├── .gitignore
├── README.md
├── configs
│   ├── inference
│   │   └── config.json
│   └── training
│       └── config.json
├── doc
│   ├── Data_Preparation.png
│   └── System_Overview.png
├── n2m
│   ├── __init__.py
│   ├── data
│   │   └── dataset.py
│   ├── model
│   │   ├── N2Mnet.py
│   │   └── pointbert
│   │       ├── checkpoint.py
│   │       ├── dvae.py
│   │       ├── logger.py
│   │       ├── misc.py
│   │       └── point_encoder.py
│   ├── module
│   │   └── N2Mmodule.py
│   └── utils
│       ├── AverageMeter.py
│       ├── checkpoint.py
│       ├── config.py
│       ├── dist_utils.py
│       ├── logger.py
│       ├── loss.py
│       ├── metrics.py
│       ├── misc.py
│       ├── parser.py
│       ├── point_cloud.py
│       ├── prediction.py
│       ├── registry.py
│       ├── sample_utils.py
│       └── visualizer.py
├── pyproject.toml
├── scripts
│   ├── process_dataset.sh
│   ├── render
│   │   ├── CMakeLists.txt
│   │   ├── README.md
│   │   ├── fpv_render.cpp
│   │   └── include
│   │       └── json.hpp
│   ├── sample_camera_poses.py
│   ├── train.py
│   └── visualize_attention.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | models
2 | datasets
3 | /training/
4 | build
5 | **.egg-info
6 | **__pycache__**
7 | wandb
8 | **.DS_Store
9 | test_codes
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # N2M: Bridging Navigation and Manipulation by Learning Pose Preference from Rollout
2 | 

Kaixin Chai*, Hyunjun Lee*, Joseph J. Lim

3 | 
4 | ![System Overview](doc/System_Overview.png)
5 | 
6 | This is the official implementation of N2M. We provide detailed instructions for training and running inference with N2M. For simulation and real-world examples, please refer to the `sim` and `real` branches.
7 | 
8 | ## TODOs
9 | - Organize N2M training / inference code ✅
10 | - Organize N2M real world code ✅
11 | - Organize N2M simulation code
12 | 
13 | ## Installation
14 | Clone the repository and install the necessary packages
15 | ```bash
16 | git clone --single-branch --branch main https://github.com/clvrai/N2M.git
17 | cd N2M
18 | 
19 | # install mamba environment
20 | mamba create -n n2m python==3.11
21 | mamba activate n2m
22 | pip install -r requirements.txt
23 | pip install -e .
24 | 
25 | # compile c++ file
26 | cd scripts/render
27 | mkdir build && cd build
28 | cmake ..
29 | make
30 | ```
31 | 
32 | ## 📊 Data preparation
33 | 
34 | ![Data Preparation](doc/Data_Preparation.png)
35 | 
36 | You should first prepare raw data consisting of pairs of a local scene and a preferable initial pose. A local scene is a point cloud of the scene; you may stitch point clouds captured by multiple calibrated cameras. This repo does not provide code for capturing local scenes.
37 | 
38 | The raw data should be placed under the `datasets` folder in the format below
39 | ```
40 | datasets/
41 | └── {dataset name}/
42 |     ├── pcl/
43 |     │   ├── 0.pcd
44 |     │   ├── 1.pcd
45 |     │   └── ...
46 |     └── meta.json
47 | 
48 | ```
49 | Replace `{dataset name}` with your own dataset name. The `pcl/` folder should contain the point clouds of your local scenes, and `meta.json` should contain the information of each local scene together with the label of its preferable initial pose. `meta.json` should be in the format below.
50 | ```json
51 | {
52 |     "meta": {
53 |         "T_base_to_cam": [
54 |             [-8.25269110e-02, -5.73057816e-01, 8.15348841e-01, 6.05364230e-04],
55 |             [-9.95784041e-01, 1.45464862e-02, -9.05661474e-02, -3.94417736e-02],
56 |             [ 4.00391906e-02, -8.19385767e-01, -5.71842485e-01, 1.64310488e-00],
57 |             [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]
58 |         ],
59 |         "camera_intrinsic": [
60 |             100.6919557412736, 100.6919557412736, 160.0, 120.0, 320, 240
61 |         ]
62 |     },
63 |     "episodes": [
64 |         {
65 |             "id": 0,
66 |             "file_path": "pcl/0.pcd",
67 |             "pose": [
68 |                 3.07151444879416,
69 |                 -0.9298766226100992,
70 |                 1.5782995419534618
71 |             ],
72 |             "object_position": [
73 |                 3.2,
74 |                 -0.2,
75 |                 1.85
76 |             ]
77 |         },
78 |         ...
79 |     ]
80 | }
81 | ```
82 | `T_base_to_cam`: A 4 x 4 transformation matrix that maps the SE(3) pose of the base to the SE(3) pose of the camera. The SE(3) pose of the base is the SE(2) base pose lifted into SE(3) format. This should be pre-calculated and provided with the dataset.
83 | 
84 | `camera_intrinsic`: The intrinsics of the camera used to capture the point cloud. These should also be provided with the dataset.
85 | 
86 | `file_path`: Relative path of the corresponding local scene point cloud.
87 | 
88 | `pose`: Preferable initial pose at which the rollout was marked successful. The above example is an SE(2) pose, but you can change it to (x, y, theta, z) depending on your setting.
89 | 
90 | `object_position`: Position of the object of interest. We later use this to check the visibility of the object when sampling viewpoints. This element is not mandatory; if it is missing, we estimate it as the position 0.5 m in front of `pose` at a height of 1 m.
91 | 
92 | An example dataset can be downloaded from this link
93 | 
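If you assemble `meta.json` yourself, the minimal sketch below may help. It only illustrates the schema above: the calibration matrix, intrinsics, dataset name, and label values are placeholders, not values shipped with this repo.
```python
import json
from pathlib import Path

# Placeholder calibration and labels -- replace with your own values.
T_BASE_TO_CAM = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1.0], [0, 0, 0, 1]]
CAMERA_INTRINSIC = [100.69, 100.69, 160.0, 120.0, 320, 240]  # assumed order: fx, fy, cx, cy, width, height
labels = [  # one entry per point cloud: SE(2) pose + optional object position
    {"pose": [3.07, -0.93, 1.58], "object_position": [3.2, -0.2, 1.85]},
]

dataset_dir = Path("datasets/my_dataset")
episodes = []
for i, label in enumerate(labels):
    pcd_file = dataset_dir / "pcl" / f"{i}.pcd"
    assert pcd_file.exists(), f"missing point cloud: {pcd_file}"
    episodes.append({"id": i, "file_path": f"pcl/{i}.pcd", **label})

meta = {"meta": {"T_base_to_cam": T_BASE_TO_CAM, "camera_intrinsic": CAMERA_INTRINSIC},
        "episodes": episodes}
with open(dataset_dir / "meta.json", "w") as f:
    json.dump(meta, f, indent=2)
```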
94 | ## 🛠️ Data Processing
95 | Now we are ready to process the data. Run the following command:
96 | ```bash
97 | sh scripts/process_dataset.sh "path/to/dataset"
98 | ```
99 | This applies viewpoint augmentation and generates augmented point clouds together with the correspondingly transformed labels. The file structure of the dataset will now look like this:
100 | ```
101 | datasets/
102 | └── {dataset name}/
103 |     ├── camera_poses/
104 |     ├── camera_poses_vis/
105 |     ├── pcl/
106 |     ├── pcl_aug/
107 |     ├── meta_aug.json
108 |     └── meta.json
109 | ```
110 | You will find a visualization of the sampled camera poses for each scene in `camera_poses_vis/`. The augmented point clouds and their corresponding labels are saved in `pcl_aug/` and `meta_aug.json`, respectively.
111 | 
112 | ## 🚀 Training
113 | Training uses `configs/training/config.json` as the configuration. Set `dataset_path: "datasets/{dataset name}"` and adjust the other training settings to your needs.
114 | 
115 | Before running training, download the pretrained PointBERT weights from this link and save them under the `models/PointBERT` folder.
116 | ```bash
117 | python scripts/train.py --config configs/training/config.json
118 | ```
119 | Your training logs will be saved under `training/{dataset name}`
120 | 
121 | ## 🏃🏻‍♂️‍➡️ Inference
122 | To use N2M, import `N2Mmodule` from `n2m/module/N2Mmodule.py`. It wraps N2Mnet with data pre-processing and post-processing, and it also performs collision checking to sample a valid initial pose from the predicted distribution. Example code is as follows
123 | ```python
124 | import json
125 | import open3d as o3d
126 | import numpy as np
127 | 
128 | from n2m.module.N2Mmodule import N2Mmodule
129 | 
130 | # initialize n2m module
131 | config = json.load(open("configs/inference/config.json"))
132 | n2m = N2Mmodule(config)
133 | 
134 | # load pcd and stack xyz + rgb into an (N, 6) array
135 | pcd = o3d.io.read_point_cloud("example.pcd")
136 | pcd_numpy = np.concatenate([np.asarray(pcd.points), np.asarray(pcd.colors)], axis=1)
137 | 
138 | # predict initial pose. If it fails to sample a valid pose within a certain number of trials, is_valid will be False.
Otherwise, is_valid will be True 139 | initial_pose, is_valid = n2m.predict(pcd_numpy) 140 | ``` -------------------------------------------------------------------------------- /configs/inference/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n2mnet": { 3 | "encoder": { 4 | "name": "PointBERT", 5 | "config": { 6 | "NAME": "PointTransformer", 7 | "trans_dim": 384, 8 | "depth": 12, 9 | "drop_path_rate": 0.1, 10 | "cls_dim": 40, 11 | "num_heads": 6, 12 | "group_size": 32, 13 | "num_group": 512, 14 | "encoder_dims": 256 15 | } 16 | }, 17 | "decoder": { 18 | "name": "mlp", 19 | "config": { 20 | "dropout": 0.1 21 | }, 22 | "layers": [ 23 | 512, 24 | 512, 25 | 512, 26 | 512, 27 | 512, 28 | 512, 29 | 512, 30 | 512 31 | ], 32 | "num_gaussians": 2, 33 | "output_dim": 3 34 | }, 35 | "ckpt": "models/N2M/N2Mnet.pth" 36 | }, 37 | "preprocess": { 38 | "pointnum": 8192 39 | }, 40 | "postprocess": { 41 | "num_samples": 100, 42 | "collision_checker": { 43 | "ground_z": 0.05, 44 | "filter_noise": true, 45 | "resolution": 0.02, 46 | "robot_width": 0.5, 47 | "robot_length": 0.63 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /configs/training/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "N2M_training", 3 | "model": { 4 | "encoder": { 5 | "name": "PointBERT", 6 | "config": { 7 | "NAME": "PointTransformer", 8 | "trans_dim": 384, 9 | "depth": 12, 10 | "drop_path_rate": 0.1, 11 | "cls_dim": 40, 12 | "num_heads": 6, 13 | "group_size": 32, 14 | "num_group": 512, 15 | "encoder_dims": 256 16 | }, 17 | "ckpt": "models/PointBERT/PointTransformer_ModelNet8192points.pth", 18 | "freeze": false 19 | }, 20 | "decoder": { 21 | "name": "mlp", 22 | "config": { 23 | "dropout": 0.1 24 | }, 25 | "num_gaussians": 1, 26 | "output_dim": 3 27 | } 28 | }, 29 | "dataset": { 30 | "dataset_path": "datasets/test", 31 | "anno_path": "meta_aug.json", 32 | "pointnum": 8192, 33 | "train_val_ratio": 0.9, 34 | "augmentations": { 35 | "rotation_se2": { 36 | "min_angle": -180, 37 | "max_angle": 180 38 | }, 39 | "translation_xy": { 40 | "radius": 1 41 | } 42 | } 43 | }, 44 | "train": { 45 | "batch_size": 32, 46 | "num_epochs": 300, 47 | "learning_rate": 1e-4, 48 | "num_workers": 4, 49 | "val_freq": 50, 50 | "loss": { 51 | "name": "mle", 52 | "config": { 53 | "lam_weight": 0.1, 54 | "lam_dist": 0.1 55 | } 56 | }, 57 | "wandb": { 58 | "name": null, 59 | "project": null, 60 | "entity": null 61 | }, 62 | "output_dir": "training/test" 63 | } 64 | } -------------------------------------------------------------------------------- /doc/Data_Preparation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/N2M/76ab68b8e75debd3c190bb6db6399b2cee179093/doc/Data_Preparation.png -------------------------------------------------------------------------------- /doc/System_Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/N2M/76ab68b8e75debd3c190bb6db6399b2cee179093/doc/System_Overview.png -------------------------------------------------------------------------------- /n2m/__init__.py: -------------------------------------------------------------------------------- 1 | from .model.N2Mnet import N2Mnet 2 | from .data.dataset import N2MDataset 3 | -------------------------------------------------------------------------------- 
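Before the per-file listings that follow, here is a minimal sketch of how the two config files above map onto the package API re-exported by `n2m/__init__.py`. This is illustrative only (not a script provided by the repo) and assumes the repo layout and the downloaded PointBERT weights described in the README.
```python
import json

from n2m import N2Mnet
from n2m.data.dataset import make_data_module

with open("configs/training/config.json") as f:
    cfg = json.load(f)

# The "model" section holds the encoder/decoder settings consumed by N2Mnet
# (loading it requires the pretrained PointBERT checkpoint referenced in the config).
model = N2Mnet(cfg["model"])

# The "dataset" section is consumed by make_data_module, which shuffles the
# episodes and returns train/val N2MDataset splits.
train_dataset, val_dataset = make_data_module(cfg["dataset"])
print(len(train_dataset), len(val_dataset))
```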
/n2m/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import random 6 | from torch.utils.data import Dataset 7 | import open3d as o3d 8 | 9 | from n2m.utils.point_cloud import apply_augmentations, fix_point_cloud_size 10 | 11 | def make_data_module(config): 12 | """ 13 | Make training dataset for SIR predictor. 14 | 15 | input(config): 16 | data_path: str 17 | anno_path: str 18 | use_color: bool 19 | pointnum: int 20 | 21 | return: 22 | train_dataset: Dataset 23 | val_dataset: Dataset 24 | """ 25 | if isinstance(config['anno_path'], str): 26 | anno_path = os.path.join(config['dataset_path'], config['anno_path']) 27 | with open(anno_path, 'r') as f: 28 | anno = json.load(f)['episodes'] 29 | train_num = int(len(anno) * config['train_val_ratio']) 30 | random.shuffle(anno) 31 | anno_train = anno[:train_num] 32 | anno_val = anno[train_num:] 33 | 34 | print(f"Train set size: {len(anno_train)}") 35 | print(f"Val set size: {len(anno_val)}") 36 | 37 | train_dataset = N2MDataset( 38 | config=config, 39 | anno=anno_train, 40 | ) 41 | val_dataset = N2MDataset( 42 | config=config, 43 | anno=anno_val, 44 | ) 45 | 46 | return train_dataset, val_dataset 47 | else: 48 | raise ValueError(f"Invalid config: {config}") 49 | 50 | class N2MDataset(Dataset): 51 | """Dataset for SIR predictor.""" 52 | def __init__(self, config, anno): 53 | self.anno = anno 54 | self.pointnum = config['pointnum'] 55 | self.config = config 56 | 57 | self.settings = config['settings'] if 'settings' in config else None 58 | 59 | if 'dataset_path' in config: 60 | self.dataset_path = config['dataset_path'] 61 | else: 62 | raise ValueError(f"Invalid config: {config}") 63 | 64 | def _load_point_cloud(self, file_path): 65 | if os.path.exists(file_path): 66 | pcd = o3d.io.read_point_cloud(file_path) 67 | if 'augmentations' in self.config and 'hpr' in self.config['augmentations']: 68 | pcd = self._apply_hpr(pcd, self.config['augmentations']['hpr']) 69 | point_cloud = np.asarray(pcd.points) 70 | colors = np.asarray(pcd.colors) 71 | point_cloud = np.concatenate([point_cloud, colors], axis=1) 72 | return point_cloud 73 | 74 | raise FileNotFoundError(f"No point cloud file found for object {file_path}") 75 | 76 | def __len__(self): 77 | """ 78 | Return number of samples in the dataset 79 | """ 80 | return len(self.anno) 81 | 82 | def __getitem__(self, index): 83 | """ 84 | Get a sample from the dataset 85 | """ 86 | data = self.anno[index] 87 | file_path = os.path.join(self.dataset_path, data['file_path']) 88 | 89 | # Load point cloud and target point 90 | point_cloud = self._load_point_cloud(file_path) 91 | target_point = np.array(data['pose'], dtype=np.float32) 92 | label = 1 93 | 94 | # Ensure point cloud has consistent size 95 | point_cloud = fix_point_cloud_size(point_cloud, self.pointnum) 96 | 97 | # Apply augmentations 98 | if 'augmentations' in self.config: 99 | point_cloud, target_point = apply_augmentations(point_cloud, target_point, self.config['augmentations']) 100 | 101 | # Convert to torch tensors 102 | point_cloud = torch.from_numpy(point_cloud.astype(np.float32)) 103 | target_point = torch.from_numpy(target_point) 104 | label = torch.tensor(label, dtype=torch.long) 105 | 106 | return { 107 | 'point_cloud': point_cloud, 108 | 'target_point': target_point, 109 | 'label': label, 110 | } 111 | -------------------------------------------------------------------------------- /n2m/model/N2Mnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from n2m.model.pointbert.point_encoder import PointTransformer 5 | import os 6 | from typing import Tuple, Dict 7 | 8 | 9 | class N2Mnet(nn.Module): 10 | """ 11 | N2Mnet predicts a Gaussian Mixture Model (GMM) from an input point cloud. 12 | """ 13 | 14 | def __init__(self, config: Dict): 15 | """ 16 | Initializes the N2Mnet model. 17 | 18 | Args: 19 | config (Dict): A configuration dictionary containing model, encoder, 20 | and decoder settings. 21 | """ 22 | super().__init__() 23 | self.encoder_config = config['encoder'] 24 | self.decoder_config = config['decoder'] 25 | 26 | self.num_gaussians = self.decoder_config['num_gaussians'] 27 | self.output_dim = self.decoder_config['output_dim'] 28 | 29 | encoder_output_dim = self._build_encoder() 30 | self._build_decoder(encoder_output_dim) 31 | 32 | if 'ckpt' in config and config['ckpt']: 33 | if not os.path.exists(config['ckpt']): 34 | raise FileNotFoundError(f"Checkpoint file not found: {config['ckpt']}") 35 | print(f"Loading model weights from {config['ckpt']}") 36 | self.load_state_dict(torch.load(config['ckpt'])['model_state_dict']) 37 | 38 | def _build_encoder(self) -> int: 39 | """ 40 | Builds the point cloud encoder based on the configuration. 41 | 42 | Returns: 43 | int: The output dimension of the encoder. 44 | """ 45 | if self.encoder_config['name'] == 'PointBERT': 46 | self.encoder = PointTransformer(self.encoder_config['config']) 47 | # Output dimension is doubled due to max and mean pooling in PointTransformer 48 | encoder_output_dim = self.encoder_config['config']['trans_dim'] * 2 49 | else: 50 | raise ValueError(f"Unsupported encoder: {self.encoder_config['name']}") 51 | 52 | # Load pretrained encoder weights if specified 53 | if 'ckpt' in self.encoder_config and self.encoder_config['ckpt']: 54 | ckpt_path = self.encoder_config['ckpt'] 55 | if not os.path.exists(ckpt_path): 56 | raise FileNotFoundError(f"Encoder checkpoint not found: {ckpt_path}") 57 | print(f"Loading encoder weights from {ckpt_path}") 58 | self.encoder.load_checkpoint(ckpt_path) 59 | 60 | # Freeze encoder weights if specified 61 | if self.encoder_config.get('freeze', False): 62 | for param in self.encoder.parameters(): 63 | param.requires_grad = False 64 | 65 | return encoder_output_dim 66 | 67 | def _build_decoder(self, decoder_input_dim: int): 68 | """ 69 | Builds the GMM parameter decoder based on the configuration. 70 | 71 | Args: 72 | decoder_input_dim (int): The input dimension for the decoder. 
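        Worked example (illustrative): each Gaussian component needs
        output_dim (mean) + output_dim**2 (covariance) + 1 (weight) values,
        so with output_dim = 3 and num_gaussians = 2 (the values in
        configs/inference/config.json) the decoder's final layer emits
        2 * (3 + 9 + 1) = 26 numbers.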
73 | """ 74 | # Each Gaussian component requires parameters for mean, covariance, and mixing weight 75 | gmm_params_per_gaussian = self.output_dim + (self.output_dim ** 2) + 1 76 | decoder_output_dim = self.num_gaussians * gmm_params_per_gaussian 77 | 78 | if self.decoder_config['name'] == 'mlp': 79 | layers = self.decoder_config.get('layers', [512, 256]) 80 | 81 | decoder_layers = [] 82 | prev_dim = decoder_input_dim 83 | for layer_dim in layers: 84 | decoder_layers.extend([ 85 | nn.Linear(prev_dim, layer_dim), 86 | nn.ReLU(), 87 | nn.Dropout(self.decoder_config['config']['dropout']) 88 | ]) 89 | prev_dim = layer_dim 90 | 91 | decoder_layers.append(nn.Linear(prev_dim, decoder_output_dim)) 92 | self.decoder = nn.Sequential(*decoder_layers) 93 | else: 94 | raise ValueError(f"Unsupported decoder: {self.decoder_config['name']}") 95 | 96 | def _construct_covariance_matrices(self, sigma_params: torch.Tensor) -> torch.Tensor: 97 | """ 98 | Constructs positive semidefinite covariance matrices from raw network outputs. 99 | 100 | This method ensures symmetry and positive definiteness using the matrix 101 | exponential. A small identity matrix is added for numerical stability before 102 | the exponential. 103 | 104 | Args: 105 | sigma_params (torch.Tensor): Raw covariance parameters from the decoder, 106 | with shape (B, K, D, D), where B is batch 107 | size, K is the number of Gaussians, and D 108 | is the output dimension. 109 | 110 | Returns: 111 | torch.Tensor: Valid, positive semidefinite covariance matrices of shape 112 | (B, K, D, D). 113 | """ 114 | B, K, D, _ = sigma_params.shape 115 | 116 | # Enforce symmetry for the input to matrix exponential 117 | sigma_params = 0.5 * (sigma_params + sigma_params.transpose(-2, -1)) 118 | 119 | # Add a small diagonal epsilon for numerical stability 120 | eye = torch.eye(D, device=sigma_params.device).unsqueeze(0).unsqueeze(0) 121 | sigma_params = sigma_params + 1e-6 * eye 122 | 123 | # The matrix exponential of a symmetric matrix is symmetric positive definite 124 | covs = torch.matrix_exp(sigma_params) 125 | 126 | return covs 127 | 128 | def forward(self, point_cloud: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 129 | """ 130 | Performs the forward pass to predict GMM parameters from a point cloud. 131 | 132 | Args: 133 | point_cloud (torch.Tensor): Input point cloud of shape (B, N, C), where B 134 | is batch size, N is the number of points, 135 | and C is feature dimension. 136 | 137 | Returns: 138 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: 139 | - means (torch.Tensor): GMM means (B, K, D). 140 | - covs (torch.Tensor): GMM covariance matrices (B, K, D, D). 141 | - weights (torch.Tensor): GMM mixing weights (B, K). 
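            Example (illustrative shapes, with K = num_gaussians and D = output_dim = 3):
                means, covs, weights = model(point_cloud)  # point_cloud: (B, N, 6), e.g. N = 8192
                # means: (B, K, 3), covs: (B, K, 3, 3), weights: (B, K); weights sum to 1 over K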
142 | """ 143 | # Encode point cloud to a global feature vector 144 | features, _ = self.encoder(point_cloud) 145 | features = features.squeeze(1) # (B, encoder_output_dim) 146 | 147 | # Decode features into a flat tensor of GMM parameters 148 | gmm_params = self.decoder(features) 149 | 150 | # Reshape the flat tensor to extract means, covariance parameters, and weights 151 | batch_size = gmm_params.size(0) 152 | 153 | # Extract means 154 | means_end = self.num_gaussians * self.output_dim 155 | means = gmm_params[:, :means_end].view( 156 | batch_size, self.num_gaussians, self.output_dim 157 | ) 158 | 159 | # Extract raw covariance parameters 160 | cov_end = means_end + self.num_gaussians * (self.output_dim ** 2) 161 | sigma_params = gmm_params[:, means_end:cov_end].view( 162 | batch_size, self.num_gaussians, self.output_dim, self.output_dim 163 | ) 164 | covs = self._construct_covariance_matrices(sigma_params) 165 | 166 | # Extract and normalize mixing weights 167 | weights = gmm_params[:, -self.num_gaussians:].view(batch_size, self.num_gaussians) 168 | weights = torch.softmax(weights, dim=-1) 169 | 170 | return means, covs, weights 171 | 172 | def sample(self, point_cloud: torch.Tensor, num_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 173 | """ 174 | Samples points from the predicted GMM for a given point cloud. 175 | 176 | Args: 177 | point_cloud (torch.Tensor): Input point cloud of shape (B, N, C). 178 | num_samples (int): Number of points to sample for each item in the batch. 179 | 180 | Returns: 181 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of: 182 | - samples (torch.Tensor): Sampled points (B, num_samples, D). 183 | - means (torch.Tensor): GMM means (B, K, D). 184 | - covs (torch.Tensor): GMM covariances (B, K, D, D). 185 | - weights (torch.Tensor): GMM weights (B, K). 186 | """ 187 | means, covs, weights = self.forward(point_cloud) 188 | batch_size = means.size(0) 189 | 190 | samples = torch.zeros(batch_size, num_samples, self.output_dim, device=means.device) 191 | 192 | # Process each item in the batch independently 193 | for b in range(batch_size): 194 | # 1. Sample component indices based on the mixture weights 195 | component_indices = torch.multinomial(weights[b], num_samples, replacement=True) 196 | 197 | # 2. Sample from the corresponding Gaussian for each chosen component 198 | for i in range(self.num_gaussians): 199 | # Find which samples belong to the current component 200 | mask = (component_indices == i) 201 | num_component_samples = mask.sum().item() 202 | 203 | if num_component_samples > 0: 204 | dist = torch.distributions.MultivariateNormal( 205 | loc=means[b, i], 206 | covariance_matrix=covs[b, i] 207 | ) 208 | samples[b, mask] = dist.sample((num_component_samples,)) 209 | 210 | return samples, means, covs, weights -------------------------------------------------------------------------------- /n2m/model/pointbert/checkpoint.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch.nn as nn 3 | 4 | from typing import Any 5 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable 6 | 7 | from termcolor import colored 8 | 9 | def get_missing_parameters_message(keys: List[str]) -> str: 10 | """ 11 | Get a logging-friendly message to report parameter names (keys) that are in 12 | the model but not found in a checkpoint. 
13 | Args: 14 | keys (list[str]): List of keys that were not found in the checkpoint. 15 | Returns: 16 | str: message. 17 | """ 18 | groups = _group_checkpoint_keys(keys) 19 | msg = "Some model parameters or buffers are not found in the checkpoint:\n" 20 | msg += "\n".join( 21 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items() 22 | ) 23 | return msg 24 | 25 | 26 | def get_unexpected_parameters_message(keys: List[str]) -> str: 27 | """ 28 | Get a logging-friendly message to report parameter names (keys) that are in 29 | the checkpoint but not found in the model. 30 | Args: 31 | keys (list[str]): List of keys that were not found in the model. 32 | Returns: 33 | str: message. 34 | """ 35 | groups = _group_checkpoint_keys(keys) 36 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n" 37 | msg += "\n".join( 38 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items() 39 | ) 40 | return msg 41 | 42 | 43 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None: 44 | """ 45 | Strip the prefix in metadata, if any. 46 | Args: 47 | state_dict (OrderedDict): a state-dict to be loaded to the model. 48 | prefix (str): prefix. 49 | """ 50 | keys = sorted(state_dict.keys()) 51 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys): 52 | return 53 | 54 | for key in keys: 55 | newkey = key[len(prefix):] 56 | state_dict[newkey] = state_dict.pop(key) 57 | 58 | # also strip the prefix in metadata, if any.. 59 | try: 60 | metadata = state_dict._metadata # pyre-ignore 61 | except AttributeError: 62 | pass 63 | else: 64 | for key in list(metadata.keys()): 65 | # for the metadata dict, the key can be: 66 | # '': for the DDP module, which we want to remove. 67 | # 'module': for the actual model. 68 | # 'module.xx.xx': for the rest. 69 | 70 | if len(key) == 0: 71 | continue 72 | newkey = key[len(prefix):] 73 | metadata[newkey] = metadata.pop(key) 74 | 75 | 76 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]: 77 | """ 78 | Group keys based on common prefixes. A prefix is the string up to the final 79 | "." in each key. 80 | Args: 81 | keys (list[str]): list of parameter names, i.e. keys in the model 82 | checkpoint dict. 83 | Returns: 84 | dict[list]: keys with common prefixes are grouped into lists. 85 | """ 86 | groups = defaultdict(list) 87 | for key in keys: 88 | pos = key.rfind(".") 89 | if pos >= 0: 90 | head, tail = key[:pos], [key[pos + 1:]] 91 | else: 92 | head, tail = key, [] 93 | groups[head].extend(tail) 94 | return groups 95 | 96 | 97 | def _group_to_str(group: List[str]) -> str: 98 | """ 99 | Format a group of parameter name suffixes into a loggable string. 100 | Args: 101 | group (list[str]): list of parameter name suffixes. 102 | Returns: 103 | str: formated string. 104 | """ 105 | if len(group) == 0: 106 | return "" 107 | 108 | if len(group) == 1: 109 | return "." + group[0] 110 | 111 | return ".{" + ", ".join(group) + "}" 112 | 113 | 114 | def _named_modules_with_dup( 115 | model: nn.Module, prefix: str = "" 116 | ) -> Iterable[Tuple[str, nn.Module]]: 117 | """ 118 | The same as `model.named_modules()`, except that it includes 119 | duplicated modules that have more than one name. 120 | """ 121 | yield prefix, model 122 | for name, module in model._modules.items(): # pyre-ignore 123 | if module is None: 124 | continue 125 | submodule_prefix = prefix + ("." 
if prefix else "") + name 126 | yield from _named_modules_with_dup(module, submodule_prefix) -------------------------------------------------------------------------------- /n2m/model/pointbert/dvae.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from . import misc 5 | 6 | # from knn_cuda import KNN 7 | 8 | # knn = KNN(k=4, transpose_mode=False) 9 | 10 | 11 | class DGCNN(nn.Module): 12 | def __init__(self, encoder_channel, output_channel): 13 | super().__init__() 14 | ''' 15 | K has to be 16 16 | ''' 17 | self.input_trans = nn.Conv1d(encoder_channel, 128, 1) 18 | 19 | self.layer1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False), 20 | nn.GroupNorm(4, 256), 21 | nn.LeakyReLU(negative_slope=0.2) 22 | ) 23 | 24 | self.layer2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1, bias=False), 25 | nn.GroupNorm(4, 512), 26 | nn.LeakyReLU(negative_slope=0.2) 27 | ) 28 | 29 | self.layer3 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=1, bias=False), 30 | nn.GroupNorm(4, 512), 31 | nn.LeakyReLU(negative_slope=0.2) 32 | ) 33 | 34 | self.layer4 = nn.Sequential(nn.Conv2d(1024, 1024, kernel_size=1, bias=False), 35 | nn.GroupNorm(4, 1024), 36 | nn.LeakyReLU(negative_slope=0.2) 37 | ) 38 | 39 | self.layer5 = nn.Sequential(nn.Conv1d(2304, output_channel, kernel_size=1, bias=False), 40 | nn.GroupNorm(4, output_channel), 41 | nn.LeakyReLU(negative_slope=0.2) 42 | ) 43 | 44 | @staticmethod 45 | def get_graph_feature(coor_q, x_q, coor_k, x_k): 46 | # coor: bs, 3, np, x: bs, c, np 47 | 48 | k = 4 49 | batch_size = x_k.size(0) 50 | num_points_k = x_k.size(2) 51 | num_points_q = x_q.size(2) 52 | 53 | with torch.no_grad(): 54 | _, idx = knn(coor_k, coor_q) # bs k np 55 | assert idx.shape[1] == k 56 | idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k 57 | idx = idx + idx_base 58 | idx = idx.view(-1) 59 | num_dims = x_k.size(1) 60 | x_k = x_k.transpose(2, 1).contiguous() 61 | feature = x_k.view(batch_size * num_points_k, -1)[idx, :] 62 | feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous() 63 | x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k) 64 | feature = torch.cat((feature - x_q, x_q), dim=1) 65 | return feature 66 | 67 | def forward(self, f, coor): 68 | # f: B G C 69 | # coor: B G 3 70 | 71 | # bs 3 N bs C N 72 | feature_list = [] 73 | coor = coor.transpose(1, 2).contiguous() # B 3 N 74 | f = f.transpose(1, 2).contiguous() # B C N 75 | f = self.input_trans(f) # B 128 N 76 | 77 | f = self.get_graph_feature(coor, f, coor, f) # B 256 N k 78 | f = self.layer1(f) # B 256 N k 79 | f = f.max(dim=-1, keepdim=False)[0] # B 256 N 80 | feature_list.append(f) 81 | 82 | f = self.get_graph_feature(coor, f, coor, f) # B 512 N k 83 | f = self.layer2(f) # B 512 N k 84 | f = f.max(dim=-1, keepdim=False)[0] # B 512 N 85 | feature_list.append(f) 86 | 87 | f = self.get_graph_feature(coor, f, coor, f) # B 1024 N k 88 | f = self.layer3(f) # B 512 N k 89 | f = f.max(dim=-1, keepdim=False)[0] # B 512 N 90 | feature_list.append(f) 91 | 92 | f = self.get_graph_feature(coor, f, coor, f) # B 1024 N k 93 | f = self.layer4(f) # B 1024 N k 94 | f = f.max(dim=-1, keepdim=False)[0] # B 1024 N 95 | feature_list.append(f) 96 | 97 | f = torch.cat(feature_list, dim=1) # B 2304 N 98 | 99 | f = self.layer5(f) # B C' N 100 | 101 | f = f.transpose(-1, -2) 102 | 103 | return f 104 | 105 | 106 | ### ref 
https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py ### 107 | def knn_point(nsample, xyz, new_xyz): 108 | """ 109 | Input: 110 | nsample: max sample number in local region 111 | xyz: all points, [B, N, C] 112 | new_xyz: query points, [B, S, C] 113 | Return: 114 | group_idx: grouped points index, [B, S, nsample] 115 | """ 116 | sqrdists = square_distance(new_xyz, xyz) 117 | _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) 118 | return group_idx 119 | 120 | 121 | def square_distance(src, dst): 122 | """ 123 | Calculate Euclid distance between each two points. 124 | src^T * dst = xn * xm + yn * ym + zn * zm; 125 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 126 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 127 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 128 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 129 | Input: 130 | src: source points, [B, N, C] 131 | dst: target points, [B, M, C] 132 | Output: 133 | dist: per-point square distance, [B, N, M] 134 | """ 135 | B, N, _ = src.shape 136 | _, M, _ = dst.shape 137 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 138 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 139 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 140 | return dist 141 | 142 | 143 | class Group(nn.Module): 144 | def __init__(self, num_group, group_size): 145 | super().__init__() 146 | self.num_group = num_group 147 | self.group_size = group_size 148 | # self.knn = KNN(k=self.group_size, transpose_mode=True) 149 | 150 | def forward(self, xyz): 151 | ''' 152 | input: B N 3 153 | --------------------------- 154 | output: B G M 3 155 | center : B G 3 156 | ''' 157 | B, N, C = xyz.shape 158 | if C > 3: 159 | data = xyz 160 | xyz = data[:,:,:3] 161 | rgb = data[:, :, 3:] 162 | batch_size, num_points, _ = xyz.shape 163 | # fps the centers out 164 | center = misc.fps(xyz, self.num_group) # B G 3 165 | 166 | # knn to get the neighborhood 167 | # _, idx = self.knn(xyz, center) # B G M 168 | idx = knn_point(self.group_size, xyz, center) # B G M 169 | assert idx.size(1) == self.num_group 170 | assert idx.size(2) == self.group_size 171 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 172 | idx = idx + idx_base 173 | idx = idx.view(-1) 174 | 175 | neighborhood_xyz = xyz.view(batch_size * num_points, -1)[idx, :] 176 | neighborhood_xyz = neighborhood_xyz.view(batch_size, self.num_group, self.group_size, 3).contiguous() 177 | if C > 3: 178 | neighborhood_rgb = rgb.view(batch_size * num_points, -1)[idx, :] 179 | neighborhood_rgb = neighborhood_rgb.view(batch_size, self.num_group, self.group_size, -1).contiguous() 180 | 181 | # normalize xyz 182 | neighborhood_xyz = neighborhood_xyz - center.unsqueeze(2) 183 | if C > 3: 184 | neighborhood = torch.cat((neighborhood_xyz, neighborhood_rgb), dim=-1) 185 | else: 186 | neighborhood = neighborhood_xyz 187 | return neighborhood, center 188 | 189 | class Encoder(nn.Module): 190 | def __init__(self, encoder_channel, point_input_dims=3): 191 | super().__init__() 192 | self.encoder_channel = encoder_channel 193 | self.point_input_dims = point_input_dims 194 | self.first_conv = nn.Sequential( 195 | nn.Conv1d(self.point_input_dims, 128, 1), 196 | nn.BatchNorm1d(128), 197 | nn.ReLU(inplace=True), 198 | nn.Conv1d(128, 256, 1) 199 | ) 200 | self.second_conv = nn.Sequential( 201 | nn.Conv1d(512, 512, 1), 202 | nn.BatchNorm1d(512), 203 | nn.ReLU(inplace=True), 204 | nn.Conv1d(512, self.encoder_channel, 1) 205 | ) 206 | 207 | def forward(self, point_groups): 
208 | ''' 209 | point_groups : B G N 3 210 | ----------------- 211 | feature_global : B G C 212 | ''' 213 | bs, g, n, c = point_groups.shape 214 | point_groups = point_groups.reshape(bs * g, n, c) 215 | # encoder 216 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n 217 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1 218 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n 219 | feature = self.second_conv(feature) # BG 1024 n 220 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024 221 | return feature_global.reshape(bs, g, self.encoder_channel) 222 | 223 | 224 | class Decoder(nn.Module): 225 | def __init__(self, encoder_channel, num_fine): 226 | super().__init__() 227 | self.num_fine = num_fine 228 | self.grid_size = 2 229 | self.num_coarse = self.num_fine // 4 230 | assert num_fine % 4 == 0 231 | 232 | self.mlp = nn.Sequential( 233 | nn.Linear(encoder_channel, 1024), 234 | nn.ReLU(inplace=True), 235 | nn.Linear(1024, 1024), 236 | nn.ReLU(inplace=True), 237 | nn.Linear(1024, 3 * self.num_coarse) 238 | ) 239 | self.final_conv = nn.Sequential( 240 | nn.Conv1d(encoder_channel + 3 + 2, 512, 1), 241 | nn.BatchNorm1d(512), 242 | nn.ReLU(inplace=True), 243 | nn.Conv1d(512, 512, 1), 244 | nn.BatchNorm1d(512), 245 | nn.ReLU(inplace=True), 246 | nn.Conv1d(512, 3, 1) 247 | ) 248 | a = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float).view(1, self.grid_size).expand( 249 | self.grid_size, self.grid_size).reshape(1, -1) 250 | b = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float).view(self.grid_size, 1).expand( 251 | self.grid_size, self.grid_size).reshape(1, -1) 252 | self.folding_seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2) # 1 2 S 253 | 254 | def forward(self, feature_global): 255 | ''' 256 | feature_global : B G C 257 | ------- 258 | coarse : B G M 3 259 | fine : B G N 3 260 | 261 | ''' 262 | bs, g, c = feature_global.shape 263 | feature_global = feature_global.reshape(bs * g, c) 264 | 265 | coarse = self.mlp(feature_global).reshape(bs * g, self.num_coarse, 3) # BG M 3 266 | 267 | point_feat = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1) # BG (M) S 3 268 | point_feat = point_feat.reshape(bs * g, self.num_fine, 3).transpose(2, 1) # BG 3 N 269 | 270 | seed = self.folding_seed.unsqueeze(2).expand(bs * g, -1, self.num_coarse, -1) # BG 2 M (S) 271 | seed = seed.reshape(bs * g, -1, self.num_fine).to(feature_global.device) # BG 2 N 272 | 273 | feature_global = feature_global.unsqueeze(2).expand(-1, -1, self.num_fine) # BG 1024 N 274 | feat = torch.cat([feature_global, seed, point_feat], dim=1) # BG C N 275 | 276 | center = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1) # BG (M) S 3 277 | center = center.reshape(bs * g, self.num_fine, 3).transpose(2, 1) # BG 3 N 278 | 279 | fine = self.final_conv(feat) + center # BG 3 N 280 | fine = fine.reshape(bs, g, 3, self.num_fine).transpose(-1, -2) 281 | coarse = coarse.reshape(bs, g, self.num_coarse, 3) 282 | return coarse, fine 283 | 284 | 285 | class DiscreteVAE(nn.Module): 286 | def __init__(self, config, **kwargs): 287 | super().__init__() 288 | self.group_size = config.group_size 289 | self.num_group = config.num_group 290 | self.encoder_dims = config.encoder_dims 291 | self.tokens_dims = config.tokens_dims 292 | 293 | self.decoder_dims = config.decoder_dims 294 | self.num_tokens = config.num_tokens 295 | 296 | self.group_divider = Group(num_group=self.num_group, 
group_size=self.group_size) 297 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 298 | self.dgcnn_1 = DGCNN(encoder_channel=self.encoder_dims, output_channel=self.num_tokens) 299 | self.codebook = nn.Parameter(torch.randn(self.num_tokens, self.tokens_dims)) 300 | 301 | self.dgcnn_2 = DGCNN(encoder_channel=self.tokens_dims, output_channel=self.decoder_dims) 302 | self.decoder = Decoder(encoder_channel=self.decoder_dims, num_fine=self.group_size) 303 | # self.build_loss_func() 304 | 305 | # def build_loss_func(self): 306 | # self.loss_func_cdl1 = ChamferDistanceL1().cuda() 307 | # self.loss_func_cdl2 = ChamferDistanceL2().cuda() 308 | # self.loss_func_emd = emd().cuda() 309 | 310 | def recon_loss(self, ret, gt): 311 | whole_coarse, whole_fine, coarse, fine, group_gt, _ = ret 312 | 313 | bs, g, _, _ = coarse.shape 314 | 315 | coarse = coarse.reshape(bs * g, -1, 3).contiguous() 316 | fine = fine.reshape(bs * g, -1, 3).contiguous() 317 | group_gt = group_gt.reshape(bs * g, -1, 3).contiguous() 318 | 319 | loss_coarse_block = self.loss_func_cdl1(coarse, group_gt) 320 | loss_fine_block = self.loss_func_cdl1(fine, group_gt) 321 | 322 | loss_recon = loss_coarse_block + loss_fine_block 323 | 324 | return loss_recon 325 | 326 | def get_loss(self, ret, gt): 327 | # reconstruction loss 328 | loss_recon = self.recon_loss(ret, gt) 329 | # kl divergence 330 | logits = ret[-1] # B G N 331 | softmax = F.softmax(logits, dim=-1) 332 | mean_softmax = softmax.mean(dim=1) 333 | log_qy = torch.log(mean_softmax) 334 | log_uniform = torch.log(torch.tensor([1. / self.num_tokens], device=gt.device)) 335 | loss_klv = F.kl_div(log_qy, log_uniform.expand(log_qy.size(0), log_qy.size(1)), None, None, 'batchmean', 336 | log_target=True) 337 | 338 | return loss_recon, loss_klv 339 | 340 | def forward(self, inp, temperature=1., hard=False, **kwargs): 341 | neighborhood, center = self.group_divider(inp) 342 | logits = self.encoder(neighborhood) # B G C 343 | logits = self.dgcnn_1(logits, center) # B G N 344 | soft_one_hot = F.gumbel_softmax(logits, tau=temperature, dim=2, hard=hard) # B G N 345 | sampled = torch.einsum('b g n, n c -> b g c', soft_one_hot, self.codebook) # B G C 346 | feature = self.dgcnn_2(sampled, center) 347 | coarse, fine = self.decoder(feature) 348 | 349 | with torch.no_grad(): 350 | whole_fine = (fine + center.unsqueeze(2)).reshape(inp.size(0), -1, 3) 351 | whole_coarse = (coarse + center.unsqueeze(2)).reshape(inp.size(0), -1, 3) 352 | 353 | assert fine.size(2) == self.group_size 354 | ret = (whole_coarse, whole_fine, coarse, fine, neighborhood, logits) 355 | return ret -------------------------------------------------------------------------------- /n2m/model/pointbert/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. 
Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 
111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /n2m/model/pointbert/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | from collections import abc 10 | # from pointnet2_ops import pointnet2_utils 11 | 12 | 13 | # def fps(data, number): 14 | # ''' 15 | # data B N 3 16 | # number int 17 | # ''' 18 | # fps_idx = pointnet2_utils.furthest_point_sample(data, number) 19 | # fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 20 | # return fps_data 21 | 22 | def index_points(points, idx): 23 | """ 24 | Input: 25 | points: input points data, [B, N, C] 26 | idx: sample index data, [B, S] 27 | Return: 28 | new_points:, indexed points data, [B, S, C] 29 | """ 30 | device = points.device 31 | B = points.shape[0] 32 | view_shape = list(idx.shape) 33 | view_shape[1:] = [1] * (len(view_shape) - 1) 34 | repeat_shape = list(idx.shape) 35 | repeat_shape[0] = 1 36 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 37 | new_points = points[batch_indices, idx, :] 38 | return new_points 39 | 40 | def fps(xyz, npoint): 41 | """ 42 | Input: 43 | xyz: pointcloud data, [B, N, 3] 44 | npoint: number of samples 45 | Return: 46 | centroids: sampled pointcloud index, [B, npoint] 47 | """ 48 | device = xyz.device 49 | B, N, C = xyz.shape 50 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 51 | distance = torch.ones(B, N).to(device) * 1e10 52 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 53 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 54 | for i in range(npoint): 55 | centroids[:, i] = farthest 56 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 57 | dist = torch.sum((xyz - centroid) ** 2, -1) 58 | distance = torch.min(distance, dist) 59 | farthest = torch.max(distance, -1)[1] 60 | return index_points(xyz, centroids) 61 | 62 | def worker_init_fn(worker_id): 63 | np.random.seed(np.random.get_state()[1][0] + worker_id) 64 | 65 | def build_lambda_sche(opti, config): 66 | if config.get('decay_step') is not None: 67 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay) 68 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd) 69 | else: 70 | raise NotImplementedError() 71 | return scheduler 72 | 73 | def build_lambda_bnsche(model, config): 74 | if config.get('decay_step') is not None: 75 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay) 76 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd) 77 | else: 78 | raise NotImplementedError() 79 | return bnm_scheduler 80 | 81 | def 
set_random_seed(seed, deterministic=False): 82 | """Set random seed. 83 | Args: 84 | seed (int): Seed to be used. 85 | deterministic (bool): Whether to set the deterministic option for 86 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 87 | to True and `torch.backends.cudnn.benchmark` to False. 88 | Default: False. 89 | 90 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 91 | if cuda_deterministic: # slower, more reproducible 92 | cudnn.deterministic = True 93 | cudnn.benchmark = False 94 | else: # faster, less reproducible 95 | cudnn.deterministic = False 96 | cudnn.benchmark = True 97 | 98 | """ 99 | random.seed(seed) 100 | np.random.seed(seed) 101 | torch.manual_seed(seed) 102 | torch.cuda.manual_seed_all(seed) 103 | if deterministic: 104 | torch.backends.cudnn.deterministic = True 105 | torch.backends.cudnn.benchmark = False 106 | 107 | 108 | def is_seq_of(seq, expected_type, seq_type=None): 109 | """Check whether it is a sequence of some type. 110 | Args: 111 | seq (Sequence): The sequence to be checked. 112 | expected_type (type): Expected type of sequence items. 113 | seq_type (type, optional): Expected sequence type. 114 | Returns: 115 | bool: Whether the sequence is valid. 116 | """ 117 | if seq_type is None: 118 | exp_seq_type = abc.Sequence 119 | else: 120 | assert isinstance(seq_type, type) 121 | exp_seq_type = seq_type 122 | if not isinstance(seq, exp_seq_type): 123 | return False 124 | for item in seq: 125 | if not isinstance(item, expected_type): 126 | return False 127 | return True 128 | 129 | 130 | def set_bn_momentum_default(bn_momentum): 131 | def fn(m): 132 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 133 | m.momentum = bn_momentum 134 | return fn 135 | 136 | class BNMomentumScheduler(object): 137 | 138 | def __init__( 139 | self, model, bn_lambda, last_epoch=-1, 140 | setter=set_bn_momentum_default 141 | ): 142 | if not isinstance(model, nn.Module): 143 | raise RuntimeError( 144 | "Class '{}' is not a PyTorch nn Module".format( 145 | type(model).__name__ 146 | ) 147 | ) 148 | 149 | self.model = model 150 | self.setter = setter 151 | self.lmbd = bn_lambda 152 | 153 | self.step(last_epoch + 1) 154 | self.last_epoch = last_epoch 155 | 156 | def step(self, epoch=None): 157 | if epoch is None: 158 | epoch = self.last_epoch + 1 159 | 160 | self.last_epoch = epoch 161 | self.model.apply(self.setter(self.lmbd(epoch))) 162 | 163 | def get_momentum(self, epoch=None): 164 | if epoch is None: 165 | epoch = self.last_epoch + 1 166 | return self.lmbd(epoch) 167 | 168 | 169 | 170 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False): 171 | ''' 172 | seprate point cloud: usage : using to generate the incomplete point cloud with a setted number. 
173 | ''' 174 | _,n,c = xyz.shape 175 | 176 | assert n == num_points 177 | assert c == 3 178 | if crop == num_points: 179 | return xyz, None 180 | 181 | INPUT = [] 182 | CROP = [] 183 | for points in xyz: 184 | if isinstance(crop,list): 185 | num_crop = random.randint(crop[0],crop[1]) 186 | else: 187 | num_crop = crop 188 | 189 | points = points.unsqueeze(0) 190 | 191 | if fixed_points is None: 192 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda() 193 | else: 194 | if isinstance(fixed_points,list): 195 | fixed_point = random.sample(fixed_points,1)[0] 196 | else: 197 | fixed_point = fixed_points 198 | center = fixed_point.reshape(1,1,3).cuda() 199 | 200 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048 201 | 202 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048 203 | 204 | if padding_zeros: 205 | input_data = points.clone() 206 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0 207 | 208 | else: 209 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3 210 | 211 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0) 212 | 213 | if isinstance(crop,list): 214 | INPUT.append(fps(input_data,2048)) 215 | CROP.append(fps(crop_data,2048)) 216 | else: 217 | INPUT.append(input_data) 218 | CROP.append(crop_data) 219 | 220 | input_data = torch.cat(INPUT,dim=0)# B N 3 221 | crop_data = torch.cat(CROP,dim=0)# B M 3 222 | 223 | return input_data.contiguous(), crop_data.contiguous() 224 | 225 | def get_ptcloud_img(ptcloud): 226 | fig = plt.figure(figsize=(8, 8)) 227 | 228 | x, z, y = ptcloud.transpose(1, 0) 229 | ax = fig.gca(projection=Axes3D.name, adjustable='box') 230 | ax.axis('off') 231 | # ax.axis('scaled') 232 | ax.view_init(30, 45) 233 | max, min = np.max(ptcloud), np.min(ptcloud) 234 | ax.set_xbound(min, max) 235 | ax.set_ybound(min, max) 236 | ax.set_zbound(min, max) 237 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet') 238 | 239 | fig.canvas.draw() 240 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 241 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, )) 242 | return img 243 | 244 | 245 | 246 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y', 247 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ): 248 | fig = plt.figure(figsize=(6*len(data_list),6)) 249 | cmax = data_list[-1][:,0].max() 250 | 251 | for i in range(len(data_list)): 252 | data = data_list[i][:-2048] if i == 1 else data_list[i] 253 | color = data[:,0] /cmax 254 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d') 255 | ax.view_init(30, -120) 256 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black') 257 | ax.set_title(titles[i]) 258 | 259 | ax.set_axis_off() 260 | ax.set_xlim(xlim) 261 | ax.set_ylim(ylim) 262 | ax.set_zlim(zlim) 263 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0) 264 | if not os.path.exists(path): 265 | os.makedirs(path) 266 | 267 | pic_path = path + '.png' 268 | fig.savefig(pic_path) 269 | 270 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy()) 271 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy()) 272 | plt.close(fig) 273 | 274 | 275 | def random_dropping(pc, e): 276 | up_num = max(64, 768 // (e//50 + 1)) 277 | pc = pc 278 | random_num = torch.randint(1, up_num, (1,1))[0,0] 279 | pc = fps(pc, random_num) 280 | padding = torch.zeros(pc.size(0), 2048 - 
pc.size(1), 3).to(pc.device) 281 | pc = torch.cat([pc, padding], dim = 1) 282 | return pc 283 | 284 | 285 | def random_scale(partial, scale_range=[0.8, 1.2]): 286 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0] 287 | return partial * scale 288 | -------------------------------------------------------------------------------- /n2m/model/pointbert/point_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath 4 | from .dvae import Group 5 | from .dvae import Encoder 6 | from .logger import print_log 7 | from collections import OrderedDict 8 | 9 | from .checkpoint import get_missing_parameters_message, get_unexpected_parameters_message 10 | 11 | class Mlp(nn.Module): 12 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 13 | super().__init__() 14 | out_features = out_features or in_features 15 | hidden_features = hidden_features or in_features 16 | self.fc1 = nn.Linear(in_features, hidden_features) 17 | self.act = act_layer() 18 | self.fc2 = nn.Linear(hidden_features, out_features) 19 | self.drop = nn.Dropout(drop) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.act(x) 24 | x = self.drop(x) 25 | x = self.fc2(x) 26 | x = self.drop(x) 27 | return x 28 | 29 | 30 | class Attention(nn.Module): 31 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 32 | super().__init__() 33 | self.num_heads = num_heads 34 | head_dim = dim // num_heads 35 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 36 | self.scale = qk_scale or head_dim ** -0.5 37 | 38 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 39 | self.attn_drop = nn.Dropout(attn_drop) 40 | self.proj = nn.Linear(dim, dim) 41 | self.proj_drop = nn.Dropout(proj_drop) 42 | 43 | def forward(self, x): 44 | B, N, C = x.shape 45 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 46 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 47 | 48 | attn = (q @ k.transpose(-2, -1)) * self.scale 49 | attn = attn.softmax(dim=-1) 50 | attn = self.attn_drop(attn) 51 | 52 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 53 | x = self.proj(x) 54 | x = self.proj_drop(x) 55 | return x, attn 56 | 57 | 58 | class Block(nn.Module): 59 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 60 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 61 | super().__init__() 62 | self.norm1 = norm_layer(dim) 63 | 64 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 65 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 66 | self.norm2 = norm_layer(dim) 67 | mlp_hidden_dim = int(dim * mlp_ratio) 68 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 69 | 70 | self.attn = Attention( 71 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 72 | 73 | def forward(self, x): 74 | attn_out, attn_weights = self.attn(self.norm1(x)) 75 | x = x + self.drop_path(attn_out) 76 | x = x + self.drop_path(self.mlp(self.norm2(x))) 77 | return x, attn_weights 78 | 79 | 80 | class TransformerEncoder(nn.Module): 81 | """ Transformer Encoder without hierarchical structure 82 | """ 83 | 84 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 85 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 86 | super().__init__() 87 | 88 | self.blocks = nn.ModuleList([ 89 | Block( 90 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 91 | drop=drop_rate, attn_drop=attn_drop_rate, 92 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 93 | ) 94 | for i in range(depth)]) 95 | 96 | def forward(self, x, pos): 97 | for i, block in enumerate(self.blocks): 98 | x, attn_weights = block(x + pos) 99 | if i == len(self.blocks) - 1: 100 | final_attn_weights = attn_weights 101 | return x, final_attn_weights 102 | 103 | 104 | class PointTransformer(nn.Module): 105 | def __init__(self, config, use_max_pool=True): 106 | super().__init__() 107 | self.config = config 108 | 109 | self.use_max_pool = use_max_pool # * whethet to max pool the features of different tokens 110 | 111 | self.trans_dim = config['trans_dim'] 112 | self.depth = config['depth'] 113 | self.drop_path_rate = config['drop_path_rate'] 114 | self.cls_dim = config['cls_dim'] 115 | self.num_heads = config['num_heads'] 116 | self.point_dims = 6 117 | self.group_size = config['group_size'] 118 | self.num_group = config['num_group'] 119 | # grouper 120 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 121 | # define the encoder 122 | self.encoder_dims = config['encoder_dims'] 123 | self.encoder = Encoder(encoder_channel=self.encoder_dims, point_input_dims=self.point_dims) 124 | # bridge encoder and transformer 125 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim) 126 | 127 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 128 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim)) 129 | 130 | self.pos_embed = nn.Sequential( 131 | nn.Linear(3, 128), 132 | nn.GELU(), 133 | nn.Linear(128, self.trans_dim) 134 | ) 135 | 136 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 137 | self.blocks = TransformerEncoder( 138 | embed_dim=self.trans_dim, 139 | depth=self.depth, 140 | drop_path_rate=dpr, 141 | num_heads=self.num_heads 142 | ) 143 | 144 | self.norm = nn.LayerNorm(self.trans_dim) 145 | 146 | def load_checkpoint(self, bert_ckpt_path): 147 | ckpt = torch.load(bert_ckpt_path, map_location='cpu') 148 | 149 | incompatible = self.load_state_dict(ckpt['base_model'], strict=False) 150 | 151 | if incompatible.missing_keys: 152 | print_log('missing_keys', logger='Transformer') 153 | print_log( 154 | get_missing_parameters_message(incompatible.missing_keys), 155 | logger='Transformer' 156 | ) 157 | if incompatible.unexpected_keys: 158 | print_log('unexpected_keys', logger='Transformer') 159 | print_log( 160 | 
get_unexpected_parameters_message(incompatible.unexpected_keys), 161 | logger='Transformer' 162 | ) 163 | if not incompatible.missing_keys and not incompatible.unexpected_keys: 164 | # * print successful loading 165 | print_log("PointBERT's weights are successfully loaded from {}".format(bert_ckpt_path), logger='Transformer') 166 | 167 | def forward(self, pts): 168 | # divide the point cloud in the same form. This is important 169 | neighborhood, center = self.group_divider(pts) 170 | # encode the input cloud blocks 171 | group_input_tokens = self.encoder(neighborhood) # B G N 172 | group_input_tokens = self.reduce_dim(group_input_tokens) 173 | # prepare cls 174 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1) 175 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1) 176 | # add pos embedding 177 | pos = self.pos_embed(center) 178 | # final input 179 | x = torch.cat((cls_tokens, group_input_tokens), dim=1) 180 | pos = torch.cat((cls_pos, pos), dim=1) 181 | # transformer 182 | x, final_attn_weights = self.blocks(x, pos) 183 | x = self.norm(x) # * B, G + 1(cls token)(513), C(384) 184 | if not self.use_max_pool: 185 | return x, final_attn_weights 186 | concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1).unsqueeze(1) # * concat the cls token and max pool the features of different tokens, make it B, 1, C 187 | return concat_f, final_attn_weights # * B, 1, C(384 + 384) -------------------------------------------------------------------------------- /n2m/module/N2Mmodule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.distributions import MultivariateNormal 4 | 5 | from n2m.model.N2Mnet import N2Mnet 6 | from n2m.utils.point_cloud import fix_point_cloud_size 7 | from n2m.utils.sample_utils import CollisionChecker 8 | 9 | class N2Mmodule: 10 | """ 11 | N2M module class that encapsulates the N2M model. 12 | This includes collision checking and point-cloud pre-processing. 13 | """ 14 | def __init__(self, config): 15 | self.n2mnet_config = config['n2mnet'] 16 | self.preprocess_config = config['preprocess'] 17 | self.postprocess_config = config['postprocess'] 18 | 19 | self.model = N2Mnet(self.n2mnet_config) 20 | self.collision_checker = CollisionChecker(self.postprocess_config['collision_checker']) 21 | 22 | def predict(self, point_cloud): 23 | """ 24 | Predicts the preferable initial pose. 25 | 26 | Inputs: 27 | point_cloud (numpy.array): Input point cloud of shape (N, 3). The point cloud should be captured from a robot-centric camera and should have the robot base at the origin. 28 | 29 | Outputs: 30 | preferable initial_pose (numpy.array): Predicted preferable initial pose of shape (3,) for SE(2) or (4,) for SE(2) + z. 31 | prediction success (Boolean): There might not be a valid sample within the sample number provided. 
return the validity of the sample 32 | """ 33 | self.model.eval() 34 | 35 | # preprocess(downsmaple) point cloud to match the specified pointnum 36 | pointnum = self.preprocess_config['pointnum'] 37 | orig_point_cloud = point_cloud.copy() 38 | point_cloud = fix_point_cloud_size(point_cloud, pointnum) 39 | 40 | # load point cloud for collision checking 41 | self.collision_checker.set_pcd(orig_point_cloud) 42 | 43 | # get predictions 44 | with torch.no_grad(): 45 | point_cloud_tensor = torch.tensor(input_point_cloud, dtype=torch.float32).unsqueeze(0) # Add batch dimension (1, N, 3) 46 | num_samples = self.postprocess_config['num_samples'] 47 | samples, means, covs, weights = self.model.sample(point_cloud_tensor, num_samples=num_samples) 48 | 49 | # first check mean's collision. If it doesn't collide, return mean value 50 | if self.collision_checker.check_collision(means[0]): 51 | return means[0], True 52 | 53 | # sort predictions in the order of predicted probability 54 | num_modes = weights.shape[1] 55 | mvns = [MultivariateNormal(means[0, i], covs[0, i]) for i in range(num_modes)] 56 | log_probs = torch.stack([mvn.log_prob(samples[0]) for mvn in mvns]) # shape: [num_modes, num_samples] 57 | log_weights = torch.log(weights[0] + 1e-8).unsqueeze(1) # shape: [num_modes, 1] 58 | weighted_log_probs = log_probs + log_weights 59 | gaussian_probabilities = torch.logsumexp(weighted_log_probs, dim=0) # shape: [num_samples] 60 | sorted_indices = torch.argsort(gaussian_probabilities, descending=True) 61 | samples = samples[0, sorted_indices] 62 | 63 | # check collision for each samples and return the non-colliding sample with the highest probability 64 | for i in range(num_samples): 65 | sample = samples[0, i] 66 | if self.collision_checker.check_collision(sample): 67 | print("Valid initial pose found: ", sample) 68 | return sample, True 69 | else: 70 | print("Invalid pose, trying again: ", sample) 71 | 72 | # prediction fail. 
return False for validity 73 | return means[0], False -------------------------------------------------------------------------------- /n2m/utils/AverageMeter.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | def __init__(self, items=None): 4 | self.items = items 5 | self.n_items = 1 if items is None else len(items) 6 | self.reset() 7 | 8 | def reset(self): 9 | self._val = [0] * self.n_items 10 | self._sum = [0] * self.n_items 11 | self._count = [0] * self.n_items 12 | 13 | def update(self, values): 14 | if type(values).__name__ == 'list': 15 | for idx, v in enumerate(values): 16 | self._val[idx] = v 17 | self._sum[idx] += v 18 | self._count[idx] += 1 19 | else: 20 | self._val[0] = values 21 | self._sum[0] += values 22 | self._count[0] += 1 23 | 24 | def val(self, idx=None): 25 | if idx is None: 26 | return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)] 27 | else: 28 | return self._val[idx] 29 | 30 | def count(self, idx=None): 31 | if idx is None: 32 | return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)] 33 | else: 34 | return self._count[idx] 35 | 36 | def avg(self, idx=None): 37 | if idx is None: 38 | return self._sum[0] / self._count[0] if self.items is None else [ 39 | self._sum[i] / self._count[i] for i in range(self.n_items) 40 | ] 41 | else: 42 | return self._sum[idx] / self._count[idx] -------------------------------------------------------------------------------- /n2m/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import copy 5 | import logging 6 | import os 7 | from collections import defaultdict 8 | import torch 9 | import torch.nn as nn 10 | 11 | from typing import Any 12 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable 13 | 14 | from termcolor import colored 15 | 16 | def get_missing_parameters_message(keys: List[str]) -> str: 17 | """ 18 | Get a logging-friendly message to report parameter names (keys) that are in 19 | the model but not found in a checkpoint. 20 | Args: 21 | keys (list[str]): List of keys that were not found in the checkpoint. 22 | Returns: 23 | str: message. 24 | """ 25 | groups = _group_checkpoint_keys(keys) 26 | msg = "Some model parameters or buffers are not found in the checkpoint:\n" 27 | msg += "\n".join( 28 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items() 29 | ) 30 | return msg 31 | 32 | 33 | def get_unexpected_parameters_message(keys: List[str]) -> str: 34 | """ 35 | Get a logging-friendly message to report parameter names (keys) that are in 36 | the checkpoint but not found in the model. 37 | Args: 38 | keys (list[str]): List of keys that were not found in the model. 39 | Returns: 40 | str: message. 41 | """ 42 | groups = _group_checkpoint_keys(keys) 43 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n" 44 | msg += "\n".join( 45 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items() 46 | ) 47 | return msg 48 | 49 | 50 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None: 51 | """ 52 | Strip the prefix in metadata, if any. 53 | Args: 54 | state_dict (OrderedDict): a state-dict to be loaded to the model. 55 | prefix (str): prefix. 
56 | """ 57 | keys = sorted(state_dict.keys()) 58 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys): 59 | return 60 | 61 | for key in keys: 62 | newkey = key[len(prefix):] 63 | state_dict[newkey] = state_dict.pop(key) 64 | 65 | # also strip the prefix in metadata, if any.. 66 | try: 67 | metadata = state_dict._metadata # pyre-ignore 68 | except AttributeError: 69 | pass 70 | else: 71 | for key in list(metadata.keys()): 72 | # for the metadata dict, the key can be: 73 | # '': for the DDP module, which we want to remove. 74 | # 'module': for the actual model. 75 | # 'module.xx.xx': for the rest. 76 | 77 | if len(key) == 0: 78 | continue 79 | newkey = key[len(prefix):] 80 | metadata[newkey] = metadata.pop(key) 81 | 82 | 83 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]: 84 | """ 85 | Group keys based on common prefixes. A prefix is the string up to the final 86 | "." in each key. 87 | Args: 88 | keys (list[str]): list of parameter names, i.e. keys in the model 89 | checkpoint dict. 90 | Returns: 91 | dict[list]: keys with common prefixes are grouped into lists. 92 | """ 93 | groups = defaultdict(list) 94 | for key in keys: 95 | pos = key.rfind(".") 96 | if pos >= 0: 97 | head, tail = key[:pos], [key[pos + 1:]] 98 | else: 99 | head, tail = key, [] 100 | groups[head].extend(tail) 101 | return groups 102 | 103 | 104 | def _group_to_str(group: List[str]) -> str: 105 | """ 106 | Format a group of parameter name suffixes into a loggable string. 107 | Args: 108 | group (list[str]): list of parameter name suffixes. 109 | Returns: 110 | str: formated string. 111 | """ 112 | if len(group) == 0: 113 | return "" 114 | 115 | if len(group) == 1: 116 | return "." + group[0] 117 | 118 | return ".{" + ", ".join(group) + "}" 119 | 120 | 121 | def _named_modules_with_dup( 122 | model: nn.Module, prefix: str = "" 123 | ) -> Iterable[Tuple[str, nn.Module]]: 124 | """ 125 | The same as `model.named_modules()`, except that it includes 126 | duplicated modules that have more than one name. 127 | """ 128 | yield prefix, model 129 | for name, module in model._modules.items(): # pyre-ignore 130 | if module is None: 131 | continue 132 | submodule_prefix = prefix + ("." if prefix else "") + name 133 | yield from _named_modules_with_dup(module, submodule_prefix) -------------------------------------------------------------------------------- /n2m/utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from easydict import EasyDict 3 | import os 4 | from .logger import print_log 5 | 6 | def log_args_to_file(args, pre='args', logger=None): 7 | for key, val in args.__dict__.items(): 8 | print_log(f'{pre}.{key} : {val}', logger = logger) 9 | 10 | def log_config_to_file(cfg, pre='cfg', logger=None): 11 | for key, val in cfg.items(): 12 | if isinstance(cfg[key], EasyDict): 13 | print_log(f'{pre}.{key} = edict()', logger = logger) 14 | log_config_to_file(cfg[key], pre=pre + '.' 
+ key, logger=logger) 15 | continue 16 | print_log(f'{pre}.{key} : {val}', logger = logger) 17 | 18 | def merge_new_config(config, new_config): 19 | for key, val in new_config.items(): 20 | if not isinstance(val, dict): 21 | if key == '_base_': 22 | with open(new_config['_base_'], 'r') as f: 23 | try: 24 | val = yaml.load(f, Loader=yaml.FullLoader) 25 | except: 26 | val = yaml.load(f) 27 | config[key] = EasyDict() 28 | merge_new_config(config[key], val) 29 | else: 30 | config[key] = val 31 | continue 32 | if key not in config: 33 | config[key] = EasyDict() 34 | merge_new_config(config[key], val) 35 | return config 36 | 37 | def cfg_from_yaml_file(cfg_file): 38 | config = EasyDict() 39 | with open(cfg_file, 'r') as f: 40 | try: 41 | new_config = yaml.load(f, Loader=yaml.FullLoader) 42 | except: 43 | new_config = yaml.load(f) 44 | merge_new_config(config=config, new_config=new_config) 45 | return config 46 | 47 | def get_config(args, logger=None): 48 | # if args.resume: 49 | # cfg_path = os.path.join(args.experiment_path, 'config.yaml') 50 | # if not os.path.exists(cfg_path): 51 | # print_log("Failed to resume", logger = logger) 52 | # raise FileNotFoundError() 53 | # print_log(f'Resume yaml from {cfg_path}', logger = logger) 54 | # args.config = cfg_path 55 | config = cfg_from_yaml_file(args.config) 56 | # if not args.resume and args.local_rank == 0: 57 | # save_experiment_config(args, config, logger) 58 | return config 59 | 60 | def save_experiment_config(args, config, logger = None): 61 | config_path = os.path.join(args.experiment_path, 'config.yaml') 62 | os.system('cp %s %s' % (args.config, config_path)) 63 | print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger ) -------------------------------------------------------------------------------- /n2m/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.multiprocessing as mp 5 | from torch import distributed as dist 6 | 7 | 8 | 9 | def init_dist(launcher, backend='nccl', **kwargs): 10 | if mp.get_start_method(allow_none=True) is None: 11 | mp.set_start_method('spawn') 12 | if launcher == 'pytorch': 13 | _init_dist_pytorch(backend, **kwargs) 14 | else: 15 | raise ValueError(f'Invalid launcher type: {launcher}') 16 | 17 | 18 | def _init_dist_pytorch(backend, **kwargs): 19 | # TODO: use local_rank instead of rank % num_gpus 20 | rank = int(os.environ['RANK']) 21 | num_gpus = torch.cuda.device_count() 22 | torch.cuda.set_device(rank % num_gpus) 23 | dist.init_process_group(backend=backend, **kwargs) 24 | print(f'init distributed in rank {torch.distributed.get_rank()}') 25 | 26 | 27 | def get_dist_info(): 28 | if dist.is_available(): 29 | initialized = dist.is_initialized() 30 | else: 31 | initialized = False 32 | if initialized: 33 | rank = dist.get_rank() 34 | world_size = dist.get_world_size() 35 | else: 36 | rank = 0 37 | world_size = 1 38 | return rank, world_size 39 | 40 | 41 | def reduce_tensor(tensor, args): 42 | ''' 43 | for acc kind, get the mean in each gpu 44 | ''' 45 | rt = tensor.clone() 46 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 47 | rt /= args.world_size 48 | return rt 49 | 50 | def gather_tensor(tensor, args): 51 | output_tensors = [tensor.clone() for _ in range(args.world_size)] 52 | torch.distributed.all_gather(output_tensors, tensor) 53 | concat = torch.cat(output_tensors, dim=0) 54 | return concat 55 | 
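The helpers in `dist_utils.py` above are thin wrappers around `torch.distributed`. Below is a minimal usage sketch (not part of the repository) of how they might be wired together in a training script; it assumes the `pytorch` launcher has exported `RANK`, and that `args.world_size` is filled in from `get_dist_info()` before `reduce_tensor` is called.

```python
# Hypothetical usage sketch of n2m.utils.dist_utils, under the assumptions stated above.
import argparse
import torch

from n2m.utils.dist_utils import init_dist, get_dist_info, reduce_tensor

args = argparse.Namespace(launcher='pytorch', world_size=1)
if args.launcher != 'none':
    init_dist(args.launcher)                   # sets the CUDA device from RANK and calls init_process_group
rank, args.world_size = get_dist_info()        # falls back to (0, 1) when torch.distributed is not initialized

# Average a per-GPU scalar (e.g. a loss) across all ranks.
loss = torch.tensor(1.0, device='cuda' if torch.cuda.is_available() else 'cpu')
if args.world_size > 1:
    loss = reduce_tensor(loss, args)           # all-reduce sum, then divide by args.world_size
```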
-------------------------------------------------------------------------------- /n2m/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. 
Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /n2m/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import numpy as np 5 | import math 6 | 7 | class Loss(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | self.name = config['name'] 11 | self.neg_weight = config.get('neg_weight', 1.0) 12 | self.lam_weight = config.get('lam_weight', 0.0) 13 | self.lam_dist = config.get('lam_dist', 0.0) 14 | self.min_logprob = config.get('min_logprob', 0.0) 15 | 16 | def forward(self, means, covs, weights, target_point, label): 17 | if self.name == 'mle': 18 | return self._mle_loss(means, covs, weights, target_point, label) 19 | elif self.name == 'mle_max': 20 | return self._mle_max_loss(means, covs, weights, target_point, label) 21 | elif self.name == 'ce': 22 | return self._ce_loss(means, covs, weights, target_point, label) 23 | else: 24 | raise ValueError(f"Unknown loss name: {self.name}") 25 | 26 | def _mle_loss(self, means, covs, weights, target_point, label): 27 | B, K, D = means.size() 28 | target_point = target_point.unsqueeze(1) 29 | diff = target_point - means 30 | inv_covs = torch.inverse(covs).to(dtype=torch.float32) 31 | diff = diff.to(dtype=torch.float32) 32 | 33 | mahalanobis = torch.sum( 34 | torch.matmul(diff.unsqueeze(2), inv_covs) * diff.unsqueeze(2), dim=-1 35 | ).squeeze(-1) 36 | 37 | log_det = torch.logdet(covs) 38 | log_prob = -0.5 * mahalanobis - 0.5 * log_det - 0.5 * D * np.log(2 * np.pi) 39 | log_prob = log_prob + torch.log(weights + 1e-6) 40 | max_log_prob = torch.max(log_prob, dim=1, keepdim=True)[0] 41 | log_prob = torch.logsumexp(log_prob - max_log_prob, dim=1) + max_log_prob.squeeze(1) 42 | if self.min_logprob < 0.0: 43 | log_prob[(label == -1) & (log_prob < self.min_logprob)] = self.min_logprob 44 | 45 | label = label.float().to(log_prob.device) 46 | label[label == 
-1] = -self.neg_weight # optional handling of ignore index 47 | 48 | entropy_weight = -torch.sum(weights * torch.log(weights + 1e-6), dim=-1).mean() 49 | entropy_dist = -torch.sum(weights * (0.5 * (D * (1 + math.log(2 * math.pi)) + log_det)), dim=-1).mean() 50 | 51 | loss = - torch.mean(log_prob * label) - self.lam_weight * entropy_weight - self.lam_dist * entropy_dist 52 | return loss 53 | 54 | def _mle_max_loss(self, means, covs, weights, target_point, label): 55 | B, K, D = means.size() 56 | target_point = target_point.unsqueeze(1) 57 | diff = target_point - means 58 | inv_covs = torch.inverse(covs).to(dtype=torch.float32) 59 | diff = diff.to(dtype=torch.float32) 60 | 61 | mahalanobis = torch.sum( 62 | torch.matmul(diff.unsqueeze(2), inv_covs) * diff.unsqueeze(2), dim=-1 63 | ).squeeze(-1) 64 | 65 | log_det = torch.logdet(covs) 66 | log_prob = -0.5 * mahalanobis - 0.5 * log_det - 0.5 * D * np.log(2 * np.pi) 67 | # log_prob = log_prob + torch.log(weights + 1e-6) 68 | # max_log_prob = torch.max(log_prob, dim=1, keepdim=True)[0] 69 | # log_prob = torch.logsumexp(log_prob - max_log_prob, dim=1) + max_log_prob.squeeze(1) 70 | log_prob, _ = torch.max(log_prob, dim=1) 71 | if self.min_logprob < 0.0: 72 | log_prob[(label == -1) & (log_prob < self.min_logprob)] = self.min_logprob 73 | 74 | label = label.float().to(log_prob.device) 75 | label[label == -1] = -self.neg_weight # optional handling of ignore index 76 | 77 | entropy_weight = -torch.sum(weights * torch.log(weights + 1e-6), dim=-1).mean() 78 | entropy_dist = -torch.sum(weights * (0.5 * (D * (1 + math.log(2 * math.pi)) + log_det)), dim=-1).mean() 79 | 80 | loss = - torch.mean(log_prob * label) - self.lam_weight * entropy_weight - self.lam_dist * entropy_dist 81 | return loss 82 | 83 | def _ce_loss(self, means, covs, weights, target_point, label): 84 | B, K, D = means.size() 85 | target_point = target_point.unsqueeze(1) 86 | diff = target_point - means 87 | inv_covs = torch.inverse(covs).to(dtype=torch.float32) 88 | diff = diff.to(dtype=torch.float32) 89 | 90 | mahalanobis = torch.sum( 91 | torch.matmul(diff.unsqueeze(2), inv_covs) * diff.unsqueeze(2), dim=-1) 92 | 93 | log_det = torch.logdet(covs) 94 | log_prob = -0.5 * mahalanobis - 0.5 * log_det - 0.5 * D * np.log(2 * np.pi) 95 | log_prob = log_prob + torch.log(weights + 1e-6) 96 | 97 | max_log_prob = torch.max(log_prob, dim=1, keepdim=True)[0] 98 | log_prob = torch.logsumexp(log_prob - max_log_prob, dim=1) + max_log_prob.squeeze(1) 99 | pdf = torch.exp(log_prob) 100 | success_rate = torch.sigmoid(pdf) 101 | 102 | label = label.float().to(log_prob.device) 103 | label = (label + torch.abs(label)) / 2 104 | 105 | loss = -torch.mean( 106 | torch.log(success_rate + 1e-6) * label + 107 | torch.log(1 - success_rate + 1e-6) * (1 - label) 108 | ) 109 | return loss -------------------------------------------------------------------------------- /n2m/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-08 14:31:30 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-05-25 09:13:32 6 | # @Email: cshzxie@gmail.com 7 | 8 | import logging 9 | import open3d 10 | 11 | from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2 12 | 13 | 14 | class Metrics(object): 15 | ITEMS = [{ 16 | 'name': 'F-Score', 17 | 'enabled': True, 18 | 'eval_func': 'cls._get_f_score', 19 | 'is_greater_better': True, 20 | 'init_value': 0 21 | }, { 22 | 'name': 'CDL1', 23 | 
'enabled': True, 24 | 'eval_func': 'cls._get_chamfer_distancel1', 25 | 'eval_object': ChamferDistanceL1(ignore_zeros=True), 26 | 'is_greater_better': False, 27 | 'init_value': 32767 28 | }, { 29 | 'name': 'CDL2', 30 | 'enabled': True, 31 | 'eval_func': 'cls._get_chamfer_distancel2', 32 | 'eval_object': ChamferDistanceL2(ignore_zeros=True), 33 | 'is_greater_better': False, 34 | 'init_value': 32767 35 | }] 36 | 37 | @classmethod 38 | def get(cls, pred, gt): 39 | _items = cls.items() 40 | _values = [0] * len(_items) 41 | for i, item in enumerate(_items): 42 | eval_func = eval(item['eval_func']) 43 | _values[i] = eval_func(pred, gt) 44 | 45 | return _values 46 | 47 | @classmethod 48 | def items(cls): 49 | return [i for i in cls.ITEMS if i['enabled']] 50 | 51 | @classmethod 52 | def names(cls): 53 | _items = cls.items() 54 | return [i['name'] for i in _items] 55 | 56 | @classmethod 57 | def _get_f_score(cls, pred, gt, th=0.01): 58 | 59 | """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py""" 60 | b = pred.size(0) 61 | assert pred.size(0) == gt.size(0) 62 | if b != 1: 63 | f_score_list = [] 64 | for idx in range(b): 65 | f_score_list.append(cls._get_f_score(pred[idx:idx+1], gt[idx:idx+1])) 66 | return sum(f_score_list)/len(f_score_list) 67 | else: 68 | pred = cls._get_open3d_ptcloud(pred) 69 | gt = cls._get_open3d_ptcloud(gt) 70 | 71 | dist1 = pred.compute_point_cloud_distance(gt) 72 | dist2 = gt.compute_point_cloud_distance(pred) 73 | 74 | recall = float(sum(d < th for d in dist2)) / float(len(dist2)) 75 | precision = float(sum(d < th for d in dist1)) / float(len(dist1)) 76 | return 2 * recall * precision / (recall + precision) if recall + precision else 0 77 | 78 | @classmethod 79 | def _get_open3d_ptcloud(cls, tensor): 80 | """pred and gt bs is 1""" 81 | tensor = tensor.squeeze().cpu().numpy() 82 | ptcloud = open3d.geometry.PointCloud() 83 | ptcloud.points = open3d.utility.Vector3dVector(tensor) 84 | 85 | return ptcloud 86 | 87 | @classmethod 88 | def _get_chamfer_distancel1(cls, pred, gt): 89 | chamfer_distance = cls.ITEMS[1]['eval_object'] 90 | return chamfer_distance(pred, gt).item() * 1000 91 | 92 | @classmethod 93 | def _get_chamfer_distancel2(cls, pred, gt): 94 | chamfer_distance = cls.ITEMS[2]['eval_object'] 95 | return chamfer_distance(pred, gt).item() * 1000 96 | 97 | def __init__(self, metric_name, values): 98 | self._items = Metrics.items() 99 | self._values = [item['init_value'] for item in self._items] 100 | self.metric_name = metric_name 101 | 102 | if type(values).__name__ == 'list': 103 | self._values = values 104 | elif type(values).__name__ == 'dict': 105 | metric_indexes = {} 106 | for idx, item in enumerate(self._items): 107 | item_name = item['name'] 108 | metric_indexes[item_name] = idx 109 | for k, v in values.items(): 110 | if k not in metric_indexes: 111 | logging.warn('Ignore Metric[Name=%s] due to disability.' 
% k) 112 | continue 113 | self._values[metric_indexes[k]] = v 114 | else: 115 | raise Exception('Unsupported value type: %s' % type(values)) 116 | 117 | def state_dict(self): 118 | _dict = dict() 119 | for i in range(len(self._items)): 120 | item = self._items[i]['name'] 121 | value = self._values[i] 122 | _dict[item] = value 123 | 124 | return _dict 125 | 126 | def __repr__(self): 127 | return str(self.state_dict()) 128 | 129 | def better_than(self, other): 130 | if other is None: 131 | return True 132 | 133 | _index = -1 134 | for i, _item in enumerate(self._items): 135 | if _item['name'] == self.metric_name: 136 | _index = i 137 | break 138 | if _index == -1: 139 | raise Exception('Invalid metric name to compare.') 140 | 141 | _metric = self._items[i] 142 | _value = self._values[_index] 143 | other_value = other._values[_index] 144 | return _value > other_value if _metric['is_greater_better'] else _value < other_value 145 | -------------------------------------------------------------------------------- /n2m/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | from collections import abc 10 | from pointnet2_ops import pointnet2_utils 11 | 12 | 13 | def fps(data, number): 14 | ''' 15 | data B N 3 16 | number int 17 | ''' 18 | fps_idx = pointnet2_utils.furthest_point_sample(data, number) 19 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 20 | return fps_data 21 | 22 | 23 | def worker_init_fn(worker_id): 24 | np.random.seed(np.random.get_state()[1][0] + worker_id) 25 | 26 | def build_lambda_sche(opti, config): 27 | if config.get('decay_step') is not None: 28 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay) 29 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd) 30 | else: 31 | raise NotImplementedError() 32 | return scheduler 33 | 34 | def build_lambda_bnsche(model, config): 35 | if config.get('decay_step') is not None: 36 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay) 37 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd) 38 | else: 39 | raise NotImplementedError() 40 | return bnm_scheduler 41 | 42 | def set_random_seed(seed, deterministic=False): 43 | """Set random seed. 44 | Args: 45 | seed (int): Seed to be used. 46 | deterministic (bool): Whether to set the deterministic option for 47 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 48 | to True and `torch.backends.cudnn.benchmark` to False. 49 | Default: False. 50 | 51 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 52 | if cuda_deterministic: # slower, more reproducible 53 | cudnn.deterministic = True 54 | cudnn.benchmark = False 55 | else: # faster, less reproducible 56 | cudnn.deterministic = False 57 | cudnn.benchmark = True 58 | 59 | """ 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | torch.manual_seed(seed) 63 | torch.cuda.manual_seed_all(seed) 64 | if deterministic: 65 | torch.backends.cudnn.deterministic = True 66 | torch.backends.cudnn.benchmark = False 67 | 68 | 69 | def is_seq_of(seq, expected_type, seq_type=None): 70 | """Check whether it is a sequence of some type. 
71 | Args: 72 | seq (Sequence): The sequence to be checked. 73 | expected_type (type): Expected type of sequence items. 74 | seq_type (type, optional): Expected sequence type. 75 | Returns: 76 | bool: Whether the sequence is valid. 77 | """ 78 | if seq_type is None: 79 | exp_seq_type = abc.Sequence 80 | else: 81 | assert isinstance(seq_type, type) 82 | exp_seq_type = seq_type 83 | if not isinstance(seq, exp_seq_type): 84 | return False 85 | for item in seq: 86 | if not isinstance(item, expected_type): 87 | return False 88 | return True 89 | 90 | 91 | def set_bn_momentum_default(bn_momentum): 92 | def fn(m): 93 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 94 | m.momentum = bn_momentum 95 | return fn 96 | 97 | class BNMomentumScheduler(object): 98 | 99 | def __init__( 100 | self, model, bn_lambda, last_epoch=-1, 101 | setter=set_bn_momentum_default 102 | ): 103 | if not isinstance(model, nn.Module): 104 | raise RuntimeError( 105 | "Class '{}' is not a PyTorch nn Module".format( 106 | type(model).__name__ 107 | ) 108 | ) 109 | 110 | self.model = model 111 | self.setter = setter 112 | self.lmbd = bn_lambda 113 | 114 | self.step(last_epoch + 1) 115 | self.last_epoch = last_epoch 116 | 117 | def step(self, epoch=None): 118 | if epoch is None: 119 | epoch = self.last_epoch + 1 120 | 121 | self.last_epoch = epoch 122 | self.model.apply(self.setter(self.lmbd(epoch))) 123 | 124 | def get_momentum(self, epoch=None): 125 | if epoch is None: 126 | epoch = self.last_epoch + 1 127 | return self.lmbd(epoch) 128 | 129 | 130 | 131 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False): 132 | ''' 133 | seprate point cloud: usage : using to generate the incomplete point cloud with a setted number. 134 | ''' 135 | _,n,c = xyz.shape 136 | 137 | assert n == num_points 138 | assert c == 3 139 | if crop == num_points: 140 | return xyz, None 141 | 142 | INPUT = [] 143 | CROP = [] 144 | for points in xyz: 145 | if isinstance(crop,list): 146 | num_crop = random.randint(crop[0],crop[1]) 147 | else: 148 | num_crop = crop 149 | 150 | points = points.unsqueeze(0) 151 | 152 | if fixed_points is None: 153 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda() 154 | else: 155 | if isinstance(fixed_points,list): 156 | fixed_point = random.sample(fixed_points,1)[0] 157 | else: 158 | fixed_point = fixed_points 159 | center = fixed_point.reshape(1,1,3).cuda() 160 | 161 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048 162 | 163 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048 164 | 165 | if padding_zeros: 166 | input_data = points.clone() 167 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0 168 | 169 | else: 170 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3 171 | 172 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0) 173 | 174 | if isinstance(crop,list): 175 | INPUT.append(fps(input_data,2048)) 176 | CROP.append(fps(crop_data,2048)) 177 | else: 178 | INPUT.append(input_data) 179 | CROP.append(crop_data) 180 | 181 | input_data = torch.cat(INPUT,dim=0)# B N 3 182 | crop_data = torch.cat(CROP,dim=0)# B M 3 183 | 184 | return input_data.contiguous(), crop_data.contiguous() 185 | 186 | def get_ptcloud_img(ptcloud): 187 | fig = plt.figure(figsize=(8, 8)) 188 | 189 | x, z, y = ptcloud.transpose(1, 0) 190 | ax = fig.gca(projection=Axes3D.name, adjustable='box') 191 | ax.axis('off') 192 | # ax.axis('scaled') 193 | ax.view_init(30, 45) 194 | 
max, min = np.max(ptcloud), np.min(ptcloud) 195 | ax.set_xbound(min, max) 196 | ax.set_ybound(min, max) 197 | ax.set_zbound(min, max) 198 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet') 199 | 200 | fig.canvas.draw() 201 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 202 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, )) 203 | return img 204 | 205 | 206 | 207 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y', 208 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ): 209 | fig = plt.figure(figsize=(6*len(data_list),6)) 210 | cmax = data_list[-1][:,0].max() 211 | 212 | for i in range(len(data_list)): 213 | data = data_list[i][:-2048] if i == 1 else data_list[i] 214 | color = data[:,0] /cmax 215 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d') 216 | ax.view_init(30, -120) 217 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black') 218 | ax.set_title(titles[i]) 219 | 220 | ax.set_axis_off() 221 | ax.set_xlim(xlim) 222 | ax.set_ylim(ylim) 223 | ax.set_zlim(zlim) 224 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0) 225 | if not os.path.exists(path): 226 | os.makedirs(path) 227 | 228 | pic_path = path + '.png' 229 | fig.savefig(pic_path) 230 | 231 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy()) 232 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy()) 233 | plt.close(fig) 234 | 235 | 236 | def random_dropping(pc, e): 237 | up_num = max(64, 768 // (e//50 + 1)) 238 | pc = pc 239 | random_num = torch.randint(1, up_num, (1,1))[0,0] 240 | pc = fps(pc, random_num) 241 | padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device) 242 | pc = torch.cat([pc, padding], dim = 1) 243 | return pc 244 | 245 | 246 | def random_scale(partial, scale_range=[0.8, 1.2]): 247 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0] 248 | return partial * scale 249 | -------------------------------------------------------------------------------- /n2m/utils/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | '--config', 9 | type = str, 10 | help = 'yaml config file') 11 | parser.add_argument( 12 | '--launcher', 13 | choices=['none', 'pytorch'], 14 | default='none', 15 | help='job launcher') 16 | parser.add_argument('--local_rank', type=int, default=0) 17 | parser.add_argument('--num_workers', type=int, default=4) 18 | # seed 19 | parser.add_argument('--seed', type=int, default=0, help='random seed') 20 | parser.add_argument( 21 | '--deterministic', 22 | action='store_true', 23 | help='whether to set deterministic options for CUDNN backend.') 24 | # bn 25 | parser.add_argument( 26 | '--sync_bn', 27 | action='store_true', 28 | default=False, 29 | help='whether to use sync bn') 30 | # some args 31 | parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name') 32 | parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path') 33 | parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path') 34 | parser.add_argument('--val_freq', type = int, default=1, help = 'test freq') 35 | parser.add_argument( 36 | '--resume', 37 | action='store_true', 38 | default=False, 39 | help = 
'autoresume training (interrupted by accident)') 40 | parser.add_argument( 41 | '--test', 42 | action='store_true', 43 | default=False, 44 | help = 'test mode for certain ckpt') 45 | parser.add_argument( 46 | '--finetune_model', 47 | action='store_true', 48 | default=False, 49 | help = 'finetune modelnet with pretrained weight') 50 | parser.add_argument( 51 | '--scratch_model', 52 | action='store_true', 53 | default=False, 54 | help = 'training modelnet from scratch') 55 | parser.add_argument( 56 | '--label_smoothing', 57 | action='store_true', 58 | default=False, 59 | help = 'use label smoothing loss trick') 60 | parser.add_argument( 61 | '--mode', 62 | choices=['easy', 'median', 'hard', None], 63 | default=None, 64 | help = 'difficulty mode for shapenet') 65 | parser.add_argument( 66 | '--way', type=int, default=-1) 67 | parser.add_argument( 68 | '--shot', type=int, default=-1) 69 | parser.add_argument( 70 | '--fold', type=int, default=-1) 71 | 72 | args = parser.parse_args() 73 | 74 | if args.test and args.resume: 75 | raise ValueError( 76 | '--test and --resume cannot be both activate') 77 | 78 | if args.resume and args.start_ckpts is not None: 79 | raise ValueError( 80 | '--resume and --start_ckpts cannot be both activate') 81 | 82 | if args.test and args.ckpts is None: 83 | raise ValueError( 84 | 'ckpts shouldnt be None while test mode') 85 | 86 | if args.finetune_model and args.ckpts is None: 87 | raise ValueError( 88 | 'ckpts shouldnt be None while finetune_model mode') 89 | 90 | if 'LOCAL_RANK' not in os.environ: 91 | os.environ['LOCAL_RANK'] = str(args.local_rank) 92 | 93 | if args.test: 94 | args.exp_name = 'test_' + args.exp_name 95 | if args.mode is not None: 96 | args.exp_name = args.exp_name + '_' +args.mode 97 | args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name) 98 | args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem,'TFBoard' ,args.exp_name) 99 | args.log_name = Path(args.config).stem 100 | create_experiment_dir(args) 101 | return args 102 | 103 | def create_experiment_dir(args): 104 | if not os.path.exists(args.experiment_path): 105 | os.makedirs(args.experiment_path) 106 | print('Create experiment path successfully at %s' % args.experiment_path) 107 | if not os.path.exists(args.tfboard_path): 108 | os.makedirs(args.tfboard_path) 109 | print('Create TFBoard path successfully at %s' % args.tfboard_path) 110 | 111 | -------------------------------------------------------------------------------- /n2m/utils/point_cloud.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def rotate_points_xy(points, angle): 4 | rotation_matrix = np.array([[np.cos(angle), -np.sin(angle), 0], 5 | [np.sin(angle), np.cos(angle), 0], 6 | [0, 0, 1]]) 7 | return points @ rotation_matrix 8 | 9 | def rotate_points_se2(points, angle): 10 | rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], 11 | [np.sin(angle), np.cos(angle)]]) 12 | points[:2] = points[:2] @ rotation_matrix 13 | points[2] = points[2] - angle 14 | return points 15 | 16 | def rotate_points_xythetaz(points, angle): 17 | rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], 18 | [np.sin(angle), np.cos(angle)]]) 19 | points[:2] = points[:2] @ rotation_matrix 20 | points[2] = points[2] - angle 21 | return points 22 | 23 | def translate_points_xy(points, radius, angle): 24 | translation_matrix = radius * np.array([np.cos(angle), 
np.sin(angle), 0]) 25 | return points + translation_matrix 26 | 27 | def translate_points_xythetaz(points, radius, angle): 28 | translation_matrix = radius * np.array([np.cos(angle), np.sin(angle), 0, 0]) 29 | return points + translation_matrix 30 | 31 | def apply_rotation_se2(point_cloud, target_point, rotation): 32 | min_angle = rotation['min_angle'] * np.pi / 180 33 | max_angle = rotation['max_angle'] * np.pi / 180 34 | 35 | angle = np.random.uniform(min_angle, max_angle) 36 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], angle) 37 | target_point = rotate_points_se2(target_point, angle) 38 | return point_cloud, target_point 39 | 40 | def apply_rotation_xy(point_cloud, target_point, rotation): 41 | min_angle = rotation['min_angle'] * np.pi / 180 42 | max_angle = rotation['max_angle'] * np.pi / 180 43 | 44 | angle = np.random.uniform(min_angle, max_angle) 45 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], angle) 46 | target_point = rotate_points_xy(target_point, angle) 47 | return point_cloud, target_point 48 | 49 | def apply_translation_xy(point_cloud, target_point, translation): 50 | radius = np.random.uniform(0, translation['radius']) 51 | angle = np.random.uniform(0, 2 * np.pi) 52 | point_cloud[:, :3] = translate_points_xy(point_cloud[:, :3], radius, angle) 53 | target_point = translate_points_xy(target_point, radius, angle) 54 | return point_cloud, target_point 55 | 56 | def apply_rotation_xythetaz(point_cloud, target_point, rotation): 57 | min_angle = rotation['min_angle'] * np.pi / 180 58 | max_angle = rotation['max_angle'] * np.pi / 180 59 | 60 | angle = np.random.uniform(min_angle, max_angle) 61 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], angle) 62 | target_point = rotate_points_xythetaz(target_point, angle) 63 | return point_cloud, target_point 64 | 65 | def apply_translation_xythetaz(point_cloud, target_point, translation): 66 | radius = np.random.uniform(0, translation['radius']) 67 | angle = np.random.uniform(0, 2 * np.pi) 68 | point_cloud[:, :3] = translate_points_xy(point_cloud[:, :3], radius, angle) 69 | target_point = translate_points_xythetaz(target_point, radius, angle) 70 | return point_cloud, target_point 71 | 72 | def apply_augmentations(point_cloud, target_point, augmentations): 73 | if 'rotation_xy' in augmentations: 74 | point_cloud, target_point = apply_rotation_xy(point_cloud, target_point, augmentations['rotation_xy']) 75 | if 'rotation_se2' in augmentations: 76 | point_cloud, target_point = apply_rotation_se2(point_cloud, target_point, augmentations['rotation_se2']) 77 | if 'rotation_xythetaz' in augmentations: 78 | point_cloud, target_point = apply_rotation_xythetaz(point_cloud, target_point, augmentations['rotation_xythetaz']) 79 | if 'translation_xy' in augmentations: 80 | point_cloud, target_point = apply_translation_xy(point_cloud, target_point, augmentations['translation_xy']) 81 | if 'translation_xythetaz' in augmentations: 82 | point_cloud, target_point = apply_translation_xythetaz(point_cloud, target_point, augmentations['translation_xythetaz']) 83 | return point_cloud, target_point 84 | 85 | def fix_point_cloud_size(point_cloud, pointnum): 86 | if point_cloud.shape[0] > pointnum: 87 | indices = np.random.choice(point_cloud.shape[0], pointnum, replace=False) 88 | point_cloud = point_cloud[indices] 89 | elif point_cloud.shape[0] < pointnum: 90 | padding = np.zeros((pointnum - point_cloud.shape[0], point_cloud.shape[1])) 91 | point_cloud = np.vstack([point_cloud, padding]) 92 | return point_cloud 93 | 94 | def 
translate_se2_point_cloud(point_cloud, se2_transform): 95 | point_cloud[:, :3] = point_cloud[:, :3] + np.array([se2_transform[0], se2_transform[1], 0]) 96 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], se2_transform[2]) 97 | return point_cloud 98 | 99 | def translate_se2_target(target_point, se2_transform): 100 | target_point = rotate_points_se2(target_point, se2_transform[2]) 101 | target_point = target_point + np.array([se2_transform[0], se2_transform[1], 0]) 102 | return target_point -------------------------------------------------------------------------------- /n2m/utils/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from torch.distributions import MultivariateNormal 5 | 6 | from n2m.utils.point_cloud import fix_point_cloud_size 7 | from n2m.utils.visualizer import save_gmm_visualization_se2 8 | from n2m.utils.point_cloud import translate_se2_point_cloud, translate_se2_target 9 | 10 | def predict_SIR_target_point( 11 | SIR_predictor, 12 | SIR_config, 13 | pc_numpy, 14 | target_helper, 15 | SIR_sample_num, 16 | robot_centric, 17 | abs_base_se2, 18 | task_name, 19 | max_trial=100, 20 | save_dir=None, 21 | id=None, 22 | ): 23 | pcl_input = fix_point_cloud_size(pc_numpy, SIR_config['dataset']['pointnum']) 24 | if robot_centric: 25 | print ("abs_base_se2", abs_base_se2) 26 | pcl_input = translate_se2_point_cloud(pcl_input, [-abs_base_se2[0], -abs_base_se2[1], abs_base_se2[2]]) 27 | pcl_input = torch.from_numpy(pcl_input)[None, ...].cuda().float() 28 | 29 | task_idx = None 30 | if 'settings' in SIR_config['dataset']: 31 | task_idx = torch.tensor([SIR_config['dataset']['settings'].index(task_name)]).cuda() 32 | 33 | target_SIR_prediction = None 34 | SIR_predictor.eval() 35 | with torch.no_grad(): 36 | finished = False 37 | trial_count = 0 38 | while not finished: 39 | if trial_count > max_trial: 40 | print(f"Failed to find a valid target after {max_trial} trials") 41 | break 42 | 43 | samples, means, covs, weights = SIR_predictor.sample(pcl_input, task_idx, SIR_sample_num) 44 | 45 | num_modes = means.shape[1] 46 | mvns = [MultivariateNormal(means[0, i], covs[0, i]) for i in range(num_modes)] 47 | log_probs = torch.stack([mvn.log_prob(samples[0]) for mvn in mvns]) # shape: [num_modes, SIR_sample_num] 48 | log_weights = torch.log(weights[0] + 1e-8).unsqueeze(1) # shape: [num_modes, 1] 49 | weighted_log_probs = log_probs + log_weights 50 | gaussian_probabilities = torch.logsumexp(weighted_log_probs, dim=0) # shape: [SIR_sample_num] 51 | sorted_indices = torch.argsort(gaussian_probabilities, descending=True) 52 | samples = samples[0, sorted_indices] 53 | 54 | # convert to numpy 55 | means = means[0].cpu().numpy() 56 | covs = covs[0].cpu().numpy() 57 | weights = weights[0].cpu().numpy() 58 | samples = samples.cpu().numpy() 59 | 60 | # First set target_SIR_prediction to the mean with highest probability 61 | target_SIR_prediction = means[0] 62 | max_prob = 0 63 | for i in range(num_modes): 64 | # Calculate probability density at the mean using the covariance and weight 65 | mvn = MultivariateNormal(torch.from_numpy(means[i]), torch.from_numpy(covs[i])) 66 | prob = mvn.log_prob(torch.from_numpy(means[i])).exp().item() * weights[i] 67 | if prob > max_prob: 68 | max_prob = prob 69 | target_SIR_prediction = means[i] 70 | 71 | if save_dir is not None: 72 | prediction_folder = os.path.join(save_dir, "prediction") 73 | os.makedirs(prediction_folder, exist_ok=True) 74 | 
save_gmm_visualization_se2( 75 | pcl_input[0].cpu().numpy(), 76 | target_SIR_prediction, 77 | 1, 78 | means, 79 | covs, 80 | weights, 81 | os.path.join(prediction_folder, f"{id}.ply"), 82 | ) 83 | 84 | if robot_centric: 85 | global_target_SIR_prediction = translate_se2_target(target_SIR_prediction.copy(), [abs_base_se2[0], abs_base_se2[1], -abs_base_se2[2]]) 86 | else: 87 | global_target_SIR_prediction = target_SIR_prediction 88 | 89 | if not target_helper.check_collision(global_target_SIR_prediction): # target_helper is based on the relative position 90 | print("target_SIR_prediction is collided with the furniture, try again") 91 | elif not target_helper.check_boundary(global_target_SIR_prediction): # target_helper is based on the relative position 92 | print("target_SIR_prediction is out of the boundary, try again") 93 | else: 94 | print("target_SIR_prediction is valid, using mean with highest probability") 95 | finished = True 96 | break 97 | 98 | 99 | # If the mean with highest probability is not valid, try with samples 100 | for i in range(SIR_sample_num): 101 | target_SIR_prediction = samples[i] 102 | 103 | if save_dir is not None: 104 | prediction_folder = os.path.join(save_dir, "prediction") 105 | os.makedirs(prediction_folder, exist_ok=True) 106 | save_gmm_visualization_se2( 107 | pcl_input[0].cpu().numpy(), 108 | target_SIR_prediction, 109 | 1, 110 | means, 111 | covs, 112 | weights, 113 | os.path.join(prediction_folder, f"{id}.ply"), 114 | ) 115 | 116 | if robot_centric: 117 | global_target_SIR_prediction = translate_se2_target(target_SIR_prediction.copy(), [abs_base_se2[0], abs_base_se2[1], -abs_base_se2[2]]) 118 | else: 119 | global_target_SIR_prediction = target_SIR_prediction 120 | 121 | if not target_helper.check_collision(global_target_SIR_prediction): # target_helper is based on the relative position 122 | print("target_SIR_prediction is collided with the furniture, try again") 123 | elif not target_helper.check_boundary(global_target_SIR_prediction): # target_helper is based on the relative position 124 | print("target_SIR_prediction is out of the boundary, try again") 125 | else: 126 | print("target_SIR_prediction is valid, using sample") 127 | finished = True 128 | break 129 | 130 | trial_count += 1 131 | 132 | return global_target_SIR_prediction, target_SIR_prediction, means, covs, weights, pcl_input[0].cpu().numpy(), finished -------------------------------------------------------------------------------- /n2m/utils/registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from functools import partial 4 | from utils import config 5 | 6 | class Registry: 7 | """A registry to map strings to classes. 8 | Registered object could be built from registry. 9 | Example: 10 | >>> MODELS = Registry('models') 11 | >>> @MODELS.register_module() 12 | >>> class ResNet: 13 | >>> pass 14 | >>> resnet = MODELS.build(dict(NAME='ResNet')) 15 | Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for 16 | advanced useage. 17 | Args: 18 | name (str): Registry name. 19 | build_func(func, optional): Build function to construct instance from 20 | Registry, func:`build_from_cfg` is used if neither ``parent`` or 21 | ``build_func`` is specified. If ``parent`` is specified and 22 | ``build_func`` is not given, ``build_func`` will be inherited 23 | from ``parent``. Default: None. 24 | parent (Registry, optional): Parent registry. 
The class registered in 25 | children registry could be built from parent. Default: None. 26 | scope (str, optional): The scope of registry. It is the key to search 27 | for children registry. If not specified, scope will be the name of 28 | the package where class is defined, e.g. mmdet, mmcls, mmseg. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, name, build_func=None, parent=None, scope=None): 33 | self._name = name 34 | self._module_dict = dict() 35 | self._children = dict() 36 | self._scope = self.infer_scope() if scope is None else scope 37 | 38 | # self.build_func will be set with the following priority: 39 | # 1. build_func 40 | # 2. parent.build_func 41 | # 3. build_from_cfg 42 | if build_func is None: 43 | if parent is not None: 44 | self.build_func = parent.build_func 45 | else: 46 | self.build_func = build_from_cfg 47 | else: 48 | self.build_func = build_func 49 | if parent is not None: 50 | assert isinstance(parent, Registry) 51 | parent._add_children(self) 52 | self.parent = parent 53 | else: 54 | self.parent = None 55 | 56 | def __len__(self): 57 | return len(self._module_dict) 58 | 59 | def __contains__(self, key): 60 | return self.get(key) is not None 61 | 62 | def __repr__(self): 63 | format_str = self.__class__.__name__ + \ 64 | f'(name={self._name}, ' \ 65 | f'items={self._module_dict})' 66 | return format_str 67 | 68 | @staticmethod 69 | def infer_scope(): 70 | """Infer the scope of registry. 71 | The name of the package where registry is defined will be returned. 72 | Example: 73 | # in mmdet/models/backbone/resnet.py 74 | >>> MODELS = Registry('models') 75 | >>> @MODELS.register_module() 76 | >>> class ResNet: 77 | >>> pass 78 | The scope of ``ResNet`` will be ``mmdet``. 79 | Returns: 80 | scope (str): The inferred scope name. 81 | """ 82 | # inspect.stack() trace where this function is called, the index-2 83 | # indicates the frame where `infer_scope()` is called 84 | filename = inspect.getmodule(inspect.stack()[2][0]).__name__ 85 | split_filename = filename.split('.') 86 | return split_filename[0] 87 | 88 | @staticmethod 89 | def split_scope_key(key): 90 | """Split scope and key. 91 | The first scope will be split from key. 92 | Examples: 93 | >>> Registry.split_scope_key('mmdet.ResNet') 94 | 'mmdet', 'ResNet' 95 | >>> Registry.split_scope_key('ResNet') 96 | None, 'ResNet' 97 | Return: 98 | scope (str, None): The first scope. 99 | key (str): The remaining key. 100 | """ 101 | split_index = key.find('.') 102 | if split_index != -1: 103 | return key[:split_index], key[split_index + 1:] 104 | else: 105 | return None, key 106 | 107 | @property 108 | def name(self): 109 | return self._name 110 | 111 | @property 112 | def scope(self): 113 | return self._scope 114 | 115 | @property 116 | def module_dict(self): 117 | return self._module_dict 118 | 119 | @property 120 | def children(self): 121 | return self._children 122 | 123 | def get(self, key): 124 | """Get the registry record. 125 | Args: 126 | key (str): The class name in string format. 127 | Returns: 128 | class: The corresponding class. 
129 | """ 130 | scope, real_key = self.split_scope_key(key) 131 | if scope is None or scope == self._scope: 132 | # get from self 133 | if real_key in self._module_dict: 134 | return self._module_dict[real_key] 135 | else: 136 | # get from self._children 137 | if scope in self._children: 138 | return self._children[scope].get(real_key) 139 | else: 140 | # goto root 141 | parent = self.parent 142 | while parent.parent is not None: 143 | parent = parent.parent 144 | return parent.get(key) 145 | 146 | def build(self, *args, **kwargs): 147 | return self.build_func(*args, **kwargs, registry=self) 148 | 149 | def _add_children(self, registry): 150 | """Add children for a registry. 151 | The ``registry`` will be added as children based on its scope. 152 | The parent registry could build objects from children registry. 153 | Example: 154 | >>> models = Registry('models') 155 | >>> mmdet_models = Registry('models', parent=models) 156 | >>> @mmdet_models.register_module() 157 | >>> class ResNet: 158 | >>> pass 159 | >>> resnet = models.build(dict(NAME='mmdet.ResNet')) 160 | """ 161 | 162 | assert isinstance(registry, Registry) 163 | assert registry.scope is not None 164 | assert registry.scope not in self.children, \ 165 | f'scope {registry.scope} exists in {self.name} registry' 166 | self.children[registry.scope] = registry 167 | 168 | def _register_module(self, module_class, module_name=None, force=False): 169 | if not inspect.isclass(module_class): 170 | raise TypeError('module must be a class, ' 171 | f'but got {type(module_class)}') 172 | 173 | if module_name is None: 174 | module_name = module_class.__name__ 175 | if isinstance(module_name, str): 176 | module_name = [module_name] 177 | for name in module_name: 178 | if not force and name in self._module_dict: 179 | raise KeyError(f'{name} is already registered ' 180 | f'in {self.name}') 181 | self._module_dict[name] = module_class 182 | 183 | def deprecated_register_module(self, cls=None, force=False): 184 | warnings.warn( 185 | 'The old API of register_module(module, force=False) ' 186 | 'is deprecated and will be removed, please use the new API ' 187 | 'register_module(name=None, force=False, module=None) instead.') 188 | if cls is None: 189 | return partial(self.deprecated_register_module, force=force) 190 | self._register_module(cls, force=force) 191 | return cls 192 | 193 | def register_module(self, name=None, force=False, module=None): 194 | """Register a module. 195 | A record will be added to `self._module_dict`, whose key is the class 196 | name or the specified name, and value is the class itself. 197 | It can be used as a decorator or a normal function. 198 | Example: 199 | >>> backbones = Registry('backbone') 200 | >>> @backbones.register_module() 201 | >>> class ResNet: 202 | >>> pass 203 | >>> backbones = Registry('backbone') 204 | >>> @backbones.register_module(name='mnet') 205 | >>> class MobileNet: 206 | >>> pass 207 | >>> backbones = Registry('backbone') 208 | >>> class ResNet: 209 | >>> pass 210 | >>> backbones.register_module(ResNet) 211 | Args: 212 | name (str | None): The module name to be registered. If not 213 | specified, the class name will be used. 214 | force (bool, optional): Whether to override an existing class with 215 | the same name. Default: False. 216 | module (type): Module class to be registered. 
217 | """ 218 | if not isinstance(force, bool): 219 | raise TypeError(f'force must be a boolean, but got {type(force)}') 220 | # NOTE: This is a walkaround to be compatible with the old api, 221 | # while it may introduce unexpected bugs. 222 | if isinstance(name, type): 223 | return self.deprecated_register_module(name, force=force) 224 | 225 | # raise the error ahead of time 226 | if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)): 227 | raise TypeError( 228 | 'name must be either of None, an instance of str or a sequence' 229 | f' of str, but got {type(name)}') 230 | 231 | # use it as a normal method: x.register_module(module=SomeClass) 232 | if module is not None: 233 | self._register_module( 234 | module_class=module, module_name=name, force=force) 235 | return module 236 | 237 | # use it as a decorator: @x.register_module() 238 | def _register(cls): 239 | self._register_module( 240 | module_class=cls, module_name=name, force=force) 241 | return cls 242 | 243 | return _register 244 | 245 | 246 | def build_from_cfg(cfg, registry, default_args=None): 247 | """Build a module from config dict. 248 | Args: 249 | cfg (edict): Config dict. It should at least contain the key "NAME". 250 | registry (:obj:`Registry`): The registry to search the type from. 251 | Returns: 252 | object: The constructed object. 253 | """ 254 | if not isinstance(cfg, dict): 255 | raise TypeError(f'cfg must be a dict, but got {type(cfg)}') 256 | if 'NAME' not in cfg: 257 | if default_args is None or 'NAME' not in default_args: 258 | raise KeyError( 259 | '`cfg` or `default_args` must contain the key "NAME", ' 260 | f'but got {cfg}\n{default_args}') 261 | if not isinstance(registry, Registry): 262 | raise TypeError('registry must be an mmcv.Registry object, ' 263 | f'but got {type(registry)}') 264 | 265 | if not (isinstance(default_args, dict) or default_args is None): 266 | raise TypeError('default_args must be a dict or None, ' 267 | f'but got {type(default_args)}') 268 | 269 | if default_args is not None: 270 | cfg = config.merge_new_config(cfg, default_args) 271 | 272 | obj_type = cfg.get('NAME') 273 | 274 | if isinstance(obj_type, str): 275 | obj_cls = registry.get(obj_type) 276 | if obj_cls is None: 277 | raise KeyError( 278 | f'{obj_type} is not in the {registry.name} registry') 279 | elif inspect.isclass(obj_type): 280 | obj_cls = obj_type 281 | else: 282 | raise TypeError( 283 | f'type must be a str or valid type, but got {type(obj_type)}') 284 | try: 285 | return obj_cls(cfg) 286 | except Exception as e: 287 | # Normal TypeError does not print class name. 
288 | raise type(e)(f'{obj_cls.__name__}: {e}') 289 | -------------------------------------------------------------------------------- /n2m/utils/sample_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.patches as patches 3 | import numpy as np 4 | import matplotlib.transforms as transforms 5 | import open3d as o3d 6 | 7 | 8 | class CollisionChecker: 9 | """Class for collision checking.""" 10 | def __init__( 11 | self, 12 | config, 13 | ): 14 | self.filter_noise = config.get("filter_noise", True) # Whether to filter noise from point cloud 15 | self.ground_z = config.get("ground_z", 0.05) # Points below this z-value are considered ground and ignored 16 | self.resolution = config.get("resolution", 0.02) # Grid resolution in meters 17 | self.robot_width = config.get("robot_width", 0.5) # Robot width in meters 18 | self.robot_length = config.get("robot_length", 0.63) # Robot length in meters 19 | 20 | def filter_point_cloud(self, pcd): 21 | """ 22 | Filter noisy points from the point cloud using Open3D filtering methods. 23 | 24 | Args: 25 | pcd: Open3D point cloud object 26 | 27 | Returns: 28 | Filtered Open3D point cloud object 29 | """ 30 | try: 31 | print(f"Original point cloud has {len(pcd.points)} points") 32 | 33 | # 1. Voxel downsampling to reduce noise and density 34 | voxel_size = 0.02 # 2cm voxel size 35 | pcd_downsampled = pcd.voxel_down_sample(voxel_size=voxel_size) 36 | print(f"After voxel downsampling: {len(pcd_downsampled.points)} points") 37 | 38 | # 2. Remove statistical outliers 39 | # This removes points that are too far from their neighbors 40 | pcd_cleaned, _ = pcd_downsampled.remove_statistical_outlier( 41 | nb_neighbors=20, # Number of neighbors to analyze 42 | std_ratio=2.0 # Standard deviation ratio threshold 43 | ) 44 | print(f"After statistical outlier removal: {len(pcd_cleaned.points)} points") 45 | 46 | # 3. 
Remove radius outliers (optional, can be commented out if too aggressive) 47 | # This removes points that have too few neighbors within a radius 48 | pcd_cleaned, _ = pcd_cleaned.remove_radius_outlier( 49 | nb_points=16, # Minimum number of points within radius 50 | radius=0.05 # Radius to search for neighbors (5cm) 51 | ) 52 | print(f"After radius outlier removal: {len(pcd_cleaned.points)} points") 53 | 54 | return pcd_cleaned 55 | 56 | except Exception as e: 57 | print(f"Error during point cloud filtering: {e}") 58 | print("Returning original point cloud without filtering") 59 | return pcd 60 | 61 | def set_occupancy_grid(self): 62 | """Create occupancy grid from point cloud, filtering out ground points below ground_z""" 63 | # Filter out ground points 64 | non_ground_points = self.pcd_points[self.pcd_points[:, 2] >= self.ground_z] 65 | 66 | # Get min and max coordinates 67 | min_coords = np.min(non_ground_points, axis=0) 68 | max_coords = np.max(non_ground_points, axis=0) 69 | 70 | # Calculate grid dimensions 71 | width = int(np.ceil((max_coords[0] - min_coords[0]) / self.resolution)) 72 | height = int(np.ceil((max_coords[1] - min_coords[1]) / self.resolution)) 73 | 74 | # Initialize occupancy grid 75 | occupancy_grid = np.zeros((height, width)) 76 | 77 | # Project points to grid 78 | for point in non_ground_points: 79 | x_idx = int((point[0] - min_coords[0]) / self.resolution) 80 | y_idx = int((point[1] - min_coords[1]) / self.resolution) 81 | 82 | if 0 <= x_idx < width and 0 <= y_idx < height: 83 | occupancy_grid[y_idx, x_idx] = 1 84 | 85 | self.occupancy_grid = occupancy_grid 86 | self.min_coords = min_coords 87 | self.max_coords = max_coords 88 | 89 | def set_pcd(self, pcd): 90 | if self.filter_noise: 91 | self.pcd = self.filter_point_cloud(pcd) 92 | else: 93 | self.pcd = pcd 94 | 95 | self.pcd_points = pcd.points 96 | self.pcd_colors = pcd.colors 97 | 98 | print("Generating occupancy grid map") 99 | self.set_occupancy_grid() 100 | 101 | def check_collision(self, pose): 102 | """Checks collision in the XY-plane for given SE(2) or SE(2)+z pose.""" 103 | if len(pose) == 3: 104 | x, y, theta = pose 105 | elif len(pose) == 4: 106 | x, y, theta, _ = pose 107 | else: 108 | raise ValueError("Pose must be length 3 or 4") 109 | 110 | if self.occupancy_grid is None or self.min_coords is None or self.max_coords is None: 111 | raise ValueError("Occupancy grid not initialized. 
Call set_pcd() first.") 112 | 113 | # Get rectangle corners in world frame 114 | corners = np.array([ 115 | [-self.robot_length*0.83, -self.robot_width/2], # Bottom left 116 | [self.robot_length*0.17, -self.robot_width/2], # Bottom right 117 | [self.robot_length*0.17, self.robot_width/2], # Top right 118 | [-self.robot_length*0.83, self.robot_width/2] # Top left 119 | ]) 120 | # Rotate corners 121 | rot_matrix = np.array([ 122 | [np.cos(theta), -np.sin(theta)], 123 | [np.sin(theta), np.cos(theta)] 124 | ]) 125 | rotated_corners = (rot_matrix @ corners.T).T 126 | 127 | # Translate corners 128 | world_corners = rotated_corners + np.array([x, y]) 129 | 130 | # Get min and max coordinates of the rectangle 131 | rect_min = np.min(world_corners, axis=0) 132 | rect_max = np.max(world_corners, axis=0) 133 | 134 | # Convert to grid indices 135 | min_x_idx = int((rect_min[0] - self.min_coords[0]) / self.resolution) 136 | min_y_idx = int((rect_min[1] - self.min_coords[1]) / self.resolution) 137 | max_x_idx = int((rect_max[0] - self.min_coords[0]) / self.resolution) + 1 138 | max_y_idx = int((rect_max[1] - self.min_coords[1]) / self.resolution) + 1 139 | 140 | # Ensure indices are within grid bounds 141 | min_x_idx = max(0, min_x_idx) 142 | min_y_idx = max(0, min_y_idx) 143 | max_x_idx = min(self.occupancy_grid.shape[1], max_x_idx) 144 | max_y_idx = min(self.occupancy_grid.shape[0], max_y_idx) 145 | 146 | # Check each grid cell in the bounding box 147 | for y_idx in range(min_y_idx, max_y_idx): 148 | for x_idx in range(min_x_idx, max_x_idx): 149 | if self.occupancy_grid[y_idx, x_idx] == 1: 150 | # Get grid cell corners in world coordinates 151 | grid_min_x = self.min_coords[0] + x_idx * self.resolution 152 | grid_min_y = self.min_coords[1] + y_idx * self.resolution 153 | grid_max_x = grid_min_x + self.resolution 154 | grid_max_y = grid_min_y + self.resolution 155 | 156 | # Check if grid cell intersects with rectangle 157 | grid_corners = np.array([ 158 | [grid_min_x, grid_min_y], 159 | [grid_max_x, grid_min_y], 160 | [grid_max_x, grid_max_y], 161 | [grid_min_x, grid_max_y] 162 | ]) 163 | translated_corners = grid_corners - np.array([x, y]) 164 | local_corners = (rot_matrix.T @ translated_corners.T).T 165 | 166 | rect_min_x = -self.robot_length*0.83 167 | rect_max_x = self.robot_length*0.17 168 | rect_min_y = -self.robot_width/2 169 | rect_max_y = self.robot_width/2 170 | 171 | for corner in local_corners: 172 | if (rect_min_x <= corner[0] <= rect_max_x and 173 | rect_min_y <= corner[1] <= rect_max_y): 174 | return False 175 | return True 176 | 177 | 178 | def get_target_helper_for_rollout_collection(inference_mode=False, all_pcd=None, se2_origin=None, vis=False, camera_intrinsic=None, filter_noise=True): 179 | if inference_mode: 180 | x_half_range = 0.5 181 | y_half_range = 0.5 182 | theta_half_range_deg = 30 183 | else: 184 | x_half_range = 0.2 185 | y_half_range = 0.2 186 | theta_half_range_deg = 15 187 | 188 | return TargetHelper( 189 | all_pcd, 190 | se2_origin, 191 | x_half_range, 192 | y_half_range, 193 | theta_half_range_deg, 194 | vis=vis, 195 | camera_intrinsic=camera_intrinsic, 196 | filter_noise=filter_noise 197 | ) 198 | 199 | class TargetHelper: 200 | def __init__(self, pcd, origin_se2, x_half_range, y_half_range, theta_half_range_deg, vis=False, camera_intrinsic=None, filter_noise=True): 201 | self.resolution = 0.02 202 | self.width = 0.5 203 | self.length = 0.63 204 | self.ground_z = 0.05 205 | # self.T_base_cam = np.array([ 206 | # [ 5.83445639e-03, 5.87238353e-01, -8.09393072e-01, 
4.49474752e-02], 207 | # [-9.99982552e-01, 2.64855534e-03, -5.28670366e-03, -1.88012307e-02], 208 | # [-9.60832741e-04, 8.09409739e-01, 5.87243519e-01, 1.64493829e-00], 209 | # [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]) # this is just for DETECT mode and 210 | # self.T_base_cam = np.array( 211 | # [[-8.14932746e-02, -5.73355454e-01, 8.15243714e-01, 9.75590401e-04], 212 | # [-9.95079241e-01, 5.53813432e-04, -9.90804744e-02, -2.56451598e-02], 213 | # [ 5.63568407e-02, -8.19306535e-01, -5.70579275e-01, 1.64368076e-01], 214 | # [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]] 215 | # ) 216 | self.T_base_cam = np.array([ 217 | [-8.25269110e-02, -5.73057816e-01, 8.15348841e-01, 6.05364230e-04], 218 | [-9.95784041e-01, 1.45464862e-02, -9.05661474e-02, -3.94417736e-02], 219 | [ 4.00391906e-02, -8.19385767e-01, -5.71842485e-01, 1.64310488e-00], 220 | [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00] 221 | ]) 222 | self.arm_length = 1 223 | 224 | # Filter noise from point cloud if requested 225 | if filter_noise: 226 | self.pcd = self.filter_point_cloud(pcd) 227 | else: 228 | self.pcd = pcd 229 | 230 | self.origin_se2 = origin_se2 231 | self.x_half_range = x_half_range 232 | self.y_half_range = y_half_range 233 | self.theta_half_range_deg = theta_half_range_deg 234 | self.to_rad = lambda x: x/180*np.pi 235 | self.pcd_np = np.asarray(self.pcd.points) 236 | self.pcd_np = np.concatenate([self.pcd_np, np.zeros((self.pcd_np.shape[0], 1))], axis=1) 237 | self.pcd_color_np = np.asarray(self.pcd.colors) 238 | self.pcd_max_value = np.max(self.pcd_np[:, :3], axis=0) 239 | self.pcd_min_value = np.min(self.pcd_np[:, :3], axis=0) 240 | self.vis = vis 241 | 242 | if camera_intrinsic is not None: 243 | self.cam_intrinsic = np.array(camera_intrinsic) 244 | else: 245 | self.cam_intrinsic = np.array([100.6919557412736, 100.6919557412736, 160.0, 120.0, 320, 240]) 246 | # Initialize occupancy grid in constructor 247 | self.get_occupancy_grid() 248 | 249 | def filter_point_cloud(self, pcd): 250 | """ 251 | Filter noisy points from the point cloud using Open3D filtering methods. 252 | 253 | Args: 254 | pcd: Open3D point cloud object 255 | 256 | Returns: 257 | Filtered Open3D point cloud object 258 | """ 259 | try: 260 | print(f"Original point cloud has {len(pcd.points)} points") 261 | 262 | # 1. Voxel downsampling to reduce noise and density 263 | voxel_size = 0.02 # 2cm voxel size 264 | pcd_downsampled = pcd.voxel_down_sample(voxel_size=voxel_size) 265 | print(f"After voxel downsampling: {len(pcd_downsampled.points)} points") 266 | 267 | # 2. Remove statistical outliers 268 | # This removes points that are too far from their neighbors 269 | pcd_cleaned, _ = pcd_downsampled.remove_statistical_outlier( 270 | nb_neighbors=20, # Number of neighbors to analyze 271 | std_ratio=2.0 # Standard deviation ratio threshold 272 | ) 273 | print(f"After statistical outlier removal: {len(pcd_cleaned.points)} points") 274 | 275 | # 3. 
Remove radius outliers (optional, can be commented out if too aggressive) 276 | # This removes points that have too few neighbors within a radius 277 | pcd_cleaned, _ = pcd_cleaned.remove_radius_outlier( 278 | nb_points=16, # Minimum number of points within radius 279 | radius=0.05 # Radius to search for neighbors (5cm) 280 | ) 281 | print(f"After radius outlier removal: {len(pcd_cleaned.points)} points") 282 | 283 | return pcd_cleaned 284 | 285 | except Exception as e: 286 | print(f"Error during point cloud filtering: {e}") 287 | print("Returning original point cloud without filtering") 288 | return pcd 289 | 290 | def filter_point_cloud_custom(self, pcd, voxel_size=0.02, nb_neighbors=20, std_ratio=2.0, 291 | nb_points=16, radius=0.05, use_radius_filter=True): 292 | """ 293 | Filter noisy points from the point cloud with custom parameters. 294 | 295 | Args: 296 | pcd: Open3D point cloud object 297 | voxel_size: Size of voxels for downsampling (in meters) 298 | nb_neighbors: Number of neighbors for statistical outlier removal 299 | std_ratio: Standard deviation ratio threshold for statistical outlier removal 300 | nb_points: Minimum number of points within radius for radius outlier removal 301 | radius: Radius to search for neighbors (in meters) 302 | use_radius_filter: Whether to apply radius outlier removal 303 | 304 | Returns: 305 | Filtered Open3D point cloud object 306 | """ 307 | try: 308 | print(f"Original point cloud has {len(pcd.points)} points") 309 | 310 | # 1. Voxel downsampling 311 | pcd_downsampled = pcd.voxel_down_sample(voxel_size=voxel_size) 312 | print(f"After voxel downsampling: {len(pcd_downsampled.points)} points") 313 | 314 | # 2. Remove statistical outliers 315 | pcd_cleaned, _ = pcd_downsampled.remove_statistical_outlier( 316 | nb_neighbors=nb_neighbors, 317 | std_ratio=std_ratio 318 | ) 319 | print(f"After statistical outlier removal: {len(pcd_cleaned.points)} points") 320 | 321 | # 3. Remove radius outliers (optional) 322 | if use_radius_filter: 323 | pcd_cleaned, _ = pcd_cleaned.remove_radius_outlier( 324 | nb_points=nb_points, 325 | radius=radius 326 | ) 327 | print(f"After radius outlier removal: {len(pcd_cleaned.points)} points") 328 | 329 | return pcd_cleaned 330 | 331 | except Exception as e: 332 | print(f"Error during point cloud filtering: {e}") 333 | print("Returning original point cloud without filtering") 334 | return pcd 335 | 336 | def visualize_filtering_comparison(self, original_pcd): 337 | """ 338 | Visualize the comparison between original and filtered point clouds. 
339 | 340 | Args: 341 | original_pcd: Original Open3D point cloud object 342 | """ 343 | try: 344 | # Create visualization window 345 | vis = o3d.visualization.Visualizer() 346 | vis.create_window(window_name="Point Cloud Filtering Comparison", width=1200, height=800) 347 | 348 | # Add original point cloud (white) 349 | original_pcd.paint_uniform_color([1, 1, 1]) # White 350 | vis.add_geometry(original_pcd) 351 | 352 | # Add filtered point cloud (red) 353 | filtered_pcd = self.pcd 354 | filtered_pcd.paint_uniform_color([1, 0, 0]) # Red 355 | vis.add_geometry(filtered_pcd) 356 | 357 | # Add coordinate frame 358 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5) 359 | vis.add_geometry(coord_frame) 360 | 361 | # Set view options 362 | opt = vis.get_render_option() 363 | opt.background_color = np.asarray([0, 0, 0]) # Black background 364 | opt.point_size = 2.0 365 | 366 | print("Visualizing point cloud filtering comparison...") 367 | print("White points: Original point cloud") 368 | print("Red points: Filtered point cloud") 369 | print("Press 'Q' to close the visualization") 370 | 371 | # Run visualization 372 | vis.run() 373 | vis.destroy_window() 374 | 375 | except Exception as e: 376 | print(f"Error during visualization: {e}") 377 | 378 | def get_occupancy_grid(self): 379 | """Create occupancy grid from point cloud, filtering out ground points below ground_z""" 380 | # Filter out ground points 381 | non_ground_points = self.pcd_np[self.pcd_np[:, 2] >= self.ground_z] 382 | 383 | # Get min and max coordinates 384 | min_coords = np.min(non_ground_points, axis=0) 385 | max_coords = np.max(non_ground_points, axis=0) 386 | 387 | # Calculate grid dimensions 388 | width = int(np.ceil((max_coords[0] - min_coords[0]) / self.resolution)) 389 | height = int(np.ceil((max_coords[1] - min_coords[1]) / self.resolution)) 390 | 391 | # Initialize occupancy grid 392 | occupancy_grid = np.zeros((height, width)) 393 | 394 | # Project points to grid 395 | for point in non_ground_points: 396 | x_idx = int((point[0] - min_coords[0]) / self.resolution) 397 | y_idx = int((point[1] - min_coords[1]) / self.resolution) 398 | 399 | if 0 <= x_idx < width and 0 <= y_idx < height: 400 | occupancy_grid[y_idx, x_idx] = 1 401 | 402 | self.occupancy_grid = occupancy_grid 403 | self.min_coords = min_coords 404 | self.max_coords = max_coords 405 | 406 | def check_boundary(self, se2_pose): 407 | """Check if the target is within the boundary""" 408 | x, y, theta = se2_pose 409 | if x < self.pcd_min_value[0] or x > self.pcd_max_value[0] or y < self.pcd_min_value[1] or y > self.pcd_max_value[1]: 410 | return False 411 | return True 412 | 413 | def check_collision(self, se2_pose): 414 | """Check if all grids intersected by the rectangle are empty""" 415 | x, y, theta = se2_pose 416 | # Get rectangle corners in world frame 417 | corners = np.array([ 418 | [-self.length*0.83, -self.width/2], # Bottom left 419 | [self.length*0.17, -self.width/2], # Bottom right 420 | [self.length*0.17, self.width/2], # Top right 421 | [-self.length*0.83, self.width/2] # Top left 422 | ]) 423 | # Rotate corners 424 | rot_matrix = np.array([ 425 | [np.cos(theta), -np.sin(theta)], 426 | [np.sin(theta), np.cos(theta)] 427 | ]) 428 | rotated_corners = (rot_matrix @ corners.T).T 429 | 430 | # Translate corners 431 | world_corners = rotated_corners + np.array([x, y]) 432 | 433 | # Get min and max coordinates of the rectangle 434 | rect_min = np.min(world_corners, axis=0) 435 | rect_max = np.max(world_corners, axis=0) 436 | 437 | # 
Convert to grid indices 438 | min_x_idx = int((rect_min[0] - self.min_coords[0]) / self.resolution) 439 | min_y_idx = int((rect_min[1] - self.min_coords[1]) / self.resolution) 440 | max_x_idx = int((rect_max[0] - self.min_coords[0]) / self.resolution) + 1 441 | max_y_idx = int((rect_max[1] - self.min_coords[1]) / self.resolution) + 1 442 | 443 | # Ensure indices are within grid bounds 444 | min_x_idx = max(0, min_x_idx) 445 | min_y_idx = max(0, min_y_idx) 446 | max_x_idx = min(self.occupancy_grid.shape[1], max_x_idx) 447 | max_y_idx = min(self.occupancy_grid.shape[0], max_y_idx) 448 | 449 | # Check each grid cell in the bounding box 450 | for y_idx in range(min_y_idx, max_y_idx): 451 | for x_idx in range(min_x_idx, max_x_idx): 452 | if self.occupancy_grid[y_idx, x_idx] == 1: 453 | # Get grid cell corners in world coordinates 454 | grid_min_x = self.min_coords[0] + x_idx * self.resolution 455 | grid_min_y = self.min_coords[1] + y_idx * self.resolution 456 | grid_max_x = grid_min_x + self.resolution 457 | grid_max_y = grid_min_y + self.resolution 458 | 459 | # Check if grid cell intersects with rectangle 460 | # Convert grid corners to rectangle's local frame 461 | grid_corners = np.array([ 462 | [grid_min_x, grid_min_y], 463 | [grid_max_x, grid_min_y], 464 | [grid_max_x, grid_max_y], 465 | [grid_min_x, grid_max_y] 466 | ]) 467 | 468 | # Translate and rotate grid corners to rectangle's local frame 469 | translated_corners = grid_corners - np.array([x, y]) 470 | local_corners = (rot_matrix.T @ translated_corners.T).T 471 | 472 | # Check if any grid corner is inside rectangle 473 | # Rectangle bounds in local frame 474 | rect_min_x = -self.length*0.83 475 | rect_max_x = self.length*0.17 476 | rect_min_y = -self.width/2 477 | rect_max_y = self.width/2 478 | 479 | # Check if any grid corner is inside rectangle 480 | for corner in local_corners: 481 | if (rect_min_x <= corner[0] <= rect_max_x and 482 | rect_min_y <= corner[1] <= rect_max_y): 483 | return False 484 | 485 | return True 486 | 487 | def visualize_occupancy_and_rectangle(self, target_se2, object_pos=None): 488 | """Visualize occupancy grid and rectangle""" 489 | try: 490 | plt.figure(figsize=(10, 10)) 491 | ax = plt.gca() 492 | 493 | # Plot occupancy grid 494 | plt.imshow(self.occupancy_grid, cmap='binary', origin='lower', 495 | extent=[self.min_coords[0], self.max_coords[0], self.min_coords[1], self.max_coords[1]]) 496 | 497 | # Draw sampling region 498 | sampling_rect = patches.Rectangle( 499 | (self.origin_se2[0] - self.x_half_range, self.origin_se2[1] - self.y_half_range), 500 | 2 * self.x_half_range, 2 * self.y_half_range, 501 | linewidth=1, edgecolor='b', facecolor='none', linestyle='--', label='Random Region' 502 | ) 503 | ax.add_patch(sampling_rect) 504 | 505 | # Draw origin point 506 | ax.plot(self.origin_se2[0], self.origin_se2[1], 'bo', markersize=8, label='Origin') 507 | 508 | # Draw rectangle with SE2 pose 509 | x, y, theta = target_se2 510 | 511 | # Create rectangle with origin at center of width edge 512 | rect = patches.Rectangle((-self.length*0.83, -self.width/2), 513 | self.length, self.width, 514 | linewidth=2, edgecolor='r', facecolor='none') 515 | 516 | # Create transformation 517 | t = transforms.Affine2D().rotate(theta).translate(x, y) 518 | rect.set_transform(t + ax.transData) 519 | 520 | # Add rectangle and reachability circle to plot 521 | ax.add_patch(rect) 522 | 523 | # Create reachability circle centered at object_pos 524 | if object_pos is not None: 525 | reachability_circle = patches.Circle(object_pos, 
self.arm_length, linewidth=2, edgecolor='b', facecolor='none') 526 | ax.add_patch(reachability_circle) 527 | 528 | # Draw coordinate frame 529 | arrow_length = 0.2 530 | # X-axis (red) 531 | ax.arrow(x, y, 532 | arrow_length * np.cos(theta), 533 | arrow_length * np.sin(theta), 534 | head_width=0.05, head_length=0.05, fc='r', ec='r') 535 | # Y-axis (green) 536 | ax.arrow(x, y, 537 | arrow_length * np.cos(theta + np.pi/2), 538 | arrow_length * np.sin(theta + np.pi/2), 539 | head_width=0.05, head_length=0.05, fc='g', ec='g') 540 | 541 | # Add axis labels 542 | plt.xlabel('X (meters)') 543 | plt.ylabel('Y (meters)') 544 | 545 | # Add grid 546 | plt.grid(True, alpha=0.3) 547 | 548 | # Add title 549 | plt.title('Occupancy Grid Map with Rectangle') 550 | 551 | # Add legend 552 | plt.legend() 553 | 554 | # Show plot 555 | plt.show() 556 | plt.pause(0.5) 557 | plt.close() # Explicitly close the figure 558 | except Exception as e: 559 | print(f"Visualization error: {e}") 560 | 561 | def visualize_pcl_with_camera_and_object(self, se2_pose, object_pos): 562 | """Visualize point cloud, camera, and object using Open3D""" 563 | try: 564 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose) 565 | se3_pose = self.calculate_target_se3(se2_pose) 566 | 567 | # Create visualization window 568 | vis = o3d.visualization.Visualizer() 569 | vis.create_window() 570 | 571 | # Add point cloud 572 | pcd = o3d.geometry.PointCloud() 573 | pcd.points = o3d.utility.Vector3dVector(self.pcd_np[:,:3]) 574 | pcd.colors = o3d.utility.Vector3dVector(self.pcd_color_np) 575 | vis.add_geometry(pcd) 576 | 577 | # Add object as sphere 578 | if object_pos.shape == (2,): 579 | object_pos = np.array([object_pos[0], object_pos[1], 1.0]) 580 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.05) 581 | sphere.translate(object_pos[:3]) 582 | sphere.paint_uniform_color([1, 0, 0]) # Red color 583 | vis.add_geometry(sphere) 584 | 585 | # # Add camera frustum 586 | camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3) 587 | camera_frame.transform(camera_extrinsic) 588 | vis.add_geometry(camera_frame) 589 | 590 | # Add target as sphere 591 | target_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3) 592 | target_frame.transform(se3_pose) 593 | vis.add_geometry(target_frame) 594 | 595 | # Add origin coordinate frame 596 | origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3) 597 | vis.add_geometry(origin_frame) 598 | 599 | # Render and capture image 600 | vis.run() 601 | vis.destroy_window() 602 | 603 | except Exception as e: 604 | print(f"Visualization error: {e}") 605 | 606 | def calculate_target_se3(self, se2_pose): 607 | """Get the target SE3 pose""" 608 | x, y, theta = se2_pose 609 | se3_pose = np.array([ 610 | [np.cos(theta), -np.sin(theta), 0, x], 611 | [np.sin(theta), np.cos(theta), 0, y], 612 | [0, 0, 1, 0], 613 | [0, 0, 0, 1] 614 | ]) 615 | return se3_pose 616 | 617 | def calculate_camera_extrinsic(self, se2_pose): 618 | """Get the camera extrinsic matrix from the target SE2 pose""" 619 | se3_pose = self.calculate_target_se3(se2_pose) 620 | camera_extrinsic = se3_pose @ self.T_base_cam 621 | return camera_extrinsic 622 | 623 | def check_object_visibility(self, se2_pose, object_pos): 624 | """Check if the object is visible from the camera""" 625 | # convert target_se2 to camera_extrinsic 626 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose) 627 | fx, fy, cx, cy, w, h = self.cam_intrinsic 628 | intrinsic_matrix = np.array([ 629 | [fx, 0, cx], 630 | 
[0, fy, cy], 631 | [0, 0, 1] 632 | ]) 633 | 634 | if object_pos.shape == (3,): 635 | object_pos = np.concatenate([object_pos, np.array([1.0])]) 636 | elif object_pos.shape == (2,): # use default height 1.0 637 | object_pos = np.concatenate([object_pos, np.array([1.0, 1.0])]) 638 | else: 639 | raise ValueError("object_pos must be a 2D or 3D array") 640 | 641 | # convert object_pos to camera coordinates 642 | object_cam = np.linalg.inv(camera_extrinsic) @ object_pos 643 | object_cam = object_cam[:3] / object_cam[3] 644 | object_pix = intrinsic_matrix @ object_cam 645 | object_pix = object_pix[:2] / object_pix[2] 646 | 647 | if self.vis: 648 | global_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1) 649 | global_frame_points = global_frame.sample_points_uniformly(number_of_points=1000) 650 | 651 | camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3) 652 | camera_frame.transform(camera_extrinsic) 653 | camera_frame_points = camera_frame.sample_points_uniformly(number_of_points=1000) 654 | 655 | base_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6) 656 | base_frame.transform(self.calculate_target_se3(se2_pose)) 657 | base_frame_points = base_frame.sample_points_uniformly(number_of_points=1000) 658 | 659 | object_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.05) 660 | object_point.translate(object_pos[:3]) 661 | object_point.paint_uniform_color([1, 0, 0]) 662 | object_point_points = object_point.sample_points_uniformly(number_of_points=1000) 663 | 664 | transformed_object_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.05) 665 | transformed_object_point.translate((np.linalg.inv(camera_extrinsic) @ object_pos)[:3]) 666 | transformed_object_point.paint_uniform_color([0, 1, 0]) 667 | transformed_object_point_points = transformed_object_point.sample_points_uniformly(number_of_points=1000) 668 | 669 | o3d.io.write_point_cloud("check_transform.ply", global_frame_points + camera_frame_points + base_frame_points + object_point_points + transformed_object_point_points) 670 | 671 | # check if the object is in front of the camera 672 | if object_cam[2] <= 0: 673 | return False 674 | 675 | # check if the object is within the image bounds 676 | if object_pix[0] < 0 or object_pix[0] > w or object_pix[1] < 0 or object_pix[1] > h: 677 | return False 678 | 679 | return True 680 | 681 | def check_reachability(self, se2_pose, object_pos): 682 | """Check if the object is reachable from the target""" 683 | # convert target_se2 to camera_extrinsic 684 | robot_xy = se2_pose[:2] 685 | object_xy = object_pos[:2] 686 | distance = np.linalg.norm(robot_xy - object_xy) 687 | if distance > self.arm_length: 688 | return False 689 | return True 690 | 691 | def get_random_target_se2(self): 692 | """Get a random target SE2 pose""" 693 | while True: 694 | random_x = np.random.uniform(-self.x_half_range, self.x_half_range) 695 | random_y = np.random.uniform(-self.y_half_range, self.y_half_range) 696 | random_theta = np.random.uniform(self.to_rad(-self.theta_half_range_deg), self.to_rad(self.theta_half_range_deg)) 697 | # Calculate absolute position by adding to origin 698 | se2_delta = [random_x, random_y, random_theta] 699 | se2_pose = [ 700 | self.origin_se2[0] + se2_delta[0], 701 | self.origin_se2[1] + se2_delta[1], 702 | self.origin_se2[2] + se2_delta[2] 703 | ] 704 | if self.check_collision(se2_pose): 705 | # print("collision, break", se2_pose) 706 | break 707 | 708 | if self.vis: 709 | try: 710 | self.visualize_occupancy_and_rectangle(se2_pose) 711 
| except Exception as e: 712 | print(f"Visualization failed: {e}") 713 | 714 | return se2_delta 715 | 716 | def get_random_target_se2_with_reachability_check(self, object_pos): 717 | """Get a random target SE2 pose with reachability check""" 718 | while True: 719 | random_x = np.random.uniform(-self.x_half_range, self.x_half_range) 720 | random_y = np.random.uniform(-self.y_half_range, self.y_half_range) 721 | random_theta = np.random.uniform(self.to_rad(-self.theta_half_range_deg), self.to_rad(self.theta_half_range_deg)) 722 | # Calculate absolute position by adding to origin 723 | se2_delta = [random_x, random_y, random_theta] 724 | se2_pose = [ 725 | self.origin_se2[0] + se2_delta[0], 726 | self.origin_se2[1] + se2_delta[1], 727 | self.origin_se2[2] + se2_delta[2] 728 | ] 729 | if self.check_collision(se2_pose) and self.check_reachability(se2_pose, object_pos): 730 | break 731 | 732 | if self.vis: 733 | try: 734 | self.visualize_occupancy_and_rectangle(se2_pose, object_pos) 735 | except Exception as e: 736 | print(f"Visualization failed: {e}") 737 | 738 | return se2_delta 739 | 740 | def get_random_target_se2_with_visibility_check(self, object_pos): 741 | """Get a random target SE2 pose with visibility check""" 742 | while True: 743 | se2_delta = self.get_random_target_se2() 744 | se2_pose = [ 745 | self.origin_se2[0] + se2_delta[0], 746 | self.origin_se2[1] + se2_delta[1], 747 | self.origin_se2[2] + se2_delta[2] 748 | ] 749 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose) 750 | if self.check_object_visibility(se2_pose, object_pos) and self.check_boundary(se2_pose): 751 | break 752 | 753 | if self.vis: 754 | try: 755 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos) 756 | except Exception as e: 757 | print(f"Visualization failed: {e}") 758 | 759 | return se2_delta, camera_extrinsic 760 | 761 | def get_random_target_se2_with_boundary_check(self, object_pos): 762 | """Get a random target SE2 pose with visibility check""" 763 | while True: 764 | se2_delta = self.get_random_target_se2() 765 | se2_pose = [ 766 | self.origin_se2[0] + se2_delta[0], 767 | self.origin_se2[1] + se2_delta[1], 768 | self.origin_se2[2] + se2_delta[2] 769 | ] 770 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose) 771 | if self.check_boundary(se2_pose): 772 | break 773 | 774 | if self.vis: 775 | try: 776 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos) 777 | except Exception as e: 778 | print(f"Visualization failed: {e}") 779 | 780 | return se2_delta, camera_extrinsic 781 | 782 | def get_random_target_se2_with_visibility_check_without_boundary(self, object_pos): 783 | """Get a random target SE2 pose with visibility check""" 784 | while True: 785 | se2_delta = self.get_random_target_se2() 786 | se2_pose = [ 787 | self.origin_se2[0] + se2_delta[0], 788 | self.origin_se2[1] + se2_delta[1], 789 | self.origin_se2[2] + se2_delta[2] 790 | ] 791 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose) 792 | if self.check_object_visibility(se2_pose, object_pos): 793 | break 794 | 795 | if self.vis: 796 | try: 797 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos) 798 | except Exception as e: 799 | print(f"Visualization failed: {e}") 800 | 801 | return se2_delta, camera_extrinsic 802 | 803 | def get_random_target_se2_with_visibility_reachability_check(self, object_pos): 804 | """Get a random target SE2 pose with visibility check""" 805 | while True: 806 | se2_delta = self.get_random_target_se2() 807 | se2_pose = [ 808 | self.origin_se2[0] + 
se2_delta[0], 809 | self.origin_se2[1] + se2_delta[1], 810 | self.origin_se2[2] + se2_delta[2] 811 | ] 812 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose) 813 | if self.check_object_visibility(se2_pose, object_pos) and self.check_reachability(se2_pose, object_pos): 814 | break 815 | 816 | if self.vis: 817 | try: 818 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos) 819 | except Exception as e: 820 | print(f"Visualization failed: {e}") 821 | 822 | return se2_delta, camera_extrinsic -------------------------------------------------------------------------------- /n2m/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import torch 4 | from scipy.stats import multivariate_normal 5 | from sklearn.neighbors import KernelDensity 6 | import yaml 7 | 8 | 9 | def _load_point_cloud(file_path, use_color=True): 10 | pcd = o3d.io.read_point_cloud(file_path) 11 | point_cloud = np.asarray(pcd.points) 12 | 13 | if use_color: 14 | colors = np.asarray(pcd.colors) 15 | point_cloud = np.concatenate([point_cloud, colors], axis=1) 16 | return point_cloud 17 | 18 | def gmm_pdf(points, means, covs, weights): 19 | pdf_vals = np.zeros(len(points)) 20 | for k in range(len(weights)): 21 | mvn = multivariate_normal(mean=means[k], cov=covs[k]) 22 | pdf_vals += weights[k] * mvn.pdf(points) 23 | return pdf_vals 24 | 25 | def create_ellipsoid(mean, cov, color, scale=1.0, num_points=100000): 26 | """ 27 | Create an ellipsoid from a mean and covariance matrix 28 | 29 | Args: 30 | mean: Mean vector (3,) 31 | cov: Covariance matrix (3, 3) 32 | num_points: Number of points to generate for the ellipsoid 33 | """ 34 | # Eigen-decomposition of covariance matrix 35 | eigvals, eigvecs = np.linalg.eigh(cov) 36 | 37 | # Create a sphere 38 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=1.0) 39 | sphere.compute_vertex_normals() 40 | 41 | # Scale sphere to form ellipsoid 42 | scales = scale * np.sqrt(eigvals) 43 | sphere_vertices = np.asarray(sphere.vertices) 44 | sphere_vertices = sphere_vertices @ np.diag(scales) 45 | 46 | # Rotate according to eigenvectors 47 | sphere_vertices = sphere_vertices @ eigvecs.T 48 | sphere.vertices = o3d.utility.Vector3dVector(sphere_vertices) 49 | 50 | # Translate to mean 51 | sphere.translate(mean) 52 | 53 | # Color the ellipsoid 54 | sphere.paint_uniform_color(color) 55 | 56 | return sphere 57 | 58 | def visualize_gmm_distribution(point_cloud, target_point, label, means, covs, weights): 59 | """ 60 | Visualize GMM distribution as a 3D heatmap overlaid on point cloud 61 | 62 | Args: 63 | point_cloud: Input point cloud (N, 3) 64 | means: Mean vectors for each Gaussian component (K, 3) 65 | covs: Covariance matrices for each Gaussian component (K, 3, 3) 66 | weights: Mixing coefficients for each Gaussian component (K) 67 | num_points: Number of points to generate for the heatmap 68 | true_points: List of true points to visualize as spheres (optional) 69 | """ 70 | 71 | # Create point cloud object for input points 72 | pcd = o3d.geometry.PointCloud() 73 | pcd.points = o3d.utility.Vector3dVector(point_cloud[:, 0:3]) 74 | pcd.colors = o3d.utility.Vector3dVector(point_cloud[:, 3:6]) 75 | 76 | # Create ellipsoids for each Gaussian component 77 | ellipsoids = [] 78 | for i in range(len(weights)): 79 | ellipsoid = create_ellipsoid(means[i], covs[i], color=[0, 1, 0], scale=2.0) 80 | ellipsoids.append(ellipsoid) 81 | 82 | # Create coordinate frame 83 | coord_frame = 
o3d.geometry.TriangleMesh.create_coordinate_frame( 84 | size=0.1, origin=[0, 0, 0]) 85 | 86 | # Combine point clouds 87 | geometries = [pcd, coord_frame] + ellipsoids 88 | 89 | # Add spheres for true points if provided 90 | if target_point is not None: 91 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.02) 92 | sphere.translate(target_point) 93 | 94 | if label == 1: 95 | sphere.paint_uniform_color([0, 0, 1]) # Green color for true points 96 | else: 97 | sphere.paint_uniform_color([1, 0, 0]) # Red color for false points 98 | geometries.append(sphere) 99 | 100 | # Return all geometries 101 | return geometries 102 | 103 | def save_gmm_visualization(point_cloud, target_point, label, means, covs, weights, output_path): 104 | """ 105 | Save GMM visualization to a file 106 | 107 | Args: 108 | point_cloud: Input point cloud (N, 3) 109 | means: Mean vectors for each Gaussian component (K, 3) 110 | covs: Covariance matrices for each Gaussian component (K, 3, 3) 111 | weights: Mixing coefficients for each Gaussian component (K) 112 | output_path: Path to save the visualization 113 | num_points: Number of points to generate for the heatmap 114 | true_points: List of true points to visualize as spheres (optional) 115 | """ 116 | geometries = visualize_gmm_distribution( 117 | point_cloud, target_point, label, means, covs, weights 118 | ) 119 | 120 | # Extract combined point cloud and convert other geometries to points 121 | combined_points = [] 122 | combined_colors = [] 123 | 124 | # Add the main point cloud points and colors 125 | combined_points.append(np.asarray(geometries[0].points)) 126 | combined_colors.append(np.asarray(geometries[0].colors)) 127 | 128 | # Convert coordinate frame to points 129 | coord_mesh = geometries[1] 130 | coord_pcd = coord_mesh.sample_points_uniformly(number_of_points=1000) 131 | combined_points.append(np.asarray(coord_pcd.points)) 132 | combined_colors.append(np.asarray(coord_pcd.colors)) 133 | 134 | # Convert ellipsoids to points if they exist 135 | for ellipsoid in geometries[2:]: 136 | ellipsoid_pcd = ellipsoid.sample_points_uniformly(number_of_points=10000) 137 | combined_points.append(np.asarray(ellipsoid_pcd.points)) 138 | combined_colors.append(np.asarray(ellipsoid_pcd.colors)) 139 | 140 | # Combine all points and colors 141 | all_points = np.vstack(combined_points) 142 | all_colors = np.vstack(combined_colors) 143 | 144 | all_colors = np.clip(all_colors, 0, 1) 145 | 146 | # Create final point cloud 147 | final_pcd = o3d.geometry.PointCloud() 148 | final_pcd.points = o3d.utility.Vector3dVector(all_points) 149 | final_pcd.colors = o3d.utility.Vector3dVector(all_colors) 150 | 151 | # Save to file 152 | o3d.io.write_point_cloud(output_path, final_pcd) 153 | 154 | 155 | def save_gmm_visualization_se2(point_cloud, target_se2, label, means, covs, weights, output_path, interval=0.1, area=[-10, 10, -10, 10], threshold=0.5, z_value=0.8): 156 | """ 157 | Given a point cloud and mean and covariance of SE(2) points, visualize the point cloud and the SE(2) points. 158 | For each grid points in the xy area with interval, calculate the most probable theta 159 | And then with SE(2) point (x, y, theta), calculate the probalility of that SE(2) point 160 | Based on the probability, color the grid and visualize the arrow of theta. 161 | Repeat this process for all grid points and save this with the point cloud. 
162 | 163 | Args: 164 | point_cloud: Input point cloud (N, 3) 165 | target_se2: Target SE(2) point to visualize 166 | label: Label of the target point (0 or 1) 167 | means: Mean vectors for each Gaussian component (K, 3) 168 | covs: Covariance matrices for each Gaussian component (K, 3, 3) 169 | weights: Mixing coefficients for each Gaussian component (K) 170 | interval: Interval of the grid to visualize the SE(2) points 171 | area: Area to visualize the SE(2) points 172 | threshold: Threshold of the probability to visualize the SE(2) points 173 | 174 | Returns: 175 | no returns, save the visualization to output_path. 176 | """ 177 | # Create a grid of points in the xy area 178 | x = np.arange(area[0], area[1] + interval, interval) 179 | y = np.arange(area[2], area[3] + interval, interval) 180 | X, Y = np.meshgrid(x, y) 181 | grid_points = np.column_stack((X.ravel(), Y.ravel())) 182 | 183 | # Initialize arrays to store results 184 | n_grid_points = len(grid_points) 185 | theta_samples = np.linspace(-np.pi, np.pi, 36) # Sample 36 different angles 186 | 187 | # Create all possible combinations of grid points and thetas 188 | # Shape: (n_grid_points * n_thetas, 3) 189 | grid_expanded = np.repeat(grid_points, len(theta_samples), axis=0) 190 | theta_expanded = np.tile(theta_samples, n_grid_points) 191 | all_se2_points = np.column_stack((grid_expanded, theta_expanded)) 192 | 193 | # Calculate probabilities for all points at once 194 | all_probs = gmm_pdf(all_se2_points, means, covs, weights) 195 | 196 | # Reshape probabilities to (n_grid_points, n_thetas) 197 | probs_matrix = all_probs.reshape(n_grid_points, len(theta_samples)) 198 | 199 | # Find the best theta and probability for each grid point 200 | best_probs = np.max(probs_matrix, axis=1) 201 | best_theta_indices = np.argmax(probs_matrix, axis=1) 202 | best_thetas = theta_samples[best_theta_indices] 203 | 204 | # Filter out probabilities below threshold and normalize probabilities to [0,1] 205 | prob_min, prob_max = best_probs.min(), best_probs.max() 206 | best_probs = (best_probs - prob_min) / (prob_max - prob_min) 207 | 208 | # Create visualization geometries 209 | geometries = [] 210 | 211 | # Add original point cloud 212 | pcd = o3d.geometry.PointCloud() 213 | pcd.points = o3d.utility.Vector3dVector(point_cloud[:, 0:3]) 214 | pcd.colors = o3d.utility.Vector3dVector(point_cloud[:, 3:6]) 215 | geometries.append(pcd) 216 | 217 | # Add coordinate frame 218 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 219 | geometries.append(coord_frame) 220 | 221 | # Add target SE(2) point if provided 222 | if target_se2 is not None: 223 | # Create sphere for position 224 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/4) 225 | sphere.translate([target_se2[0], target_se2[1], z_value]) # Use x,y coordinates 226 | if label == 1: 227 | sphere.paint_uniform_color([0, 0, 1]) # Blue for positive label 228 | else: 229 | sphere.paint_uniform_color([1, 0, 0]) # Red for negative label 230 | geometries.append(sphere) 231 | 232 | # Create arrow for theta 233 | arrow_length = interval * 1 234 | arrow = o3d.geometry.TriangleMesh.create_arrow( 235 | cylinder_radius=interval/16, 236 | cone_radius=interval/12, 237 | cylinder_height=arrow_length*0.8, 238 | cone_height=arrow_length*0.2 239 | ) 240 | 241 | # Rotate and translate arrow to position 242 | R_x = np.array([ 243 | [0, 0, 1], 244 | [0, 1, 0], 245 | [1, 0, 0] 246 | ]) 247 | R = np.array([ 248 | [np.cos(target_se2[2]), -np.sin(target_se2[2]), 0], 249 | 
[np.sin(target_se2[2]), np.cos(target_se2[2]), 0], 250 | [0, 0, 1] 251 | ]) 252 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis 253 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta 254 | arrow.translate([target_se2[0], target_se2[1], z_value]) 255 | arrow.paint_uniform_color([0, 0, 1] if label == 1 else [1, 0, 0]) 256 | geometries.append(arrow) 257 | 258 | # Add grid points and arrows for points above threshold 259 | arrow_length = interval * 1 260 | for i in range(len(best_probs)): 261 | if best_probs[i] > threshold: 262 | # Create sphere for grid point 263 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/8) 264 | sphere.translate([grid_points[i,0], grid_points[i,1], z_value]) 265 | 266 | # Color based on probability (blue->red gradient) 267 | red_value = (best_probs[i] - threshold) / (1 - threshold) * 0.99 268 | color = [1-red_value, red_value, 0] 269 | sphere.paint_uniform_color(color) 270 | geometries.append(sphere) 271 | 272 | # Create arrow for theta 273 | arrow = o3d.geometry.TriangleMesh.create_arrow( 274 | cylinder_radius=interval/16, 275 | cone_radius=interval/12, 276 | cylinder_height=arrow_length*0.8, 277 | cone_height=arrow_length*0.2 278 | ) 279 | 280 | # Rotate and translate arrow to position 281 | R_x = np.array([ 282 | [0, 0, 1], 283 | [0, 1, 0], 284 | [1, 0, 0] 285 | ]) 286 | R = np.array([ 287 | [np.cos(best_thetas[i]), -np.sin(best_thetas[i]), 0], 288 | [np.sin(best_thetas[i]), np.cos(best_thetas[i]), 0], 289 | [0, 0, 1] 290 | ]) 291 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis 292 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta 293 | arrow.translate([grid_points[i,0], grid_points[i,1], z_value]) 294 | arrow.paint_uniform_color(color) 295 | geometries.append(arrow) 296 | 297 | # Convert all geometries to point clouds and combine 298 | combined_points = [] 299 | combined_colors = [] 300 | 301 | for geom in geometries: 302 | if isinstance(geom, o3d.geometry.PointCloud): 303 | combined_points.append(np.asarray(geom.points)) 304 | combined_colors.append(np.asarray(geom.colors)) 305 | else: 306 | # Sample points from mesh 307 | pcd = geom.sample_points_uniformly(number_of_points=500) 308 | combined_points.append(np.asarray(pcd.points)) 309 | combined_colors.append(np.asarray(pcd.colors)) 310 | 311 | # Combine all points and colors 312 | all_points = np.vstack(combined_points) 313 | all_colors = np.vstack(combined_colors) 314 | all_colors = np.clip(all_colors, 0, 1) 315 | 316 | # Create final point cloud 317 | final_pcd = o3d.geometry.PointCloud() 318 | final_pcd.points = o3d.utility.Vector3dVector(all_points) 319 | final_pcd.colors = o3d.utility.Vector3dVector(all_colors) 320 | 321 | # Save to file 322 | o3d.io.write_point_cloud(output_path, final_pcd) 323 | 324 | def save_gmm_visualization_xythetaz(point_cloud, target_xythetaz, label, means, covs, weights, output_path, interval=0.1, area=[-10, 10, -10, 10, -10, 10], threshold=0.5): 325 | """ 326 | Given a point cloud and mean and covariance of SE(2)+Z points, visualize the point cloud and the SE(2)+Z points. 327 | For each grid points in the xyz area with interval, calculate the most probable theta 328 | And then with SE(2)+Z point (x, y, theta, z), calculate the probability of that point 329 | Based on the probability, color the grid and visualize the arrow of theta. 330 | Repeat this process for all grid points and save this with the point cloud. 
331 | 332 | Args: 333 | point_cloud: Input point cloud (N, 3) 334 | target_xythetaz: Target SE(2)+Z point to visualize (x,y,theta,z) 335 | label: Label of the target point (0 or 1) 336 | means: Mean vectors for each Gaussian component (K, 4) 337 | covs: Covariance matrices for each Gaussian component (K, 4, 4) 338 | weights: Mixing coefficients for each Gaussian component (K) 339 | interval: Interval of the grid to visualize the points 340 | area: Area to visualize the points [xmin, xmax, ymin, ymax, zmin, zmax] 341 | threshold: Threshold of the probability to visualize the points 342 | """ 343 | # Create a grid of points in the xyz area 344 | x = np.arange(area[0], area[1] + interval, interval) 345 | y = np.arange(area[2], area[3] + interval, interval) 346 | z = np.arange(area[4], area[5] + interval, interval) 347 | X, Y, Z = np.meshgrid(x, y, z) 348 | grid_points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel())) 349 | 350 | # Initialize arrays to store results 351 | n_grid_points = len(grid_points) 352 | theta_samples = np.linspace(-np.pi, np.pi, 36) # Sample 36 different angles 353 | 354 | # Create all possible combinations of grid points and thetas 355 | # Shape: (n_grid_points * n_thetas, 4) 356 | grid_expanded = np.repeat(grid_points, len(theta_samples), axis=0) 357 | theta_expanded = np.tile(theta_samples, n_grid_points) 358 | # Rearrange to x,y,theta,z format 359 | all_xythetaz_points = np.column_stack(( 360 | grid_expanded[:, 0], # x 361 | grid_expanded[:, 1], # y 362 | theta_expanded, # theta 363 | grid_expanded[:, 2] # z 364 | )) 365 | 366 | # Calculate probabilities for all points at once 367 | all_probs = gmm_pdf(all_xythetaz_points, means, covs, weights) 368 | 369 | # Reshape probabilities to (n_grid_points, n_thetas) 370 | probs_matrix = all_probs.reshape(n_grid_points, len(theta_samples)) 371 | 372 | # Find the best theta and probability for each grid point 373 | best_probs = np.max(probs_matrix, axis=1) 374 | best_theta_indices = np.argmax(probs_matrix, axis=1) 375 | best_thetas = theta_samples[best_theta_indices] 376 | 377 | # Filter out probabilities below threshold and normalize probabilities to [0,1] 378 | prob_min, prob_max = best_probs.min(), best_probs.max() 379 | best_probs = (best_probs - prob_min) / (prob_max - prob_min) 380 | 381 | # Create visualization geometries 382 | geometries = [] 383 | 384 | # Add original point cloud 385 | pcd = o3d.geometry.PointCloud() 386 | pcd.points = o3d.utility.Vector3dVector(point_cloud[:, 0:3]) 387 | pcd.colors = o3d.utility.Vector3dVector(point_cloud[:, 3:6]) 388 | geometries.append(pcd) 389 | 390 | # Add coordinate frame 391 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0]) 392 | geometries.append(coord_frame) 393 | 394 | # Add target xythetaz point if provided 395 | if target_xythetaz is not None: 396 | # Create sphere for position 397 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/4) 398 | sphere.translate([target_xythetaz[0], target_xythetaz[1], target_xythetaz[3]]) # Use x,y,z coordinates 399 | if label == 1: 400 | sphere.paint_uniform_color([0, 0, 1]) # Blue for positive label 401 | else: 402 | sphere.paint_uniform_color([1, 0, 0]) # Red for negative label 403 | geometries.append(sphere) 404 | 405 | # Create arrow for theta 406 | arrow_length = interval * 1 407 | arrow = o3d.geometry.TriangleMesh.create_arrow( 408 | cylinder_radius=interval/16, 409 | cone_radius=interval/12, 410 | cylinder_height=arrow_length*0.8, 411 | cone_height=arrow_length*0.2 412 | ) 
413 | 414 | # Rotate and translate arrow to position 415 | R_x = np.array([ 416 | [0, 0, 1], 417 | [0, 1, 0], 418 | [1, 0, 0] 419 | ]) 420 | R = np.array([ 421 | [np.cos(target_xythetaz[2]), -np.sin(target_xythetaz[2]), 0], 422 | [np.sin(target_xythetaz[2]), np.cos(target_xythetaz[2]), 0], 423 | [0, 0, 1] 424 | ]) 425 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis 426 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta 427 | arrow.translate([target_xythetaz[0], target_xythetaz[1], target_xythetaz[3]]) 428 | arrow.paint_uniform_color([0, 0, 1] if label == 1 else [1, 0, 0]) 429 | geometries.append(arrow) 430 | 431 | # Add grid points and arrows for points above threshold 432 | arrow_length = interval * 1 433 | for i in range(len(best_probs)): 434 | if best_probs[i] > threshold: 435 | # Create sphere for grid point 436 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/8) 437 | sphere.translate([grid_points[i,0], grid_points[i,1], grid_points[i,2]]) 438 | 439 | # Color based on probability (blue->red gradient) 440 | red_value = (best_probs[i] - threshold) / (1 - threshold) * 0.99 441 | color = [1-red_value, red_value, 0] 442 | sphere.paint_uniform_color(color) 443 | geometries.append(sphere) 444 | 445 | # Create arrow for theta 446 | arrow = o3d.geometry.TriangleMesh.create_arrow( 447 | cylinder_radius=interval/16, 448 | cone_radius=interval/12, 449 | cylinder_height=arrow_length*0.8, 450 | cone_height=arrow_length*0.2 451 | ) 452 | 453 | # Rotate and translate arrow to position 454 | R_x = np.array([ 455 | [0, 0, 1], 456 | [0, 1, 0], 457 | [1, 0, 0] 458 | ]) 459 | R = np.array([ 460 | [np.cos(best_thetas[i]), -np.sin(best_thetas[i]), 0], 461 | [np.sin(best_thetas[i]), np.cos(best_thetas[i]), 0], 462 | [0, 0, 1] 463 | ]) 464 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis 465 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta 466 | arrow.translate([grid_points[i,0], grid_points[i,1], grid_points[i,2]]) 467 | arrow.paint_uniform_color(color) 468 | geometries.append(arrow) 469 | 470 | # Convert all geometries to point clouds and combine 471 | combined_points = [] 472 | combined_colors = [] 473 | 474 | for geom in geometries: 475 | if isinstance(geom, o3d.geometry.PointCloud): 476 | combined_points.append(np.asarray(geom.points)) 477 | combined_colors.append(np.asarray(geom.colors)) 478 | else: 479 | # Sample points from mesh 480 | pcd = geom.sample_points_uniformly(number_of_points=500) 481 | combined_points.append(np.asarray(pcd.points)) 482 | combined_colors.append(np.asarray(pcd.colors)) 483 | 484 | # Combine all points and colors 485 | all_points = np.vstack(combined_points) 486 | all_colors = np.vstack(combined_colors) 487 | all_colors = np.clip(all_colors, 0, 1) 488 | 489 | # Create final point cloud 490 | final_pcd = o3d.geometry.PointCloud() 491 | final_pcd.points = o3d.utility.Vector3dVector(all_points) 492 | final_pcd.colors = o3d.utility.Vector3dVector(all_colors) 493 | 494 | # Save to file 495 | o3d.io.write_point_cloud(output_path, final_pcd) 496 | 497 | if __name__ == "__main__": 498 | # Load data 499 | point_cloud = _load_point_cloud("/home/hyunjun/projects/CoRL2025/nav2man/exp_pcl_train/data/point_clouds/0001.ply") 500 | means = [[0, 0, 0], [1, 0, 0], [0, 1, 0]] 501 | covs = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]] 502 | weights = [0.1, 0.4, 0.5] 503 | save_gmm_visualization(point_cloud, means, covs, weights, "data/gmm_visualization") 
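For reference, here is a minimal usage sketch of the SE(2) visualizer defined above. It is illustrative only: `scene.pcd`, the GMM parameters, and the output file name are placeholder values, the scene cloud is assumed to carry RGB colors, and the import path assumes the package has been installed with `pip install -e .` as described in the README. The argument order follows the signature `save_gmm_visualization_se2(point_cloud, target_se2, label, means, covs, weights, output_path, ...)`.

```python
# Hypothetical example: visualize a 2-component SE(2) GMM over a colored scene cloud.
import numpy as np
import open3d as o3d

from n2m.utils.visualizer import save_gmm_visualization_se2

# Load a colored point cloud and pack it as (N, 6) = xyz + rgb, as the visualizer expects.
pcd = o3d.io.read_point_cloud("scene.pcd")  # placeholder path
point_cloud = np.concatenate([np.asarray(pcd.points), np.asarray(pcd.colors)], axis=1)

target_se2 = np.array([0.5, -0.2, np.pi / 6])  # labeled pose (x, y, theta)

# Made-up GMM parameters in (x, y, theta): shapes (K, 3), (K, 3, 3), (K,).
means = np.array([[0.4, -0.1, 0.4], [0.8, 0.3, -0.2]])
covs = np.stack([np.diag([0.05, 0.05, 0.1])] * 2)
weights = np.array([0.6, 0.4])

save_gmm_visualization_se2(
    point_cloud, target_se2, 1, means, covs, weights,
    "gmm_visualization_se2.ply",
    interval=0.1, area=[-1.0, 1.5, -1.0, 1.0], threshold=0.5, z_value=0.8,
)
```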
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "n2m"
3 | version = "0.1.0"
4 | description = "N2M: Bridging Navigation and Manipulation by Learning Pose Preference from Rollout"
5 | readme = "README.md"
6 | requires-python = ">=3.8"
7 | dependencies = [
8 |     "easydict>=1.13",
9 |     "numpy==1.23.5",
10 |     "open3d>=0.19.0",
11 |     "termcolor>=2.4.0",
12 |     "timm>=1.0.17",
13 |     "torch==2.2.1",
14 |     "torchaudio==2.2.1",
15 |     "torchvision==0.17.1",
16 |     "wandb>=0.21.0",
17 | ]
18 | 
--------------------------------------------------------------------------------
/scripts/process_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | python scripts/sample_camera_poses.py --dataset_path $1 --num_poses 300 --vis
4 | scripts/render/build/fpv_render $1
--------------------------------------------------------------------------------
/scripts/render/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.10)
2 | project(PointCloudRenderer)
3 | 
4 | set(CMAKE_CXX_STANDARD 17)
5 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
6 | set(CMAKE_BUILD_TYPE Release)
7 | 
8 | # Find packages
9 | find_package(PCL 1.8 REQUIRED)
10 | find_package(Eigen3 REQUIRED)
11 | 
12 | # Include PCL definitions and directories
13 | include_directories(${PCL_INCLUDE_DIRS})
14 | link_directories(${PCL_LIBRARY_DIRS})
15 | add_definitions(${PCL_DEFINITIONS})
16 | 
17 | # Add executable
18 | add_executable(fpv_render fpv_render.cpp)
19 | 
20 | # Link libraries
21 | target_link_libraries(fpv_render
22 |     ${PCL_LIBRARIES}
23 |     Eigen3::Eigen
24 |     stdc++fs # Add filesystem library
25 | )
26 | 
27 | # Include directories
28 | target_include_directories(fpv_render PRIVATE
29 |     ${PCL_INCLUDE_DIRS}
30 |     ${EIGEN3_INCLUDE_DIR}
31 |     ${PROJECT_SOURCE_DIR}/include
32 | )
33 | 
34 | # Compiler flags
35 | if(CMAKE_COMPILER_IS_GNUCXX)
36 |     target_compile_options(fpv_render PRIVATE -O3 -march=native)
37 | endif()
--------------------------------------------------------------------------------
/scripts/render/README.md:
--------------------------------------------------------------------------------
1 | # Point Cloud Renderer
2 | 
3 | This is a C++ implementation of a point cloud renderer with occlusion handling, converted from a Python implementation for improved performance. It uses the Point Cloud Library (PCL) for point cloud processing and visualization.
4 | 
5 | ## Performance Comparison
6 | 
7 | The C++ implementation is expected to be significantly faster than the Python version due to:
8 | 
9 | 1. Native code execution without Python interpreter overhead
10 | 2. More efficient memory management
11 | 3. Better optimization with compiler flags (`-O3 -march=native`)
12 | 4. Efficient point cloud operations using PCL's optimized algorithms
13 | 5. Reduced memory allocation overhead in the point processing loop
14 | 
15 | ## Requirements
16 | 
17 | - C++17 compatible compiler (GCC or Clang recommended)
18 | - CMake (version 3.10 or higher)
19 | - Point Cloud Library (PCL) version 1.8 or higher
20 | - Eigen3 library
21 | - librealsense2 SDK (optional: used for capturing point clouds, not for building the renderer, which links only against PCL and Eigen)
22 | 
23 | ## Installation
24 | 
25 | ### Installing Dependencies
26 | 
27 | #### Ubuntu/Debian:
28 | 
29 | ```bash
30 | # Install build tools
31 | sudo apt-get update
32 | sudo apt-get install build-essential cmake
33 | 
34 | # Install PCL and dependencies
35 | sudo apt-get install libpcl-dev
36 | 
37 | # Install Eigen3
38 | sudo apt-get install libeigen3-dev
39 | 
40 | # Optional: install librealsense (following Intel's guide)
41 | # See: https://github.com/IntelRealSense/librealsense/blob/master/doc/installation.md
42 | sudo apt-key adv --keyserver keys.gnupg.net --recv-key F6E65AC044F831AC80A06380C8B3A55A6F3EFCDE || sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-key F6E65AC044F831AC80A06380C8B3A55A6F3EFCDE
43 | sudo add-apt-repository "deb https://librealsense.intel.com/Debian/apt-repo $(lsb_release -cs) main" -u
44 | sudo apt-get install librealsense2-dkms librealsense2-utils librealsense2-dev
45 | ```
46 | 
47 | ## Building the Project
48 | 
49 | ```bash
50 | # From the N2M repository root
51 | cd scripts/render
52 | 
53 | # Create build directory
54 | mkdir build && cd build
55 | 
56 | # Configure and build
57 | cmake ..
58 | make -j$(nproc)
59 | ```
60 | 
61 | ## Usage
62 | 
63 | The renderer expects a dataset directory containing `pcl/`, `meta.json`, and the `camera_poses/` folder (`camera_poses.json` and `base_poses.json`) produced by `scripts/sample_camera_poses.py`.
64 | 
65 | ```bash
66 | # Run the renderer on a dataset (from the repository root)
67 | scripts/render/build/fpv_render <dataset_path>
68 | ```
69 | 
70 | For every episode and every sampled camera pose, the program will:
71 | 1. Load the episode's point cloud from `pcl/`
72 | 2. Render it from the sampled camera pose with occlusion handling (per-pixel depth test)
73 | 3. Re-express the rendered cloud in the sampled base frame
74 | 4. Save the result to `pcl_aug/` and append the transformed label to `meta_aug.json`
75 | 
76 | ## Comparing with Python Version
77 | 
78 | To compare performance:
79 | 
80 | 1. Run the Python version and note the reported timing
81 | 2. Run the C++ version and note the reported timing
82 | 3. Compare the quality of the rendered point clouds
83 | 
84 | The C++ implementation using PCL is expected to show significant performance improvements, especially for large point clouds, compared to the Python implementation using Open3D. PCL's native C++ implementation avoids Python's interpreter overhead and benefits from highly optimized algorithms specifically designed for point cloud processing.
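For reference, the occlusion handling implemented by `fpv_render.cpp` (listed below) boils down to a per-pixel depth test: each point is projected through the pinhole intrinsics and only the nearest point per pixel survives. A compact NumPy sketch of that idea, illustrative only and not the original Python implementation this tool was converted from:

```python
import numpy as np

def visible_point_indices(points_cam, fx, fy, cx, cy, width, height,
                          min_depth=0.1, max_depth=10.0):
    """Return indices of points that win the per-pixel depth test.

    points_cam: (N, 3) points already expressed in the camera frame, +Z forward.
    """
    x, y, z = points_cam[:, 0], points_cam[:, 1], points_cam[:, 2]

    # Keep points in the valid depth range (also guarantees z > 0 for projection)
    idx = np.flatnonzero((z > min_depth) & (z < max_depth))

    # Pinhole projection onto the image plane
    u = (x[idx] * fx / z[idx] + cx).astype(int)
    v = (y[idx] * fy / z[idx] + cy).astype(int)
    in_image = (u >= 0) & (u < width) & (v >= 0) & (v < height)
    idx, u, v = idx[in_image], u[in_image], v[in_image]

    # Z-buffer: keep only the nearest point per pixel
    depth = np.full((height, width), np.inf)
    winner = np.full((height, width), -1, dtype=int)
    for i, col, row in zip(idx, u, v):
        if z[i] < depth[row, col]:
            depth[row, col] = z[i]
            winner[row, col] = i
    return winner[winner >= 0]
```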
-------------------------------------------------------------------------------- /scripts/render/fpv_render.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | // PCL includes 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using json = nlohmann::json; 19 | namespace fs = std::filesystem; 20 | 21 | namespace nlohmann { 22 | // Convert Eigen::Vector3f to JSON 23 | void to_json(json& j, const Eigen::Vector3f& v) { 24 | j = json::array({v(0), v(1), v(2)}); 25 | } 26 | 27 | // Convert JSON to Eigen::Vector3f 28 | void from_json(const json& j, Eigen::Vector3f& v) { 29 | v(0) = j[0].get(); 30 | v(1) = j[1].get(); 31 | v(2) = j[2].get(); 32 | } 33 | 34 | // Convert Eigen::Vector4f to JSON 35 | void to_json(json& j, const Eigen::Vector4f& v) { 36 | j = json::array({v(0), v(1), v(2), v(3)}); 37 | } 38 | 39 | // Convert JSON to Eigen::Vector4f 40 | void from_json(const json& j, Eigen::Vector4f& v) { 41 | v(0) = j[0].get(); 42 | v(1) = j[1].get(); 43 | v(2) = j[2].get(); 44 | v(3) = j[3].get(); 45 | } 46 | } 47 | 48 | class CameraIntrinsics { 49 | public: 50 | float fx, fy, cx, cy; 51 | int width, height; 52 | 53 | CameraIntrinsics(float fx, float fy, float cx, float cy, int width, int height) 54 | : fx(fx), fy(fy), cx(cx), cy(cy), width(width), height(height) {} 55 | }; 56 | 57 | template 58 | void transformPoints(const pcl::PointCloud& source, 59 | pcl::PointCloud& target, 60 | const Eigen::Matrix4f& transform) { 61 | pcl::transformPointCloud(source, target, transform); 62 | } 63 | 64 | template 65 | typename pcl::PointCloud::Ptr translatePointCloud( 66 | const typename pcl::PointCloud::Ptr& pointCloud, 67 | const Eigen::Matrix4f& basePose 68 | ) { 69 | // Create output point cloud 70 | typename pcl::PointCloud::Ptr transformedCloud(new pcl::PointCloud()); 71 | 72 | // Create inverse transformation to move base pose to origin 73 | Eigen::Matrix4f inverseTransform = basePose.inverse(); 74 | 75 | // Transform the point cloud 76 | pcl::transformPointCloud(*pointCloud, *transformedCloud, inverseTransform); 77 | 78 | return transformedCloud; 79 | } 80 | 81 | 82 | 83 | template 84 | Eigen::Vector3f translateTarget( 85 | const Eigen::Vector3f& target_se2, 86 | const Eigen::Matrix4f& basePose 87 | ) { 88 | // SE(2) pose is [x, y, theta] 89 | float x = target_se2(0); 90 | float y = target_se2(1); 91 | float theta = target_se2(2); 92 | 93 | // Create SE(2) transformation matrix 94 | Eigen::Matrix3f se2_transform; 95 | se2_transform << cos(theta), -sin(theta), x, 96 | sin(theta), cos(theta), y, 97 | 0, 0, 1; 98 | 99 | // Get base SE(2) transformation 100 | float base_theta = atan2(basePose(1,0), basePose(0,0)); 101 | Eigen::Matrix3f base_se2; 102 | base_se2 << cos(base_theta), -sin(base_theta), basePose(0,3), 103 | sin(base_theta), cos(base_theta), basePose(1,3), 104 | 0, 0, 1; 105 | 106 | // Transform SE(2) pose 107 | Eigen::Matrix3f transformed_se2 = base_se2.inverse() * se2_transform; 108 | 109 | // Extract transformed SE(2) parameters 110 | Eigen::Vector3f translated_se2; 111 | translated_se2(0) = transformed_se2(0,2); 112 | translated_se2(1) = transformed_se2(1,2); 113 | translated_se2(2) = atan2(transformed_se2(1,0), transformed_se2(0,0)); 114 | 115 | return translated_se2; 116 | } 117 | 118 | template 119 | typename pcl::PointCloud::Ptr renderPointCloudWithOcclusion( 120 | const typename pcl::PointCloud::Ptr& pointCloud, 
121 | const Eigen::Matrix4f& cameraPose, 122 | const CameraIntrinsics& intrinsics, 123 | float minDepth = 0.1, 124 | float maxDepth = 10.0 125 | ) { 126 | // Transform points to camera frame 127 | Eigen::Matrix4f cameraToWorld = cameraPose.inverse(); 128 | typename pcl::PointCloud::Ptr cloudCamera(new pcl::PointCloud()); 129 | transformPoints(*pointCloud, *cloudCamera, cameraToWorld); 130 | 131 | // Initialize depth buffer 132 | std::vector> depthBuffer(intrinsics.width, 133 | std::vector(intrinsics.height, 134 | std::numeric_limits::infinity())); 135 | std::vector> pointIndices(intrinsics.width, 136 | std::vector(intrinsics.height, -1)); 137 | 138 | auto timeStart = std::chrono::high_resolution_clock::now(); 139 | double tTotal = 0.0; 140 | 141 | // Process each point 142 | for (size_t i = 0; i < cloudCamera->points.size(); ++i) { 143 | auto pointStart = std::chrono::high_resolution_clock::now(); 144 | 145 | const auto& point = cloudCamera->points[i]; 146 | 147 | // Transform coordinates to match Python implementation 148 | // In Python, camera looks along Y axis, with Z up and X right 149 | float x = point.x; // X stays the same 150 | float y = point.y; // flip y 151 | float z = point.z; // flip z 152 | 153 | // Skip if behind camera or invalid 154 | if (z <= 0 || !std::isfinite(x) || !std::isfinite(y) || !std::isfinite(z)) { 155 | continue; 156 | } 157 | 158 | // Project point to image plane 159 | float u = (x * intrinsics.fx / z) + intrinsics.cx; 160 | float v = (y * intrinsics.fy / z) + intrinsics.cy; 161 | 162 | // Skip if outside image bounds 163 | if (u < 0 || u >= intrinsics.width || v < 0 || v >= intrinsics.height) { 164 | continue; 165 | } 166 | 167 | // Get pixel coordinates 168 | int pixelX = static_cast(u); 169 | int pixelY = static_cast(v); 170 | 171 | // Skip if outside depth range 172 | float depth = z; 173 | if (depth < minDepth || depth > maxDepth) { 174 | continue; 175 | } 176 | 177 | // Update depth buffer if point is closer 178 | if (depth < depthBuffer[pixelX][pixelY]) { 179 | depthBuffer[pixelX][pixelY] = depth; 180 | pointIndices[pixelX][pixelY] = i; 181 | } 182 | 183 | auto pointEnd = std::chrono::high_resolution_clock::now(); 184 | tTotal += std::chrono::duration(pointEnd - pointStart).count(); 185 | } 186 | 187 | std::cout << "Time taken for point processing: " << tTotal * 1000 << " ms" << std::endl; 188 | 189 | // Create output point cloud 190 | std::vector validIndices; 191 | for (const auto& row : pointIndices) { 192 | for (int idx : row) { 193 | if (idx >= 0) { 194 | validIndices.push_back(idx); 195 | } 196 | } 197 | } 198 | 199 | if (validIndices.empty()) { 200 | return typename pcl::PointCloud::Ptr(new pcl::PointCloud()); 201 | } 202 | 203 | typename pcl::PointCloud::Ptr renderedCloud(new pcl::PointCloud()); 204 | renderedCloud->reserve(validIndices.size()); 205 | 206 | // Extract visible points 207 | for (int idx : validIndices) { 208 | renderedCloud->points.push_back(pointCloud->points[idx]); 209 | } 210 | 211 | renderedCloud->width = renderedCloud->points.size(); 212 | renderedCloud->height = 1; 213 | renderedCloud->is_dense = false; 214 | 215 | return renderedCloud; 216 | } 217 | 218 | 219 | 220 | Eigen::Matrix4f createCameraPoseFromTranslationQuaternion( 221 | float tx, float ty, float tz, 222 | float qw, float qx, float qy, float qz 223 | ) { 224 | // Normalize quaternion 225 | float norm = std::sqrt(qw*qw + qx*qx + qy*qy + qz*qz); 226 | qw /= norm; 227 | qx /= norm; 228 | qy /= norm; 229 | qz /= norm; 230 | 231 | // Quaternion to rotation matrix 
conversion 232 | Eigen::Matrix3f R; 233 | R << 1 - 2*qy*qy - 2*qz*qz, 2*qx*qy - 2*qw*qz, 2*qx*qz + 2*qw*qy, 234 | 2*qx*qy + 2*qw*qz, 1 - 2*qx*qx - 2*qz*qz, 2*qy*qz - 2*qw*qx, 235 | 2*qx*qz - 2*qw*qy, 2*qy*qz + 2*qw*qx, 1 - 2*qx*qx - 2*qy*qy; 236 | 237 | // Create 4x4 transformation matrix 238 | Eigen::Matrix4f pose = Eigen::Matrix4f::Identity(); 239 | pose.block<3,3>(0,0) = R; 240 | pose(0,3) = tx; 241 | pose(1,3) = ty; 242 | pose(2,3) = tz; 243 | 244 | return pose; 245 | } 246 | 247 | int main(int argc, char** argv) { 248 | if (argc != 2) { 249 | std::cerr << "Usage: " << argv[0] << " " << std::endl; 250 | return 1; 251 | } 252 | 253 | std::string dataset_path = argv[1]; 254 | std::string pcl_dir = dataset_path + "/pcl"; 255 | std::string meta_path = dataset_path + "/meta.json"; 256 | std::string output_pcl_dir = dataset_path + "/pcl_aug"; 257 | std::string output_meta_path = dataset_path + "/meta_aug.json"; 258 | std::string camera_poses_path = dataset_path + "/camera_poses/camera_poses.json"; 259 | std::string base_poses_path = dataset_path + "/camera_poses/base_poses.json"; 260 | 261 | // Create output directory if it doesn't exist 262 | fs::create_directories(output_pcl_dir); 263 | 264 | std::ifstream camera_poses_file(camera_poses_path); 265 | json camera_poses; 266 | camera_poses_file >> camera_poses; 267 | 268 | std::ifstream meta_file(meta_path); 269 | json meta; 270 | meta_file >> meta; 271 | 272 | std::ifstream base_poses_file(base_poses_path); 273 | json base_poses; 274 | base_poses_file >> base_poses; 275 | 276 | // Read camera intrinsics from meta.json 277 | std::vector cam_intr; 278 | if (meta["meta"].contains("camera_intrinsic")) { 279 | for (const auto& v : meta["meta"]["camera_intrinsic"]) { 280 | cam_intr.push_back(v.get()); 281 | } 282 | } else { 283 | std::cerr << "camera_intrinsic not found in meta.json, using default values." << std::endl; 284 | cam_intr = {100.6919557412736, 100.6919557412736, 160.0, 120.0, 320, 240}; 285 | } 286 | if (cam_intr.size() != 6) { 287 | std::cerr << "camera_intrinsic in meta.json must have 6 values (fx, fy, cx, cy, width, height)." 
<< std::endl; 288 | return 1; 289 | } 290 | CameraIntrinsics intrinsics( 291 | cam_intr[0], cam_intr[1], cam_intr[2], cam_intr[3], (int)cam_intr[4], (int)cam_intr[5] 292 | ); 293 | 294 | json meta_aug; 295 | meta_aug["meta"] = meta["meta"]; 296 | meta_aug["episodes"] = json::array(); 297 | 298 | auto timeStart = std::chrono::high_resolution_clock::now(); 299 | for (const auto& episode : meta["episodes"]) { 300 | // Load point cloud 301 | std::string pcl_path = fs::path(dataset_path) / episode["file_path"]; 302 | pcl::PointCloud::Ptr pointCloud(new pcl::PointCloud); 303 | if (pcl::io::loadPCDFile(pcl_path, *pointCloud) == -1) { 304 | std::cerr << "Failed to load point cloud file: " << pcl_path << std::endl; 305 | continue; 306 | } 307 | 308 | int episode_id = episode["id"]; 309 | for (int i = 0; i < camera_poses[episode_id].size(); i++) { 310 | json episode_aug = episode; 311 | 312 | std::string file_name = std::to_string(episode_id) + "_" + std::to_string(i) + ".pcd"; 313 | std::string output_path = (fs::path(output_pcl_dir) / file_name).string(); 314 | if (fs::exists(output_path)) { 315 | std::cout << "Skipping: " << output_path << std::endl; 316 | continue; 317 | } 318 | 319 | // define camera pose 320 | Eigen::Matrix4f cameraPose = Eigen::Matrix4f::Identity(); 321 | for (int j = 0; j < 4; j++) { 322 | for (int k = 0; k < 4; k++) { 323 | cameraPose(j, k) = camera_poses[episode_id][i][j][k]; 324 | } 325 | }; 326 | 327 | // translate point cloud 328 | Eigen::Matrix4f basePose = Eigen::Matrix4f::Identity(); 329 | for (int j = 0; j < 4; j++) { 330 | for (int k = 0; k < 4; k++) { 331 | basePose(j, k) = base_poses[episode_id][i][j][k]; 332 | } 333 | }; 334 | 335 | // Render point cloud 336 | auto renderedCloud = renderPointCloudWithOcclusion( 337 | pointCloud, 338 | cameraPose, 339 | intrinsics, 340 | 0.1f, 341 | 10.0f 342 | ); 343 | 344 | // translate point cloud 345 | auto translatedCloud = translatePointCloud(renderedCloud, basePose); 346 | 347 | // translate target 348 | Eigen::Vector3f pose_se2; 349 | for (int idx = 0; idx < 3; ++idx) { 350 | pose_se2(idx) = episode_aug["pose"][idx].get(); 351 | } 352 | auto translated_se2 = translateTarget(pose_se2, basePose); 353 | episode_aug["pose"] = translated_se2; 354 | 355 | // Save rendered point cloud 356 | pcl::io::savePCDFile(output_path, *translatedCloud); 357 | 358 | episode_aug["file_path"] = "pcl_aug/" + file_name; 359 | meta_aug["episodes"].push_back(episode_aug); 360 | 361 | auto timeEnd = std::chrono::high_resolution_clock::now(); 362 | auto duration = std::chrono::duration(timeEnd - timeStart).count(); 363 | 364 | std::cout << "Rendered point cloud saved to: " << output_path << std::endl; 365 | std::cout << "Number of points: " << renderedCloud->size() << std::endl; 366 | std::cout << "Total time: " << duration << " s" << std::endl; 367 | std::cout << "----------------------------------------" << std::endl; 368 | } 369 | } 370 | 371 | // Save augmented meta file 372 | std::ofstream output_meta_file(output_meta_path); 373 | output_meta_file << meta_aug.dump(4); 374 | std::cout << "Augmented meta file saved to: " << output_meta_path << std::endl; 375 | 376 | return 0; 377 | } -------------------------------------------------------------------------------- /scripts/sample_camera_poses.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import open3d as o3d 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from n2m.utils.sample_utils import 
TargetHelper
9 | 
10 | def save_pose_visualization(pcl, poses, furniture_pos, save_path):
11 |     # Create point cloud for the scene
12 |     pcd = o3d.geometry.PointCloud()
13 |     pcd.points = o3d.utility.Vector3dVector(pcl.points)
14 |     pcd.colors = o3d.utility.Vector3dVector(pcl.colors)
15 | 
16 |     # Add furniture position as red points
17 |     furniture_pcd = o3d.geometry.PointCloud()
18 |     furniture_pcd.points = o3d.utility.Vector3dVector([furniture_pos[:3]])
19 |     furniture_pcd.colors = o3d.utility.Vector3dVector([[1, 0, 0]]) # Red
20 |     pcd = pcd + furniture_pcd
21 | 
22 |     for pose in poses:
23 |         # Add a coordinate frame mesh at each camera pose and sample it into points
24 |         camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
25 |         camera_frame.transform(pose)
26 |         camera_frame_points = camera_frame.sample_points_uniformly(number_of_points=300)
27 |         pcd = pcd + camera_frame_points
28 | 
29 |     # Save combined point cloud
30 |     o3d.io.write_point_cloud(save_path, pcd)
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     parser = argparse.ArgumentParser()
35 |     parser.add_argument("--dataset_path", type=str, default="datasets/rollouts/PnPCounterToCab_BCtransformer_rollouts/PnPCounterToCab_BCtransformer_rollout_scene1/20250430230003")
36 |     parser.add_argument("--num_poses", type=int, default=10)
37 |     parser.add_argument("--vis", action="store_true")
38 |     args = parser.parse_args()
39 | 
40 |     print("loading meta...")
41 |     meta_path = os.path.join(args.dataset_path, "meta.json")
42 |     with open(meta_path, "r") as f:
43 |         meta = json.load(f)
44 |     episodes = meta['episodes']
45 |     print("meta loaded")
46 | 
47 |     T_base_to_cam = meta["meta"]["T_base_to_cam"]
48 |     camera_intrinsic = meta["meta"]["camera_intrinsic"]
49 | 
50 |     camera_pose_save_dir = os.path.join(args.dataset_path, "camera_poses")
51 |     os.makedirs(camera_pose_save_dir, exist_ok=True)
52 | 
53 |     if args.vis:
54 |         camera_pose_vis_dir = os.path.join(args.dataset_path, "camera_poses_vis")
55 |         os.makedirs(camera_pose_vis_dir, exist_ok=True)
56 | 
57 |     camera_poses = []
58 |     base_poses = []
59 |     for episode in tqdm(episodes):
60 |         pcl_path = os.path.join(args.dataset_path, episode['file_path'])
61 |         pcl = o3d.io.read_point_cloud(pcl_path)
62 |         pose = episode['pose']
63 |         se2_origin = pose
64 | 
65 |         # load object position if it exists; otherwise assume it is 0.5 m in front of the pose at 1 m height
66 |         if "object_position" in episode:
67 |             object_position = episode['object_position']
68 |         else:
69 |             object_position = [pose[0] + 0.5 * np.cos(pose[2]), pose[1] + 0.5 * np.sin(pose[2]), 1.0]
70 |         object_position = np.array(object_position)
71 | 
72 |         target_helper = TargetHelper(pcl, origin_se2=se2_origin, x_half_range=2, y_half_range=2, theta_half_range_deg=60, vis=False, camera_intrinsic=camera_intrinsic)
73 |         target_helper.T_base_cam = np.array(T_base_to_cam)
74 |         episode_camera_poses = []
75 |         episode_base_poses = []
76 |         for _ in range(args.num_poses):
77 |             rel_base_se2, camera_extrinsic = target_helper.get_random_target_se2_with_visibility_check(object_position[:2])
78 |             abs_base_se2 = [x + y for x, y in zip(rel_base_se2, pose)]
79 |             matrix_base_se3 = target_helper.calculate_target_se3(abs_base_se2)
80 |             episode_base_poses.append(matrix_base_se3.tolist())
81 |             episode_camera_poses.append(camera_extrinsic.tolist())
82 | 
83 |         if args.vis:
84 |             save_pose_visualization(pcl, episode_camera_poses, object_position, os.path.join(camera_pose_vis_dir, f"{episode['id']}.pcd"))
85 | 
86 |         camera_poses.append(episode_camera_poses)
87 |         base_poses.append(episode_base_poses)
88 | 
89 |     with open(os.path.join(camera_pose_save_dir, 
"camera_poses.json"), "w") as f: 90 | json.dump(camera_poses, f, indent=4) 91 | with open(os.path.join(camera_pose_save_dir, "base_poses.json"), "w") as f: 92 | json.dump(base_poses, f, indent=4) 93 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import json 6 | from tqdm import tqdm 7 | import argparse 8 | import wandb 9 | import time 10 | import datetime 11 | import random 12 | import numpy as np 13 | 14 | from n2m.data.dataset import make_data_module 15 | from n2m.model.N2Mnet import N2Mnet 16 | from n2m.utils.config import * 17 | from n2m.utils.visualizer import save_gmm_visualization, save_gmm_visualization_se2, save_gmm_visualization_xythetaz 18 | from n2m.utils.loss import Loss 19 | 20 | def set_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | def load_config(config_path): 29 | """Load configuration from YAML file""" 30 | with open(config_path, 'r') as f: 31 | config = json.load(f) 32 | return config 33 | 34 | def train_one_epoch(model, train_loader, loss_fn, optimizer, epoch, device, train_config): 35 | model.train() 36 | total_loss = 0 37 | pbar = tqdm(train_loader, desc=f'Epoch {epoch}') 38 | 39 | for batch in pbar: 40 | # Move data to device 41 | point_cloud = batch['point_cloud'].to(device) 42 | target_point = batch['target_point'].to(device) 43 | label = batch['label'].to(device) 44 | 45 | # Forward pass 46 | means, covs, weights = model(point_cloud) 47 | loss = loss_fn(means, covs, weights, target_point, label) 48 | 49 | # Backward pass 50 | optimizer.zero_grad() 51 | loss.backward() 52 | optimizer.step() 53 | 54 | # Update metrics 55 | total_loss += loss.item() 56 | pbar.set_postfix({'loss': loss.item()}) 57 | 58 | 59 | avg_loss = total_loss / len(train_loader) 60 | wandb.log({"train/loss": avg_loss, "epoch": epoch}) 61 | return avg_loss 62 | 63 | def validate(model, val_loader, loss_fn, epoch, device, val_dir,train_config): 64 | model.eval() 65 | total_loss = 0 66 | 67 | # Create visualization directory 68 | os.makedirs(val_dir, exist_ok=True) 69 | 70 | with torch.no_grad(): 71 | for batch_idx, batch in enumerate(tqdm(val_loader, desc='Validation')): 72 | # Move data to device 73 | point_cloud = batch['point_cloud'].to(device) 74 | target_point = batch['target_point'].to(device) 75 | label = batch['label'].to(device) 76 | 77 | # Forward pass 78 | means, covs, weights = model(point_cloud) 79 | loss = loss_fn(means, covs, weights, target_point, label) 80 | 81 | # Update metrics 82 | total_loss += loss.item() 83 | 84 | # Generate visualization for first item in batch 85 | for i in range(len(batch['point_cloud'])): 86 | if target_point[i].shape[0] == 3: 87 | save_gmm_visualization_se2( 88 | point_cloud[i].cpu().numpy(), 89 | target_point[i].cpu().numpy(), 90 | label[i].cpu().numpy(), 91 | means[i].cpu().numpy(), 92 | covs[i].cpu().numpy(), 93 | weights[i].cpu().numpy(), 94 | os.path.join(val_dir, f'batch_{batch_idx}_{i}.pcd') 95 | ) 96 | elif target_point[i].shape[0] == 4: 97 | save_gmm_visualization_xythetaz( 98 | point_cloud[i].cpu().numpy(), 99 | target_point[i].cpu().numpy(), 100 | label[i].cpu().numpy(), 101 | means[i].cpu().numpy(), 102 | 
covs[i].cpu().numpy(), 103 | weights[i].cpu().numpy(), 104 | os.path.join(val_dir, f'batch_{batch_idx}_{i}.pcd') 105 | ) 106 | 107 | 108 | avg_loss = total_loss / len(val_loader) 109 | 110 | # Log to wandb 111 | wandb.log({"val/loss": avg_loss, "epoch": epoch}) 112 | 113 | return avg_loss 114 | 115 | def save_checkpoint(model, optimizer, epoch, loss, save_path): 116 | checkpoint = { 117 | 'epoch': epoch, 118 | 'model_state_dict': model.state_dict(), 119 | 'optimizer_state_dict': optimizer.state_dict(), 120 | 'loss': loss, 121 | } 122 | torch.save(checkpoint, save_path) 123 | 124 | def get_exp_dir(train_config): 125 | t_now = time.time() 126 | time_str = datetime.datetime.fromtimestamp(t_now).strftime('%Y%m%d%H%M%S') 127 | output_dir = os.path.join(train_config['output_dir'], time_str) 128 | 129 | ckpt_dir = os.path.join(output_dir, 'ckpts') 130 | val_dir = os.path.join(output_dir, 'val') 131 | log_dir = os.path.join(output_dir, 'logs') 132 | 133 | # create output directory 134 | os.makedirs(output_dir, exist_ok=True) 135 | os.makedirs(ckpt_dir, exist_ok=True) 136 | os.makedirs(val_dir, exist_ok=True) 137 | os.makedirs(log_dir, exist_ok=True) 138 | 139 | return output_dir, ckpt_dir, val_dir, log_dir 140 | 141 | def main(): 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--config', type=str, default='configs/training/config.json', help='Path to the config file.') 144 | args = parser.parse_args() 145 | 146 | # Load configuration 147 | config = load_config(args.config) 148 | train_config = config['train'] 149 | model_config = config['model'] 150 | dataset_config = config['dataset'] 151 | 152 | # Create output directory 153 | output_dir, ckpt_dir, val_dir, log_dir = get_exp_dir(train_config) 154 | 155 | # Initialize wandb 156 | wandb_config = train_config['wandb'] 157 | wandb.init( 158 | project=wandb_config['project'], 159 | entity=wandb_config['entity'], 160 | name=wandb_config['name'], 161 | config=config, 162 | mode="online" if wandb_config['entity'] is not None else "disabled" 163 | ) 164 | 165 | # save config 166 | config_save_path = os.path.join(output_dir, 'config.json') 167 | with open(config_save_path, 'w') as f: 168 | json.dump(config, f, indent=4) 169 | 170 | # Set device 171 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 172 | print(f'Using device: {device}') 173 | 174 | # Create model 175 | model = N2Mnet( 176 | config=model_config 177 | ).to(device) 178 | 179 | # Watch model with wandb 180 | wandb.watch(model, log="all", log_freq=100) 181 | 182 | # Create data loaders 183 | train_dataset, val_dataset = make_data_module(dataset_config) 184 | 185 | train_loader = DataLoader( 186 | train_dataset, 187 | batch_size=train_config['batch_size'], 188 | shuffle=True, 189 | num_workers=train_config['num_workers'], 190 | pin_memory=True 191 | ) 192 | 193 | val_loader = DataLoader( 194 | val_dataset, 195 | batch_size=train_config['batch_size'], 196 | shuffle=False, 197 | num_workers=train_config['num_workers'], 198 | pin_memory=True 199 | ) 200 | if len(val_loader) == 0: 201 | val_loader = train_loader 202 | 203 | # Create optimizer and scheduler 204 | optimizer = torch.optim.Adam(model.parameters(), lr=float(train_config['learning_rate'])) 205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 206 | optimizer, 207 | T_max=int(train_config['num_epochs']) 208 | ) 209 | 210 | # create loss function 211 | loss_config = train_config['loss'] 212 | loss_fn = Loss(loss_config) 213 | 214 | # Training loop 215 | best_val_loss = float('inf') 216 | for 
epoch in range(train_config['num_epochs']): 217 | # Train 218 | train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch, device, train_config) 219 | print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}') 220 | 221 | # Validate 222 | if (epoch + 1) % train_config['val_freq'] == 0: 223 | val_loss = validate(model, val_loader, loss_fn, epoch, device, val_dir, train_config) 224 | print(f'Epoch {epoch}: Val Loss = {val_loss:.4f}') 225 | 226 | # Save best model 227 | if val_loss < best_val_loss: 228 | best_val_loss = val_loss 229 | save_checkpoint( 230 | model, 231 | optimizer, 232 | epoch, 233 | val_loss, 234 | os.path.join(ckpt_dir, 'best_model.pth') 235 | ) 236 | save_checkpoint( 237 | model, 238 | optimizer, 239 | epoch, 240 | train_loss, 241 | os.path.join(ckpt_dir, f'model_{epoch}.pth') 242 | ) 243 | 244 | # Step scheduler 245 | scheduler.step() 246 | 247 | # Save final model 248 | save_checkpoint( 249 | model, 250 | optimizer, 251 | train_config['num_epochs'], 252 | train_loss, 253 | os.path.join(ckpt_dir, 'final_model.pth') 254 | ) 255 | 256 | wandb.finish() 257 | 258 | if __name__ == '__main__': 259 | main() 260 | -------------------------------------------------------------------------------- /scripts/visualize_attention.py: -------------------------------------------------------------------------------- 1 | def visualize_attention(pts, center, attn_weights): 2 | """ 3 | Visualizes attention weights on the point cloud. 4 | 5 | Args: 6 | pts (Tensor): Original input point cloud (B, N_pts, 3). 7 | center (Tensor): Center points of the point groups (B, N_groups, 3). 8 | attn_weights (Tensor): Attention matrix from the final block (B, num_heads, N, N). 9 | """ 10 | # Assuming batch size is 1 for simplicity 11 | attn = attn_weights[0] # Shape: (num_heads, N, N) 12 | 13 | # Average attention across all heads for simplicity, or pick a specific head 14 | attn_avg = attn.mean(dim=0) # Shape: (N, N) 15 | 16 | # Get the attention scores of the class token (index 0) to all other tokens 17 | # Note: The first token is the cls token, subsequent tokens correspond to point groups 18 | cls_attn = attn_avg[0, 1:] # Shape: (N_groups,) 19 | 20 | # Normalize scores for better visualization (e.g., between 0 and 1) 21 | cls_attn_normalized = (cls_attn - cls_attn.min()) / (cls_attn.max() - cls_attn.min()) 22 | 23 | # You can now use a library like Matplotlib, Mayavi, or Open3D to visualize 24 | # the points. For each point in `center`, use `cls_attn_normalized` to 25 | # determine its color (e.g., a colormap from blue to red) or size. 26 | 27 | # Here's a conceptual example using a library for visualization 28 | # Let's say we have a function `plot_points_with_colors` 29 | # You would need to implement this part based on your chosen library. 30 | 31 | # For visualization, you'd plot `center` points. The `cls_attn_normalized` 32 | # provides the scalar value for the heatmap. 33 | 34 | # Example for coloring points: 35 | # colors = plt.cm.jet(cls_attn_normalized.numpy())[:, :3] # Get RGB from colormap 36 | # plot_points_with_colors(center.numpy()[0], colors) 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | 41 | # Assuming you have loaded your model and data 42 | model = PointTransformer(config) 43 | model.load_checkpoint(bert_ckpt_path) 44 | 45 | # Prepare a single ego-centric point cloud for visualization 46 | # This should be a point cloud from your dataset. 47 | input_pts = ... 
# Your point cloud data, shape (1, N_pts, 3) 48 | 49 | # Run the forward pass to get the output and attention weights 50 | output_features, attn_weights = model(input_pts) 51 | 52 | # Get the center points from the group divider (you might need to run it again) 53 | _, center = model.group_divider(input_pts) 54 | 55 | # Call the visualization function 56 | visualize_attention(input_pts, center, attn_weights) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="n2m", 5 | packages=[ 6 | package for package in find_packages() if package.startswith("n2m") 7 | ], 8 | install_requires=[], 9 | version="0.0.1", 10 | ) --------------------------------------------------------------------------------
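Returning to `scripts/visualize_attention.py` above: the plotting step is left as a placeholder there. One way to fill it in, consistent with how the other scripts in this repository write colored point clouds with Open3D, is sketched below; the function name `plot_points_with_colors` comes from the comments in that script and is not part of the existing codebase.

```python
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt

def plot_points_with_colors(points, colors, save_path="attention_heatmap.pcd"):
    """Write a point cloud whose colors encode per-group attention weights."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points, dtype=float))
    pcd.colors = o3d.utility.Vector3dVector(np.clip(np.asarray(colors, dtype=float), 0.0, 1.0))
    o3d.io.write_point_cloud(save_path, pcd)

# Example wiring, following the names used in visualize_attention():
#   colors = plt.cm.jet(cls_attn_normalized.cpu().numpy())[:, :3]
#   plot_points_with_colors(center[0].cpu().numpy(), colors)
```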