├── .gitignore
├── README.md
├── configs
│   ├── inference
│   │   └── config.json
│   └── training
│       └── config.json
├── doc
│   ├── Data_Preparation.png
│   └── System_Overview.png
├── n2m
│   ├── __init__.py
│   ├── data
│   │   └── dataset.py
│   ├── model
│   │   ├── N2Mnet.py
│   │   └── pointbert
│   │       ├── checkpoint.py
│   │       ├── dvae.py
│   │       ├── logger.py
│   │       ├── misc.py
│   │       └── point_encoder.py
│   ├── module
│   │   └── N2Mmodule.py
│   └── utils
│       ├── AverageMeter.py
│       ├── checkpoint.py
│       ├── config.py
│       ├── dist_utils.py
│       ├── logger.py
│       ├── loss.py
│       ├── metrics.py
│       ├── misc.py
│       ├── parser.py
│       ├── point_cloud.py
│       ├── prediction.py
│       ├── registry.py
│       ├── sample_utils.py
│       └── visualizer.py
├── pyproject.toml
├── scripts
│   ├── process_dataset.sh
│   ├── render
│   │   ├── CMakeLists.txt
│   │   ├── README.md
│   │   ├── fpv_render.cpp
│   │   └── include
│   │       └── json.hpp
│   ├── sample_camera_poses.py
│   ├── train.py
│   └── visualize_attention.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | models
2 | datasets
3 | /training/
4 | build
5 | **.egg-info
6 | **__pycache__**
7 | wandb
8 | **.DS_Store
9 | test_codes
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # N2M: Bridging Navigation and Manipulation by Learning Pose Preference from Rollout
2 | Kaixin Chai*, Hyunjun Lee*, Joseph J. Lim
3 |
4 | 
5 |
6 | This is an official implementation of N2M. We provide detailed instructions for training and running inference with N2M. For examples in simulation and the real world, please refer to the `sim` and `real` branches.
7 |
8 | ## TODOs
9 | - Organize N2M training / inference code ✅
10 | - Organize N2M real world code ✅
11 | - Organize N2M simulation code
12 |
13 | ## Installation
14 | Clone the repository and install the necessary packages:
15 | ```bash
16 | git clone --single-branch --branch main https://github.com/clvrai/N2M.git
17 | cd N2M
18 |
19 | # create the mamba environment
20 | mamba create -n n2m python==3.11
21 | mamba activate n2m
22 | pip install -r requirements.txt
23 | pip install -e .
24 |
25 | # compile the C++ renderer
26 | cd scripts/render
27 | mkdir build && cd build
28 | cmake ..
29 | make
30 | ```
31 |
32 | ## 📊 Data preparation
33 |
34 | 
35 |
36 | You should first prepare raw data consisting of pairs of a local scene and a preferable initial pose. A local scene is a point cloud of the scene; you may stitch point clouds from multiple calibrated cameras, for example as sketched below. This repo does not provide code for capturing local scenes.
37 |
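As a reference only (not part of this repo), a minimal stitching sketch with Open3D could look like the following; the capture paths and per-camera extrinsics are placeholders for your own calibrated setup.
```python
import numpy as np
import open3d as o3d

# Illustrative camera-to-world extrinsics from your calibration (placeholders)
T_world_cam0 = np.eye(4)
T_world_cam1 = np.eye(4)
T_world_cam1[:3, 3] = [0.0, 0.5, 0.0]   # e.g. second camera mounted 0.5 m along y

clouds = {
    "captures/cam0.pcd": T_world_cam0,   # placeholder capture paths
    "captures/cam1.pcd": T_world_cam1,
}

merged = o3d.geometry.PointCloud()
for path, T_world_cam in clouds.items():
    pcd = o3d.io.read_point_cloud(path)
    pcd.transform(T_world_cam)           # express every cloud in a common frame
    merged += pcd

# optional: thin out duplicated points where the camera views overlap
merged = merged.voxel_down_sample(voxel_size=0.005)
o3d.io.write_point_cloud("datasets/my_dataset/pcl/0.pcd", merged)
```
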
38 | The raw data should be placed under the `datasets` folder in the format below:
39 | ```
40 | datasets/
41 | └── {dataset name}/
42 |     ├── pcl/
43 |     │   ├── 0.pcd
44 |     │   ├── 1.pcd
45 |     │   └── ...
46 |     └── meta.json
47 |
48 | ```
49 | Replace `{dataset name}` with your own dataset name. The `pcl/` folder should contain the point clouds of your local scenes, and `meta.json` should contain the information of each local scene together with the label of its preferable initial pose. `meta.json` should follow the format below.
50 | ```json
51 | {
52 |     "meta": {
53 |         "T_base_to_cam": [
54 |             [-8.25269110e-02, -5.73057816e-01, 8.15348841e-01, 6.05364230e-04],
55 |             [-9.95784041e-01, 1.45464862e-02, -9.05661474e-02, -3.94417736e-02],
56 |             [ 4.00391906e-02, -8.19385767e-01, -5.71842485e-01, 1.64310488e-00],
57 |             [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]
58 |         ],
59 |         "camera_intrinsic": [
60 |             100.6919557412736, 100.6919557412736, 160.0, 120.0, 320, 240
61 |         ]
62 |     },
63 |     "episodes": [
64 |         {
65 |             "id": 0,
66 |             "file_path": "pcl/0.pcd",
67 |             "pose": [
68 |                 3.07151444879416,
69 |                 -0.9298766226100992,
70 |                 1.5782995419534618
71 |             ],
72 |             "object_position": [
73 |                 3.2,
74 |                 -0.2,
75 |                 1.85
76 |             ]
77 |         },
78 |         ...
79 |     ]
80 | }
81 | ```
82 | `T_base_to_cam`: A 4 x 4 transformation matrix that transforms the SE(3) pose of the base into the SE(3) pose of the camera. The SE(3) pose of the base is the SE(2) base pose expressed in SE(3) form. This should be pre-calculated and provided with the dataset.
83 |
84 | `camera_intrinsic`: The intrinsics of the camera used to capture the point cloud. These should also be provided in the dataset.
85 |
86 | `file_path`: Relative path of the corresponding local scene point cloud.
87 |
88 | `pose`: The preferable initial pose at which you marked the rollout as successful. The example above is an SE(2) pose, but you can change it to (x, y, theta, z) depending on your setting.
89 |
90 | `object_position`: Position of the object of interest. It is later used to check the visibility of the object when sampling viewpoints. This field is not mandatory; if it is empty, we estimate it as the position 0.5 m in front of `pose` at a height of 1 m.
91 |
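For reference, below is a minimal, illustrative sketch of how `meta.json` could be assembled once you have the calibration, labels, and point clouds; the values, dataset path, and variable names are placeholders.
```python
import json
import numpy as np

# Placeholder values; substitute your own calibration and labels
T_base_to_cam = np.eye(4)                                   # pre-calculated 4 x 4 transform
camera_intrinsic = [100.69, 100.69, 160.0, 120.0, 320, 240]

# One entry per captured local scene
labels = [
    {"pose": [3.07, -0.93, 1.58], "object_position": [3.2, -0.2, 1.85]},
]

meta = {
    "meta": {
        "T_base_to_cam": T_base_to_cam.tolist(),
        "camera_intrinsic": camera_intrinsic,
    },
    "episodes": [
        {
            "id": i,
            "file_path": f"pcl/{i}.pcd",
            "pose": label["pose"],                          # SE(2): (x, y, theta)
            "object_position": label["object_position"],    # optional, see above
        }
        for i, label in enumerate(labels)
    ],
}

with open("datasets/my_dataset/meta.json", "w") as f:
    json.dump(meta, f, indent=4)
```
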
92 | An example dataset can be downloaded from this link.
93 |
94 | ## 🛠️ Data Processing
95 | Now we are ready to process the data. Run the following command:
96 | ```bash
97 | sh scripts/process_dataset.sh "path/to/dataset"
98 | ```
99 | This will apply viewpoint augmentation and generate augmented point clouds together with the correspondingly transformed labels. The file structure of the dataset will now look like this:
100 | ```
101 | datasets/
102 | └── {dataset name}/
103 |     ├── camera_poses/
104 |     ├── camera_poses_vis/
105 |     ├── pcl/
106 |     ├── pcl_aug/
107 |     ├── meta_aug.json
108 |     └── meta.json
109 | ```
110 | You will find visualizations of the sampled camera poses for each scene in `camera_poses_vis/`. The augmented point clouds and their corresponding labels are saved in `pcl_aug/` and `meta_aug.json`, respectively.
111 |
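To sanity-check the processed data, a small script along these lines can load one augmented sample; it assumes `meta_aug.json` follows the same schema as `meta.json`, and the dataset path is a placeholder.
```python
import json
import open3d as o3d

root = "datasets/my_dataset"                      # placeholder dataset path
with open(f"{root}/meta_aug.json") as f:
    episodes = json.load(f)["episodes"]

ep = episodes[0]
pcd = o3d.io.read_point_cloud(f"{root}/{ep['file_path']}")
print(ep["pose"], len(pcd.points))                # transformed label and cloud size
# o3d.visualization.draw_geometries([pcd])        # uncomment to inspect visually
```
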
112 | ## 🚀 Training
113 | Use `configs/training/config.json` as the training configuration. Set `dataset_path` to `"datasets/{dataset name}"` and adjust the remaining training settings to your taste.
114 |
115 | Before running training, download the pretrained PointBERT weights from this link and save them under the `models/PointBERT` folder.
116 | ```bash
117 | python scripts/train.py --config configs/training/config.json
118 | ```
119 | Your training logs will be saved under `training/{dataset name}`.
120 |
121 | ## 🏃🏻♂️➡️ Inference
122 | To use N2M, import `N2Mmodule` from `n2m/module/N2Mmodule.py`. It is a wrapper around N2Mnet that handles data pre-processing and post-processing, and it also performs collision checking to sample a valid initial pose from the predicted distribution. Example code is as follows:
123 | ```python
124 | import json
125 | import open3d as o3d
126 | import numpy as np
127 |
128 | from n2m.module import N2Mmodule
129 |
130 | # initialize n2m module
131 | config = json.load(open("configs/inference/config.json"))
132 | n2m = N2Mmodule(config)
133 |
134 | # load pcd
135 | pcd = o3d.io.read_point_cloud("example.pcd")
136 | pcd_numpy = np.concatenate([np.asarray(pcd.points), np.asarray(pcd.colors)], axis=1)
137 |
138 | # predict initial pose. If no valid pose is sampled within a certain number of trials, is_valid will be False; otherwise it will be True
139 | initial_pose, is_valid = n2m.predict(pcd_numpy)
140 | ```
--------------------------------------------------------------------------------
/configs/inference/config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "n2mnet": {
3 |         "encoder": {
4 |             "name": "PointBERT",
5 |             "config": {
6 |                 "NAME": "PointTransformer",
7 |                 "trans_dim": 384,
8 |                 "depth": 12,
9 |                 "drop_path_rate": 0.1,
10 |                 "cls_dim": 40,
11 |                 "num_heads": 6,
12 |                 "group_size": 32,
13 |                 "num_group": 512,
14 |                 "encoder_dims": 256
15 |             }
16 |         },
17 |         "decoder": {
18 |             "name": "mlp",
19 |             "config": {
20 |                 "dropout": 0.1
21 |             },
22 |             "layers": [
23 |                 512,
24 |                 512,
25 |                 512,
26 |                 512,
27 |                 512,
28 |                 512,
29 |                 512,
30 |                 512
31 |             ],
32 |             "num_gaussians": 2,
33 |             "output_dim": 3
34 |         },
35 |         "ckpt": "models/N2M/N2Mnet.pth"
36 |     },
37 |     "preprocess": {
38 |         "pointnum": 8192
39 |     },
40 |     "postprocess": {
41 |         "num_samples": 100,
42 |         "collision_checker": {
43 |             "ground_z": 0.05,
44 |             "filter_noise": true,
45 |             "resolution": 0.02,
46 |             "robot_width": 0.5,
47 |             "robot_length": 0.63
48 |         }
49 |     }
50 | }
--------------------------------------------------------------------------------
/configs/training/config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "exp_name": "N2M_training",
3 |     "model": {
4 |         "encoder": {
5 |             "name": "PointBERT",
6 |             "config": {
7 |                 "NAME": "PointTransformer",
8 |                 "trans_dim": 384,
9 |                 "depth": 12,
10 |                 "drop_path_rate": 0.1,
11 |                 "cls_dim": 40,
12 |                 "num_heads": 6,
13 |                 "group_size": 32,
14 |                 "num_group": 512,
15 |                 "encoder_dims": 256
16 |             },
17 |             "ckpt": "models/PointBERT/PointTransformer_ModelNet8192points.pth",
18 |             "freeze": false
19 |         },
20 |         "decoder": {
21 |             "name": "mlp",
22 |             "config": {
23 |                 "dropout": 0.1
24 |             },
25 |             "num_gaussians": 1,
26 |             "output_dim": 3
27 |         }
28 |     },
29 |     "dataset": {
30 |         "dataset_path": "datasets/test",
31 |         "anno_path": "meta_aug.json",
32 |         "pointnum": 8192,
33 |         "train_val_ratio": 0.9,
34 |         "augmentations": {
35 |             "rotation_se2": {
36 |                 "min_angle": -180,
37 |                 "max_angle": 180
38 |             },
39 |             "translation_xy": {
40 |                 "radius": 1
41 |             }
42 |         }
43 |     },
44 |     "train": {
45 |         "batch_size": 32,
46 |         "num_epochs": 300,
47 |         "learning_rate": 1e-4,
48 |         "num_workers": 4,
49 |         "val_freq": 50,
50 |         "loss": {
51 |             "name": "mle",
52 |             "config": {
53 |                 "lam_weight": 0.1,
54 |                 "lam_dist": 0.1
55 |             }
56 |         },
57 |         "wandb": {
58 |             "name": null,
59 |             "project": null,
60 |             "entity": null
61 |         },
62 |         "output_dir": "training/test"
63 |     }
64 | }
--------------------------------------------------------------------------------
/doc/Data_Preparation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/N2M/76ab68b8e75debd3c190bb6db6399b2cee179093/doc/Data_Preparation.png
--------------------------------------------------------------------------------
/doc/System_Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/N2M/76ab68b8e75debd3c190bb6db6399b2cee179093/doc/System_Overview.png
--------------------------------------------------------------------------------
/n2m/__init__.py:
--------------------------------------------------------------------------------
1 | from .model.N2Mnet import N2Mnet
2 | from .data.dataset import N2MDataset
3 |
--------------------------------------------------------------------------------
/n2m/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import numpy as np
5 | import random
6 | from torch.utils.data import Dataset
7 | import open3d as o3d
8 |
9 | from n2m.utils.point_cloud import apply_augmentations, fix_point_cloud_size
10 |
11 | def make_data_module(config):
12 | """
13 | Make training and validation datasets for the SIR predictor.
14 |
15 | input(config):
16 | dataset_path: str
17 | anno_path: str
18 | train_val_ratio: float
19 | pointnum: int
20 |
21 | return:
22 | train_dataset: Dataset
23 | val_dataset: Dataset
24 | """
25 | if isinstance(config['anno_path'], str):
26 | anno_path = os.path.join(config['dataset_path'], config['anno_path'])
27 | with open(anno_path, 'r') as f:
28 | anno = json.load(f)['episodes']
29 | train_num = int(len(anno) * config['train_val_ratio'])
30 | random.shuffle(anno)
31 | anno_train = anno[:train_num]
32 | anno_val = anno[train_num:]
33 |
34 | print(f"Train set size: {len(anno_train)}")
35 | print(f"Val set size: {len(anno_val)}")
36 |
37 | train_dataset = N2MDataset(
38 | config=config,
39 | anno=anno_train,
40 | )
41 | val_dataset = N2MDataset(
42 | config=config,
43 | anno=anno_val,
44 | )
45 |
46 | return train_dataset, val_dataset
47 | else:
48 | raise ValueError(f"Invalid config: {config}")
49 |
50 | class N2MDataset(Dataset):
51 | """Dataset for SIR predictor."""
52 | def __init__(self, config, anno):
53 | self.anno = anno
54 | self.pointnum = config['pointnum']
55 | self.config = config
56 |
57 | self.settings = config['settings'] if 'settings' in config else None
58 |
59 | if 'dataset_path' in config:
60 | self.dataset_path = config['dataset_path']
61 | else:
62 | raise ValueError(f"Invalid config: {config}")
63 |
64 | def _load_point_cloud(self, file_path):
65 | if os.path.exists(file_path):
66 | pcd = o3d.io.read_point_cloud(file_path)
67 | if 'augmentations' in self.config and 'hpr' in self.config['augmentations']:
68 | pcd = self._apply_hpr(pcd, self.config['augmentations']['hpr'])
69 | point_cloud = np.asarray(pcd.points)
70 | colors = np.asarray(pcd.colors)
71 | point_cloud = np.concatenate([point_cloud, colors], axis=1)
72 | return point_cloud
73 |
74 | raise FileNotFoundError(f"No point cloud file found for object {file_path}")
75 |
76 | def __len__(self):
77 | """
78 | Return number of samples in the dataset
79 | """
80 | return len(self.anno)
81 |
82 | def __getitem__(self, index):
83 | """
84 | Get a sample from the dataset
85 | """
86 | data = self.anno[index]
87 | file_path = os.path.join(self.dataset_path, data['file_path'])
88 |
89 | # Load point cloud and target point
90 | point_cloud = self._load_point_cloud(file_path)
91 | target_point = np.array(data['pose'], dtype=np.float32)
92 | label = 1
93 |
94 | # Ensure point cloud has consistent size
95 | point_cloud = fix_point_cloud_size(point_cloud, self.pointnum)
96 |
97 | # Apply augmentations
98 | if 'augmentations' in self.config:
99 | point_cloud, target_point = apply_augmentations(point_cloud, target_point, self.config['augmentations'])
100 |
101 | # Convert to torch tensors
102 | point_cloud = torch.from_numpy(point_cloud.astype(np.float32))
103 | target_point = torch.from_numpy(target_point)
104 | label = torch.tensor(label, dtype=torch.long)
105 |
106 | return {
107 | 'point_cloud': point_cloud,
108 | 'target_point': target_point,
109 | 'label': label,
110 | }
111 |
--------------------------------------------------------------------------------
/n2m/model/N2Mnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from n2m.model.pointbert.point_encoder import PointTransformer
5 | import os
6 | from typing import Tuple, Dict
7 |
8 |
9 | class N2Mnet(nn.Module):
10 | """
11 | N2Mnet predicts a Gaussian Mixture Model (GMM) from an input point cloud.
12 | """
13 |
14 | def __init__(self, config: Dict):
15 | """
16 | Initializes the N2Mnet model.
17 |
18 | Args:
19 | config (Dict): A configuration dictionary containing model, encoder,
20 | and decoder settings.
21 | """
22 | super().__init__()
23 | self.encoder_config = config['encoder']
24 | self.decoder_config = config['decoder']
25 |
26 | self.num_gaussians = self.decoder_config['num_gaussians']
27 | self.output_dim = self.decoder_config['output_dim']
28 |
29 | encoder_output_dim = self._build_encoder()
30 | self._build_decoder(encoder_output_dim)
31 |
32 | if 'ckpt' in config and config['ckpt']:
33 | if not os.path.exists(config['ckpt']):
34 | raise FileNotFoundError(f"Checkpoint file not found: {config['ckpt']}")
35 | print(f"Loading model weights from {config['ckpt']}")
36 | self.load_state_dict(torch.load(config['ckpt'])['model_state_dict'])
37 |
38 | def _build_encoder(self) -> int:
39 | """
40 | Builds the point cloud encoder based on the configuration.
41 |
42 | Returns:
43 | int: The output dimension of the encoder.
44 | """
45 | if self.encoder_config['name'] == 'PointBERT':
46 | self.encoder = PointTransformer(self.encoder_config['config'])
47 | # Output dimension is doubled due to max and mean pooling in PointTransformer
48 | encoder_output_dim = self.encoder_config['config']['trans_dim'] * 2
49 | else:
50 | raise ValueError(f"Unsupported encoder: {self.encoder_config['name']}")
51 |
52 | # Load pretrained encoder weights if specified
53 | if 'ckpt' in self.encoder_config and self.encoder_config['ckpt']:
54 | ckpt_path = self.encoder_config['ckpt']
55 | if not os.path.exists(ckpt_path):
56 | raise FileNotFoundError(f"Encoder checkpoint not found: {ckpt_path}")
57 | print(f"Loading encoder weights from {ckpt_path}")
58 | self.encoder.load_checkpoint(ckpt_path)
59 |
60 | # Freeze encoder weights if specified
61 | if self.encoder_config.get('freeze', False):
62 | for param in self.encoder.parameters():
63 | param.requires_grad = False
64 |
65 | return encoder_output_dim
66 |
67 | def _build_decoder(self, decoder_input_dim: int):
68 | """
69 | Builds the GMM parameter decoder based on the configuration.
70 |
71 | Args:
72 | decoder_input_dim (int): The input dimension for the decoder.
73 | """
74 | # Each Gaussian component requires parameters for mean, covariance, and mixing weight
75 | gmm_params_per_gaussian = self.output_dim + (self.output_dim ** 2) + 1
76 | decoder_output_dim = self.num_gaussians * gmm_params_per_gaussian
77 |
78 | if self.decoder_config['name'] == 'mlp':
79 | layers = self.decoder_config.get('layers', [512, 256])
80 |
81 | decoder_layers = []
82 | prev_dim = decoder_input_dim
83 | for layer_dim in layers:
84 | decoder_layers.extend([
85 | nn.Linear(prev_dim, layer_dim),
86 | nn.ReLU(),
87 | nn.Dropout(self.decoder_config['config']['dropout'])
88 | ])
89 | prev_dim = layer_dim
90 |
91 | decoder_layers.append(nn.Linear(prev_dim, decoder_output_dim))
92 | self.decoder = nn.Sequential(*decoder_layers)
93 | else:
94 | raise ValueError(f"Unsupported decoder: {self.decoder_config['name']}")
95 |
96 | def _construct_covariance_matrices(self, sigma_params: torch.Tensor) -> torch.Tensor:
97 | """
98 | Constructs symmetric positive definite covariance matrices from raw network outputs.
99 |
100 | This method ensures symmetry and positive definiteness using the matrix
101 | exponential. A small identity matrix is added for numerical stability before
102 | the exponential.
103 |
104 | Args:
105 | sigma_params (torch.Tensor): Raw covariance parameters from the decoder,
106 | with shape (B, K, D, D), where B is batch
107 | size, K is the number of Gaussians, and D
108 | is the output dimension.
109 |
110 | Returns:
111 | torch.Tensor: Valid, symmetric positive definite covariance matrices of shape
112 | (B, K, D, D).
113 | """
114 | B, K, D, _ = sigma_params.shape
115 |
116 | # Enforce symmetry for the input to matrix exponential
117 | sigma_params = 0.5 * (sigma_params + sigma_params.transpose(-2, -1))
118 |
119 | # Add a small diagonal epsilon for numerical stability
120 | eye = torch.eye(D, device=sigma_params.device).unsqueeze(0).unsqueeze(0)
121 | sigma_params = sigma_params + 1e-6 * eye
122 |
123 | # The matrix exponential of a symmetric matrix is symmetric positive definite
124 | covs = torch.matrix_exp(sigma_params)
125 |
126 | return covs
127 |
128 | def forward(self, point_cloud: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
129 | """
130 | Performs the forward pass to predict GMM parameters from a point cloud.
131 |
132 | Args:
133 | point_cloud (torch.Tensor): Input point cloud of shape (B, N, C), where B
134 | is batch size, N is the number of points,
135 | and C is feature dimension.
136 |
137 | Returns:
138 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
139 | - means (torch.Tensor): GMM means (B, K, D).
140 | - covs (torch.Tensor): GMM covariance matrices (B, K, D, D).
141 | - weights (torch.Tensor): GMM mixing weights (B, K).
142 | """
143 | # Encode point cloud to a global feature vector
144 | features, _ = self.encoder(point_cloud)
145 | features = features.squeeze(1) # (B, encoder_output_dim)
146 |
147 | # Decode features into a flat tensor of GMM parameters
148 | gmm_params = self.decoder(features)
149 |
150 | # Reshape the flat tensor to extract means, covariance parameters, and weights
151 | batch_size = gmm_params.size(0)
152 |
153 | # Extract means
154 | means_end = self.num_gaussians * self.output_dim
155 | means = gmm_params[:, :means_end].view(
156 | batch_size, self.num_gaussians, self.output_dim
157 | )
158 |
159 | # Extract raw covariance parameters
160 | cov_end = means_end + self.num_gaussians * (self.output_dim ** 2)
161 | sigma_params = gmm_params[:, means_end:cov_end].view(
162 | batch_size, self.num_gaussians, self.output_dim, self.output_dim
163 | )
164 | covs = self._construct_covariance_matrices(sigma_params)
165 |
166 | # Extract and normalize mixing weights
167 | weights = gmm_params[:, -self.num_gaussians:].view(batch_size, self.num_gaussians)
168 | weights = torch.softmax(weights, dim=-1)
169 |
170 | return means, covs, weights
171 |
172 | def sample(self, point_cloud: torch.Tensor, num_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
173 | """
174 | Samples points from the predicted GMM for a given point cloud.
175 |
176 | Args:
177 | point_cloud (torch.Tensor): Input point cloud of shape (B, N, C).
178 | num_samples (int): Number of points to sample for each item in the batch.
179 |
180 | Returns:
181 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of:
182 | - samples (torch.Tensor): Sampled points (B, num_samples, D).
183 | - means (torch.Tensor): GMM means (B, K, D).
184 | - covs (torch.Tensor): GMM covariances (B, K, D, D).
185 | - weights (torch.Tensor): GMM weights (B, K).
186 | """
187 | means, covs, weights = self.forward(point_cloud)
188 | batch_size = means.size(0)
189 |
190 | samples = torch.zeros(batch_size, num_samples, self.output_dim, device=means.device)
191 |
192 | # Process each item in the batch independently
193 | for b in range(batch_size):
194 | # 1. Sample component indices based on the mixture weights
195 | component_indices = torch.multinomial(weights[b], num_samples, replacement=True)
196 |
197 | # 2. Sample from the corresponding Gaussian for each chosen component
198 | for i in range(self.num_gaussians):
199 | # Find which samples belong to the current component
200 | mask = (component_indices == i)
201 | num_component_samples = mask.sum().item()
202 |
203 | if num_component_samples > 0:
204 | dist = torch.distributions.MultivariateNormal(
205 | loc=means[b, i],
206 | covariance_matrix=covs[b, i]
207 | )
208 | samples[b, mask] = dist.sample((num_component_samples,))
209 |
210 | return samples, means, covs, weights
--------------------------------------------------------------------------------
/n2m/model/pointbert/checkpoint.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch.nn as nn
3 |
4 | from typing import Any
5 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
6 |
7 | from termcolor import colored
8 |
9 | def get_missing_parameters_message(keys: List[str]) -> str:
10 | """
11 | Get a logging-friendly message to report parameter names (keys) that are in
12 | the model but not found in a checkpoint.
13 | Args:
14 | keys (list[str]): List of keys that were not found in the checkpoint.
15 | Returns:
16 | str: message.
17 | """
18 | groups = _group_checkpoint_keys(keys)
19 | msg = "Some model parameters or buffers are not found in the checkpoint:\n"
20 | msg += "\n".join(
21 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
22 | )
23 | return msg
24 |
25 |
26 | def get_unexpected_parameters_message(keys: List[str]) -> str:
27 | """
28 | Get a logging-friendly message to report parameter names (keys) that are in
29 | the checkpoint but not found in the model.
30 | Args:
31 | keys (list[str]): List of keys that were not found in the model.
32 | Returns:
33 | str: message.
34 | """
35 | groups = _group_checkpoint_keys(keys)
36 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
37 | msg += "\n".join(
38 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
39 | )
40 | return msg
41 |
42 |
43 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
44 | """
45 | Strip the prefix in metadata, if any.
46 | Args:
47 | state_dict (OrderedDict): a state-dict to be loaded to the model.
48 | prefix (str): prefix.
49 | """
50 | keys = sorted(state_dict.keys())
51 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
52 | return
53 |
54 | for key in keys:
55 | newkey = key[len(prefix):]
56 | state_dict[newkey] = state_dict.pop(key)
57 |
58 | # also strip the prefix in metadata, if any..
59 | try:
60 | metadata = state_dict._metadata # pyre-ignore
61 | except AttributeError:
62 | pass
63 | else:
64 | for key in list(metadata.keys()):
65 | # for the metadata dict, the key can be:
66 | # '': for the DDP module, which we want to remove.
67 | # 'module': for the actual model.
68 | # 'module.xx.xx': for the rest.
69 |
70 | if len(key) == 0:
71 | continue
72 | newkey = key[len(prefix):]
73 | metadata[newkey] = metadata.pop(key)
74 |
75 |
76 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
77 | """
78 | Group keys based on common prefixes. A prefix is the string up to the final
79 | "." in each key.
80 | Args:
81 | keys (list[str]): list of parameter names, i.e. keys in the model
82 | checkpoint dict.
83 | Returns:
84 | dict[list]: keys with common prefixes are grouped into lists.
85 | """
86 | groups = defaultdict(list)
87 | for key in keys:
88 | pos = key.rfind(".")
89 | if pos >= 0:
90 | head, tail = key[:pos], [key[pos + 1:]]
91 | else:
92 | head, tail = key, []
93 | groups[head].extend(tail)
94 | return groups
95 |
96 |
97 | def _group_to_str(group: List[str]) -> str:
98 | """
99 | Format a group of parameter name suffixes into a loggable string.
100 | Args:
101 | group (list[str]): list of parameter name suffixes.
102 | Returns:
103 | str: formatted string.
104 | """
105 | if len(group) == 0:
106 | return ""
107 |
108 | if len(group) == 1:
109 | return "." + group[0]
110 |
111 | return ".{" + ", ".join(group) + "}"
112 |
113 |
114 | def _named_modules_with_dup(
115 | model: nn.Module, prefix: str = ""
116 | ) -> Iterable[Tuple[str, nn.Module]]:
117 | """
118 | The same as `model.named_modules()`, except that it includes
119 | duplicated modules that have more than one name.
120 | """
121 | yield prefix, model
122 | for name, module in model._modules.items(): # pyre-ignore
123 | if module is None:
124 | continue
125 | submodule_prefix = prefix + ("." if prefix else "") + name
126 | yield from _named_modules_with_dup(module, submodule_prefix)
--------------------------------------------------------------------------------
/n2m/model/pointbert/dvae.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from . import misc
5 |
6 | # from knn_cuda import KNN
7 |
8 | # knn = KNN(k=4, transpose_mode=False)
9 |
10 |
11 | class DGCNN(nn.Module):
12 | def __init__(self, encoder_channel, output_channel):
13 | super().__init__()
14 | '''
15 | K has to be 16
16 | '''
17 | self.input_trans = nn.Conv1d(encoder_channel, 128, 1)
18 |
19 | self.layer1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False),
20 | nn.GroupNorm(4, 256),
21 | nn.LeakyReLU(negative_slope=0.2)
22 | )
23 |
24 | self.layer2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1, bias=False),
25 | nn.GroupNorm(4, 512),
26 | nn.LeakyReLU(negative_slope=0.2)
27 | )
28 |
29 | self.layer3 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=1, bias=False),
30 | nn.GroupNorm(4, 512),
31 | nn.LeakyReLU(negative_slope=0.2)
32 | )
33 |
34 | self.layer4 = nn.Sequential(nn.Conv2d(1024, 1024, kernel_size=1, bias=False),
35 | nn.GroupNorm(4, 1024),
36 | nn.LeakyReLU(negative_slope=0.2)
37 | )
38 |
39 | self.layer5 = nn.Sequential(nn.Conv1d(2304, output_channel, kernel_size=1, bias=False),
40 | nn.GroupNorm(4, output_channel),
41 | nn.LeakyReLU(negative_slope=0.2)
42 | )
43 |
44 | @staticmethod
45 | def get_graph_feature(coor_q, x_q, coor_k, x_k):
46 | # coor: bs, 3, np, x: bs, c, np
47 |
48 | k = 4
49 | batch_size = x_k.size(0)
50 | num_points_k = x_k.size(2)
51 | num_points_q = x_q.size(2)
52 |
53 | with torch.no_grad():
54 | _, idx = knn(coor_k, coor_q) # bs k np; requires the knn_cuda import commented out at the top of this file
55 | assert idx.shape[1] == k
56 | idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
57 | idx = idx + idx_base
58 | idx = idx.view(-1)
59 | num_dims = x_k.size(1)
60 | x_k = x_k.transpose(2, 1).contiguous()
61 | feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
62 | feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()
63 | x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
64 | feature = torch.cat((feature - x_q, x_q), dim=1)
65 | return feature
66 |
67 | def forward(self, f, coor):
68 | # f: B G C
69 | # coor: B G 3
70 |
71 | # bs 3 N bs C N
72 | feature_list = []
73 | coor = coor.transpose(1, 2).contiguous() # B 3 N
74 | f = f.transpose(1, 2).contiguous() # B C N
75 | f = self.input_trans(f) # B 128 N
76 |
77 | f = self.get_graph_feature(coor, f, coor, f) # B 256 N k
78 | f = self.layer1(f) # B 256 N k
79 | f = f.max(dim=-1, keepdim=False)[0] # B 256 N
80 | feature_list.append(f)
81 |
82 | f = self.get_graph_feature(coor, f, coor, f) # B 512 N k
83 | f = self.layer2(f) # B 512 N k
84 | f = f.max(dim=-1, keepdim=False)[0] # B 512 N
85 | feature_list.append(f)
86 |
87 | f = self.get_graph_feature(coor, f, coor, f) # B 1024 N k
88 | f = self.layer3(f) # B 512 N k
89 | f = f.max(dim=-1, keepdim=False)[0] # B 512 N
90 | feature_list.append(f)
91 |
92 | f = self.get_graph_feature(coor, f, coor, f) # B 1024 N k
93 | f = self.layer4(f) # B 1024 N k
94 | f = f.max(dim=-1, keepdim=False)[0] # B 1024 N
95 | feature_list.append(f)
96 |
97 | f = torch.cat(feature_list, dim=1) # B 2304 N
98 |
99 | f = self.layer5(f) # B C' N
100 |
101 | f = f.transpose(-1, -2)
102 |
103 | return f
104 |
105 |
106 | ### ref https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py ###
107 | def knn_point(nsample, xyz, new_xyz):
108 | """
109 | Input:
110 | nsample: max sample number in local region
111 | xyz: all points, [B, N, C]
112 | new_xyz: query points, [B, S, C]
113 | Return:
114 | group_idx: grouped points index, [B, S, nsample]
115 | """
116 | sqrdists = square_distance(new_xyz, xyz)
117 | _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
118 | return group_idx
119 |
120 |
121 | def square_distance(src, dst):
122 | """
123 | Calculate Euclid distance between each two points.
124 | src^T * dst = xn * xm + yn * ym + zn * zm;
125 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
126 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
127 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
128 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
129 | Input:
130 | src: source points, [B, N, C]
131 | dst: target points, [B, M, C]
132 | Output:
133 | dist: per-point square distance, [B, N, M]
134 | """
135 | B, N, _ = src.shape
136 | _, M, _ = dst.shape
137 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
138 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
139 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
140 | return dist
141 |
142 |
143 | class Group(nn.Module):
144 | def __init__(self, num_group, group_size):
145 | super().__init__()
146 | self.num_group = num_group
147 | self.group_size = group_size
148 | # self.knn = KNN(k=self.group_size, transpose_mode=True)
149 |
150 | def forward(self, xyz):
151 | '''
152 | input: B N 3
153 | ---------------------------
154 | output: B G M 3
155 | center : B G 3
156 | '''
157 | B, N, C = xyz.shape
158 | if C > 3:
159 | data = xyz
160 | xyz = data[:,:,:3]
161 | rgb = data[:, :, 3:]
162 | batch_size, num_points, _ = xyz.shape
163 | # fps the centers out
164 | center = misc.fps(xyz, self.num_group) # B G 3
165 |
166 | # knn to get the neighborhood
167 | # _, idx = self.knn(xyz, center) # B G M
168 | idx = knn_point(self.group_size, xyz, center) # B G M
169 | assert idx.size(1) == self.num_group
170 | assert idx.size(2) == self.group_size
171 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
172 | idx = idx + idx_base
173 | idx = idx.view(-1)
174 |
175 | neighborhood_xyz = xyz.view(batch_size * num_points, -1)[idx, :]
176 | neighborhood_xyz = neighborhood_xyz.view(batch_size, self.num_group, self.group_size, 3).contiguous()
177 | if C > 3:
178 | neighborhood_rgb = rgb.view(batch_size * num_points, -1)[idx, :]
179 | neighborhood_rgb = neighborhood_rgb.view(batch_size, self.num_group, self.group_size, -1).contiguous()
180 |
181 | # normalize xyz
182 | neighborhood_xyz = neighborhood_xyz - center.unsqueeze(2)
183 | if C > 3:
184 | neighborhood = torch.cat((neighborhood_xyz, neighborhood_rgb), dim=-1)
185 | else:
186 | neighborhood = neighborhood_xyz
187 | return neighborhood, center
188 |
189 | class Encoder(nn.Module):
190 | def __init__(self, encoder_channel, point_input_dims=3):
191 | super().__init__()
192 | self.encoder_channel = encoder_channel
193 | self.point_input_dims = point_input_dims
194 | self.first_conv = nn.Sequential(
195 | nn.Conv1d(self.point_input_dims, 128, 1),
196 | nn.BatchNorm1d(128),
197 | nn.ReLU(inplace=True),
198 | nn.Conv1d(128, 256, 1)
199 | )
200 | self.second_conv = nn.Sequential(
201 | nn.Conv1d(512, 512, 1),
202 | nn.BatchNorm1d(512),
203 | nn.ReLU(inplace=True),
204 | nn.Conv1d(512, self.encoder_channel, 1)
205 | )
206 |
207 | def forward(self, point_groups):
208 | '''
209 | point_groups : B G N 3
210 | -----------------
211 | feature_global : B G C
212 | '''
213 | bs, g, n, c = point_groups.shape
214 | point_groups = point_groups.reshape(bs * g, n, c)
215 | # encoder
216 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
217 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
218 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
219 | feature = self.second_conv(feature) # BG 1024 n
220 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
221 | return feature_global.reshape(bs, g, self.encoder_channel)
222 |
223 |
224 | class Decoder(nn.Module):
225 | def __init__(self, encoder_channel, num_fine):
226 | super().__init__()
227 | self.num_fine = num_fine
228 | self.grid_size = 2
229 | self.num_coarse = self.num_fine // 4
230 | assert num_fine % 4 == 0
231 |
232 | self.mlp = nn.Sequential(
233 | nn.Linear(encoder_channel, 1024),
234 | nn.ReLU(inplace=True),
235 | nn.Linear(1024, 1024),
236 | nn.ReLU(inplace=True),
237 | nn.Linear(1024, 3 * self.num_coarse)
238 | )
239 | self.final_conv = nn.Sequential(
240 | nn.Conv1d(encoder_channel + 3 + 2, 512, 1),
241 | nn.BatchNorm1d(512),
242 | nn.ReLU(inplace=True),
243 | nn.Conv1d(512, 512, 1),
244 | nn.BatchNorm1d(512),
245 | nn.ReLU(inplace=True),
246 | nn.Conv1d(512, 3, 1)
247 | )
248 | a = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float).view(1, self.grid_size).expand(
249 | self.grid_size, self.grid_size).reshape(1, -1)
250 | b = torch.linspace(-0.05, 0.05, steps=self.grid_size, dtype=torch.float).view(self.grid_size, 1).expand(
251 | self.grid_size, self.grid_size).reshape(1, -1)
252 | self.folding_seed = torch.cat([a, b], dim=0).view(1, 2, self.grid_size ** 2) # 1 2 S
253 |
254 | def forward(self, feature_global):
255 | '''
256 | feature_global : B G C
257 | -------
258 | coarse : B G M 3
259 | fine : B G N 3
260 |
261 | '''
262 | bs, g, c = feature_global.shape
263 | feature_global = feature_global.reshape(bs * g, c)
264 |
265 | coarse = self.mlp(feature_global).reshape(bs * g, self.num_coarse, 3) # BG M 3
266 |
267 | point_feat = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1) # BG (M) S 3
268 | point_feat = point_feat.reshape(bs * g, self.num_fine, 3).transpose(2, 1) # BG 3 N
269 |
270 | seed = self.folding_seed.unsqueeze(2).expand(bs * g, -1, self.num_coarse, -1) # BG 2 M (S)
271 | seed = seed.reshape(bs * g, -1, self.num_fine).to(feature_global.device) # BG 2 N
272 |
273 | feature_global = feature_global.unsqueeze(2).expand(-1, -1, self.num_fine) # BG 1024 N
274 | feat = torch.cat([feature_global, seed, point_feat], dim=1) # BG C N
275 |
276 | center = coarse.unsqueeze(2).expand(-1, -1, self.grid_size ** 2, -1) # BG (M) S 3
277 | center = center.reshape(bs * g, self.num_fine, 3).transpose(2, 1) # BG 3 N
278 |
279 | fine = self.final_conv(feat) + center # BG 3 N
280 | fine = fine.reshape(bs, g, 3, self.num_fine).transpose(-1, -2)
281 | coarse = coarse.reshape(bs, g, self.num_coarse, 3)
282 | return coarse, fine
283 |
284 |
285 | class DiscreteVAE(nn.Module):
286 | def __init__(self, config, **kwargs):
287 | super().__init__()
288 | self.group_size = config.group_size
289 | self.num_group = config.num_group
290 | self.encoder_dims = config.encoder_dims
291 | self.tokens_dims = config.tokens_dims
292 |
293 | self.decoder_dims = config.decoder_dims
294 | self.num_tokens = config.num_tokens
295 |
296 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
297 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
298 | self.dgcnn_1 = DGCNN(encoder_channel=self.encoder_dims, output_channel=self.num_tokens)
299 | self.codebook = nn.Parameter(torch.randn(self.num_tokens, self.tokens_dims))
300 |
301 | self.dgcnn_2 = DGCNN(encoder_channel=self.tokens_dims, output_channel=self.decoder_dims)
302 | self.decoder = Decoder(encoder_channel=self.decoder_dims, num_fine=self.group_size)
303 | # self.build_loss_func()
304 |
305 | # def build_loss_func(self):
306 | # self.loss_func_cdl1 = ChamferDistanceL1().cuda()
307 | # self.loss_func_cdl2 = ChamferDistanceL2().cuda()
308 | # self.loss_func_emd = emd().cuda()
309 |
310 | def recon_loss(self, ret, gt):
311 | whole_coarse, whole_fine, coarse, fine, group_gt, _ = ret
312 |
313 | bs, g, _, _ = coarse.shape
314 |
315 | coarse = coarse.reshape(bs * g, -1, 3).contiguous()
316 | fine = fine.reshape(bs * g, -1, 3).contiguous()
317 | group_gt = group_gt.reshape(bs * g, -1, 3).contiguous()
318 |
319 | loss_coarse_block = self.loss_func_cdl1(coarse, group_gt)
320 | loss_fine_block = self.loss_func_cdl1(fine, group_gt)
321 |
322 | loss_recon = loss_coarse_block + loss_fine_block
323 |
324 | return loss_recon
325 |
326 | def get_loss(self, ret, gt):
327 | # reconstruction loss
328 | loss_recon = self.recon_loss(ret, gt)
329 | # kl divergence
330 | logits = ret[-1] # B G N
331 | softmax = F.softmax(logits, dim=-1)
332 | mean_softmax = softmax.mean(dim=1)
333 | log_qy = torch.log(mean_softmax)
334 | log_uniform = torch.log(torch.tensor([1. / self.num_tokens], device=gt.device))
335 | loss_klv = F.kl_div(log_qy, log_uniform.expand(log_qy.size(0), log_qy.size(1)), None, None, 'batchmean',
336 | log_target=True)
337 |
338 | return loss_recon, loss_klv
339 |
340 | def forward(self, inp, temperature=1., hard=False, **kwargs):
341 | neighborhood, center = self.group_divider(inp)
342 | logits = self.encoder(neighborhood) # B G C
343 | logits = self.dgcnn_1(logits, center) # B G N
344 | soft_one_hot = F.gumbel_softmax(logits, tau=temperature, dim=2, hard=hard) # B G N
345 | sampled = torch.einsum('b g n, n c -> b g c', soft_one_hot, self.codebook) # B G C
346 | feature = self.dgcnn_2(sampled, center)
347 | coarse, fine = self.decoder(feature)
348 |
349 | with torch.no_grad():
350 | whole_fine = (fine + center.unsqueeze(2)).reshape(inp.size(0), -1, 3)
351 | whole_coarse = (coarse + center.unsqueeze(2)).reshape(inp.size(0), -1, 3)
352 |
353 | assert fine.size(2) == self.group_size
354 | ret = (whole_coarse, whole_fine, coarse, fine, neighborhood, logits)
355 | return ret
--------------------------------------------------------------------------------
/n2m/model/pointbert/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | logger_initialized = {}
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
7 | """Get root logger and add a keyword filter to it.
8 | The logger will be initialized if it has not been initialized. By default a
9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
10 | also be added. The name of the root logger is the top-level package name,
11 | e.g., "mmdet3d".
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str, optional): The name of the root logger, also used as a
17 | filter keyword. Defaults to 'main'.
18 | Returns:
19 | :obj:`logging.Logger`: The obtained logger
20 | """
21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
22 | # add a logging filter
23 | logging_filter = logging.Filter(name)
24 | logging_filter.filter = lambda record: record.find(name) != -1
25 |
26 | return logger
27 |
28 |
29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
30 | """Initialize and get a logger by name.
31 | If the logger has not been initialized, this method will initialize the
32 | logger by adding one or two handlers, otherwise the initialized logger will
33 | be directly returned. During initialization, a StreamHandler will always be
34 | added. If `log_file` is specified and the process rank is 0, a FileHandler
35 | will also be added.
36 | Args:
37 | name (str): Logger name.
38 | log_file (str | None): The log filename. If specified, a FileHandler
39 | will be added to the logger.
40 | log_level (int): The logger level. Note that only the process of
41 | rank 0 is affected, and other processes will set the level to
42 | "Error" thus be silent most of the time.
43 | file_mode (str): The file mode used in opening log file.
44 | Defaults to 'w'.
45 | Returns:
46 | logging.Logger: The expected logger.
47 | """
48 | logger = logging.getLogger(name)
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | # handle duplicate logs to the console
59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
60 | # to the root logger. As logger.propagate is True by default, this root
61 | # level handler causes logging messages from rank>0 processes to
62 | # unexpectedly show up on the console, creating much unwanted clutter.
63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
64 | # at the ERROR level.
65 | for handler in logger.root.handlers:
66 | if type(handler) is logging.StreamHandler:
67 | handler.setLevel(logging.ERROR)
68 |
69 | stream_handler = logging.StreamHandler()
70 | handlers = [stream_handler]
71 |
72 | if dist.is_available() and dist.is_initialized():
73 | rank = dist.get_rank()
74 | else:
75 | rank = 0
76 |
77 | # only rank 0 will add a FileHandler
78 | if rank == 0 and log_file is not None:
79 | # Here, the default behaviour of the official logger is 'a'. Thus, we
80 | # provide an interface to change the file mode to the default
81 | # behaviour.
82 | file_handler = logging.FileHandler(log_file, file_mode)
83 | handlers.append(file_handler)
84 |
85 | formatter = logging.Formatter(
86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87 | for handler in handlers:
88 | handler.setFormatter(formatter)
89 | handler.setLevel(log_level)
90 | logger.addHandler(handler)
91 |
92 | if rank == 0:
93 | logger.setLevel(log_level)
94 | else:
95 | logger.setLevel(logging.ERROR)
96 |
97 | logger_initialized[name] = True
98 |
99 |
100 | return logger
101 |
102 |
103 | def print_log(msg, logger=None, level=logging.INFO):
104 | """Print a log message.
105 | Args:
106 | msg (str): The message to be logged.
107 | logger (logging.Logger | str | None): The logger to be used.
108 | Some special loggers are:
109 | - "silent": no message will be printed.
110 | - other str: the logger obtained with `get_root_logger(logger)`.
111 | - None: The `print()` method will be used to print log messages.
112 | level (int): Logging level. Only available when `logger` is a Logger
113 | object or "root".
114 | """
115 | if logger is None:
116 | print(msg)
117 | elif isinstance(logger, logging.Logger):
118 | logger.log(level, msg)
119 | elif logger == 'silent':
120 | pass
121 | elif isinstance(logger, str):
122 | _logger = get_logger(logger)
123 | _logger.log(level, msg)
124 | else:
125 | raise TypeError(
126 | 'logger should be either a logging.Logger object, str, '
127 | f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
/n2m/model/pointbert/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | from collections import abc
10 | # from pointnet2_ops import pointnet2_utils
11 |
12 |
13 | # def fps(data, number):
14 | # '''
15 | # data B N 3
16 | # number int
17 | # '''
18 | # fps_idx = pointnet2_utils.furthest_point_sample(data, number)
19 | # fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
20 | # return fps_data
21 |
22 | def index_points(points, idx):
23 | """
24 | Input:
25 | points: input points data, [B, N, C]
26 | idx: sample index data, [B, S]
27 | Return:
28 | new_points:, indexed points data, [B, S, C]
29 | """
30 | device = points.device
31 | B = points.shape[0]
32 | view_shape = list(idx.shape)
33 | view_shape[1:] = [1] * (len(view_shape) - 1)
34 | repeat_shape = list(idx.shape)
35 | repeat_shape[0] = 1
36 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
37 | new_points = points[batch_indices, idx, :]
38 | return new_points
39 |
40 | def fps(xyz, npoint):
41 | """
42 | Input:
43 | xyz: pointcloud data, [B, N, 3]
44 | npoint: number of samples
45 | Return:
46 | centroids: sampled pointcloud index, [B, npoint]
47 | """
48 | device = xyz.device
49 | B, N, C = xyz.shape
50 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
51 | distance = torch.ones(B, N).to(device) * 1e10
52 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
53 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
54 | for i in range(npoint):
55 | centroids[:, i] = farthest
56 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
57 | dist = torch.sum((xyz - centroid) ** 2, -1)
58 | distance = torch.min(distance, dist)
59 | farthest = torch.max(distance, -1)[1]
60 | return index_points(xyz, centroids)
61 |
62 | def worker_init_fn(worker_id):
63 | np.random.seed(np.random.get_state()[1][0] + worker_id)
64 |
65 | def build_lambda_sche(opti, config):
66 | if config.get('decay_step') is not None:
67 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
68 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
69 | else:
70 | raise NotImplementedError()
71 | return scheduler
72 |
73 | def build_lambda_bnsche(model, config):
74 | if config.get('decay_step') is not None:
75 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
76 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
77 | else:
78 | raise NotImplementedError()
79 | return bnm_scheduler
80 |
81 | def set_random_seed(seed, deterministic=False):
82 | """Set random seed.
83 | Args:
84 | seed (int): Seed to be used.
85 | deterministic (bool): Whether to set the deterministic option for
86 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
87 | to True and `torch.backends.cudnn.benchmark` to False.
88 | Default: False.
89 |
90 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
91 | if cuda_deterministic: # slower, more reproducible
92 | cudnn.deterministic = True
93 | cudnn.benchmark = False
94 | else: # faster, less reproducible
95 | cudnn.deterministic = False
96 | cudnn.benchmark = True
97 |
98 | """
99 | random.seed(seed)
100 | np.random.seed(seed)
101 | torch.manual_seed(seed)
102 | torch.cuda.manual_seed_all(seed)
103 | if deterministic:
104 | torch.backends.cudnn.deterministic = True
105 | torch.backends.cudnn.benchmark = False
106 |
107 |
108 | def is_seq_of(seq, expected_type, seq_type=None):
109 | """Check whether it is a sequence of some type.
110 | Args:
111 | seq (Sequence): The sequence to be checked.
112 | expected_type (type): Expected type of sequence items.
113 | seq_type (type, optional): Expected sequence type.
114 | Returns:
115 | bool: Whether the sequence is valid.
116 | """
117 | if seq_type is None:
118 | exp_seq_type = abc.Sequence
119 | else:
120 | assert isinstance(seq_type, type)
121 | exp_seq_type = seq_type
122 | if not isinstance(seq, exp_seq_type):
123 | return False
124 | for item in seq:
125 | if not isinstance(item, expected_type):
126 | return False
127 | return True
128 |
129 |
130 | def set_bn_momentum_default(bn_momentum):
131 | def fn(m):
132 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
133 | m.momentum = bn_momentum
134 | return fn
135 |
136 | class BNMomentumScheduler(object):
137 |
138 | def __init__(
139 | self, model, bn_lambda, last_epoch=-1,
140 | setter=set_bn_momentum_default
141 | ):
142 | if not isinstance(model, nn.Module):
143 | raise RuntimeError(
144 | "Class '{}' is not a PyTorch nn Module".format(
145 | type(model).__name__
146 | )
147 | )
148 |
149 | self.model = model
150 | self.setter = setter
151 | self.lmbd = bn_lambda
152 |
153 | self.step(last_epoch + 1)
154 | self.last_epoch = last_epoch
155 |
156 | def step(self, epoch=None):
157 | if epoch is None:
158 | epoch = self.last_epoch + 1
159 |
160 | self.last_epoch = epoch
161 | self.model.apply(self.setter(self.lmbd(epoch)))
162 |
163 | def get_momentum(self, epoch=None):
164 | if epoch is None:
165 | epoch = self.last_epoch + 1
166 | return self.lmbd(epoch)
167 |
168 |
169 |
170 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
171 | '''
172 | Separate point cloud: used to generate an incomplete point cloud with a set number of points.
173 | '''
174 | _,n,c = xyz.shape
175 |
176 | assert n == num_points
177 | assert c == 3
178 | if crop == num_points:
179 | return xyz, None
180 |
181 | INPUT = []
182 | CROP = []
183 | for points in xyz:
184 | if isinstance(crop,list):
185 | num_crop = random.randint(crop[0],crop[1])
186 | else:
187 | num_crop = crop
188 |
189 | points = points.unsqueeze(0)
190 |
191 | if fixed_points is None:
192 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
193 | else:
194 | if isinstance(fixed_points,list):
195 | fixed_point = random.sample(fixed_points,1)[0]
196 | else:
197 | fixed_point = fixed_points
198 | center = fixed_point.reshape(1,1,3).cuda()
199 |
200 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
201 |
202 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
203 |
204 | if padding_zeros:
205 | input_data = points.clone()
206 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
207 |
208 | else:
209 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
210 |
211 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
212 |
213 | if isinstance(crop,list):
214 | INPUT.append(fps(input_data,2048))
215 | CROP.append(fps(crop_data,2048))
216 | else:
217 | INPUT.append(input_data)
218 | CROP.append(crop_data)
219 |
220 | input_data = torch.cat(INPUT,dim=0)# B N 3
221 | crop_data = torch.cat(CROP,dim=0)# B M 3
222 |
223 | return input_data.contiguous(), crop_data.contiguous()
224 |
225 | def get_ptcloud_img(ptcloud):
226 | fig = plt.figure(figsize=(8, 8))
227 |
228 | x, z, y = ptcloud.transpose(1, 0)
229 | ax = fig.gca(projection=Axes3D.name, adjustable='box')
230 | ax.axis('off')
231 | # ax.axis('scaled')
232 | ax.view_init(30, 45)
233 | max, min = np.max(ptcloud), np.min(ptcloud)
234 | ax.set_xbound(min, max)
235 | ax.set_ybound(min, max)
236 | ax.set_zbound(min, max)
237 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet')
238 |
239 | fig.canvas.draw()
240 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
241 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
242 | return img
243 |
244 |
245 |
246 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
247 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
248 | fig = plt.figure(figsize=(6*len(data_list),6))
249 | cmax = data_list[-1][:,0].max()
250 |
251 | for i in range(len(data_list)):
252 | data = data_list[i][:-2048] if i == 1 else data_list[i]
253 | color = data[:,0] /cmax
254 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
255 | ax.view_init(30, -120)
256 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
257 | ax.set_title(titles[i])
258 |
259 | ax.set_axis_off()
260 | ax.set_xlim(xlim)
261 | ax.set_ylim(ylim)
262 | ax.set_zlim(zlim)
263 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
264 | if not os.path.exists(path):
265 | os.makedirs(path)
266 |
267 | pic_path = path + '.png'
268 | fig.savefig(pic_path)
269 |
270 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
271 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
272 | plt.close(fig)
273 |
274 |
275 | def random_dropping(pc, e):
276 | up_num = max(64, 768 // (e//50 + 1))
277 | pc = pc
278 | random_num = torch.randint(1, up_num, (1,1))[0,0]
279 | pc = fps(pc, random_num)
280 | padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
281 | pc = torch.cat([pc, padding], dim = 1)
282 | return pc
283 |
284 |
285 | def random_scale(partial, scale_range=[0.8, 1.2]):
286 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
287 | return partial * scale
288 |
--------------------------------------------------------------------------------
/n2m/model/pointbert/point_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import DropPath
4 | from .dvae import Group
5 | from .dvae import Encoder
6 | from .logger import print_log
7 | from collections import OrderedDict
8 |
9 | from .checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
10 |
11 | class Mlp(nn.Module):
12 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
13 | super().__init__()
14 | out_features = out_features or in_features
15 | hidden_features = hidden_features or in_features
16 | self.fc1 = nn.Linear(in_features, hidden_features)
17 | self.act = act_layer()
18 | self.fc2 = nn.Linear(hidden_features, out_features)
19 | self.drop = nn.Dropout(drop)
20 |
21 | def forward(self, x):
22 | x = self.fc1(x)
23 | x = self.act(x)
24 | x = self.drop(x)
25 | x = self.fc2(x)
26 | x = self.drop(x)
27 | return x
28 |
29 |
30 | class Attention(nn.Module):
31 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
32 | super().__init__()
33 | self.num_heads = num_heads
34 | head_dim = dim // num_heads
35 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
36 | self.scale = qk_scale or head_dim ** -0.5
37 |
38 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
39 | self.attn_drop = nn.Dropout(attn_drop)
40 | self.proj = nn.Linear(dim, dim)
41 | self.proj_drop = nn.Dropout(proj_drop)
42 |
43 | def forward(self, x):
44 | B, N, C = x.shape
45 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
46 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
47 |
48 | attn = (q @ k.transpose(-2, -1)) * self.scale
49 | attn = attn.softmax(dim=-1)
50 | attn = self.attn_drop(attn)
51 |
52 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
53 | x = self.proj(x)
54 | x = self.proj_drop(x)
55 | return x, attn
56 |
57 |
58 | class Block(nn.Module):
59 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
60 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
61 | super().__init__()
62 | self.norm1 = norm_layer(dim)
63 |
64 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
65 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
66 | self.norm2 = norm_layer(dim)
67 | mlp_hidden_dim = int(dim * mlp_ratio)
68 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
69 |
70 | self.attn = Attention(
71 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
72 |
73 | def forward(self, x):
74 | attn_out, attn_weights = self.attn(self.norm1(x))
75 | x = x + self.drop_path(attn_out)
76 | x = x + self.drop_path(self.mlp(self.norm2(x)))
77 | return x, attn_weights
78 |
79 |
80 | class TransformerEncoder(nn.Module):
81 | """ Transformer Encoder without hierarchical structure
82 | """
83 |
84 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
85 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
86 | super().__init__()
87 |
88 | self.blocks = nn.ModuleList([
89 | Block(
90 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
91 | drop=drop_rate, attn_drop=attn_drop_rate,
92 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
93 | )
94 | for i in range(depth)])
95 |
96 | def forward(self, x, pos):
97 | for i, block in enumerate(self.blocks):
98 | x, attn_weights = block(x + pos)
99 | if i == len(self.blocks) - 1:
100 | final_attn_weights = attn_weights
101 | return x, final_attn_weights
102 |
103 |
104 | class PointTransformer(nn.Module):
105 | def __init__(self, config, use_max_pool=True):
106 | super().__init__()
107 | self.config = config
108 |
109 |         self.use_max_pool = use_max_pool # * whether to max pool the features of different tokens
110 |
111 | self.trans_dim = config['trans_dim']
112 | self.depth = config['depth']
113 | self.drop_path_rate = config['drop_path_rate']
114 | self.cls_dim = config['cls_dim']
115 | self.num_heads = config['num_heads']
116 | self.point_dims = 6
117 | self.group_size = config['group_size']
118 | self.num_group = config['num_group']
119 | # grouper
120 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
121 | # define the encoder
122 | self.encoder_dims = config['encoder_dims']
123 | self.encoder = Encoder(encoder_channel=self.encoder_dims, point_input_dims=self.point_dims)
124 | # bridge encoder and transformer
125 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
126 |
127 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
128 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
129 |
130 | self.pos_embed = nn.Sequential(
131 | nn.Linear(3, 128),
132 | nn.GELU(),
133 | nn.Linear(128, self.trans_dim)
134 | )
135 |
136 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
137 | self.blocks = TransformerEncoder(
138 | embed_dim=self.trans_dim,
139 | depth=self.depth,
140 | drop_path_rate=dpr,
141 | num_heads=self.num_heads
142 | )
143 |
144 | self.norm = nn.LayerNorm(self.trans_dim)
145 |
146 | def load_checkpoint(self, bert_ckpt_path):
147 | ckpt = torch.load(bert_ckpt_path, map_location='cpu')
148 |
149 | incompatible = self.load_state_dict(ckpt['base_model'], strict=False)
150 |
151 | if incompatible.missing_keys:
152 | print_log('missing_keys', logger='Transformer')
153 | print_log(
154 | get_missing_parameters_message(incompatible.missing_keys),
155 | logger='Transformer'
156 | )
157 | if incompatible.unexpected_keys:
158 | print_log('unexpected_keys', logger='Transformer')
159 | print_log(
160 | get_unexpected_parameters_message(incompatible.unexpected_keys),
161 | logger='Transformer'
162 | )
163 | if not incompatible.missing_keys and not incompatible.unexpected_keys:
164 | # * print successful loading
165 | print_log("PointBERT's weights are successfully loaded from {}".format(bert_ckpt_path), logger='Transformer')
166 |
167 | def forward(self, pts):
168 | # divide the point cloud in the same form. This is important
169 | neighborhood, center = self.group_divider(pts)
170 |         # encode the input cloud blocks
171 | group_input_tokens = self.encoder(neighborhood) # B G N
172 | group_input_tokens = self.reduce_dim(group_input_tokens)
173 | # prepare cls
174 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
175 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
176 | # add pos embedding
177 | pos = self.pos_embed(center)
178 | # final input
179 | x = torch.cat((cls_tokens, group_input_tokens), dim=1)
180 | pos = torch.cat((cls_pos, pos), dim=1)
181 | # transformer
182 | x, final_attn_weights = self.blocks(x, pos)
183 | x = self.norm(x) # * B, G + 1(cls token)(513), C(384)
184 | if not self.use_max_pool:
185 | return x, final_attn_weights
186 | concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1).unsqueeze(1) # * concat the cls token and max pool the features of different tokens, make it B, 1, C
187 | return concat_f, final_attn_weights # * B, 1, C(384 + 384)
--------------------------------------------------------------------------------
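`PointTransformer` above is driven entirely by its config dict; the keys read in `__init__` are `trans_dim`, `depth`, `drop_path_rate`, `cls_dim`, `num_heads`, `group_size`, `num_group`, and `encoder_dims`, with the per-point feature size fixed at 6. A minimal instantiation sketch (the numeric values are illustrative assumptions, not the shipped config, and the grouper may require a CUDA build of the point-cloud ops):

```python
import torch
from n2m.model.pointbert.point_encoder import PointTransformer

# Illustrative hyper-parameters (assumptions, not the repository's config)
config = {
    'trans_dim': 384, 'depth': 12, 'drop_path_rate': 0.1, 'cls_dim': 40,
    'num_heads': 6, 'group_size': 32, 'num_group': 128, 'encoder_dims': 256,
}
model = PointTransformer(config, use_max_pool=True).cuda()

pts = torch.rand(2, 2048, 6).cuda()  # B x N x 6 (point_dims is hard-coded to 6)
feat, attn = model(pts)              # feat: B x 1 x (2 * trans_dim), attn: last block's attention
```
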
/n2m/module/N2Mmodule.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.distributions import MultivariateNormal
4 |
5 | from n2m.model.N2Mnet import N2Mnet
6 | from n2m.utils.point_cloud import fix_point_cloud_size
7 | from n2m.utils.sample_utils import CollisionChecker
8 |
9 | class N2Mmodule:
10 | """
11 | N2M module class that encapsulates the N2M model.
12 |     This includes point cloud pre-processing and collision checking.
13 | """
14 | def __init__(self, config):
15 | self.n2mnet_config = config['n2mnet']
16 | self.preprocess_config = config['preprocess']
17 | self.postprocess_config = config['postprocess']
18 |
19 | self.model = N2Mnet(self.n2mnet_config)
20 | self.collision_checker = CollisionChecker(self.postprocess_config['collision_checker'])
21 |
22 | def predict(self, point_cloud):
23 | """
24 | Predicts the preferable initial pose.
25 |
26 | Inputs:
27 |             point_cloud (numpy.array): Input point cloud of shape (N, 3). The point cloud should be captured from a robot-centric camera, with the robot base at the origin.
28 |
29 | Outputs:
30 |             preferable initial_pose (numpy.array): Predicted preferable initial pose of shape (3,) for SE(2), or (4,) for SE(2) + z.
31 |             prediction success (bool): Whether a collision-free pose was found within the given number of samples.
32 | """
33 | self.model.eval()
34 |
35 |         # preprocess (downsample) the point cloud to match the specified pointnum
36 | pointnum = self.preprocess_config['pointnum']
37 | orig_point_cloud = point_cloud.copy()
38 | point_cloud = fix_point_cloud_size(point_cloud, pointnum)
39 |
40 | # load point cloud for collision checking
41 | self.collision_checker.set_pcd(orig_point_cloud)
42 |
43 | # get predictions
44 | with torch.no_grad():
45 |             point_cloud_tensor = torch.tensor(point_cloud, dtype=torch.float32).unsqueeze(0) # Add batch dimension (1, N, 3)
46 | num_samples = self.postprocess_config['num_samples']
47 | samples, means, covs, weights = self.model.sample(point_cloud_tensor, num_samples=num_samples)
48 |
49 | # first check mean's collision. If it doesn't collide, return mean value
50 | if self.collision_checker.check_collision(means[0]):
51 | return means[0], True
52 |
53 | # sort predictions in the order of predicted probability
54 | num_modes = weights.shape[1]
55 | mvns = [MultivariateNormal(means[0, i], covs[0, i]) for i in range(num_modes)]
56 | log_probs = torch.stack([mvn.log_prob(samples[0]) for mvn in mvns]) # shape: [num_modes, num_samples]
57 | log_weights = torch.log(weights[0] + 1e-8).unsqueeze(1) # shape: [num_modes, 1]
58 | weighted_log_probs = log_probs + log_weights
59 | gaussian_probabilities = torch.logsumexp(weighted_log_probs, dim=0) # shape: [num_samples]
60 | sorted_indices = torch.argsort(gaussian_probabilities, descending=True)
61 | samples = samples[0, sorted_indices]
62 |
63 |         # check collision for each sample and return the first non-colliding one (highest probability first)
64 | for i in range(num_samples):
65 |             sample = samples[i]
66 | if self.collision_checker.check_collision(sample):
67 | print("Valid initial pose found: ", sample)
68 | return sample, True
69 | else:
70 | print("Invalid pose, trying again: ", sample)
71 |
72 | # prediction fail. return False for validity
73 | return means[0], False
--------------------------------------------------------------------------------
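A hedged usage sketch for `N2Mmodule` above: the config needs `n2mnet`, `preprocess` (with `pointnum`), and `postprocess` (with `collision_checker` and `num_samples`) sections, and `predict` consumes a robot-centric `(N, 3)` cloud. Whether `configs/inference/config.json` follows exactly this layout is an assumption here; adjust the path and keys to your setup.

```python
import json
import numpy as np
from n2m.module.N2Mmodule import N2Mmodule

# Assumed to contain the 'n2mnet' / 'preprocess' / 'postprocess' sections used above
with open('configs/inference/config.json') as f:
    config = json.load(f)

module = N2Mmodule(config)
cloud = np.random.rand(40000, 3).astype(np.float32)  # robot-centric cloud, base at the origin
pose, ok = module.predict(cloud)                      # pose: SE(2) or SE(2)+z, ok: collision-free sample found
```
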
/n2m/utils/AverageMeter.py:
--------------------------------------------------------------------------------
1 |
2 | class AverageMeter(object):
3 | def __init__(self, items=None):
4 | self.items = items
5 | self.n_items = 1 if items is None else len(items)
6 | self.reset()
7 |
8 | def reset(self):
9 | self._val = [0] * self.n_items
10 | self._sum = [0] * self.n_items
11 | self._count = [0] * self.n_items
12 |
13 | def update(self, values):
14 | if type(values).__name__ == 'list':
15 | for idx, v in enumerate(values):
16 | self._val[idx] = v
17 | self._sum[idx] += v
18 | self._count[idx] += 1
19 | else:
20 | self._val[0] = values
21 | self._sum[0] += values
22 | self._count[0] += 1
23 |
24 | def val(self, idx=None):
25 | if idx is None:
26 | return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)]
27 | else:
28 | return self._val[idx]
29 |
30 | def count(self, idx=None):
31 | if idx is None:
32 | return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)]
33 | else:
34 | return self._count[idx]
35 |
36 | def avg(self, idx=None):
37 | if idx is None:
38 | return self._sum[0] / self._count[0] if self.items is None else [
39 | self._sum[i] / self._count[i] for i in range(self.n_items)
40 | ]
41 | else:
42 | return self._sum[idx] / self._count[idx]
--------------------------------------------------------------------------------
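A quick example of the meter above; given a list of item names it keeps one running average per item:

```python
from n2m.utils.AverageMeter import AverageMeter

meter = AverageMeter(['loss', 'acc'])
meter.update([1.0, 0.5])
meter.update([3.0, 1.5])
print(meter.val())  # last values:   [3.0, 1.5]
print(meter.avg())  # running means: [2.0, 1.0]
```
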
/n2m/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | import copy
5 | import logging
6 | import os
7 | from collections import defaultdict
8 | import torch
9 | import torch.nn as nn
10 |
11 | from typing import Any
12 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
13 |
14 | from termcolor import colored
15 |
16 | def get_missing_parameters_message(keys: List[str]) -> str:
17 | """
18 | Get a logging-friendly message to report parameter names (keys) that are in
19 | the model but not found in a checkpoint.
20 | Args:
21 | keys (list[str]): List of keys that were not found in the checkpoint.
22 | Returns:
23 | str: message.
24 | """
25 | groups = _group_checkpoint_keys(keys)
26 | msg = "Some model parameters or buffers are not found in the checkpoint:\n"
27 | msg += "\n".join(
28 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
29 | )
30 | return msg
31 |
32 |
33 | def get_unexpected_parameters_message(keys: List[str]) -> str:
34 | """
35 | Get a logging-friendly message to report parameter names (keys) that are in
36 | the checkpoint but not found in the model.
37 | Args:
38 | keys (list[str]): List of keys that were not found in the model.
39 | Returns:
40 | str: message.
41 | """
42 | groups = _group_checkpoint_keys(keys)
43 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
44 | msg += "\n".join(
45 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
46 | )
47 | return msg
48 |
49 |
50 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
51 | """
52 | Strip the prefix in metadata, if any.
53 | Args:
54 | state_dict (OrderedDict): a state-dict to be loaded to the model.
55 | prefix (str): prefix.
56 | """
57 | keys = sorted(state_dict.keys())
58 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
59 | return
60 |
61 | for key in keys:
62 | newkey = key[len(prefix):]
63 | state_dict[newkey] = state_dict.pop(key)
64 |
65 | # also strip the prefix in metadata, if any..
66 | try:
67 | metadata = state_dict._metadata # pyre-ignore
68 | except AttributeError:
69 | pass
70 | else:
71 | for key in list(metadata.keys()):
72 | # for the metadata dict, the key can be:
73 | # '': for the DDP module, which we want to remove.
74 | # 'module': for the actual model.
75 | # 'module.xx.xx': for the rest.
76 |
77 | if len(key) == 0:
78 | continue
79 | newkey = key[len(prefix):]
80 | metadata[newkey] = metadata.pop(key)
81 |
82 |
83 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
84 | """
85 | Group keys based on common prefixes. A prefix is the string up to the final
86 | "." in each key.
87 | Args:
88 | keys (list[str]): list of parameter names, i.e. keys in the model
89 | checkpoint dict.
90 | Returns:
91 | dict[list]: keys with common prefixes are grouped into lists.
92 | """
93 | groups = defaultdict(list)
94 | for key in keys:
95 | pos = key.rfind(".")
96 | if pos >= 0:
97 | head, tail = key[:pos], [key[pos + 1:]]
98 | else:
99 | head, tail = key, []
100 | groups[head].extend(tail)
101 | return groups
102 |
103 |
104 | def _group_to_str(group: List[str]) -> str:
105 | """
106 | Format a group of parameter name suffixes into a loggable string.
107 | Args:
108 | group (list[str]): list of parameter name suffixes.
109 | Returns:
110 |         str: formatted string.
111 | """
112 | if len(group) == 0:
113 | return ""
114 |
115 | if len(group) == 1:
116 | return "." + group[0]
117 |
118 | return ".{" + ", ".join(group) + "}"
119 |
120 |
121 | def _named_modules_with_dup(
122 | model: nn.Module, prefix: str = ""
123 | ) -> Iterable[Tuple[str, nn.Module]]:
124 | """
125 | The same as `model.named_modules()`, except that it includes
126 | duplicated modules that have more than one name.
127 | """
128 | yield prefix, model
129 | for name, module in model._modules.items(): # pyre-ignore
130 | if module is None:
131 | continue
132 | submodule_prefix = prefix + ("." if prefix else "") + name
133 | yield from _named_modules_with_dup(module, submodule_prefix)
--------------------------------------------------------------------------------
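As a concrete example of the helpers above, keys sharing a prefix are collapsed into a single report line (ANSI color codes from `colored` omitted):

```python
from n2m.utils.checkpoint import get_missing_parameters_message

print(get_missing_parameters_message(
    ['backbone.conv1.weight', 'backbone.conv1.bias', 'cls_token']))
# Some model parameters or buffers are not found in the checkpoint:
#   backbone.conv1.{weight, bias}
#   cls_token
```
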
/n2m/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from easydict import EasyDict
3 | import os
4 | from .logger import print_log
5 |
6 | def log_args_to_file(args, pre='args', logger=None):
7 | for key, val in args.__dict__.items():
8 | print_log(f'{pre}.{key} : {val}', logger = logger)
9 |
10 | def log_config_to_file(cfg, pre='cfg', logger=None):
11 | for key, val in cfg.items():
12 | if isinstance(cfg[key], EasyDict):
13 | print_log(f'{pre}.{key} = edict()', logger = logger)
14 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
15 | continue
16 | print_log(f'{pre}.{key} : {val}', logger = logger)
17 |
18 | def merge_new_config(config, new_config):
19 | for key, val in new_config.items():
20 | if not isinstance(val, dict):
21 | if key == '_base_':
22 | with open(new_config['_base_'], 'r') as f:
23 | try:
24 | val = yaml.load(f, Loader=yaml.FullLoader)
25 | except:
26 | val = yaml.load(f)
27 | config[key] = EasyDict()
28 | merge_new_config(config[key], val)
29 | else:
30 | config[key] = val
31 | continue
32 | if key not in config:
33 | config[key] = EasyDict()
34 | merge_new_config(config[key], val)
35 | return config
36 |
37 | def cfg_from_yaml_file(cfg_file):
38 | config = EasyDict()
39 | with open(cfg_file, 'r') as f:
40 | try:
41 | new_config = yaml.load(f, Loader=yaml.FullLoader)
42 | except:
43 | new_config = yaml.load(f)
44 | merge_new_config(config=config, new_config=new_config)
45 | return config
46 |
47 | def get_config(args, logger=None):
48 | # if args.resume:
49 | # cfg_path = os.path.join(args.experiment_path, 'config.yaml')
50 | # if not os.path.exists(cfg_path):
51 | # print_log("Failed to resume", logger = logger)
52 | # raise FileNotFoundError()
53 | # print_log(f'Resume yaml from {cfg_path}', logger = logger)
54 | # args.config = cfg_path
55 | config = cfg_from_yaml_file(args.config)
56 | # if not args.resume and args.local_rank == 0:
57 | # save_experiment_config(args, config, logger)
58 | return config
59 |
60 | def save_experiment_config(args, config, logger = None):
61 | config_path = os.path.join(args.experiment_path, 'config.yaml')
62 | os.system('cp %s %s' % (args.config, config_path))
63 | print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger )
--------------------------------------------------------------------------------
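Note how `_base_` is handled above: the referenced YAML is loaded and nested under the `_base_` key of the resulting `EasyDict`, not merged into the top level. A small self-contained sketch (file paths are made up for illustration):

```python
import yaml
from n2m.utils.config import cfg_from_yaml_file

with open('/tmp/base.yaml', 'w') as f:
    yaml.safe_dump({'optimizer': {'lr': 0.001}}, f)
with open('/tmp/child.yaml', 'w') as f:
    yaml.safe_dump({'_base_': '/tmp/base.yaml', 'model': {'depth': 12}}, f)

cfg = cfg_from_yaml_file('/tmp/child.yaml')
print(cfg.model.depth)                   # 12
print(cfg['_base_']['optimizer']['lr'])  # 0.001 -- the base config lives under '_base_'
```
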
/n2m/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.multiprocessing as mp
5 | from torch import distributed as dist
6 |
7 |
8 |
9 | def init_dist(launcher, backend='nccl', **kwargs):
10 | if mp.get_start_method(allow_none=True) is None:
11 | mp.set_start_method('spawn')
12 | if launcher == 'pytorch':
13 | _init_dist_pytorch(backend, **kwargs)
14 | else:
15 | raise ValueError(f'Invalid launcher type: {launcher}')
16 |
17 |
18 | def _init_dist_pytorch(backend, **kwargs):
19 | # TODO: use local_rank instead of rank % num_gpus
20 | rank = int(os.environ['RANK'])
21 | num_gpus = torch.cuda.device_count()
22 | torch.cuda.set_device(rank % num_gpus)
23 | dist.init_process_group(backend=backend, **kwargs)
24 | print(f'init distributed in rank {torch.distributed.get_rank()}')
25 |
26 |
27 | def get_dist_info():
28 | if dist.is_available():
29 | initialized = dist.is_initialized()
30 | else:
31 | initialized = False
32 | if initialized:
33 | rank = dist.get_rank()
34 | world_size = dist.get_world_size()
35 | else:
36 | rank = 0
37 | world_size = 1
38 | return rank, world_size
39 |
40 |
41 | def reduce_tensor(tensor, args):
42 | '''
43 | for acc kind, get the mean in each gpu
44 | '''
45 | rt = tensor.clone()
46 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
47 | rt /= args.world_size
48 | return rt
49 |
50 | def gather_tensor(tensor, args):
51 | output_tensors = [tensor.clone() for _ in range(args.world_size)]
52 | torch.distributed.all_gather(output_tensors, tensor)
53 | concat = torch.cat(output_tensors, dim=0)
54 | return concat
55 |
--------------------------------------------------------------------------------
/n2m/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | logger_initialized = {}
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
7 | """Get root logger and add a keyword filter to it.
8 | The logger will be initialized if it has not been initialized. By default a
9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
10 | also be added. The name of the root logger is given by the ``name`` argument,
11 | which defaults to 'main'.
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str, optional): The name of the root logger, also used as a
17 | filter keyword. Defaults to 'main'.
18 | Returns:
19 | :obj:`logging.Logger`: The obtained logger
20 | """
21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
22 | # add a logging filter
23 | logging_filter = logging.Filter(name)
24 | logging_filter.filter = lambda record: record.find(name) != -1
25 |
26 | return logger
27 |
28 |
29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
30 | """Initialize and get a logger by name.
31 | If the logger has not been initialized, this method will initialize the
32 | logger by adding one or two handlers, otherwise the initialized logger will
33 | be directly returned. During initialization, a StreamHandler will always be
34 | added. If `log_file` is specified and the process rank is 0, a FileHandler
35 | will also be added.
36 | Args:
37 | name (str): Logger name.
38 | log_file (str | None): The log filename. If specified, a FileHandler
39 | will be added to the logger.
40 | log_level (int): The logger level. Note that only the process of
41 | rank 0 is affected, and other processes will set the level to
42 | "Error" thus be silent most of the time.
43 | file_mode (str): The file mode used in opening log file.
44 | Defaults to 'w'.
45 | Returns:
46 | logging.Logger: The expected logger.
47 | """
48 | logger = logging.getLogger(name)
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | # handle duplicate logs to the console
59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
60 | # to the root logger. As logger.propagate is True by default, this root
61 | # level handler causes logging messages from rank>0 processes to
62 | # unexpectedly show up on the console, creating much unwanted clutter.
63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
64 | # at the ERROR level.
65 | for handler in logger.root.handlers:
66 | if type(handler) is logging.StreamHandler:
67 | handler.setLevel(logging.ERROR)
68 |
69 | stream_handler = logging.StreamHandler()
70 | handlers = [stream_handler]
71 |
72 | if dist.is_available() and dist.is_initialized():
73 | rank = dist.get_rank()
74 | else:
75 | rank = 0
76 |
77 | # only rank 0 will add a FileHandler
78 | if rank == 0 and log_file is not None:
79 | # Here, the default behaviour of the official logger is 'a'. Thus, we
80 | # provide an interface to change the file mode to the default
81 | # behaviour.
82 | file_handler = logging.FileHandler(log_file, file_mode)
83 | handlers.append(file_handler)
84 |
85 | formatter = logging.Formatter(
86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87 | for handler in handlers:
88 | handler.setFormatter(formatter)
89 | handler.setLevel(log_level)
90 | logger.addHandler(handler)
91 |
92 | if rank == 0:
93 | logger.setLevel(log_level)
94 | else:
95 | logger.setLevel(logging.ERROR)
96 |
97 | logger_initialized[name] = True
98 |
99 |
100 | return logger
101 |
102 |
103 | def print_log(msg, logger=None, level=logging.INFO):
104 | """Print a log message.
105 | Args:
106 | msg (str): The message to be logged.
107 | logger (logging.Logger | str | None): The logger to be used.
108 | Some special loggers are:
109 | - "silent": no message will be printed.
110 | - other str: the logger obtained with `get_root_logger(logger)`.
111 | - None: The `print()` method will be used to print log messages.
112 | level (int): Logging level. Only available when `logger` is a Logger
113 | object or "root".
114 | """
115 | if logger is None:
116 | print(msg)
117 | elif isinstance(logger, logging.Logger):
118 | logger.log(level, msg)
119 | elif logger == 'silent':
120 | pass
121 | elif isinstance(logger, str):
122 | _logger = get_logger(logger)
123 | _logger.log(level, msg)
124 | else:
125 | raise TypeError(
126 | 'logger should be either a logging.Logger object, str, '
127 | f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
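A short usage sketch for the logging helpers above (the log path is hypothetical). Only rank 0 gets a `FileHandler`, and `print_log` falls back to plain `print` when no logger is given:

```python
from n2m.utils.logger import get_root_logger, print_log

logger = get_root_logger(log_file='/tmp/n2m_train.log', name='n2m')
print_log('starting training', logger=logger)  # routed through the logging module
print_log('plain message')                     # no logger -> print()
print_log('hidden', logger='silent')           # suppressed
```
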
/n2m/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np
5 | import math
6 |
7 | class Loss(nn.Module):
8 | def __init__(self, config):
9 | super().__init__()
10 | self.name = config['name']
11 | self.neg_weight = config.get('neg_weight', 1.0)
12 | self.lam_weight = config.get('lam_weight', 0.0)
13 | self.lam_dist = config.get('lam_dist', 0.0)
14 | self.min_logprob = config.get('min_logprob', 0.0)
15 |
16 | def forward(self, means, covs, weights, target_point, label):
17 | if self.name == 'mle':
18 | return self._mle_loss(means, covs, weights, target_point, label)
19 | elif self.name == 'mle_max':
20 | return self._mle_max_loss(means, covs, weights, target_point, label)
21 | elif self.name == 'ce':
22 | return self._ce_loss(means, covs, weights, target_point, label)
23 | else:
24 | raise ValueError(f"Unknown loss name: {self.name}")
25 |
26 | def _mle_loss(self, means, covs, weights, target_point, label):
27 | B, K, D = means.size()
28 | target_point = target_point.unsqueeze(1)
29 | diff = target_point - means
30 | inv_covs = torch.inverse(covs).to(dtype=torch.float32)
31 | diff = diff.to(dtype=torch.float32)
32 |
33 | mahalanobis = torch.sum(
34 | torch.matmul(diff.unsqueeze(2), inv_covs) * diff.unsqueeze(2), dim=-1
35 | ).squeeze(-1)
36 |
37 | log_det = torch.logdet(covs)
38 | log_prob = -0.5 * mahalanobis - 0.5 * log_det - 0.5 * D * np.log(2 * np.pi)
39 | log_prob = log_prob + torch.log(weights + 1e-6)
40 | max_log_prob = torch.max(log_prob, dim=1, keepdim=True)[0]
41 | log_prob = torch.logsumexp(log_prob - max_log_prob, dim=1) + max_log_prob.squeeze(1)
42 | if self.min_logprob < 0.0:
43 | log_prob[(label == -1) & (log_prob < self.min_logprob)] = self.min_logprob
44 |
45 | label = label.float().to(log_prob.device)
46 |         label[label == -1] = -self.neg_weight # negative labels (-1) push the density down, scaled by neg_weight
47 |
48 | entropy_weight = -torch.sum(weights * torch.log(weights + 1e-6), dim=-1).mean()
49 | entropy_dist = -torch.sum(weights * (0.5 * (D * (1 + math.log(2 * math.pi)) + log_det)), dim=-1).mean()
50 |
51 | loss = - torch.mean(log_prob * label) - self.lam_weight * entropy_weight - self.lam_dist * entropy_dist
52 | return loss
53 |
54 | def _mle_max_loss(self, means, covs, weights, target_point, label):
55 | B, K, D = means.size()
56 | target_point = target_point.unsqueeze(1)
57 | diff = target_point - means
58 | inv_covs = torch.inverse(covs).to(dtype=torch.float32)
59 | diff = diff.to(dtype=torch.float32)
60 |
61 | mahalanobis = torch.sum(
62 | torch.matmul(diff.unsqueeze(2), inv_covs) * diff.unsqueeze(2), dim=-1
63 | ).squeeze(-1)
64 |
65 | log_det = torch.logdet(covs)
66 | log_prob = -0.5 * mahalanobis - 0.5 * log_det - 0.5 * D * np.log(2 * np.pi)
67 | # log_prob = log_prob + torch.log(weights + 1e-6)
68 | # max_log_prob = torch.max(log_prob, dim=1, keepdim=True)[0]
69 | # log_prob = torch.logsumexp(log_prob - max_log_prob, dim=1) + max_log_prob.squeeze(1)
70 | log_prob, _ = torch.max(log_prob, dim=1)
71 | if self.min_logprob < 0.0:
72 | log_prob[(label == -1) & (log_prob < self.min_logprob)] = self.min_logprob
73 |
74 | label = label.float().to(log_prob.device)
75 |         label[label == -1] = -self.neg_weight # negative labels (-1) push the density down, scaled by neg_weight
76 |
77 | entropy_weight = -torch.sum(weights * torch.log(weights + 1e-6), dim=-1).mean()
78 | entropy_dist = -torch.sum(weights * (0.5 * (D * (1 + math.log(2 * math.pi)) + log_det)), dim=-1).mean()
79 |
80 | loss = - torch.mean(log_prob * label) - self.lam_weight * entropy_weight - self.lam_dist * entropy_dist
81 | return loss
82 |
83 | def _ce_loss(self, means, covs, weights, target_point, label):
84 | B, K, D = means.size()
85 | target_point = target_point.unsqueeze(1)
86 | diff = target_point - means
87 | inv_covs = torch.inverse(covs).to(dtype=torch.float32)
88 | diff = diff.to(dtype=torch.float32)
89 |
90 | mahalanobis = torch.sum(
91 |             torch.matmul(diff.unsqueeze(2), inv_covs) * diff.unsqueeze(2), dim=-1).squeeze(-1)
92 |
93 | log_det = torch.logdet(covs)
94 | log_prob = -0.5 * mahalanobis - 0.5 * log_det - 0.5 * D * np.log(2 * np.pi)
95 | log_prob = log_prob + torch.log(weights + 1e-6)
96 |
97 | max_log_prob = torch.max(log_prob, dim=1, keepdim=True)[0]
98 | log_prob = torch.logsumexp(log_prob - max_log_prob, dim=1) + max_log_prob.squeeze(1)
99 | pdf = torch.exp(log_prob)
100 | success_rate = torch.sigmoid(pdf)
101 |
102 | label = label.float().to(log_prob.device)
103 | label = (label + torch.abs(label)) / 2
104 |
105 | loss = -torch.mean(
106 | torch.log(success_rate + 1e-6) * label +
107 | torch.log(1 - success_rate + 1e-6) * (1 - label)
108 | )
109 | return loss
--------------------------------------------------------------------------------
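For orientation, the `mle` branch above evaluates the log-density of the predicted Gaussian mixture at the labeled pose, `log p(x) = logsumexp_k(log w_k + log N(x; mu_k, Sigma_k))`, scales it by the signed, `neg_weight`-weighted label, and adds the two entropy regularizers. A small sketch with assumed shapes computing the same mixture log-density via `torch.distributions`:

```python
import torch
from torch.distributions import MultivariateNormal

B, K, D = 2, 3, 3                        # batch, mixture components, pose dimension (illustrative)
means = torch.randn(B, K, D)
covs = 0.1 * torch.eye(D).expand(B, K, D, D)
weights = torch.softmax(torch.randn(B, K), dim=-1)
target = torch.randn(B, D)

comp = MultivariateNormal(means, covs)             # batch of B x K Gaussians
log_prob = comp.log_prob(target.unsqueeze(1))      # (B, K): log N(x; mu_k, Sigma_k)
log_mix = torch.logsumexp(log_prob + torch.log(weights), dim=1)  # (B,): mixture log-density
```
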
/n2m/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-08-08 14:31:30
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2020-05-25 09:13:32
6 | # @Email: cshzxie@gmail.com
7 |
8 | import logging
9 | import open3d
10 |
11 | from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
12 |
13 |
14 | class Metrics(object):
15 | ITEMS = [{
16 | 'name': 'F-Score',
17 | 'enabled': True,
18 | 'eval_func': 'cls._get_f_score',
19 | 'is_greater_better': True,
20 | 'init_value': 0
21 | }, {
22 | 'name': 'CDL1',
23 | 'enabled': True,
24 | 'eval_func': 'cls._get_chamfer_distancel1',
25 | 'eval_object': ChamferDistanceL1(ignore_zeros=True),
26 | 'is_greater_better': False,
27 | 'init_value': 32767
28 | }, {
29 | 'name': 'CDL2',
30 | 'enabled': True,
31 | 'eval_func': 'cls._get_chamfer_distancel2',
32 | 'eval_object': ChamferDistanceL2(ignore_zeros=True),
33 | 'is_greater_better': False,
34 | 'init_value': 32767
35 | }]
36 |
37 | @classmethod
38 | def get(cls, pred, gt):
39 | _items = cls.items()
40 | _values = [0] * len(_items)
41 | for i, item in enumerate(_items):
42 | eval_func = eval(item['eval_func'])
43 | _values[i] = eval_func(pred, gt)
44 |
45 | return _values
46 |
47 | @classmethod
48 | def items(cls):
49 | return [i for i in cls.ITEMS if i['enabled']]
50 |
51 | @classmethod
52 | def names(cls):
53 | _items = cls.items()
54 | return [i['name'] for i in _items]
55 |
56 | @classmethod
57 | def _get_f_score(cls, pred, gt, th=0.01):
58 |
59 | """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py"""
60 | b = pred.size(0)
61 | assert pred.size(0) == gt.size(0)
62 | if b != 1:
63 | f_score_list = []
64 | for idx in range(b):
65 | f_score_list.append(cls._get_f_score(pred[idx:idx+1], gt[idx:idx+1]))
66 | return sum(f_score_list)/len(f_score_list)
67 | else:
68 | pred = cls._get_open3d_ptcloud(pred)
69 | gt = cls._get_open3d_ptcloud(gt)
70 |
71 | dist1 = pred.compute_point_cloud_distance(gt)
72 | dist2 = gt.compute_point_cloud_distance(pred)
73 |
74 | recall = float(sum(d < th for d in dist2)) / float(len(dist2))
75 | precision = float(sum(d < th for d in dist1)) / float(len(dist1))
76 | return 2 * recall * precision / (recall + precision) if recall + precision else 0
77 |
78 | @classmethod
79 | def _get_open3d_ptcloud(cls, tensor):
80 | """pred and gt bs is 1"""
81 | tensor = tensor.squeeze().cpu().numpy()
82 | ptcloud = open3d.geometry.PointCloud()
83 | ptcloud.points = open3d.utility.Vector3dVector(tensor)
84 |
85 | return ptcloud
86 |
87 | @classmethod
88 | def _get_chamfer_distancel1(cls, pred, gt):
89 | chamfer_distance = cls.ITEMS[1]['eval_object']
90 | return chamfer_distance(pred, gt).item() * 1000
91 |
92 | @classmethod
93 | def _get_chamfer_distancel2(cls, pred, gt):
94 | chamfer_distance = cls.ITEMS[2]['eval_object']
95 | return chamfer_distance(pred, gt).item() * 1000
96 |
97 | def __init__(self, metric_name, values):
98 | self._items = Metrics.items()
99 | self._values = [item['init_value'] for item in self._items]
100 | self.metric_name = metric_name
101 |
102 | if type(values).__name__ == 'list':
103 | self._values = values
104 | elif type(values).__name__ == 'dict':
105 | metric_indexes = {}
106 | for idx, item in enumerate(self._items):
107 | item_name = item['name']
108 | metric_indexes[item_name] = idx
109 | for k, v in values.items():
110 | if k not in metric_indexes:
111 |                     logging.warning('Ignore Metric[Name=%s] because it is disabled.' % k)
112 | continue
113 | self._values[metric_indexes[k]] = v
114 | else:
115 | raise Exception('Unsupported value type: %s' % type(values))
116 |
117 | def state_dict(self):
118 | _dict = dict()
119 | for i in range(len(self._items)):
120 | item = self._items[i]['name']
121 | value = self._values[i]
122 | _dict[item] = value
123 |
124 | return _dict
125 |
126 | def __repr__(self):
127 | return str(self.state_dict())
128 |
129 | def better_than(self, other):
130 | if other is None:
131 | return True
132 |
133 | _index = -1
134 | for i, _item in enumerate(self._items):
135 | if _item['name'] == self.metric_name:
136 | _index = i
137 | break
138 | if _index == -1:
139 | raise Exception('Invalid metric name to compare.')
140 |
141 |         _metric = self._items[_index]
142 | _value = self._values[_index]
143 | other_value = other._values[_index]
144 | return _value > other_value if _metric['is_greater_better'] else _value < other_value
145 |
--------------------------------------------------------------------------------
/n2m/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | from collections import abc
10 | from pointnet2_ops import pointnet2_utils
11 |
12 |
13 | def fps(data, number):
14 | '''
15 | data B N 3
16 | number int
17 | '''
18 | fps_idx = pointnet2_utils.furthest_point_sample(data, number)
19 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
20 | return fps_data
21 |
22 |
23 | def worker_init_fn(worker_id):
24 | np.random.seed(np.random.get_state()[1][0] + worker_id)
25 |
26 | def build_lambda_sche(opti, config):
27 | if config.get('decay_step') is not None:
28 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
29 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
30 | else:
31 | raise NotImplementedError()
32 | return scheduler
33 |
34 | def build_lambda_bnsche(model, config):
35 | if config.get('decay_step') is not None:
36 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
37 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
38 | else:
39 | raise NotImplementedError()
40 | return bnm_scheduler
41 |
42 | def set_random_seed(seed, deterministic=False):
43 | """Set random seed.
44 | Args:
45 | seed (int): Seed to be used.
46 | deterministic (bool): Whether to set the deterministic option for
47 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
48 | to True and `torch.backends.cudnn.benchmark` to False.
49 | Default: False.
50 |
51 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
52 | if cuda_deterministic: # slower, more reproducible
53 | cudnn.deterministic = True
54 | cudnn.benchmark = False
55 | else: # faster, less reproducible
56 | cudnn.deterministic = False
57 | cudnn.benchmark = True
58 |
59 | """
60 | random.seed(seed)
61 | np.random.seed(seed)
62 | torch.manual_seed(seed)
63 | torch.cuda.manual_seed_all(seed)
64 | if deterministic:
65 | torch.backends.cudnn.deterministic = True
66 | torch.backends.cudnn.benchmark = False
67 |
68 |
69 | def is_seq_of(seq, expected_type, seq_type=None):
70 | """Check whether it is a sequence of some type.
71 | Args:
72 | seq (Sequence): The sequence to be checked.
73 | expected_type (type): Expected type of sequence items.
74 | seq_type (type, optional): Expected sequence type.
75 | Returns:
76 | bool: Whether the sequence is valid.
77 | """
78 | if seq_type is None:
79 | exp_seq_type = abc.Sequence
80 | else:
81 | assert isinstance(seq_type, type)
82 | exp_seq_type = seq_type
83 | if not isinstance(seq, exp_seq_type):
84 | return False
85 | for item in seq:
86 | if not isinstance(item, expected_type):
87 | return False
88 | return True
89 |
90 |
91 | def set_bn_momentum_default(bn_momentum):
92 | def fn(m):
93 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
94 | m.momentum = bn_momentum
95 | return fn
96 |
97 | class BNMomentumScheduler(object):
98 |
99 | def __init__(
100 | self, model, bn_lambda, last_epoch=-1,
101 | setter=set_bn_momentum_default
102 | ):
103 | if not isinstance(model, nn.Module):
104 | raise RuntimeError(
105 | "Class '{}' is not a PyTorch nn Module".format(
106 | type(model).__name__
107 | )
108 | )
109 |
110 | self.model = model
111 | self.setter = setter
112 | self.lmbd = bn_lambda
113 |
114 | self.step(last_epoch + 1)
115 | self.last_epoch = last_epoch
116 |
117 | def step(self, epoch=None):
118 | if epoch is None:
119 | epoch = self.last_epoch + 1
120 |
121 | self.last_epoch = epoch
122 | self.model.apply(self.setter(self.lmbd(epoch)))
123 |
124 | def get_momentum(self, epoch=None):
125 | if epoch is None:
126 | epoch = self.last_epoch + 1
127 | return self.lmbd(epoch)
128 |
129 |
130 |
131 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
132 | '''
133 |      separate point cloud: used to generate an incomplete point cloud with a given number of cropped points.
134 | '''
135 | _,n,c = xyz.shape
136 |
137 | assert n == num_points
138 | assert c == 3
139 | if crop == num_points:
140 | return xyz, None
141 |
142 | INPUT = []
143 | CROP = []
144 | for points in xyz:
145 | if isinstance(crop,list):
146 | num_crop = random.randint(crop[0],crop[1])
147 | else:
148 | num_crop = crop
149 |
150 | points = points.unsqueeze(0)
151 |
152 | if fixed_points is None:
153 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
154 | else:
155 | if isinstance(fixed_points,list):
156 | fixed_point = random.sample(fixed_points,1)[0]
157 | else:
158 | fixed_point = fixed_points
159 | center = fixed_point.reshape(1,1,3).cuda()
160 |
161 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
162 |
163 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
164 |
165 | if padding_zeros:
166 | input_data = points.clone()
167 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
168 |
169 | else:
170 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
171 |
172 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
173 |
174 | if isinstance(crop,list):
175 | INPUT.append(fps(input_data,2048))
176 | CROP.append(fps(crop_data,2048))
177 | else:
178 | INPUT.append(input_data)
179 | CROP.append(crop_data)
180 |
181 | input_data = torch.cat(INPUT,dim=0)# B N 3
182 | crop_data = torch.cat(CROP,dim=0)# B M 3
183 |
184 | return input_data.contiguous(), crop_data.contiguous()
185 |
186 | def get_ptcloud_img(ptcloud):
187 | fig = plt.figure(figsize=(8, 8))
188 |
189 | x, z, y = ptcloud.transpose(1, 0)
190 | ax = fig.gca(projection=Axes3D.name, adjustable='box')
191 | ax.axis('off')
192 | # ax.axis('scaled')
193 | ax.view_init(30, 45)
194 | max, min = np.max(ptcloud), np.min(ptcloud)
195 | ax.set_xbound(min, max)
196 | ax.set_ybound(min, max)
197 | ax.set_zbound(min, max)
198 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet')
199 |
200 | fig.canvas.draw()
201 |     img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
202 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
203 | return img
204 |
205 |
206 |
207 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
208 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
209 | fig = plt.figure(figsize=(6*len(data_list),6))
210 | cmax = data_list[-1][:,0].max()
211 |
212 | for i in range(len(data_list)):
213 | data = data_list[i][:-2048] if i == 1 else data_list[i]
214 | color = data[:,0] /cmax
215 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
216 | ax.view_init(30, -120)
217 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
218 | ax.set_title(titles[i])
219 |
220 | ax.set_axis_off()
221 | ax.set_xlim(xlim)
222 | ax.set_ylim(ylim)
223 | ax.set_zlim(zlim)
224 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
225 | if not os.path.exists(path):
226 | os.makedirs(path)
227 |
228 | pic_path = path + '.png'
229 | fig.savefig(pic_path)
230 |
231 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
232 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
233 | plt.close(fig)
234 |
235 |
236 | def random_dropping(pc, e):
237 | up_num = max(64, 768 // (e//50 + 1))
238 | pc = pc
239 | random_num = torch.randint(1, up_num, (1,1))[0,0]
240 | pc = fps(pc, random_num)
241 | padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
242 | pc = torch.cat([pc, padding], dim = 1)
243 | return pc
244 |
245 |
246 | def random_scale(partial, scale_range=[0.8, 1.2]):
247 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
248 | return partial * scale
249 |
--------------------------------------------------------------------------------
/n2m/utils/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pathlib import Path
4 |
5 | def get_args():
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 | '--config',
9 | type = str,
10 | help = 'yaml config file')
11 | parser.add_argument(
12 | '--launcher',
13 | choices=['none', 'pytorch'],
14 | default='none',
15 | help='job launcher')
16 | parser.add_argument('--local_rank', type=int, default=0)
17 | parser.add_argument('--num_workers', type=int, default=4)
18 | # seed
19 | parser.add_argument('--seed', type=int, default=0, help='random seed')
20 | parser.add_argument(
21 | '--deterministic',
22 | action='store_true',
23 | help='whether to set deterministic options for CUDNN backend.')
24 | # bn
25 | parser.add_argument(
26 | '--sync_bn',
27 | action='store_true',
28 | default=False,
29 | help='whether to use sync bn')
30 | # some args
31 | parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name')
32 | parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path')
33 | parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path')
34 | parser.add_argument('--val_freq', type = int, default=1, help = 'test freq')
35 | parser.add_argument(
36 | '--resume',
37 | action='store_true',
38 | default=False,
39 | help = 'autoresume training (interrupted by accident)')
40 | parser.add_argument(
41 | '--test',
42 | action='store_true',
43 | default=False,
44 | help = 'test mode for certain ckpt')
45 | parser.add_argument(
46 | '--finetune_model',
47 | action='store_true',
48 | default=False,
49 | help = 'finetune modelnet with pretrained weight')
50 | parser.add_argument(
51 | '--scratch_model',
52 | action='store_true',
53 | default=False,
54 | help = 'training modelnet from scratch')
55 | parser.add_argument(
56 | '--label_smoothing',
57 | action='store_true',
58 | default=False,
59 | help = 'use label smoothing loss trick')
60 | parser.add_argument(
61 | '--mode',
62 | choices=['easy', 'median', 'hard', None],
63 | default=None,
64 | help = 'difficulty mode for shapenet')
65 | parser.add_argument(
66 | '--way', type=int, default=-1)
67 | parser.add_argument(
68 | '--shot', type=int, default=-1)
69 | parser.add_argument(
70 | '--fold', type=int, default=-1)
71 |
72 | args = parser.parse_args()
73 |
74 | if args.test and args.resume:
75 | raise ValueError(
76 |             '--test and --resume cannot both be active')
77 |
78 | if args.resume and args.start_ckpts is not None:
79 | raise ValueError(
80 |             '--resume and --start_ckpts cannot both be active')
81 |
82 | if args.test and args.ckpts is None:
83 | raise ValueError(
84 |             'ckpts should not be None in test mode')
85 |
86 | if args.finetune_model and args.ckpts is None:
87 | raise ValueError(
88 |             'ckpts should not be None in finetune_model mode')
89 |
90 | if 'LOCAL_RANK' not in os.environ:
91 | os.environ['LOCAL_RANK'] = str(args.local_rank)
92 |
93 | if args.test:
94 | args.exp_name = 'test_' + args.exp_name
95 | if args.mode is not None:
96 | args.exp_name = args.exp_name + '_' +args.mode
97 | args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name)
98 | args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem,'TFBoard' ,args.exp_name)
99 | args.log_name = Path(args.config).stem
100 | create_experiment_dir(args)
101 | return args
102 |
103 | def create_experiment_dir(args):
104 | if not os.path.exists(args.experiment_path):
105 | os.makedirs(args.experiment_path)
106 | print('Create experiment path successfully at %s' % args.experiment_path)
107 | if not os.path.exists(args.tfboard_path):
108 | os.makedirs(args.tfboard_path)
109 | print('Create TFBoard path successfully at %s' % args.tfboard_path)
110 |
111 |
--------------------------------------------------------------------------------
/n2m/utils/point_cloud.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def rotate_points_xy(points, angle):
4 | rotation_matrix = np.array([[np.cos(angle), -np.sin(angle), 0],
5 | [np.sin(angle), np.cos(angle), 0],
6 | [0, 0, 1]])
7 | return points @ rotation_matrix
8 |
9 | def rotate_points_se2(points, angle):
10 | rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],
11 | [np.sin(angle), np.cos(angle)]])
12 | points[:2] = points[:2] @ rotation_matrix
13 | points[2] = points[2] - angle
14 | return points
15 |
16 | def rotate_points_xythetaz(points, angle):
17 | rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],
18 | [np.sin(angle), np.cos(angle)]])
19 | points[:2] = points[:2] @ rotation_matrix
20 | points[2] = points[2] - angle
21 | return points
22 |
23 | def translate_points_xy(points, radius, angle):
24 | translation_matrix = radius * np.array([np.cos(angle), np.sin(angle), 0])
25 | return points + translation_matrix
26 |
27 | def translate_points_xythetaz(points, radius, angle):
28 | translation_matrix = radius * np.array([np.cos(angle), np.sin(angle), 0, 0])
29 | return points + translation_matrix
30 |
31 | def apply_rotation_se2(point_cloud, target_point, rotation):
32 | min_angle = rotation['min_angle'] * np.pi / 180
33 | max_angle = rotation['max_angle'] * np.pi / 180
34 |
35 | angle = np.random.uniform(min_angle, max_angle)
36 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], angle)
37 | target_point = rotate_points_se2(target_point, angle)
38 | return point_cloud, target_point
39 |
40 | def apply_rotation_xy(point_cloud, target_point, rotation):
41 | min_angle = rotation['min_angle'] * np.pi / 180
42 | max_angle = rotation['max_angle'] * np.pi / 180
43 |
44 | angle = np.random.uniform(min_angle, max_angle)
45 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], angle)
46 | target_point = rotate_points_xy(target_point, angle)
47 | return point_cloud, target_point
48 |
49 | def apply_translation_xy(point_cloud, target_point, translation):
50 | radius = np.random.uniform(0, translation['radius'])
51 | angle = np.random.uniform(0, 2 * np.pi)
52 | point_cloud[:, :3] = translate_points_xy(point_cloud[:, :3], radius, angle)
53 | target_point = translate_points_xy(target_point, radius, angle)
54 | return point_cloud, target_point
55 |
56 | def apply_rotation_xythetaz(point_cloud, target_point, rotation):
57 | min_angle = rotation['min_angle'] * np.pi / 180
58 | max_angle = rotation['max_angle'] * np.pi / 180
59 |
60 | angle = np.random.uniform(min_angle, max_angle)
61 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], angle)
62 | target_point = rotate_points_xythetaz(target_point, angle)
63 | return point_cloud, target_point
64 |
65 | def apply_translation_xythetaz(point_cloud, target_point, translation):
66 | radius = np.random.uniform(0, translation['radius'])
67 | angle = np.random.uniform(0, 2 * np.pi)
68 | point_cloud[:, :3] = translate_points_xy(point_cloud[:, :3], radius, angle)
69 | target_point = translate_points_xythetaz(target_point, radius, angle)
70 | return point_cloud, target_point
71 |
72 | def apply_augmentations(point_cloud, target_point, augmentations):
73 | if 'rotation_xy' in augmentations:
74 | point_cloud, target_point = apply_rotation_xy(point_cloud, target_point, augmentations['rotation_xy'])
75 | if 'rotation_se2' in augmentations:
76 | point_cloud, target_point = apply_rotation_se2(point_cloud, target_point, augmentations['rotation_se2'])
77 | if 'rotation_xythetaz' in augmentations:
78 | point_cloud, target_point = apply_rotation_xythetaz(point_cloud, target_point, augmentations['rotation_xythetaz'])
79 | if 'translation_xy' in augmentations:
80 | point_cloud, target_point = apply_translation_xy(point_cloud, target_point, augmentations['translation_xy'])
81 | if 'translation_xythetaz' in augmentations:
82 | point_cloud, target_point = apply_translation_xythetaz(point_cloud, target_point, augmentations['translation_xythetaz'])
83 | return point_cloud, target_point
84 |
85 | def fix_point_cloud_size(point_cloud, pointnum):
86 | if point_cloud.shape[0] > pointnum:
87 | indices = np.random.choice(point_cloud.shape[0], pointnum, replace=False)
88 | point_cloud = point_cloud[indices]
89 | elif point_cloud.shape[0] < pointnum:
90 | padding = np.zeros((pointnum - point_cloud.shape[0], point_cloud.shape[1]))
91 | point_cloud = np.vstack([point_cloud, padding])
92 | return point_cloud
93 |
94 | def translate_se2_point_cloud(point_cloud, se2_transform):
95 | point_cloud[:, :3] = point_cloud[:, :3] + np.array([se2_transform[0], se2_transform[1], 0])
96 | point_cloud[:, :3] = rotate_points_xy(point_cloud[:, :3], se2_transform[2])
97 | return point_cloud
98 |
99 | def translate_se2_target(target_point, se2_transform):
100 | target_point = rotate_points_se2(target_point, se2_transform[2])
101 | target_point = target_point + np.array([se2_transform[0], se2_transform[1], 0])
102 | return target_point
--------------------------------------------------------------------------------
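A usage sketch for the augmentation helpers above. The dictionary keys mirror what `apply_augmentations` looks up; angle limits are in degrees and the translation radius is in the cloud's units (the values here are illustrative):

```python
import numpy as np
from n2m.utils.point_cloud import apply_augmentations, fix_point_cloud_size

cloud = np.random.rand(5000, 3).astype(np.float32)  # local scene, robot-centric
target = np.array([0.5, -0.2, np.pi / 4])            # SE(2) label: x, y, theta

augs = {
    'rotation_se2': {'min_angle': -30, 'max_angle': 30},
    'translation_xy': {'radius': 0.1},
}
cloud, target = apply_augmentations(cloud, target, augs)
cloud = fix_point_cloud_size(cloud, 4096)            # subsample or zero-pad to a fixed point count
```
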
/n2m/utils/prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from torch.distributions import MultivariateNormal
5 |
6 | from n2m.utils.point_cloud import fix_point_cloud_size
7 | from n2m.utils.visualizer import save_gmm_visualization_se2
8 | from n2m.utils.point_cloud import translate_se2_point_cloud, translate_se2_target
9 |
10 | def predict_SIR_target_point(
11 | SIR_predictor,
12 | SIR_config,
13 | pc_numpy,
14 | target_helper,
15 | SIR_sample_num,
16 | robot_centric,
17 | abs_base_se2,
18 | task_name,
19 | max_trial=100,
20 | save_dir=None,
21 | id=None,
22 | ):
23 | pcl_input = fix_point_cloud_size(pc_numpy, SIR_config['dataset']['pointnum'])
24 | if robot_centric:
25 | print ("abs_base_se2", abs_base_se2)
26 | pcl_input = translate_se2_point_cloud(pcl_input, [-abs_base_se2[0], -abs_base_se2[1], abs_base_se2[2]])
27 | pcl_input = torch.from_numpy(pcl_input)[None, ...].cuda().float()
28 |
29 | task_idx = None
30 | if 'settings' in SIR_config['dataset']:
31 | task_idx = torch.tensor([SIR_config['dataset']['settings'].index(task_name)]).cuda()
32 |
33 | target_SIR_prediction = None
34 | SIR_predictor.eval()
35 | with torch.no_grad():
36 | finished = False
37 | trial_count = 0
38 | while not finished:
39 | if trial_count > max_trial:
40 | print(f"Failed to find a valid target after {max_trial} trials")
41 | break
42 |
43 | samples, means, covs, weights = SIR_predictor.sample(pcl_input, task_idx, SIR_sample_num)
44 |
45 | num_modes = means.shape[1]
46 | mvns = [MultivariateNormal(means[0, i], covs[0, i]) for i in range(num_modes)]
47 | log_probs = torch.stack([mvn.log_prob(samples[0]) for mvn in mvns]) # shape: [num_modes, SIR_sample_num]
48 | log_weights = torch.log(weights[0] + 1e-8).unsqueeze(1) # shape: [num_modes, 1]
49 | weighted_log_probs = log_probs + log_weights
50 | gaussian_probabilities = torch.logsumexp(weighted_log_probs, dim=0) # shape: [SIR_sample_num]
51 | sorted_indices = torch.argsort(gaussian_probabilities, descending=True)
52 | samples = samples[0, sorted_indices]
53 |
54 | # convert to numpy
55 | means = means[0].cpu().numpy()
56 | covs = covs[0].cpu().numpy()
57 | weights = weights[0].cpu().numpy()
58 | samples = samples.cpu().numpy()
59 |
60 | # First set target_SIR_prediction to the mean with highest probability
61 | target_SIR_prediction = means[0]
62 | max_prob = 0
63 | for i in range(num_modes):
64 | # Calculate probability density at the mean using the covariance and weight
65 | mvn = MultivariateNormal(torch.from_numpy(means[i]), torch.from_numpy(covs[i]))
66 | prob = mvn.log_prob(torch.from_numpy(means[i])).exp().item() * weights[i]
67 | if prob > max_prob:
68 | max_prob = prob
69 | target_SIR_prediction = means[i]
70 |
71 | if save_dir is not None:
72 | prediction_folder = os.path.join(save_dir, "prediction")
73 | os.makedirs(prediction_folder, exist_ok=True)
74 | save_gmm_visualization_se2(
75 | pcl_input[0].cpu().numpy(),
76 | target_SIR_prediction,
77 | 1,
78 | means,
79 | covs,
80 | weights,
81 | os.path.join(prediction_folder, f"{id}.ply"),
82 | )
83 |
84 | if robot_centric:
85 | global_target_SIR_prediction = translate_se2_target(target_SIR_prediction.copy(), [abs_base_se2[0], abs_base_se2[1], -abs_base_se2[2]])
86 | else:
87 | global_target_SIR_prediction = target_SIR_prediction
88 |
89 | if not target_helper.check_collision(global_target_SIR_prediction): # target_helper is based on the relative position
90 | print("target_SIR_prediction is collided with the furniture, try again")
91 | elif not target_helper.check_boundary(global_target_SIR_prediction): # target_helper is based on the relative position
92 | print("target_SIR_prediction is out of the boundary, try again")
93 | else:
94 | print("target_SIR_prediction is valid, using mean with highest probability")
95 | finished = True
96 | break
97 |
98 |
99 | # If the mean with highest probability is not valid, try with samples
100 | for i in range(SIR_sample_num):
101 | target_SIR_prediction = samples[i]
102 |
103 | if save_dir is not None:
104 | prediction_folder = os.path.join(save_dir, "prediction")
105 | os.makedirs(prediction_folder, exist_ok=True)
106 | save_gmm_visualization_se2(
107 | pcl_input[0].cpu().numpy(),
108 | target_SIR_prediction,
109 | 1,
110 | means,
111 | covs,
112 | weights,
113 | os.path.join(prediction_folder, f"{id}.ply"),
114 | )
115 |
116 | if robot_centric:
117 | global_target_SIR_prediction = translate_se2_target(target_SIR_prediction.copy(), [abs_base_se2[0], abs_base_se2[1], -abs_base_se2[2]])
118 | else:
119 | global_target_SIR_prediction = target_SIR_prediction
120 |
121 | if not target_helper.check_collision(global_target_SIR_prediction): # target_helper is based on the relative position
122 | print("target_SIR_prediction is collided with the furniture, try again")
123 | elif not target_helper.check_boundary(global_target_SIR_prediction): # target_helper is based on the relative position
124 | print("target_SIR_prediction is out of the boundary, try again")
125 | else:
126 | print("target_SIR_prediction is valid, using sample")
127 | finished = True
128 | break
129 |
130 | trial_count += 1
131 |
132 | return global_target_SIR_prediction, target_SIR_prediction, means, covs, weights, pcl_input[0].cpu().numpy(), finished
--------------------------------------------------------------------------------
/n2m/utils/registry.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from functools import partial
4 | from n2m.utils import config
5 |
6 | class Registry:
7 | """A registry to map strings to classes.
8 | Registered object could be built from registry.
9 | Example:
10 | >>> MODELS = Registry('models')
11 | >>> @MODELS.register_module()
12 | >>> class ResNet:
13 | >>> pass
14 | >>> resnet = MODELS.build(dict(NAME='ResNet'))
15 | Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
16 | advanced usage.
17 | Args:
18 | name (str): Registry name.
19 | build_func(func, optional): Build function to construct instance from
20 | Registry, func:`build_from_cfg` is used if neither ``parent`` or
21 | ``build_func`` is specified. If ``parent`` is specified and
22 | ``build_func`` is not given, ``build_func`` will be inherited
23 | from ``parent``. Default: None.
24 | parent (Registry, optional): Parent registry. The class registered in
25 | children registry could be built from parent. Default: None.
26 | scope (str, optional): The scope of registry. It is the key to search
27 | for children registry. If not specified, scope will be the name of
28 | the package where class is defined, e.g. mmdet, mmcls, mmseg.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, name, build_func=None, parent=None, scope=None):
33 | self._name = name
34 | self._module_dict = dict()
35 | self._children = dict()
36 | self._scope = self.infer_scope() if scope is None else scope
37 |
38 | # self.build_func will be set with the following priority:
39 | # 1. build_func
40 | # 2. parent.build_func
41 | # 3. build_from_cfg
42 | if build_func is None:
43 | if parent is not None:
44 | self.build_func = parent.build_func
45 | else:
46 | self.build_func = build_from_cfg
47 | else:
48 | self.build_func = build_func
49 | if parent is not None:
50 | assert isinstance(parent, Registry)
51 | parent._add_children(self)
52 | self.parent = parent
53 | else:
54 | self.parent = None
55 |
56 | def __len__(self):
57 | return len(self._module_dict)
58 |
59 | def __contains__(self, key):
60 | return self.get(key) is not None
61 |
62 | def __repr__(self):
63 | format_str = self.__class__.__name__ + \
64 | f'(name={self._name}, ' \
65 | f'items={self._module_dict})'
66 | return format_str
67 |
68 | @staticmethod
69 | def infer_scope():
70 | """Infer the scope of registry.
71 | The name of the package where registry is defined will be returned.
72 | Example:
73 | # in mmdet/models/backbone/resnet.py
74 | >>> MODELS = Registry('models')
75 | >>> @MODELS.register_module()
76 | >>> class ResNet:
77 | >>> pass
78 | The scope of ``ResNet`` will be ``mmdet``.
79 | Returns:
80 | scope (str): The inferred scope name.
81 | """
82 | # inspect.stack() traces where this function is called; index 2
83 | # refers to the frame from which `infer_scope()` was called
84 | filename = inspect.getmodule(inspect.stack()[2][0]).__name__
85 | split_filename = filename.split('.')
86 | return split_filename[0]
87 |
88 | @staticmethod
89 | def split_scope_key(key):
90 | """Split scope and key.
91 | The first scope will be split from key.
92 | Examples:
93 | >>> Registry.split_scope_key('mmdet.ResNet')
94 | 'mmdet', 'ResNet'
95 | >>> Registry.split_scope_key('ResNet')
96 | None, 'ResNet'
97 | Returns:
98 | scope (str, None): The first scope.
99 | key (str): The remaining key.
100 | """
101 | split_index = key.find('.')
102 | if split_index != -1:
103 | return key[:split_index], key[split_index + 1:]
104 | else:
105 | return None, key
106 |
107 | @property
108 | def name(self):
109 | return self._name
110 |
111 | @property
112 | def scope(self):
113 | return self._scope
114 |
115 | @property
116 | def module_dict(self):
117 | return self._module_dict
118 |
119 | @property
120 | def children(self):
121 | return self._children
122 |
123 | def get(self, key):
124 | """Get the registry record.
125 | Args:
126 | key (str): The class name in string format.
127 | Returns:
128 | class: The corresponding class.
129 | """
130 | scope, real_key = self.split_scope_key(key)
131 | if scope is None or scope == self._scope:
132 | # get from self
133 | if real_key in self._module_dict:
134 | return self._module_dict[real_key]
135 | else:
136 | # get from self._children
137 | if scope in self._children:
138 | return self._children[scope].get(real_key)
139 | else:
140 | # goto root
141 | parent = self.parent
142 | while parent.parent is not None:
143 | parent = parent.parent
144 | return parent.get(key)
145 |
146 | def build(self, *args, **kwargs):
147 | return self.build_func(*args, **kwargs, registry=self)
148 |
149 | def _add_children(self, registry):
150 | """Add children for a registry.
151 | The ``registry`` will be added as children based on its scope.
152 | The parent registry could build objects from children registry.
153 | Example:
154 | >>> models = Registry('models')
155 | >>> mmdet_models = Registry('models', parent=models)
156 | >>> @mmdet_models.register_module()
157 | >>> class ResNet:
158 | >>> pass
159 | >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
160 | """
161 |
162 | assert isinstance(registry, Registry)
163 | assert registry.scope is not None
164 | assert registry.scope not in self.children, \
165 | f'scope {registry.scope} exists in {self.name} registry'
166 | self.children[registry.scope] = registry
167 |
168 | def _register_module(self, module_class, module_name=None, force=False):
169 | if not inspect.isclass(module_class):
170 | raise TypeError('module must be a class, '
171 | f'but got {type(module_class)}')
172 |
173 | if module_name is None:
174 | module_name = module_class.__name__
175 | if isinstance(module_name, str):
176 | module_name = [module_name]
177 | for name in module_name:
178 | if not force and name in self._module_dict:
179 | raise KeyError(f'{name} is already registered '
180 | f'in {self.name}')
181 | self._module_dict[name] = module_class
182 |
183 | def deprecated_register_module(self, cls=None, force=False):
184 | warnings.warn(
185 | 'The old API of register_module(module, force=False) '
186 | 'is deprecated and will be removed, please use the new API '
187 | 'register_module(name=None, force=False, module=None) instead.')
188 | if cls is None:
189 | return partial(self.deprecated_register_module, force=force)
190 | self._register_module(cls, force=force)
191 | return cls
192 |
193 | def register_module(self, name=None, force=False, module=None):
194 | """Register a module.
195 | A record will be added to `self._module_dict`, whose key is the class
196 | name or the specified name, and value is the class itself.
197 | It can be used as a decorator or a normal function.
198 | Example:
199 | >>> backbones = Registry('backbone')
200 | >>> @backbones.register_module()
201 | >>> class ResNet:
202 | >>> pass
203 | >>> backbones = Registry('backbone')
204 | >>> @backbones.register_module(name='mnet')
205 | >>> class MobileNet:
206 | >>> pass
207 | >>> backbones = Registry('backbone')
208 | >>> class ResNet:
209 | >>> pass
210 | >>> backbones.register_module(ResNet)
211 | Args:
212 | name (str | None): The module name to be registered. If not
213 | specified, the class name will be used.
214 | force (bool, optional): Whether to override an existing class with
215 | the same name. Default: False.
216 | module (type): Module class to be registered.
217 | """
218 | if not isinstance(force, bool):
219 | raise TypeError(f'force must be a boolean, but got {type(force)}')
220 | # NOTE: This is a workaround to stay compatible with the old API,
221 | # though it may introduce unexpected bugs.
222 | if isinstance(name, type):
223 | return self.deprecated_register_module(name, force=force)
224 |
225 | # raise the error ahead of time
226 | if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
227 | raise TypeError(
228 | 'name must be either of None, an instance of str or a sequence'
229 | f' of str, but got {type(name)}')
230 |
231 | # use it as a normal method: x.register_module(module=SomeClass)
232 | if module is not None:
233 | self._register_module(
234 | module_class=module, module_name=name, force=force)
235 | return module
236 |
237 | # use it as a decorator: @x.register_module()
238 | def _register(cls):
239 | self._register_module(
240 | module_class=cls, module_name=name, force=force)
241 | return cls
242 |
243 | return _register
244 |
245 |
246 | def build_from_cfg(cfg, registry, default_args=None):
247 | """Build a module from config dict.
248 | Args:
249 | cfg (edict): Config dict. It should at least contain the key "NAME".
250 | registry (:obj:`Registry`): The registry to search the type from.
251 | Returns:
252 | object: The constructed object.
253 | """
254 | if not isinstance(cfg, dict):
255 | raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
256 | if 'NAME' not in cfg:
257 | if default_args is None or 'NAME' not in default_args:
258 | raise KeyError(
259 | '`cfg` or `default_args` must contain the key "NAME", '
260 | f'but got {cfg}\n{default_args}')
261 | if not isinstance(registry, Registry):
262 | raise TypeError('registry must be an mmcv.Registry object, '
263 | f'but got {type(registry)}')
264 |
265 | if not (isinstance(default_args, dict) or default_args is None):
266 | raise TypeError('default_args must be a dict or None, '
267 | f'but got {type(default_args)}')
268 |
269 | if default_args is not None:
270 | cfg = config.merge_new_config(cfg, default_args)
271 |
272 | obj_type = cfg.get('NAME')
273 |
274 | if isinstance(obj_type, str):
275 | obj_cls = registry.get(obj_type)
276 | if obj_cls is None:
277 | raise KeyError(
278 | f'{obj_type} is not in the {registry.name} registry')
279 | elif inspect.isclass(obj_type):
280 | obj_cls = obj_type
281 | else:
282 | raise TypeError(
283 | f'type must be a str or valid type, but got {type(obj_type)}')
284 | try:
285 | return obj_cls(cfg)
286 | except Exception as e:
287 | # Normal TypeError does not print class name.
288 | raise type(e)(f'{obj_cls.__name__}: {e}')
289 |
--------------------------------------------------------------------------------
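
Pulling the pieces of `registry.py` together, the intended workflow mirrors the docstring examples: create a `Registry`, register classes with the decorator, and build instances from a config dict whose `NAME` key selects the class. The sketch below assumes the import path implied by this repository's layout; the `PointMLP` class and its `hidden_dim` key are made up for illustration.

```python
# Usage sketch for the registry (illustrative class and config keys).
from n2m.utils.registry import Registry

MODELS = Registry('models')

@MODELS.register_module()
class PointMLP:
    def __init__(self, cfg):
        # build_from_cfg passes the whole config dict to the constructor
        self.hidden_dim = cfg.get('hidden_dim', 128)

# 'NAME' selects the registered class; the remaining keys are up to __init__.
model = MODELS.build({'NAME': 'PointMLP', 'hidden_dim': 256})
print(type(model).__name__, model.hidden_dim)   # -> PointMLP 256
```
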
/n2m/utils/sample_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib.patches as patches
3 | import numpy as np
4 | import matplotlib.transforms as transforms
5 | import open3d as o3d
6 |
7 |
8 | class CollisionChecker:
9 | """Class for collision checking."""
10 | def __init__(
11 | self,
12 | config,
13 | ):
14 | self.filter_noise = config.get("filter_noise", True) # Whether to filter noise from point cloud
15 | self.ground_z = config.get("ground_z", 0.05) # Points below this z-value are considered ground and ignored
16 | self.resolution = config.get("resolution", 0.02) # Grid resolution in meters
17 | self.robot_width = config.get("robot_width", 0.5) # Robot width in meters
18 | self.robot_length = config.get("robot_length", 0.63) # Robot length in meters
19 |
20 | def filter_point_cloud(self, pcd):
21 | """
22 | Filter noisy points from the point cloud using Open3D filtering methods.
23 |
24 | Args:
25 | pcd: Open3D point cloud object
26 |
27 | Returns:
28 | Filtered Open3D point cloud object
29 | """
30 | try:
31 | print(f"Original point cloud has {len(pcd.points)} points")
32 |
33 | # 1. Voxel downsampling to reduce noise and density
34 | voxel_size = 0.02 # 2cm voxel size
35 | pcd_downsampled = pcd.voxel_down_sample(voxel_size=voxel_size)
36 | print(f"After voxel downsampling: {len(pcd_downsampled.points)} points")
37 |
38 | # 2. Remove statistical outliers
39 | # This removes points that are too far from their neighbors
40 | pcd_cleaned, _ = pcd_downsampled.remove_statistical_outlier(
41 | nb_neighbors=20, # Number of neighbors to analyze
42 | std_ratio=2.0 # Standard deviation ratio threshold
43 | )
44 | print(f"After statistical outlier removal: {len(pcd_cleaned.points)} points")
45 |
46 | # 3. Remove radius outliers (optional, can be commented out if too aggressive)
47 | # This removes points that have too few neighbors within a radius
48 | pcd_cleaned, _ = pcd_cleaned.remove_radius_outlier(
49 | nb_points=16, # Minimum number of points within radius
50 | radius=0.05 # Radius to search for neighbors (5cm)
51 | )
52 | print(f"After radius outlier removal: {len(pcd_cleaned.points)} points")
53 |
54 | return pcd_cleaned
55 |
56 | except Exception as e:
57 | print(f"Error during point cloud filtering: {e}")
58 | print("Returning original point cloud without filtering")
59 | return pcd
60 |
61 | def set_occupancy_grid(self):
62 | """Create occupancy grid from point cloud, filtering out ground points below ground_z"""
63 | # Filter out ground points
64 | non_ground_points = self.pcd_points[self.pcd_points[:, 2] >= self.ground_z]
65 |
66 | # Get min and max coordinates
67 | min_coords = np.min(non_ground_points, axis=0)
68 | max_coords = np.max(non_ground_points, axis=0)
69 |
70 | # Calculate grid dimensions
71 | width = int(np.ceil((max_coords[0] - min_coords[0]) / self.resolution))
72 | height = int(np.ceil((max_coords[1] - min_coords[1]) / self.resolution))
73 |
74 | # Initialize occupancy grid
75 | occupancy_grid = np.zeros((height, width))
76 |
77 | # Project points to grid
78 | for point in non_ground_points:
79 | x_idx = int((point[0] - min_coords[0]) / self.resolution)
80 | y_idx = int((point[1] - min_coords[1]) / self.resolution)
81 |
82 | if 0 <= x_idx < width and 0 <= y_idx < height:
83 | occupancy_grid[y_idx, x_idx] = 1
84 |
85 | self.occupancy_grid = occupancy_grid
86 | self.min_coords = min_coords
87 | self.max_coords = max_coords
88 |
89 | def set_pcd(self, pcd):
90 | if self.filter_noise:
91 | self.pcd = self.filter_point_cloud(pcd)
92 | else:
93 | self.pcd = pcd
94 |
95 | self.pcd_points = np.asarray(self.pcd.points)
96 | self.pcd_colors = np.asarray(self.pcd.colors)
97 |
98 | print("Generating occupancy grid map")
99 | self.set_occupancy_grid()
100 |
101 | def check_collision(self, pose):
102 | """Checks collision in the XY-plane for given SE(2) or SE(2)+z pose."""
103 | if len(pose) == 3:
104 | x, y, theta = pose
105 | elif len(pose) == 4:
106 | x, y, theta, _ = pose
107 | else:
108 | raise ValueError("Pose must be length 3 or 4")
109 |
110 | if self.occupancy_grid is None or self.min_coords is None or self.max_coords is None:
111 | raise ValueError("Occupancy grid not initialized. Call set_pcd() first.")
112 |
113 | # Get rectangle corners in world frame
114 | corners = np.array([
115 | [-self.robot_length*0.83, -self.robot_width/2], # Bottom left
116 | [self.robot_length*0.17, -self.robot_width/2], # Bottom right
117 | [self.robot_length*0.17, self.robot_width/2], # Top right
118 | [-self.robot_length*0.83, self.robot_width/2] # Top left
119 | ])
120 | # Rotate corners
121 | rot_matrix = np.array([
122 | [np.cos(theta), -np.sin(theta)],
123 | [np.sin(theta), np.cos(theta)]
124 | ])
125 | rotated_corners = (rot_matrix @ corners.T).T
126 |
127 | # Translate corners
128 | world_corners = rotated_corners + np.array([x, y])
129 |
130 | # Get min and max coordinates of the rectangle
131 | rect_min = np.min(world_corners, axis=0)
132 | rect_max = np.max(world_corners, axis=0)
133 |
134 | # Convert to grid indices
135 | min_x_idx = int((rect_min[0] - self.min_coords[0]) / self.resolution)
136 | min_y_idx = int((rect_min[1] - self.min_coords[1]) / self.resolution)
137 | max_x_idx = int((rect_max[0] - self.min_coords[0]) / self.resolution) + 1
138 | max_y_idx = int((rect_max[1] - self.min_coords[1]) / self.resolution) + 1
139 |
140 | # Ensure indices are within grid bounds
141 | min_x_idx = max(0, min_x_idx)
142 | min_y_idx = max(0, min_y_idx)
143 | max_x_idx = min(self.occupancy_grid.shape[1], max_x_idx)
144 | max_y_idx = min(self.occupancy_grid.shape[0], max_y_idx)
145 |
146 | # Check each grid cell in the bounding box
147 | for y_idx in range(min_y_idx, max_y_idx):
148 | for x_idx in range(min_x_idx, max_x_idx):
149 | if self.occupancy_grid[y_idx, x_idx] == 1:
150 | # Get grid cell corners in world coordinates
151 | grid_min_x = self.min_coords[0] + x_idx * self.resolution
152 | grid_min_y = self.min_coords[1] + y_idx * self.resolution
153 | grid_max_x = grid_min_x + self.resolution
154 | grid_max_y = grid_min_y + self.resolution
155 |
156 | # Check if grid cell intersects with rectangle
157 | grid_corners = np.array([
158 | [grid_min_x, grid_min_y],
159 | [grid_max_x, grid_min_y],
160 | [grid_max_x, grid_max_y],
161 | [grid_min_x, grid_max_y]
162 | ])
163 | translated_corners = grid_corners - np.array([x, y])
164 | local_corners = (rot_matrix.T @ translated_corners.T).T
165 |
166 | rect_min_x = -self.robot_length*0.83
167 | rect_max_x = self.robot_length*0.17
168 | rect_min_y = -self.robot_width/2
169 | rect_max_y = self.robot_width/2
170 |
171 | for corner in local_corners:
172 | if (rect_min_x <= corner[0] <= rect_max_x and
173 | rect_min_y <= corner[1] <= rect_max_y):
174 | return False
175 | return True
176 |
177 |
178 | def get_target_helper_for_rollout_collection(inference_mode=False, all_pcd=None, se2_origin=None, vis=False, camera_intrinsic=None, filter_noise=True):
179 | if inference_mode:
180 | x_half_range = 0.5
181 | y_half_range = 0.5
182 | theta_half_range_deg = 30
183 | else:
184 | x_half_range = 0.2
185 | y_half_range = 0.2
186 | theta_half_range_deg = 15
187 |
188 | return TargetHelper(
189 | all_pcd,
190 | se2_origin,
191 | x_half_range,
192 | y_half_range,
193 | theta_half_range_deg,
194 | vis=vis,
195 | camera_intrinsic=camera_intrinsic,
196 | filter_noise=filter_noise
197 | )
198 |
199 | class TargetHelper:
200 | def __init__(self, pcd, origin_se2, x_half_range, y_half_range, theta_half_range_deg, vis=False, camera_intrinsic=None, filter_noise=True):
201 | self.resolution = 0.02
202 | self.width = 0.5
203 | self.length = 0.63
204 | self.ground_z = 0.05
205 | # self.T_base_cam = np.array([
206 | # [ 5.83445639e-03, 5.87238353e-01, -8.09393072e-01, 4.49474752e-02],
207 | # [-9.99982552e-01, 2.64855534e-03, -5.28670366e-03, -1.88012307e-02],
208 | # [-9.60832741e-04, 8.09409739e-01, 5.87243519e-01, 1.64493829e-00],
209 | # [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]) # this is just for DETECT mode and
210 | # self.T_base_cam = np.array(
211 | # [[-8.14932746e-02, -5.73355454e-01, 8.15243714e-01, 9.75590401e-04],
212 | # [-9.95079241e-01, 5.53813432e-04, -9.90804744e-02, -2.56451598e-02],
213 | # [ 5.63568407e-02, -8.19306535e-01, -5.70579275e-01, 1.64368076e-01],
214 | # [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]
215 | # )
216 | self.T_base_cam = np.array([
217 | [-8.25269110e-02, -5.73057816e-01, 8.15348841e-01, 6.05364230e-04],
218 | [-9.95784041e-01, 1.45464862e-02, -9.05661474e-02, -3.94417736e-02],
219 | [ 4.00391906e-02, -8.19385767e-01, -5.71842485e-01, 1.64310488e-00],
220 | [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]
221 | ])
222 | self.arm_length = 1
223 |
224 | # Filter noise from point cloud if requested
225 | if filter_noise:
226 | self.pcd = self.filter_point_cloud(pcd)
227 | else:
228 | self.pcd = pcd
229 |
230 | self.origin_se2 = origin_se2
231 | self.x_half_range = x_half_range
232 | self.y_half_range = y_half_range
233 | self.theta_half_range_deg = theta_half_range_deg
234 | self.to_rad = lambda x: x/180*np.pi
235 | self.pcd_np = np.asarray(self.pcd.points)
236 | self.pcd_np = np.concatenate([self.pcd_np, np.zeros((self.pcd_np.shape[0], 1))], axis=1)
237 | self.pcd_color_np = np.asarray(self.pcd.colors)
238 | self.pcd_max_value = np.max(self.pcd_np[:, :3], axis=0)
239 | self.pcd_min_value = np.min(self.pcd_np[:, :3], axis=0)
240 | self.vis = vis
241 |
242 | if camera_intrinsic is not None:
243 | self.cam_intrinsic = np.array(camera_intrinsic)
244 | else:
245 | self.cam_intrinsic = np.array([100.6919557412736, 100.6919557412736, 160.0, 120.0, 320, 240])
246 | # Initialize occupancy grid in constructor
247 | self.get_occupancy_grid()
248 |
249 | def filter_point_cloud(self, pcd):
250 | """
251 | Filter noisy points from the point cloud using Open3D filtering methods.
252 |
253 | Args:
254 | pcd: Open3D point cloud object
255 |
256 | Returns:
257 | Filtered Open3D point cloud object
258 | """
259 | try:
260 | print(f"Original point cloud has {len(pcd.points)} points")
261 |
262 | # 1. Voxel downsampling to reduce noise and density
263 | voxel_size = 0.02 # 2cm voxel size
264 | pcd_downsampled = pcd.voxel_down_sample(voxel_size=voxel_size)
265 | print(f"After voxel downsampling: {len(pcd_downsampled.points)} points")
266 |
267 | # 2. Remove statistical outliers
268 | # This removes points that are too far from their neighbors
269 | pcd_cleaned, _ = pcd_downsampled.remove_statistical_outlier(
270 | nb_neighbors=20, # Number of neighbors to analyze
271 | std_ratio=2.0 # Standard deviation ratio threshold
272 | )
273 | print(f"After statistical outlier removal: {len(pcd_cleaned.points)} points")
274 |
275 | # 3. Remove radius outliers (optional, can be commented out if too aggressive)
276 | # This removes points that have too few neighbors within a radius
277 | pcd_cleaned, _ = pcd_cleaned.remove_radius_outlier(
278 | nb_points=16, # Minimum number of points within radius
279 | radius=0.05 # Radius to search for neighbors (5cm)
280 | )
281 | print(f"After radius outlier removal: {len(pcd_cleaned.points)} points")
282 |
283 | return pcd_cleaned
284 |
285 | except Exception as e:
286 | print(f"Error during point cloud filtering: {e}")
287 | print("Returning original point cloud without filtering")
288 | return pcd
289 |
290 | def filter_point_cloud_custom(self, pcd, voxel_size=0.02, nb_neighbors=20, std_ratio=2.0,
291 | nb_points=16, radius=0.05, use_radius_filter=True):
292 | """
293 | Filter noisy points from the point cloud with custom parameters.
294 |
295 | Args:
296 | pcd: Open3D point cloud object
297 | voxel_size: Size of voxels for downsampling (in meters)
298 | nb_neighbors: Number of neighbors for statistical outlier removal
299 | std_ratio: Standard deviation ratio threshold for statistical outlier removal
300 | nb_points: Minimum number of points within radius for radius outlier removal
301 | radius: Radius to search for neighbors (in meters)
302 | use_radius_filter: Whether to apply radius outlier removal
303 |
304 | Returns:
305 | Filtered Open3D point cloud object
306 | """
307 | try:
308 | print(f"Original point cloud has {len(pcd.points)} points")
309 |
310 | # 1. Voxel downsampling
311 | pcd_downsampled = pcd.voxel_down_sample(voxel_size=voxel_size)
312 | print(f"After voxel downsampling: {len(pcd_downsampled.points)} points")
313 |
314 | # 2. Remove statistical outliers
315 | pcd_cleaned, _ = pcd_downsampled.remove_statistical_outlier(
316 | nb_neighbors=nb_neighbors,
317 | std_ratio=std_ratio
318 | )
319 | print(f"After statistical outlier removal: {len(pcd_cleaned.points)} points")
320 |
321 | # 3. Remove radius outliers (optional)
322 | if use_radius_filter:
323 | pcd_cleaned, _ = pcd_cleaned.remove_radius_outlier(
324 | nb_points=nb_points,
325 | radius=radius
326 | )
327 | print(f"After radius outlier removal: {len(pcd_cleaned.points)} points")
328 |
329 | return pcd_cleaned
330 |
331 | except Exception as e:
332 | print(f"Error during point cloud filtering: {e}")
333 | print("Returning original point cloud without filtering")
334 | return pcd
335 |
336 | def visualize_filtering_comparison(self, original_pcd):
337 | """
338 | Visualize the comparison between original and filtered point clouds.
339 |
340 | Args:
341 | original_pcd: Original Open3D point cloud object
342 | """
343 | try:
344 | # Create visualization window
345 | vis = o3d.visualization.Visualizer()
346 | vis.create_window(window_name="Point Cloud Filtering Comparison", width=1200, height=800)
347 |
348 | # Add original point cloud (white)
349 | original_pcd.paint_uniform_color([1, 1, 1]) # White
350 | vis.add_geometry(original_pcd)
351 |
352 | # Add filtered point cloud (red)
353 | filtered_pcd = self.pcd
354 | filtered_pcd.paint_uniform_color([1, 0, 0]) # Red
355 | vis.add_geometry(filtered_pcd)
356 |
357 | # Add coordinate frame
358 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5)
359 | vis.add_geometry(coord_frame)
360 |
361 | # Set view options
362 | opt = vis.get_render_option()
363 | opt.background_color = np.asarray([0, 0, 0]) # Black background
364 | opt.point_size = 2.0
365 |
366 | print("Visualizing point cloud filtering comparison...")
367 | print("White points: Original point cloud")
368 | print("Red points: Filtered point cloud")
369 | print("Press 'Q' to close the visualization")
370 |
371 | # Run visualization
372 | vis.run()
373 | vis.destroy_window()
374 |
375 | except Exception as e:
376 | print(f"Error during visualization: {e}")
377 |
378 | def get_occupancy_grid(self):
379 | """Create occupancy grid from point cloud, filtering out ground points below ground_z"""
380 | # Filter out ground points
381 | non_ground_points = self.pcd_np[self.pcd_np[:, 2] >= self.ground_z]
382 |
383 | # Get min and max coordinates
384 | min_coords = np.min(non_ground_points, axis=0)
385 | max_coords = np.max(non_ground_points, axis=0)
386 |
387 | # Calculate grid dimensions
388 | width = int(np.ceil((max_coords[0] - min_coords[0]) / self.resolution))
389 | height = int(np.ceil((max_coords[1] - min_coords[1]) / self.resolution))
390 |
391 | # Initialize occupancy grid
392 | occupancy_grid = np.zeros((height, width))
393 |
394 | # Project points to grid
395 | for point in non_ground_points:
396 | x_idx = int((point[0] - min_coords[0]) / self.resolution)
397 | y_idx = int((point[1] - min_coords[1]) / self.resolution)
398 |
399 | if 0 <= x_idx < width and 0 <= y_idx < height:
400 | occupancy_grid[y_idx, x_idx] = 1
401 |
402 | self.occupancy_grid = occupancy_grid
403 | self.min_coords = min_coords
404 | self.max_coords = max_coords
405 |
406 | def check_boundary(self, se2_pose):
407 | """Check if the target is within the boundary"""
408 | x, y, theta = se2_pose
409 | if x < self.pcd_min_value[0] or x > self.pcd_max_value[0] or y < self.pcd_min_value[1] or y > self.pcd_max_value[1]:
410 | return False
411 | return True
412 |
413 | def check_collision(self, se2_pose):
414 | """Check if all grids intersected by the rectangle are empty"""
415 | x, y, theta = se2_pose
416 | # Get rectangle corners in world frame
417 | corners = np.array([
418 | [-self.length*0.83, -self.width/2], # Bottom left
419 | [self.length*0.17, -self.width/2], # Bottom right
420 | [self.length*0.17, self.width/2], # Top right
421 | [-self.length*0.83, self.width/2] # Top left
422 | ])
423 | # Rotate corners
424 | rot_matrix = np.array([
425 | [np.cos(theta), -np.sin(theta)],
426 | [np.sin(theta), np.cos(theta)]
427 | ])
428 | rotated_corners = (rot_matrix @ corners.T).T
429 |
430 | # Translate corners
431 | world_corners = rotated_corners + np.array([x, y])
432 |
433 | # Get min and max coordinates of the rectangle
434 | rect_min = np.min(world_corners, axis=0)
435 | rect_max = np.max(world_corners, axis=0)
436 |
437 | # Convert to grid indices
438 | min_x_idx = int((rect_min[0] - self.min_coords[0]) / self.resolution)
439 | min_y_idx = int((rect_min[1] - self.min_coords[1]) / self.resolution)
440 | max_x_idx = int((rect_max[0] - self.min_coords[0]) / self.resolution) + 1
441 | max_y_idx = int((rect_max[1] - self.min_coords[1]) / self.resolution) + 1
442 |
443 | # Ensure indices are within grid bounds
444 | min_x_idx = max(0, min_x_idx)
445 | min_y_idx = max(0, min_y_idx)
446 | max_x_idx = min(self.occupancy_grid.shape[1], max_x_idx)
447 | max_y_idx = min(self.occupancy_grid.shape[0], max_y_idx)
448 |
449 | # Check each grid cell in the bounding box
450 | for y_idx in range(min_y_idx, max_y_idx):
451 | for x_idx in range(min_x_idx, max_x_idx):
452 | if self.occupancy_grid[y_idx, x_idx] == 1:
453 | # Get grid cell corners in world coordinates
454 | grid_min_x = self.min_coords[0] + x_idx * self.resolution
455 | grid_min_y = self.min_coords[1] + y_idx * self.resolution
456 | grid_max_x = grid_min_x + self.resolution
457 | grid_max_y = grid_min_y + self.resolution
458 |
459 | # Check if grid cell intersects with rectangle
460 | # Convert grid corners to rectangle's local frame
461 | grid_corners = np.array([
462 | [grid_min_x, grid_min_y],
463 | [grid_max_x, grid_min_y],
464 | [grid_max_x, grid_max_y],
465 | [grid_min_x, grid_max_y]
466 | ])
467 |
468 | # Translate and rotate grid corners to rectangle's local frame
469 | translated_corners = grid_corners - np.array([x, y])
470 | local_corners = (rot_matrix.T @ translated_corners.T).T
471 |
472 | # Check if any grid corner is inside rectangle
473 | # Rectangle bounds in local frame
474 | rect_min_x = -self.length*0.83
475 | rect_max_x = self.length*0.17
476 | rect_min_y = -self.width/2
477 | rect_max_y = self.width/2
478 |
479 | # Check if any grid corner is inside rectangle
480 | for corner in local_corners:
481 | if (rect_min_x <= corner[0] <= rect_max_x and
482 | rect_min_y <= corner[1] <= rect_max_y):
483 | return False
484 |
485 | return True
486 |
487 | def visualize_occupancy_and_rectangle(self, target_se2, object_pos=None):
488 | """Visualize occupancy grid and rectangle"""
489 | try:
490 | plt.figure(figsize=(10, 10))
491 | ax = plt.gca()
492 |
493 | # Plot occupancy grid
494 | plt.imshow(self.occupancy_grid, cmap='binary', origin='lower',
495 | extent=[self.min_coords[0], self.max_coords[0], self.min_coords[1], self.max_coords[1]])
496 |
497 | # Draw sampling region
498 | sampling_rect = patches.Rectangle(
499 | (self.origin_se2[0] - self.x_half_range, self.origin_se2[1] - self.y_half_range),
500 | 2 * self.x_half_range, 2 * self.y_half_range,
501 | linewidth=1, edgecolor='b', facecolor='none', linestyle='--', label='Random Region'
502 | )
503 | ax.add_patch(sampling_rect)
504 |
505 | # Draw origin point
506 | ax.plot(self.origin_se2[0], self.origin_se2[1], 'bo', markersize=8, label='Origin')
507 |
508 | # Draw rectangle with SE2 pose
509 | x, y, theta = target_se2
510 |
511 | # Create rectangle with origin at center of width edge
512 | rect = patches.Rectangle((-self.length*0.83, -self.width/2),
513 | self.length, self.width,
514 | linewidth=2, edgecolor='r', facecolor='none')
515 |
516 | # Create transformation
517 | t = transforms.Affine2D().rotate(theta).translate(x, y)
518 | rect.set_transform(t + ax.transData)
519 |
520 | # Add rectangle and reachability circle to plot
521 | ax.add_patch(rect)
522 |
523 | # Create reachability circle centered at object_pos
524 | if object_pos is not None:
525 | reachability_circle = patches.Circle(object_pos, self.arm_length, linewidth=2, edgecolor='b', facecolor='none')
526 | ax.add_patch(reachability_circle)
527 |
528 | # Draw coordinate frame
529 | arrow_length = 0.2
530 | # X-axis (red)
531 | ax.arrow(x, y,
532 | arrow_length * np.cos(theta),
533 | arrow_length * np.sin(theta),
534 | head_width=0.05, head_length=0.05, fc='r', ec='r')
535 | # Y-axis (green)
536 | ax.arrow(x, y,
537 | arrow_length * np.cos(theta + np.pi/2),
538 | arrow_length * np.sin(theta + np.pi/2),
539 | head_width=0.05, head_length=0.05, fc='g', ec='g')
540 |
541 | # Add axis labels
542 | plt.xlabel('X (meters)')
543 | plt.ylabel('Y (meters)')
544 |
545 | # Add grid
546 | plt.grid(True, alpha=0.3)
547 |
548 | # Add title
549 | plt.title('Occupancy Grid Map with Rectangle')
550 |
551 | # Add legend
552 | plt.legend()
553 |
554 | # Show plot
555 | plt.show()
556 | plt.pause(0.5)
557 | plt.close() # Explicitly close the figure
558 | except Exception as e:
559 | print(f"Visualization error: {e}")
560 |
561 | def visualize_pcl_with_camera_and_object(self, se2_pose, object_pos):
562 | """Visualize point cloud, camera, and object using Open3D"""
563 | try:
564 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose)
565 | se3_pose = self.calculate_target_se3(se2_pose)
566 |
567 | # Create visualization window
568 | vis = o3d.visualization.Visualizer()
569 | vis.create_window()
570 |
571 | # Add point cloud
572 | pcd = o3d.geometry.PointCloud()
573 | pcd.points = o3d.utility.Vector3dVector(self.pcd_np[:,:3])
574 | pcd.colors = o3d.utility.Vector3dVector(self.pcd_color_np)
575 | vis.add_geometry(pcd)
576 |
577 | # Add object as sphere
578 | if object_pos.shape == (2,):
579 | object_pos = np.array([object_pos[0], object_pos[1], 1.0])
580 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.05)
581 | sphere.translate(object_pos[:3])
582 | sphere.paint_uniform_color([1, 0, 0]) # Red color
583 | vis.add_geometry(sphere)
584 |
585 | # # Add camera frustum
586 | camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3)
587 | camera_frame.transform(camera_extrinsic)
588 | vis.add_geometry(camera_frame)
589 |
590 | # Add target as sphere
591 | target_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3)
592 | target_frame.transform(se3_pose)
593 | vis.add_geometry(target_frame)
594 |
595 | # Add origin coordinate frame
596 | origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3)
597 | vis.add_geometry(origin_frame)
598 |
599 | # Render and capture image
600 | vis.run()
601 | vis.destroy_window()
602 |
603 | except Exception as e:
604 | print(f"Visualization error: {e}")
605 |
606 | def calculate_target_se3(self, se2_pose):
607 | """Get the target SE3 pose"""
608 | x, y, theta = se2_pose
609 | se3_pose = np.array([
610 | [np.cos(theta), -np.sin(theta), 0, x],
611 | [np.sin(theta), np.cos(theta), 0, y],
612 | [0, 0, 1, 0],
613 | [0, 0, 0, 1]
614 | ])
615 | return se3_pose
616 |
617 | def calculate_camera_extrinsic(self, se2_pose):
618 | """Get the camera extrinsic matrix from the target SE2 pose"""
619 | se3_pose = self.calculate_target_se3(se2_pose)
620 | camera_extrinsic = se3_pose @ self.T_base_cam
621 | return camera_extrinsic
622 |
623 | def check_object_visibility(self, se2_pose, object_pos):
624 | """Check if the object is visible from the camera"""
625 | # convert target_se2 to camera_extrinsic
626 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose)
627 | fx, fy, cx, cy, w, h = self.cam_intrinsic
628 | intrinsic_matrix = np.array([
629 | [fx, 0, cx],
630 | [0, fy, cy],
631 | [0, 0, 1]
632 | ])
633 |
634 | if object_pos.shape == (3,):
635 | object_pos = np.concatenate([object_pos, np.array([1.0])])
636 | elif object_pos.shape == (2,): # use default height 1.0
637 | object_pos = np.concatenate([object_pos, np.array([1.0, 1.0])])
638 | else:
639 | raise ValueError("object_pos must be a 2D or 3D array")
640 |
641 | # convert object_pos to camera coordinates
642 | object_cam = np.linalg.inv(camera_extrinsic) @ object_pos
643 | object_cam = object_cam[:3] / object_cam[3]
644 | object_pix = intrinsic_matrix @ object_cam
645 | object_pix = object_pix[:2] / object_pix[2]
646 |
647 | if self.vis:
648 | global_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1)
649 | global_frame_points = global_frame.sample_points_uniformly(number_of_points=1000)
650 |
651 | camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.3)
652 | camera_frame.transform(camera_extrinsic)
653 | camera_frame_points = camera_frame.sample_points_uniformly(number_of_points=1000)
654 |
655 | base_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6)
656 | base_frame.transform(self.calculate_target_se3(se2_pose))
657 | base_frame_points = base_frame.sample_points_uniformly(number_of_points=1000)
658 |
659 | object_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.05)
660 | object_point.translate(object_pos[:3])
661 | object_point.paint_uniform_color([1, 0, 0])
662 | object_point_points = object_point.sample_points_uniformly(number_of_points=1000)
663 |
664 | transformed_object_point = o3d.geometry.TriangleMesh.create_sphere(radius=0.05)
665 | transformed_object_point.translate((np.linalg.inv(camera_extrinsic) @ object_pos)[:3])
666 | transformed_object_point.paint_uniform_color([0, 1, 0])
667 | transformed_object_point_points = transformed_object_point.sample_points_uniformly(number_of_points=1000)
668 |
669 | o3d.io.write_point_cloud("check_transform.ply", global_frame_points + camera_frame_points + base_frame_points + object_point_points + transformed_object_point_points)
670 |
671 | # check if the object is in front of the camera
672 | if object_cam[2] <= 0:
673 | return False
674 |
675 | # check if the object is within the image bounds
676 | if object_pix[0] < 0 or object_pix[0] > w or object_pix[1] < 0 or object_pix[1] > h:
677 | return False
678 |
679 | return True
680 |
681 | def check_reachability(self, se2_pose, object_pos):
682 | """Check if the object is reachable from the target"""
683 | # convert target_se2 to camera_extrinsic
684 | robot_xy = se2_pose[:2]
685 | object_xy = object_pos[:2]
686 | distance = np.linalg.norm(robot_xy - object_xy)
687 | if distance > self.arm_length:
688 | return False
689 | return True
690 |
691 | def get_random_target_se2(self):
692 | """Get a random target SE2 pose"""
693 | while True:
694 | random_x = np.random.uniform(-self.x_half_range, self.x_half_range)
695 | random_y = np.random.uniform(-self.y_half_range, self.y_half_range)
696 | random_theta = np.random.uniform(self.to_rad(-self.theta_half_range_deg), self.to_rad(self.theta_half_range_deg))
697 | # Calculate absolute position by adding to origin
698 | se2_delta = [random_x, random_y, random_theta]
699 | se2_pose = [
700 | self.origin_se2[0] + se2_delta[0],
701 | self.origin_se2[1] + se2_delta[1],
702 | self.origin_se2[2] + se2_delta[2]
703 | ]
704 | if self.check_collision(se2_pose):
705 | # print("collision, break", se2_pose)
706 | break
707 |
708 | if self.vis:
709 | try:
710 | self.visualize_occupancy_and_rectangle(se2_pose)
711 | except Exception as e:
712 | print(f"Visualization failed: {e}")
713 |
714 | return se2_delta
715 |
716 | def get_random_target_se2_with_reachability_check(self, object_pos):
717 | """Get a random target SE2 pose with reachability check"""
718 | while True:
719 | random_x = np.random.uniform(-self.x_half_range, self.x_half_range)
720 | random_y = np.random.uniform(-self.y_half_range, self.y_half_range)
721 | random_theta = np.random.uniform(self.to_rad(-self.theta_half_range_deg), self.to_rad(self.theta_half_range_deg))
722 | # Calculate absolute position by adding to origin
723 | se2_delta = [random_x, random_y, random_theta]
724 | se2_pose = [
725 | self.origin_se2[0] + se2_delta[0],
726 | self.origin_se2[1] + se2_delta[1],
727 | self.origin_se2[2] + se2_delta[2]
728 | ]
729 | if self.check_collision(se2_pose) and self.check_reachability(se2_pose, object_pos):
730 | break
731 |
732 | if self.vis:
733 | try:
734 | self.visualize_occupancy_and_rectangle(se2_pose, object_pos)
735 | except Exception as e:
736 | print(f"Visualization failed: {e}")
737 |
738 | return se2_delta
739 |
740 | def get_random_target_se2_with_visibility_check(self, object_pos):
741 | """Get a random target SE2 pose with visibility check"""
742 | while True:
743 | se2_delta = self.get_random_target_se2()
744 | se2_pose = [
745 | self.origin_se2[0] + se2_delta[0],
746 | self.origin_se2[1] + se2_delta[1],
747 | self.origin_se2[2] + se2_delta[2]
748 | ]
749 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose)
750 | if self.check_object_visibility(se2_pose, object_pos) and self.check_boundary(se2_pose):
751 | break
752 |
753 | if self.vis:
754 | try:
755 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos)
756 | except Exception as e:
757 | print(f"Visualization failed: {e}")
758 |
759 | return se2_delta, camera_extrinsic
760 |
761 | def get_random_target_se2_with_boundary_check(self, object_pos):
762 | """Get a random target SE2 pose with visibility check"""
763 | while True:
764 | se2_delta = self.get_random_target_se2()
765 | se2_pose = [
766 | self.origin_se2[0] + se2_delta[0],
767 | self.origin_se2[1] + se2_delta[1],
768 | self.origin_se2[2] + se2_delta[2]
769 | ]
770 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose)
771 | if self.check_boundary(se2_pose):
772 | break
773 |
774 | if self.vis:
775 | try:
776 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos)
777 | except Exception as e:
778 | print(f"Visualization failed: {e}")
779 |
780 | return se2_delta, camera_extrinsic
781 |
782 | def get_random_target_se2_with_visibility_check_without_boundary(self, object_pos):
783 | """Get a random target SE2 pose with visibility check"""
784 | while True:
785 | se2_delta = self.get_random_target_se2()
786 | se2_pose = [
787 | self.origin_se2[0] + se2_delta[0],
788 | self.origin_se2[1] + se2_delta[1],
789 | self.origin_se2[2] + se2_delta[2]
790 | ]
791 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose)
792 | if self.check_object_visibility(se2_pose, object_pos):
793 | break
794 |
795 | if self.vis:
796 | try:
797 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos)
798 | except Exception as e:
799 | print(f"Visualization failed: {e}")
800 |
801 | return se2_delta, camera_extrinsic
802 |
803 | def get_random_target_se2_with_visibility_reachability_check(self, object_pos):
804 | """Get a random target SE2 pose with visibility check"""
805 | while True:
806 | se2_delta = self.get_random_target_se2()
807 | se2_pose = [
808 | self.origin_se2[0] + se2_delta[0],
809 | self.origin_se2[1] + se2_delta[1],
810 | self.origin_se2[2] + se2_delta[2]
811 | ]
812 | camera_extrinsic = self.calculate_camera_extrinsic(se2_pose)
813 | if self.check_object_visibility(se2_pose, object_pos) and self.check_reachability(se2_pose, object_pos):
814 | break
815 |
816 | if self.vis:
817 | try:
818 | self.visualize_pcl_with_camera_and_object(se2_pose, object_pos)
819 | except Exception as e:
820 | print(f"Visualization failed: {e}")
821 |
822 | return se2_delta, camera_extrinsic
--------------------------------------------------------------------------------
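
`TargetHelper` is the component the prediction code relies on for validity checks, so a quick way to sanity-check it is to feed it a synthetic scene. The sketch below builds a fake obstacle out of random points and queries the collision and boundary checks; the point-cloud geometry and parameter values are assumptions for illustration only.

```python
# Sanity-check sketch for TargetHelper on a synthetic scene (illustrative values).
import numpy as np
import open3d as o3d
from n2m.utils.sample_utils import TargetHelper

rng = np.random.default_rng(0)

# Fake "furniture": a dense block of points standing about 1 m ahead of the origin.
block = rng.uniform(low=[0.8, -0.3, 0.1], high=[1.2, 0.3, 0.8], size=(5000, 3))
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(block)
pcd.colors = o3d.utility.Vector3dVector(np.tile([0.5, 0.5, 0.5], (len(block), 1)))

helper = TargetHelper(
    pcd,
    origin_se2=[0.0, 0.0, 0.0],
    x_half_range=0.2,
    y_half_range=0.2,
    theta_half_range_deg=15,
    filter_noise=False,   # keep the synthetic cloud as-is
)

# check_collision returns True when the robot footprint is collision-free.
print(helper.check_collision([0.0, 0.0, 0.0]))  # footprint clear of the block -> True
print(helper.check_collision([1.0, 0.0, 0.0]))  # footprint inside the block -> False
print(helper.check_boundary([5.0, 0.0, 0.0]))   # outside the cloud's xy extent -> False
```
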
/n2m/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import torch
4 | from scipy.stats import multivariate_normal
5 | from sklearn.neighbors import KernelDensity
6 | import yaml
7 |
8 |
9 | def _load_point_cloud(file_path, use_color=True):
10 | pcd = o3d.io.read_point_cloud(file_path)
11 | point_cloud = np.asarray(pcd.points)
12 |
13 | if use_color:
14 | colors = np.asarray(pcd.colors)
15 | point_cloud = np.concatenate([point_cloud, colors], axis=1)
16 | return point_cloud
17 |
18 | def gmm_pdf(points, means, covs, weights):
19 | pdf_vals = np.zeros(len(points))
20 | for k in range(len(weights)):
21 | mvn = multivariate_normal(mean=means[k], cov=covs[k])
22 | pdf_vals += weights[k] * mvn.pdf(points)
23 | return pdf_vals
24 |
25 | def create_ellipsoid(mean, cov, color, scale=1.0, num_points=100000):
26 | """
27 | Create an ellipsoid from a mean and covariance matrix
28 |
29 | Args:
30 | mean: Mean vector (3,)
31 | cov: Covariance matrix (3, 3)
32 | color: RGB color of the ellipsoid; scale: factor multiplying the sqrt of the eigenvalues (axis half-lengths)
33 | """
34 | # Eigen-decomposition of covariance matrix
35 | eigvals, eigvecs = np.linalg.eigh(cov)
36 |
37 | # Create a sphere
38 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=1.0)
39 | sphere.compute_vertex_normals()
40 |
41 | # Scale sphere to form ellipsoid
42 | scales = scale * np.sqrt(eigvals)
43 | sphere_vertices = np.asarray(sphere.vertices)
44 | sphere_vertices = sphere_vertices @ np.diag(scales)
45 |
46 | # Rotate according to eigenvectors
47 | sphere_vertices = sphere_vertices @ eigvecs.T
48 | sphere.vertices = o3d.utility.Vector3dVector(sphere_vertices)
49 |
50 | # Translate to mean
51 | sphere.translate(mean)
52 |
53 | # Color the ellipsoid
54 | sphere.paint_uniform_color(color)
55 |
56 | return sphere
57 |
58 | def visualize_gmm_distribution(point_cloud, target_point, label, means, covs, weights):
59 | """
60 | Visualize GMM distribution as a 3D heatmap overlaid on point cloud
61 |
62 | Args:
63 | point_cloud: Input point cloud with colors (N, 6)
64 | target_point: Target point to visualize as a sphere (optional)
65 | label: Label of the target point (1 for positive, otherwise negative)
66 | means: Mean vectors for each Gaussian component (K, 3)
67 | covs: Covariance matrices for each Gaussian component (K, 3, 3)
68 | weights: Mixing coefficients for each Gaussian component (K)
69 | """
70 |
71 | # Create point cloud object for input points
72 | pcd = o3d.geometry.PointCloud()
73 | pcd.points = o3d.utility.Vector3dVector(point_cloud[:, 0:3])
74 | pcd.colors = o3d.utility.Vector3dVector(point_cloud[:, 3:6])
75 |
76 | # Create ellipsoids for each Gaussian component
77 | ellipsoids = []
78 | for i in range(len(weights)):
79 | ellipsoid = create_ellipsoid(means[i], covs[i], color=[0, 1, 0], scale=2.0)
80 | ellipsoids.append(ellipsoid)
81 |
82 | # Create coordinate frame
83 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
84 | size=0.1, origin=[0, 0, 0])
85 |
86 | # Combine point clouds
87 | geometries = [pcd, coord_frame] + ellipsoids
88 |
89 | # Add spheres for true points if provided
90 | if target_point is not None:
91 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.02)
92 | sphere.translate(target_point)
93 |
94 | if label == 1:
95 | sphere.paint_uniform_color([0, 0, 1]) # Blue for positive labels
96 | else:
97 | sphere.paint_uniform_color([1, 0, 0]) # Red color for false points
98 | geometries.append(sphere)
99 |
100 | # Return all geometries
101 | return geometries
102 |
103 | def save_gmm_visualization(point_cloud, target_point, label, means, covs, weights, output_path):
104 | """
105 | Save GMM visualization to a file
106 |
107 | Args:
108 | point_cloud: Input point cloud with colors (N, 6)
109 | target_point: Target point to visualize as a sphere (optional)
110 | label: Label of the target point (1 for positive, otherwise negative)
111 | means: Mean vectors for each Gaussian component (K, 3)
112 | covs: Covariance matrices for each Gaussian component (K, 3, 3)
113 | weights: Mixing coefficients for each Gaussian component (K)
114 | output_path: Path to save the visualization
115 | """
116 | geometries = visualize_gmm_distribution(
117 | point_cloud, target_point, label, means, covs, weights
118 | )
119 |
120 | # Extract combined point cloud and convert other geometries to points
121 | combined_points = []
122 | combined_colors = []
123 |
124 | # Add the main point cloud points and colors
125 | combined_points.append(np.asarray(geometries[0].points))
126 | combined_colors.append(np.asarray(geometries[0].colors))
127 |
128 | # Convert coordinate frame to points
129 | coord_mesh = geometries[1]
130 | coord_pcd = coord_mesh.sample_points_uniformly(number_of_points=1000)
131 | combined_points.append(np.asarray(coord_pcd.points))
132 | combined_colors.append(np.asarray(coord_pcd.colors))
133 |
134 | # Convert ellipsoids to points if they exist
135 | for ellipsoid in geometries[2:]:
136 | ellipsoid_pcd = ellipsoid.sample_points_uniformly(number_of_points=10000)
137 | combined_points.append(np.asarray(ellipsoid_pcd.points))
138 | combined_colors.append(np.asarray(ellipsoid_pcd.colors))
139 |
140 | # Combine all points and colors
141 | all_points = np.vstack(combined_points)
142 | all_colors = np.vstack(combined_colors)
143 |
144 | all_colors = np.clip(all_colors, 0, 1)
145 |
146 | # Create final point cloud
147 | final_pcd = o3d.geometry.PointCloud()
148 | final_pcd.points = o3d.utility.Vector3dVector(all_points)
149 | final_pcd.colors = o3d.utility.Vector3dVector(all_colors)
150 |
151 | # Save to file
152 | o3d.io.write_point_cloud(output_path, final_pcd)
153 |
154 |
155 | def save_gmm_visualization_se2(point_cloud, target_se2, label, means, covs, weights, output_path, interval=0.1, area=[-10, 10, -10, 10], threshold=0.5, z_value=0.8):
156 | """
157 | Given a point cloud and the means and covariances of an SE(2) GMM, visualize the point cloud together with the SE(2) distribution.
158 | For each grid point in the xy area (spaced by `interval`), find the most probable theta,
159 | then evaluate the probability of the SE(2) point (x, y, theta).
160 | Based on that probability, color the grid point and draw an arrow for theta.
161 | Repeat this for all grid points and save the result together with the point cloud.
162 |
163 | Args:
164 | point_cloud: Input point cloud (N, 3)
165 | target_se2: Target SE(2) point to visualize
166 | label: Label of the target point (0 or 1)
167 | means: Mean vectors for each Gaussian component (K, 3)
168 | covs: Covariance matrices for each Gaussian component (K, 3, 3)
169 | weights: Mixing coefficients for each Gaussian component (K)
170 | interval: Interval of the grid to visualize the SE(2) points
171 | area: Area to visualize the SE(2) points
172 | threshold: Threshold of the probability to visualize the SE(2) points
173 |
174 | Returns:
175 | None. The visualization is saved to output_path.
176 | """
177 | # Create a grid of points in the xy area
178 | x = np.arange(area[0], area[1] + interval, interval)
179 | y = np.arange(area[2], area[3] + interval, interval)
180 | X, Y = np.meshgrid(x, y)
181 | grid_points = np.column_stack((X.ravel(), Y.ravel()))
182 |
183 | # Initialize arrays to store results
184 | n_grid_points = len(grid_points)
185 | theta_samples = np.linspace(-np.pi, np.pi, 36) # Sample 36 different angles
186 |
187 | # Create all possible combinations of grid points and thetas
188 | # Shape: (n_grid_points * n_thetas, 3)
189 | grid_expanded = np.repeat(grid_points, len(theta_samples), axis=0)
190 | theta_expanded = np.tile(theta_samples, n_grid_points)
191 | all_se2_points = np.column_stack((grid_expanded, theta_expanded))
192 |
193 | # Calculate probabilities for all points at once
194 | all_probs = gmm_pdf(all_se2_points, means, covs, weights)
195 |
196 | # Reshape probabilities to (n_grid_points, n_thetas)
197 | probs_matrix = all_probs.reshape(n_grid_points, len(theta_samples))
198 |
199 | # Find the best theta and probability for each grid point
200 | best_probs = np.max(probs_matrix, axis=1)
201 | best_theta_indices = np.argmax(probs_matrix, axis=1)
202 | best_thetas = theta_samples[best_theta_indices]
203 |
204 | # Normalize probabilities to [0, 1]; grid points below the threshold are skipped when drawing
205 | prob_min, prob_max = best_probs.min(), best_probs.max()
206 | best_probs = (best_probs - prob_min) / (prob_max - prob_min)
207 |
208 | # Create visualization geometries
209 | geometries = []
210 |
211 | # Add original point cloud
212 | pcd = o3d.geometry.PointCloud()
213 | pcd.points = o3d.utility.Vector3dVector(point_cloud[:, 0:3])
214 | pcd.colors = o3d.utility.Vector3dVector(point_cloud[:, 3:6])
215 | geometries.append(pcd)
216 |
217 | # Add coordinate frame
218 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
219 | geometries.append(coord_frame)
220 |
221 | # Add target SE(2) point if provided
222 | if target_se2 is not None:
223 | # Create sphere for position
224 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/4)
225 | sphere.translate([target_se2[0], target_se2[1], z_value]) # Use x,y coordinates
226 | if label == 1:
227 | sphere.paint_uniform_color([0, 0, 1]) # Blue for positive label
228 | else:
229 | sphere.paint_uniform_color([1, 0, 0]) # Red for negative label
230 | geometries.append(sphere)
231 |
232 | # Create arrow for theta
233 | arrow_length = interval * 1
234 | arrow = o3d.geometry.TriangleMesh.create_arrow(
235 | cylinder_radius=interval/16,
236 | cone_radius=interval/12,
237 | cylinder_height=arrow_length*0.8,
238 | cone_height=arrow_length*0.2
239 | )
240 |
241 | # Rotate and translate arrow to position
242 | R_x = np.array([
243 | [0, 0, 1],
244 | [0, 1, 0],
245 | [1, 0, 0]
246 | ])
247 | R = np.array([
248 | [np.cos(target_se2[2]), -np.sin(target_se2[2]), 0],
249 | [np.sin(target_se2[2]), np.cos(target_se2[2]), 0],
250 | [0, 0, 1]
251 | ])
252 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis
253 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta
254 | arrow.translate([target_se2[0], target_se2[1], z_value])
255 | arrow.paint_uniform_color([0, 0, 1] if label == 1 else [1, 0, 0])
256 | geometries.append(arrow)
257 |
258 | # Add grid points and arrows for points above threshold
259 | arrow_length = interval * 1
260 | for i in range(len(best_probs)):
261 | if best_probs[i] > threshold:
262 | # Create sphere for grid point
263 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/8)
264 | sphere.translate([grid_points[i,0], grid_points[i,1], z_value])
265 |
266 | # Color based on probability (red -> green gradient)
267 | red_value = (best_probs[i] - threshold) / (1 - threshold) * 0.99
268 | color = [1-red_value, red_value, 0]
269 | sphere.paint_uniform_color(color)
270 | geometries.append(sphere)
271 |
272 | # Create arrow for theta
273 | arrow = o3d.geometry.TriangleMesh.create_arrow(
274 | cylinder_radius=interval/16,
275 | cone_radius=interval/12,
276 | cylinder_height=arrow_length*0.8,
277 | cone_height=arrow_length*0.2
278 | )
279 |
280 | # Rotate and translate arrow to position
281 | R_x = np.array([
282 | [0, 0, 1],
283 | [0, 1, 0],
284 | [1, 0, 0]
285 | ])
286 | R = np.array([
287 | [np.cos(best_thetas[i]), -np.sin(best_thetas[i]), 0],
288 | [np.sin(best_thetas[i]), np.cos(best_thetas[i]), 0],
289 | [0, 0, 1]
290 | ])
291 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis
292 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta
293 | arrow.translate([grid_points[i,0], grid_points[i,1], z_value])
294 | arrow.paint_uniform_color(color)
295 | geometries.append(arrow)
296 |
297 | # Convert all geometries to point clouds and combine
298 | combined_points = []
299 | combined_colors = []
300 |
301 | for geom in geometries:
302 | if isinstance(geom, o3d.geometry.PointCloud):
303 | combined_points.append(np.asarray(geom.points))
304 | combined_colors.append(np.asarray(geom.colors))
305 | else:
306 | # Sample points from mesh
307 | pcd = geom.sample_points_uniformly(number_of_points=500)
308 | combined_points.append(np.asarray(pcd.points))
309 | combined_colors.append(np.asarray(pcd.colors))
310 |
311 | # Combine all points and colors
312 | all_points = np.vstack(combined_points)
313 | all_colors = np.vstack(combined_colors)
314 | all_colors = np.clip(all_colors, 0, 1)
315 |
316 | # Create final point cloud
317 | final_pcd = o3d.geometry.PointCloud()
318 | final_pcd.points = o3d.utility.Vector3dVector(all_points)
319 | final_pcd.colors = o3d.utility.Vector3dVector(all_colors)
320 |
321 | # Save to file
322 | o3d.io.write_point_cloud(output_path, final_pcd)
323 |
324 | def save_gmm_visualization_xythetaz(point_cloud, target_xythetaz, label, means, covs, weights, output_path, interval=0.1, area=[-10, 10, -10, 10, -10, 10], threshold=0.5):
325 | """
326 | Given a point cloud and the means and covariances of an SE(2)+Z GMM, visualize the point cloud together with the SE(2)+Z distribution.
327 | For each grid point in the xyz area (spaced by `interval`), find the most probable theta,
328 | then evaluate the probability of the point (x, y, theta, z).
329 | Based on that probability, color the grid point and draw an arrow for theta.
330 | Repeat this for all grid points and save the result together with the point cloud.
331 |
332 | Args:
333 | point_cloud: Input point cloud (N, 3)
334 | target_xythetaz: Target SE(2)+Z point to visualize (x,y,theta,z)
335 | label: Label of the target point (0 or 1)
336 | means: Mean vectors for each Gaussian component (K, 4)
337 | covs: Covariance matrices for each Gaussian component (K, 4, 4)
338 | weights: Mixing coefficients for each Gaussian component (K)
339 | interval: Interval of the grid to visualize the points
340 | area: Area to visualize the points [xmin, xmax, ymin, ymax, zmin, zmax]
341 | threshold: Threshold of the probability to visualize the points
342 | """
343 | # Create a grid of points in the xyz area
344 | x = np.arange(area[0], area[1] + interval, interval)
345 | y = np.arange(area[2], area[3] + interval, interval)
346 | z = np.arange(area[4], area[5] + interval, interval)
347 | X, Y, Z = np.meshgrid(x, y, z)
348 | grid_points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))
349 |
350 | # Initialize arrays to store results
351 | n_grid_points = len(grid_points)
352 | theta_samples = np.linspace(-np.pi, np.pi, 36) # Sample 36 different angles
353 |
354 | # Create all possible combinations of grid points and thetas
355 | # Shape: (n_grid_points * n_thetas, 4)
356 | grid_expanded = np.repeat(grid_points, len(theta_samples), axis=0)
357 | theta_expanded = np.tile(theta_samples, n_grid_points)
358 | # Rearrange to x,y,theta,z format
359 | all_xythetaz_points = np.column_stack((
360 | grid_expanded[:, 0], # x
361 | grid_expanded[:, 1], # y
362 | theta_expanded, # theta
363 | grid_expanded[:, 2] # z
364 | ))
365 |
366 | # Calculate probabilities for all points at once
367 | all_probs = gmm_pdf(all_xythetaz_points, means, covs, weights)
368 |
369 | # Reshape probabilities to (n_grid_points, n_thetas)
370 | probs_matrix = all_probs.reshape(n_grid_points, len(theta_samples))
371 |
372 | # Find the best theta and probability for each grid point
373 | best_probs = np.max(probs_matrix, axis=1)
374 | best_theta_indices = np.argmax(probs_matrix, axis=1)
375 | best_thetas = theta_samples[best_theta_indices]
376 |
377 | # Filter out probabilities below threshold and normalize probabilities to [0,1]
378 | prob_min, prob_max = best_probs.min(), best_probs.max()
379 | best_probs = (best_probs - prob_min) / (prob_max - prob_min)
380 |
381 | # Create visualization geometries
382 | geometries = []
383 |
384 | # Add original point cloud
385 | pcd = o3d.geometry.PointCloud()
386 | pcd.points = o3d.utility.Vector3dVector(point_cloud[:, 0:3])
387 | pcd.colors = o3d.utility.Vector3dVector(point_cloud[:, 3:6])
388 | geometries.append(pcd)
389 |
390 | # Add coordinate frame
391 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
392 | geometries.append(coord_frame)
393 |
394 | # Add target xythetaz point if provided
395 | if target_xythetaz is not None:
396 | # Create sphere for position
397 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/4)
398 | sphere.translate([target_xythetaz[0], target_xythetaz[1], target_xythetaz[3]]) # Use x,y,z coordinates
399 | if label == 1:
400 | sphere.paint_uniform_color([0, 0, 1]) # Blue for positive label
401 | else:
402 | sphere.paint_uniform_color([1, 0, 0]) # Red for negative label
403 | geometries.append(sphere)
404 |
405 | # Create arrow for theta
406 | arrow_length = interval * 1
407 | arrow = o3d.geometry.TriangleMesh.create_arrow(
408 | cylinder_radius=interval/16,
409 | cone_radius=interval/12,
410 | cylinder_height=arrow_length*0.8,
411 | cone_height=arrow_length*0.2
412 | )
413 |
414 | # Rotate and translate arrow to position
415 | R_x = np.array([
416 | [0, 0, 1],
417 | [0, 1, 0],
418 | [1, 0, 0]
419 | ])
420 | R = np.array([
421 | [np.cos(target_xythetaz[2]), -np.sin(target_xythetaz[2]), 0],
422 | [np.sin(target_xythetaz[2]), np.cos(target_xythetaz[2]), 0],
423 | [0, 0, 1]
424 | ])
425 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis
426 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta
427 | arrow.translate([target_xythetaz[0], target_xythetaz[1], target_xythetaz[3]])
428 | arrow.paint_uniform_color([0, 0, 1] if label == 1 else [1, 0, 0])
429 | geometries.append(arrow)
430 |
431 | # Add grid points and arrows for points above threshold
432 | arrow_length = interval * 1
433 | for i in range(len(best_probs)):
434 | if best_probs[i] > threshold:
435 | # Create sphere for grid point
436 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=interval/8)
437 | sphere.translate([grid_points[i,0], grid_points[i,1], grid_points[i,2]])
438 |
439 | # Color based on probability (red for low, green for high)
440 | green_value = (best_probs[i] - threshold) / (1 - threshold) * 0.99
441 | color = [1 - green_value, green_value, 0]
442 | sphere.paint_uniform_color(color)
443 | geometries.append(sphere)
444 |
445 | # Create arrow for theta
446 | arrow = o3d.geometry.TriangleMesh.create_arrow(
447 | cylinder_radius=interval/16,
448 | cone_radius=interval/12,
449 | cylinder_height=arrow_length*0.8,
450 | cone_height=arrow_length*0.2
451 | )
452 |
453 | # Rotate and translate arrow to position
454 | R_x = np.array([
455 | [0, 0, 1],
456 | [0, 1, 0],
457 | [1, 0, 0]
458 | ])
459 | R = np.array([
460 | [np.cos(best_thetas[i]), -np.sin(best_thetas[i]), 0],
461 | [np.sin(best_thetas[i]), np.cos(best_thetas[i]), 0],
462 | [0, 0, 1]
463 | ])
464 | arrow.rotate(R_x, center=[0, 0, 0]) # rotate to x-axis
465 | arrow.rotate(R, center=[0, 0, 0]) # rotate to theta
466 | arrow.translate([grid_points[i,0], grid_points[i,1], grid_points[i,2]])
467 | arrow.paint_uniform_color(color)
468 | geometries.append(arrow)
469 |
470 | # Convert all geometries to point clouds and combine
471 | combined_points = []
472 | combined_colors = []
473 |
474 | for geom in geometries:
475 | if isinstance(geom, o3d.geometry.PointCloud):
476 | combined_points.append(np.asarray(geom.points))
477 | combined_colors.append(np.asarray(geom.colors))
478 | else:
479 | # Sample points from mesh
480 | pcd = geom.sample_points_uniformly(number_of_points=500)
481 | combined_points.append(np.asarray(pcd.points))
482 | combined_colors.append(np.asarray(pcd.colors))
483 |
484 | # Combine all points and colors
485 | all_points = np.vstack(combined_points)
486 | all_colors = np.vstack(combined_colors)
487 | all_colors = np.clip(all_colors, 0, 1)
488 |
489 | # Create final point cloud
490 | final_pcd = o3d.geometry.PointCloud()
491 | final_pcd.points = o3d.utility.Vector3dVector(all_points)
492 | final_pcd.colors = o3d.utility.Vector3dVector(all_colors)
493 |
494 | # Save to file
495 | o3d.io.write_point_cloud(output_path, final_pcd)
496 |
497 | if __name__ == "__main__":
498 | # Load data
499 | point_cloud = _load_point_cloud("/home/hyunjun/projects/CoRL2025/nav2man/exp_pcl_train/data/point_clouds/0001.ply")
500 | means = [[0, 0, 0], [1, 0, 0], [0, 1, 0]]
501 | covs = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
502 | weights = [0.1, 0.4, 0.5]
503 | save_gmm_visualization(point_cloud, means, covs, weights, "data/gmm_visualization.pcd")
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "n2m"
3 | version = "0.1.0"
4 | description = "N2M training and inference"
5 | readme = "README.md"
6 | requires-python = ">=3.8"
7 | dependencies = [
8 | "easydict>=1.13",
9 | "numpy==1.23.5",
10 | "open3d>=0.19.0",
11 | "termcolor>=2.4.0",
12 | "timm>=1.0.17",
13 | "torch==2.2.1",
14 | "torchaudio==2.2.1",
15 | "torchvision==0.17.1",
16 | "wandb>=0.21.0",
17 | ]
18 |
--------------------------------------------------------------------------------
/scripts/process_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Usage: bash scripts/process_dataset.sh <dataset_path>
3 | 
4 | python scripts/sample_camera_poses.py --dataset_path "$1" --num_poses 300 --vis
5 | scripts/render/build/fpv_render "$1"
--------------------------------------------------------------------------------
/scripts/render/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.10)
2 | project(PointCloudRenderer)
3 |
4 | set(CMAKE_CXX_STANDARD 17)
5 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
6 | set(CMAKE_BUILD_TYPE Release)
7 |
8 | # Find packages
9 | find_package(PCL 1.8 REQUIRED)
10 | find_package(Eigen3 REQUIRED)
11 |
12 | # Include PCL definitions and directories
13 | include_directories(${PCL_INCLUDE_DIRS})
14 | link_directories(${PCL_LIBRARY_DIRS})
15 | add_definitions(${PCL_DEFINITIONS})
16 |
17 | # Add executable
18 | add_executable(fpv_render fpv_render.cpp)
19 |
20 | # Link libraries
21 | target_link_libraries(fpv_render
22 | ${PCL_LIBRARIES}
23 | Eigen3::Eigen
24 | stdc++fs # Add filesystem library
25 | )
26 |
27 | # Include directories
28 | target_include_directories(fpv_render PRIVATE
29 | ${PCL_INCLUDE_DIRS}
30 | ${EIGEN3_INCLUDE_DIR}
31 | ${PROJECT_SOURCE_DIR}/include
32 | )
33 |
34 | # Compiler flags
35 | if(CMAKE_COMPILER_IS_GNUCXX)
36 | target_compile_options(fpv_render PRIVATE -O3 -march=native)
37 | endif()
--------------------------------------------------------------------------------
/scripts/render/README.md:
--------------------------------------------------------------------------------
1 | # Point Cloud Renderer
2 |
3 | This is a C++ implementation of a point cloud renderer with occlusion handling, converted from a Python implementation for improved performance. It uses the Point Cloud Library (PCL) for point cloud processing and visualization.
4 |
5 | ## Performance Comparison
6 |
7 | The C++ implementation is expected to be significantly faster than the Python version due to:
8 |
9 | 1. Native code execution without Python interpreter overhead
10 | 2. More efficient memory management
11 | 3. Better optimization with compiler flags (`-O3 -march=native`)
12 | 4. Efficient point cloud operations using PCL's optimized algorithms
13 | 5. Reduced memory allocation overhead in the point processing loop
14 |
15 | ## Requirements
16 |
17 | - C++17 compatible compiler (GCC or Clang recommended)
18 | - CMake (version 3.10 or higher)
19 | - Point Cloud Library (PCL) version 1.8 or higher
20 | - librealsense2 SDK
21 | - Eigen3 library
22 |
23 | ## Installation
24 |
25 | ### Installing Dependencies
26 |
27 | #### Ubuntu/Debian:
28 |
29 | ```bash
30 | # Install build tools
31 | sudo apt-get update
32 | sudo apt-get install build-essential cmake
33 |
34 | # Install PCL and dependencies
35 | sudo apt-get install libpcl-dev
36 |
37 | # Install Eigen3
38 | sudo apt-get install libeigen3-dev
39 |
40 | # Install librealsense (following Intel's guide)
41 | # See: https://github.com/IntelRealSense/librealsense/blob/master/doc/installation.md
42 | sudo apt-key adv --keyserver keys.gnupg.net --recv-key F6E65AC044F831AC80A06380C8B3A55A6F3EFCDE || sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv-key F6E65AC044F831AC80A06380C8B3A55A6F3EFCDE
43 | sudo add-apt-repository "deb https://librealsense.intel.com/Debian/apt-repo $(lsb_release -cs) main" -u
44 | sudo apt-get install librealsense2-dkms librealsense2-utils librealsense2-dev
45 | ```
46 |
47 | ## Building the Project
48 |
49 | ```bash
50 | # From the root of the N2M repository
51 | cd scripts/render
52 | 
54 | # Create build directory
55 | mkdir build && cd build
56 |
57 | # Configure and build
58 | cmake ..
59 | make -j$(nproc)
60 | ```
61 |
62 | ## Usage
63 | 
64 | The renderer takes a dataset directory as its only argument. The directory must contain `meta.json`, the `pcl/` point clouds, and the camera and base poses produced by `scripts/sample_camera_poses.py` (`camera_poses/camera_poses.json` and `camera_poses/base_poses.json`).
65 | 
66 | ```bash
67 | # Run the renderer on a prepared dataset
68 | ./fpv_render <dataset_path>
69 | ```
70 | 
71 | For every episode and every sampled camera pose, the program will:
72 | 1. Load the episode's point cloud
73 | 2. Render it from the sampled camera pose with occlusion handling
74 | 3. Transform the rendered cloud into the corresponding robot base frame and save it under `pcl_aug/`
75 | 4. Append the transformed target pose and new file path to `meta_aug.json`
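76 | 
77 | As a minimal end-to-end sketch (the dataset path below is illustrative), the renderer is typically invoked from the repository root right after sampling camera poses:
78 | 
79 | ```bash
80 | # Sample camera poses (writes camera_poses/*.json), then render the augmented point clouds
81 | python scripts/sample_camera_poses.py --dataset_path datasets/my_dataset --num_poses 300 --vis
82 | scripts/render/build/fpv_render datasets/my_dataset
83 | ```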
76 |
77 | ## Comparing with Python Version
78 |
79 | To compare performance:
80 |
81 | 1. Run the Python version and note the reported timing
82 | 2. Run the C++ version and note the reported timing
83 | 3. Compare the quality of the rendered point clouds
84 |
85 | The C++ implementation using PCL is expected to show significant performance improvements, especially for large point clouds, compared to the Python implementation using Open3D. PCL's native C++ implementation avoids Python's interpreter overhead and benefits from highly optimized algorithms specifically designed for point cloud processing.
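86 | 
87 | A rough wall-clock comparison can be made with `time`. The Python command below is hypothetical (only the C++ renderer ships in this directory); substitute the path of your own Python implementation:
88 | 
89 | ```bash
90 | # Hypothetical Python renderer, timed for comparison
91 | time python fpv_render.py <dataset_path>
92 | # C++ renderer built under scripts/render/build
93 | time ./build/fpv_render <dataset_path>
94 | ```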
--------------------------------------------------------------------------------
/scripts/render/fpv_render.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <fstream>
3 | #include <string>
4 | #include <vector>
5 | #include <chrono>
6 | 
7 | #include <filesystem>
8 | #include <Eigen/Dense>
9 | #include <json.hpp>
10 | 
11 | // PCL includes
12 | #include <pcl/point_types.h>
13 | #include <pcl/point_cloud.h>
14 | #include <pcl/io/pcd_io.h>
15 | #include <pcl/common/transforms.h>
16 | #include <pcl/common/common.h>
17 |
18 | using json = nlohmann::json;
19 | namespace fs = std::filesystem;
20 |
21 | namespace nlohmann {
22 | // Convert Eigen::Vector3f to JSON
23 | void to_json(json& j, const Eigen::Vector3f& v) {
24 | j = json::array({v(0), v(1), v(2)});
25 | }
26 |
27 | // Convert JSON to Eigen::Vector3f
28 | void from_json(const json& j, Eigen::Vector3f& v) {
29 | v(0) = j[0].get<float>();
30 | v(1) = j[1].get<float>();
31 | v(2) = j[2].get<float>();
32 | }
33 |
34 | // Convert Eigen::Vector4f to JSON
35 | void to_json(json& j, const Eigen::Vector4f& v) {
36 | j = json::array({v(0), v(1), v(2), v(3)});
37 | }
38 |
39 | // Convert JSON to Eigen::Vector4f
40 | void from_json(const json& j, Eigen::Vector4f& v) {
41 | v(0) = j[0].get<float>();
42 | v(1) = j[1].get<float>();
43 | v(2) = j[2].get<float>();
44 | v(3) = j[3].get<float>();
45 | }
46 | }
47 |
48 | class CameraIntrinsics {
49 | public:
50 | float fx, fy, cx, cy;
51 | int width, height;
52 |
53 | CameraIntrinsics(float fx, float fy, float cx, float cy, int width, int height)
54 | : fx(fx), fy(fy), cx(cx), cy(cy), width(width), height(height) {}
55 | };
56 |
57 | template <typename PointT>
58 | void transformPoints(const pcl::PointCloud<PointT>& source,
59 | pcl::PointCloud<PointT>& target,
60 | const Eigen::Matrix4f& transform) {
61 | pcl::transformPointCloud(source, target, transform);
62 | }
63 |
64 | template <typename PointT>
65 | typename pcl::PointCloud<PointT>::Ptr translatePointCloud(
66 | const typename pcl::PointCloud<PointT>::Ptr& pointCloud,
67 | const Eigen::Matrix4f& basePose
68 | ) {
69 | // Create output point cloud
70 | typename pcl::PointCloud<PointT>::Ptr transformedCloud(new pcl::PointCloud<PointT>());
71 |
72 | // Create inverse transformation to move base pose to origin
73 | Eigen::Matrix4f inverseTransform = basePose.inverse();
74 |
75 | // Transform the point cloud
76 | pcl::transformPointCloud(*pointCloud, *transformedCloud, inverseTransform);
77 |
78 | return transformedCloud;
79 | }
80 |
81 |
82 |
83 | // Express an SE(2) target pose in the frame of the given base pose
84 | Eigen::Vector3f translateTarget(
85 | const Eigen::Vector3f& target_se2,
86 | const Eigen::Matrix4f& basePose
87 | ) {
88 | // SE(2) pose is [x, y, theta]
89 | float x = target_se2(0);
90 | float y = target_se2(1);
91 | float theta = target_se2(2);
92 |
93 | // Create SE(2) transformation matrix
94 | Eigen::Matrix3f se2_transform;
95 | se2_transform << cos(theta), -sin(theta), x,
96 | sin(theta), cos(theta), y,
97 | 0, 0, 1;
98 |
99 | // Get base SE(2) transformation
100 | float base_theta = atan2(basePose(1,0), basePose(0,0));
101 | Eigen::Matrix3f base_se2;
102 | base_se2 << cos(base_theta), -sin(base_theta), basePose(0,3),
103 | sin(base_theta), cos(base_theta), basePose(1,3),
104 | 0, 0, 1;
105 |
106 | // Transform SE(2) pose
107 | Eigen::Matrix3f transformed_se2 = base_se2.inverse() * se2_transform;
108 |
109 | // Extract transformed SE(2) parameters
110 | Eigen::Vector3f translated_se2;
111 | translated_se2(0) = transformed_se2(0,2);
112 | translated_se2(1) = transformed_se2(1,2);
113 | translated_se2(2) = atan2(transformed_se2(1,0), transformed_se2(0,0));
114 |
115 | return translated_se2;
116 | }
117 |
118 | template <typename PointT>
119 | typename pcl::PointCloud<PointT>::Ptr renderPointCloudWithOcclusion(
120 | const typename pcl::PointCloud<PointT>::Ptr& pointCloud,
121 | const Eigen::Matrix4f& cameraPose,
122 | const CameraIntrinsics& intrinsics,
123 | float minDepth = 0.1,
124 | float maxDepth = 10.0
125 | ) {
126 | // Transform points to camera frame
127 | Eigen::Matrix4f cameraToWorld = cameraPose.inverse();
128 | typename pcl::PointCloud<PointT>::Ptr cloudCamera(new pcl::PointCloud<PointT>());
129 | transformPoints(*pointCloud, *cloudCamera, cameraToWorld);
130 |
131 | // Initialize depth buffer
132 | std::vector<std::vector<float>> depthBuffer(intrinsics.width,
133 | std::vector<float>(intrinsics.height,
134 | std::numeric_limits<float>::infinity()));
135 | std::vector<std::vector<int>> pointIndices(intrinsics.width,
136 | std::vector<int>(intrinsics.height, -1));
137 |
138 | auto timeStart = std::chrono::high_resolution_clock::now();
139 | double tTotal = 0.0;
140 |
141 | // Process each point
142 | for (size_t i = 0; i < cloudCamera->points.size(); ++i) {
143 | auto pointStart = std::chrono::high_resolution_clock::now();
144 |
145 | const auto& point = cloudCamera->points[i];
146 |
147 | // Camera-frame coordinates are used directly:
148 | // the camera looks along +Z, with X to the right and Y down (pinhole convention)
149 | float x = point.x;
150 | float y = point.y;
151 | float z = point.z;
152 |
153 | // Skip if behind camera or invalid
154 | if (z <= 0 || !std::isfinite(x) || !std::isfinite(y) || !std::isfinite(z)) {
155 | continue;
156 | }
157 |
158 | // Project point to image plane
159 | float u = (x * intrinsics.fx / z) + intrinsics.cx;
160 | float v = (y * intrinsics.fy / z) + intrinsics.cy;
161 |
162 | // Skip if outside image bounds
163 | if (u < 0 || u >= intrinsics.width || v < 0 || v >= intrinsics.height) {
164 | continue;
165 | }
166 |
167 | // Get pixel coordinates
168 | int pixelX = static_cast<int>(u);
169 | int pixelY = static_cast<int>(v);
170 |
171 | // Skip if outside depth range
172 | float depth = z;
173 | if (depth < minDepth || depth > maxDepth) {
174 | continue;
175 | }
176 |
177 | // Update depth buffer if point is closer
178 | if (depth < depthBuffer[pixelX][pixelY]) {
179 | depthBuffer[pixelX][pixelY] = depth;
180 | pointIndices[pixelX][pixelY] = i;
181 | }
182 |
183 | auto pointEnd = std::chrono::high_resolution_clock::now();
184 | tTotal += std::chrono::duration(pointEnd - pointStart).count();
185 | }
186 |
187 | std::cout << "Time taken for point processing: " << tTotal * 1000 << " ms" << std::endl;
188 |
189 | // Create output point cloud
190 | std::vector<int> validIndices;
191 | for (const auto& row : pointIndices) {
192 | for (int idx : row) {
193 | if (idx >= 0) {
194 | validIndices.push_back(idx);
195 | }
196 | }
197 | }
198 |
199 | if (validIndices.empty()) {
200 | return typename pcl::PointCloud<PointT>::Ptr(new pcl::PointCloud<PointT>());
201 | }
202 |
203 | typename pcl::PointCloud<PointT>::Ptr renderedCloud(new pcl::PointCloud<PointT>());
204 | renderedCloud->reserve(validIndices.size());
205 |
206 | // Extract visible points
207 | for (int idx : validIndices) {
208 | renderedCloud->points.push_back(pointCloud->points[idx]);
209 | }
210 |
211 | renderedCloud->width = renderedCloud->points.size();
212 | renderedCloud->height = 1;
213 | renderedCloud->is_dense = false;
214 |
215 | return renderedCloud;
216 | }
217 |
218 |
219 |
220 | Eigen::Matrix4f createCameraPoseFromTranslationQuaternion(
221 | float tx, float ty, float tz,
222 | float qw, float qx, float qy, float qz
223 | ) {
224 | // Normalize quaternion
225 | float norm = std::sqrt(qw*qw + qx*qx + qy*qy + qz*qz);
226 | qw /= norm;
227 | qx /= norm;
228 | qy /= norm;
229 | qz /= norm;
230 |
231 | // Quaternion to rotation matrix conversion
232 | Eigen::Matrix3f R;
233 | R << 1 - 2*qy*qy - 2*qz*qz, 2*qx*qy - 2*qw*qz, 2*qx*qz + 2*qw*qy,
234 | 2*qx*qy + 2*qw*qz, 1 - 2*qx*qx - 2*qz*qz, 2*qy*qz - 2*qw*qx,
235 | 2*qx*qz - 2*qw*qy, 2*qy*qz + 2*qw*qx, 1 - 2*qx*qx - 2*qy*qy;
236 |
237 | // Create 4x4 transformation matrix
238 | Eigen::Matrix4f pose = Eigen::Matrix4f::Identity();
239 | pose.block<3,3>(0,0) = R;
240 | pose(0,3) = tx;
241 | pose(1,3) = ty;
242 | pose(2,3) = tz;
243 |
244 | return pose;
245 | }
246 |
247 | int main(int argc, char** argv) {
248 | if (argc != 2) {
249 | std::cerr << "Usage: " << argv[0] << " <dataset_path>" << std::endl;
250 | return 1;
251 | }
252 |
253 | std::string dataset_path = argv[1];
254 | std::string pcl_dir = dataset_path + "/pcl";
255 | std::string meta_path = dataset_path + "/meta.json";
256 | std::string output_pcl_dir = dataset_path + "/pcl_aug";
257 | std::string output_meta_path = dataset_path + "/meta_aug.json";
258 | std::string camera_poses_path = dataset_path + "/camera_poses/camera_poses.json";
259 | std::string base_poses_path = dataset_path + "/camera_poses/base_poses.json";
260 |
261 | // Create output directory if it doesn't exist
262 | fs::create_directories(output_pcl_dir);
263 |
264 | std::ifstream camera_poses_file(camera_poses_path);
265 | json camera_poses;
266 | camera_poses_file >> camera_poses;
267 |
268 | std::ifstream meta_file(meta_path);
269 | json meta;
270 | meta_file >> meta;
271 |
272 | std::ifstream base_poses_file(base_poses_path);
273 | json base_poses;
274 | base_poses_file >> base_poses;
275 |
276 | // Read camera intrinsics from meta.json
277 | std::vector<double> cam_intr;
278 | if (meta["meta"].contains("camera_intrinsic")) {
279 | for (const auto& v : meta["meta"]["camera_intrinsic"]) {
280 | cam_intr.push_back(v.get<double>());
281 | }
282 | } else {
283 | std::cerr << "camera_intrinsic not found in meta.json, using default values." << std::endl;
284 | cam_intr = {100.6919557412736, 100.6919557412736, 160.0, 120.0, 320, 240};
285 | }
286 | if (cam_intr.size() != 6) {
287 | std::cerr << "camera_intrinsic in meta.json must have 6 values (fx, fy, cx, cy, width, height)." << std::endl;
288 | return 1;
289 | }
290 | CameraIntrinsics intrinsics(
291 | cam_intr[0], cam_intr[1], cam_intr[2], cam_intr[3], (int)cam_intr[4], (int)cam_intr[5]
292 | );
293 |
294 | json meta_aug;
295 | meta_aug["meta"] = meta["meta"];
296 | meta_aug["episodes"] = json::array();
297 |
298 | auto timeStart = std::chrono::high_resolution_clock::now();
299 | for (const auto& episode : meta["episodes"]) {
300 | // Load point cloud
301 | std::string pcl_path = fs::path(dataset_path) / episode["file_path"];
302 | pcl::PointCloud<pcl::PointXYZRGB>::Ptr pointCloud(new pcl::PointCloud<pcl::PointXYZRGB>);
303 | if (pcl::io::loadPCDFile(pcl_path, *pointCloud) == -1) {
304 | std::cerr << "Failed to load point cloud file: " << pcl_path << std::endl;
305 | continue;
306 | }
307 |
308 | int episode_id = episode["id"];
309 | for (int i = 0; i < camera_poses[episode_id].size(); i++) {
310 | json episode_aug = episode;
311 |
312 | std::string file_name = std::to_string(episode_id) + "_" + std::to_string(i) + ".pcd";
313 | std::string output_path = (fs::path(output_pcl_dir) / file_name).string();
314 | if (fs::exists(output_path)) {
315 | std::cout << "Skipping: " << output_path << std::endl;
316 | continue;
317 | }
318 |
319 | // define camera pose
320 | Eigen::Matrix4f cameraPose = Eigen::Matrix4f::Identity();
321 | for (int j = 0; j < 4; j++) {
322 | for (int k = 0; k < 4; k++) {
323 | cameraPose(j, k) = camera_poses[episode_id][i][j][k];
324 | }
325 | };
326 |
327 | // translate point cloud
328 | Eigen::Matrix4f basePose = Eigen::Matrix4f::Identity();
329 | for (int j = 0; j < 4; j++) {
330 | for (int k = 0; k < 4; k++) {
331 | basePose(j, k) = base_poses[episode_id][i][j][k];
332 | }
333 | };
334 |
335 | // Render point cloud
336 | auto renderedCloud = renderPointCloudWithOcclusion<pcl::PointXYZRGB>(
337 | pointCloud,
338 | cameraPose,
339 | intrinsics,
340 | 0.1f,
341 | 10.0f
342 | );
343 |
344 | // translate point cloud
345 | auto translatedCloud = translatePointCloud<pcl::PointXYZRGB>(renderedCloud, basePose);
346 |
347 | // translate target
348 | Eigen::Vector3f pose_se2;
349 | for (int idx = 0; idx < 3; ++idx) {
350 | pose_se2(idx) = episode_aug["pose"][idx].get<float>();
351 | }
352 | auto translated_se2 = translateTarget(pose_se2, basePose);
353 | episode_aug["pose"] = translated_se2;
354 |
355 | // Save rendered point cloud
356 | pcl::io::savePCDFile(output_path, *translatedCloud);
357 |
358 | episode_aug["file_path"] = "pcl_aug/" + file_name;
359 | meta_aug["episodes"].push_back(episode_aug);
360 |
361 | auto timeEnd = std::chrono::high_resolution_clock::now();
362 | auto duration = std::chrono::duration(timeEnd - timeStart).count();
363 |
364 | std::cout << "Rendered point cloud saved to: " << output_path << std::endl;
365 | std::cout << "Number of points: " << renderedCloud->size() << std::endl;
366 | std::cout << "Total time: " << duration << " s" << std::endl;
367 | std::cout << "----------------------------------------" << std::endl;
368 | }
369 | }
370 |
371 | // Save augmented meta file
372 | std::ofstream output_meta_file(output_meta_path);
373 | output_meta_file << meta_aug.dump(4);
374 | std::cout << "Augmented meta file saved to: " << output_meta_path << std::endl;
375 |
376 | return 0;
377 | }
--------------------------------------------------------------------------------
/scripts/sample_camera_poses.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import open3d as o3d
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from n2m.utils.sample_utils import TargetHelper
9 |
10 | def save_pose_visualization(pcl, poses, furniture_pos, save_path):
11 | # Create point cloud for the scene
12 | pcd = o3d.geometry.PointCloud()
13 | pcd.points = o3d.utility.Vector3dVector(pcl.points)
14 | pcd.colors = o3d.utility.Vector3dVector(pcl.colors)
15 |
16 | # Add furniture position as red points
17 | furniture_pcd = o3d.geometry.PointCloud()
18 | furniture_pcd.points = o3d.utility.Vector3dVector([furniture_pos[:3]])
19 | furniture_pcd.colors = o3d.utility.Vector3dVector([[1, 0, 0]]) # Red
20 | pcd = pcd + furniture_pcd
21 |
22 | for pose in poses:
23 | # Get axis endpoints in world coordinates
24 | camera_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1)
25 | camera_frame.transform(pose)
26 | camera_frame_points = camera_frame.sample_points_uniformly(number_of_points=300)
27 | pcd = pcd + camera_frame_points
28 |
29 | # Save combined point cloud
30 | o3d.io.write_point_cloud(save_path, pcd)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--dataset_path", type=str, default="datasets/rollouts/PnPCounterToCab_BCtransformer_rollouts/PnPCounterToCab_BCtransformer_rollout_scene1/20250430230003")
36 | parser.add_argument("--num_poses", type=int, default=10)
37 | parser.add_argument("--vis", action="store_true")
38 | args = parser.parse_args()
39 |
40 | print("loading meta...")
41 | meta_path = os.path.join(args.dataset_path, "meta.json")
42 | with open(meta_path, "r") as f:
43 | meta = json.load(f)
44 | episodes = meta['episodes']
45 | print("meta loaded")
46 |
47 | T_base_to_cam = meta["meta"]["T_base_to_cam"]
48 | camera_intrinsic = meta["meta"]["camera_intrinsic"]
49 |
50 | camera_pose_save_dir = os.path.join(args.dataset_path, "camera_poses")
51 | os.makedirs(camera_pose_save_dir, exist_ok=True)
52 |
53 | if args.vis:
54 | camera_pose_vis_dir = os.path.join(args.dataset_path, "camera_poses_vis")
55 | os.makedirs(camera_pose_vis_dir, exist_ok=True)
56 |
57 | camera_poses = []
58 | base_poses = []
59 | for episode in tqdm(episodes):
60 | pcl_path = os.path.join(args.dataset_path, episode['file_path'])
61 | pcl = o3d.io.read_point_cloud(pcl_path)
62 | pose = episode['pose']
63 | se2_origin = pose
64 |
65 | # load object position if it exists, otherwise place it 0.5 m in front of the robot
66 | if "object_position" in episode:
67 | object_position = np.array(episode['object_position'])
68 | else:
69 | # add the offset element-wise to the SE(2) pose (x, y, theta)
70 | object_position = np.array(pose) + np.array([0.5 * np.cos(pose[2]), 0.5 * np.sin(pose[2]), 0.0])
71 |
72 | target_helper = TargetHelper(pcl, origin_se2=se2_origin, x_half_range=2, y_half_range=2, theta_half_range_deg=60, vis=False, camera_intrinsic=camera_intrinsic)
73 | target_helper.T_base_cam = np.array(T_base_to_cam)
74 | episode_camera_poses = []
75 | episode_base_poses = []
76 | for _ in range(args.num_poses):
77 | rel_base_se2, camera_extrinsic = target_helper.get_random_target_se2_with_visibility_check(object_position[:2])
78 | abs_base_se2 = [x + y for x, y in zip(rel_base_se2, pose)]
79 | matrix_base_se3 = target_helper.calculate_target_se3(abs_base_se2)
80 | episode_base_poses.append(matrix_base_se3.tolist())
81 | episode_camera_poses.append(camera_extrinsic.tolist())
82 |
83 | if args.vis:
84 | save_pose_visualization(pcl, episode_camera_poses, object_position, os.path.join(camera_pose_vis_dir, f"{episode['id']}.pcd"))
85 |
86 | camera_poses.append(episode_camera_poses)
87 | base_poses.append(episode_base_poses)
88 |
89 | with open(os.path.join(camera_pose_save_dir, "camera_poses.json"), "w") as f:
90 | json.dump(camera_poses, f, indent=4)
91 | with open(os.path.join(camera_pose_save_dir, "base_poses.json"), "w") as f:
92 | json.dump(base_poses, f, indent=4)
93 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.data import DataLoader
5 | import json
6 | from tqdm import tqdm
7 | import argparse
8 | import wandb
9 | import time
10 | import datetime
11 | import random
12 | import numpy as np
13 |
14 | from n2m.data.dataset import make_data_module
15 | from n2m.model.N2Mnet import N2Mnet
16 | from n2m.utils.config import *
17 | from n2m.utils.visualizer import save_gmm_visualization, save_gmm_visualization_se2, save_gmm_visualization_xythetaz
18 | from n2m.utils.loss import Loss
19 |
20 | def set_seed(seed):
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 | np.random.seed(seed)
24 | random.seed(seed)
25 | torch.backends.cudnn.deterministic = True
26 | torch.backends.cudnn.benchmark = False
27 |
28 | def load_config(config_path):
29 | """Load configuration from a JSON file"""
30 | with open(config_path, 'r') as f:
31 | config = json.load(f)
32 | return config
33 |
34 | def train_one_epoch(model, train_loader, loss_fn, optimizer, epoch, device, train_config):
35 | model.train()
36 | total_loss = 0
37 | pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
38 |
39 | for batch in pbar:
40 | # Move data to device
41 | point_cloud = batch['point_cloud'].to(device)
42 | target_point = batch['target_point'].to(device)
43 | label = batch['label'].to(device)
44 |
45 | # Forward pass
46 | means, covs, weights = model(point_cloud)
47 | loss = loss_fn(means, covs, weights, target_point, label)
48 |
49 | # Backward pass
50 | optimizer.zero_grad()
51 | loss.backward()
52 | optimizer.step()
53 |
54 | # Update metrics
55 | total_loss += loss.item()
56 | pbar.set_postfix({'loss': loss.item()})
57 |
58 |
59 | avg_loss = total_loss / len(train_loader)
60 | wandb.log({"train/loss": avg_loss, "epoch": epoch})
61 | return avg_loss
62 |
63 | def validate(model, val_loader, loss_fn, epoch, device, val_dir, train_config):
64 | model.eval()
65 | total_loss = 0
66 |
67 | # Create visualization directory
68 | os.makedirs(val_dir, exist_ok=True)
69 |
70 | with torch.no_grad():
71 | for batch_idx, batch in enumerate(tqdm(val_loader, desc='Validation')):
72 | # Move data to device
73 | point_cloud = batch['point_cloud'].to(device)
74 | target_point = batch['target_point'].to(device)
75 | label = batch['label'].to(device)
76 |
77 | # Forward pass
78 | means, covs, weights = model(point_cloud)
79 | loss = loss_fn(means, covs, weights, target_point, label)
80 |
81 | # Update metrics
82 | total_loss += loss.item()
83 |
84 | # Generate visualizations for each item in the batch
85 | for i in range(len(batch['point_cloud'])):
86 | if target_point[i].shape[0] == 3:
87 | save_gmm_visualization_se2(
88 | point_cloud[i].cpu().numpy(),
89 | target_point[i].cpu().numpy(),
90 | label[i].cpu().numpy(),
91 | means[i].cpu().numpy(),
92 | covs[i].cpu().numpy(),
93 | weights[i].cpu().numpy(),
94 | os.path.join(val_dir, f'batch_{batch_idx}_{i}.pcd')
95 | )
96 | elif target_point[i].shape[0] == 4:
97 | save_gmm_visualization_xythetaz(
98 | point_cloud[i].cpu().numpy(),
99 | target_point[i].cpu().numpy(),
100 | label[i].cpu().numpy(),
101 | means[i].cpu().numpy(),
102 | covs[i].cpu().numpy(),
103 | weights[i].cpu().numpy(),
104 | os.path.join(val_dir, f'batch_{batch_idx}_{i}.pcd')
105 | )
106 |
107 |
108 | avg_loss = total_loss / len(val_loader)
109 |
110 | # Log to wandb
111 | wandb.log({"val/loss": avg_loss, "epoch": epoch})
112 |
113 | return avg_loss
114 |
115 | def save_checkpoint(model, optimizer, epoch, loss, save_path):
116 | checkpoint = {
117 | 'epoch': epoch,
118 | 'model_state_dict': model.state_dict(),
119 | 'optimizer_state_dict': optimizer.state_dict(),
120 | 'loss': loss,
121 | }
122 | torch.save(checkpoint, save_path)
123 |
124 | def get_exp_dir(train_config):
125 | t_now = time.time()
126 | time_str = datetime.datetime.fromtimestamp(t_now).strftime('%Y%m%d%H%M%S')
127 | output_dir = os.path.join(train_config['output_dir'], time_str)
128 |
129 | ckpt_dir = os.path.join(output_dir, 'ckpts')
130 | val_dir = os.path.join(output_dir, 'val')
131 | log_dir = os.path.join(output_dir, 'logs')
132 |
133 | # create output directory
134 | os.makedirs(output_dir, exist_ok=True)
135 | os.makedirs(ckpt_dir, exist_ok=True)
136 | os.makedirs(val_dir, exist_ok=True)
137 | os.makedirs(log_dir, exist_ok=True)
138 |
139 | return output_dir, ckpt_dir, val_dir, log_dir
140 |
141 | def main():
142 | parser = argparse.ArgumentParser()
143 | parser.add_argument('--config', type=str, default='configs/training/config.json', help='Path to the config file.')
144 | args = parser.parse_args()
145 |
146 | # Load configuration
147 | config = load_config(args.config)
148 | train_config = config['train']
149 | model_config = config['model']
150 | dataset_config = config['dataset']
151 |
152 | # Create output directory
153 | output_dir, ckpt_dir, val_dir, log_dir = get_exp_dir(train_config)
154 |
155 | # Initialize wandb
156 | wandb_config = train_config['wandb']
157 | wandb.init(
158 | project=wandb_config['project'],
159 | entity=wandb_config['entity'],
160 | name=wandb_config['name'],
161 | config=config,
162 | mode="online" if wandb_config['entity'] is not None else "disabled"
163 | )
164 |
165 | # save config
166 | config_save_path = os.path.join(output_dir, 'config.json')
167 | with open(config_save_path, 'w') as f:
168 | json.dump(config, f, indent=4)
169 |
170 | # Set device
171 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
172 | print(f'Using device: {device}')
173 |
174 | # Create model
175 | model = N2Mnet(
176 | config=model_config
177 | ).to(device)
178 |
179 | # Watch model with wandb
180 | wandb.watch(model, log="all", log_freq=100)
181 |
182 | # Create data loaders
183 | train_dataset, val_dataset = make_data_module(dataset_config)
184 |
185 | train_loader = DataLoader(
186 | train_dataset,
187 | batch_size=train_config['batch_size'],
188 | shuffle=True,
189 | num_workers=train_config['num_workers'],
190 | pin_memory=True
191 | )
192 |
193 | val_loader = DataLoader(
194 | val_dataset,
195 | batch_size=train_config['batch_size'],
196 | shuffle=False,
197 | num_workers=train_config['num_workers'],
198 | pin_memory=True
199 | )
200 | if len(val_loader) == 0:
201 | val_loader = train_loader
202 |
203 | # Create optimizer and scheduler
204 | optimizer = torch.optim.Adam(model.parameters(), lr=float(train_config['learning_rate']))
205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
206 | optimizer,
207 | T_max=int(train_config['num_epochs'])
208 | )
209 |
210 | # create loss function
211 | loss_config = train_config['loss']
212 | loss_fn = Loss(loss_config)
213 |
214 | # Training loop
215 | best_val_loss = float('inf')
216 | for epoch in range(train_config['num_epochs']):
217 | # Train
218 | train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch, device, train_config)
219 | print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}')
220 |
221 | # Validate
222 | if (epoch + 1) % train_config['val_freq'] == 0:
223 | val_loss = validate(model, val_loader, loss_fn, epoch, device, val_dir, train_config)
224 | print(f'Epoch {epoch}: Val Loss = {val_loss:.4f}')
225 |
226 | # Save best model
227 | if val_loss < best_val_loss:
228 | best_val_loss = val_loss
229 | save_checkpoint(
230 | model,
231 | optimizer,
232 | epoch,
233 | val_loss,
234 | os.path.join(ckpt_dir, 'best_model.pth')
235 | )
236 | save_checkpoint(
237 | model,
238 | optimizer,
239 | epoch,
240 | train_loss,
241 | os.path.join(ckpt_dir, f'model_{epoch}.pth')
242 | )
243 |
244 | # Step scheduler
245 | scheduler.step()
246 |
247 | # Save final model
248 | save_checkpoint(
249 | model,
250 | optimizer,
251 | train_config['num_epochs'],
252 | train_loss,
253 | os.path.join(ckpt_dir, 'final_model.pth')
254 | )
255 |
256 | wandb.finish()
257 |
258 | if __name__ == '__main__':
259 | main()
260 |
--------------------------------------------------------------------------------
/scripts/visualize_attention.py:
--------------------------------------------------------------------------------
1 | def visualize_attention(pts, center, attn_weights):
2 | """
3 | Visualizes attention weights on the point cloud.
4 |
5 | Args:
6 | pts (Tensor): Original input point cloud (B, N_pts, 3).
7 | center (Tensor): Center points of the point groups (B, N_groups, 3).
8 | attn_weights (Tensor): Attention matrix from the final block (B, num_heads, N, N).
9 | """
10 | # Assuming batch size is 1 for simplicity
11 | attn = attn_weights[0] # Shape: (num_heads, N, N)
12 |
13 | # Average attention across all heads for simplicity, or pick a specific head
14 | attn_avg = attn.mean(dim=0) # Shape: (N, N)
15 |
16 | # Get the attention scores of the class token (index 0) to all other tokens
17 | # Note: The first token is the cls token, subsequent tokens correspond to point groups
18 | cls_attn = attn_avg[0, 1:] # Shape: (N_groups,)
19 |
20 | # Normalize scores for better visualization (e.g., between 0 and 1)
21 | cls_attn_normalized = (cls_attn - cls_attn.min()) / (cls_attn.max() - cls_attn.min())
22 |
23 | # Color the group centers by their normalized attention scores and show them with Open3D
24 | # (a minimal sketch; assumes batch size 1 and a simple blue (low) -> red (high) colormap).
25 | import numpy as np
26 | import open3d as o3d
27 | 
28 | centers = center[0].detach().cpu().numpy() # (N_groups, 3)
29 | scores = cls_attn_normalized.detach().cpu().numpy() # (N_groups,)
30 | colors = np.stack([scores, np.zeros_like(scores), 1.0 - scores], axis=1)
31 | 
32 | pcd = o3d.geometry.PointCloud()
33 | pcd.points = o3d.utility.Vector3dVector(centers)
34 | pcd.colors = o3d.utility.Vector3dVector(colors)
35 | o3d.visualization.draw_geometries([pcd])
36 | 
37 |
38 | if __name__ == "__main__":
39 | parser = argparse.ArgumentParser()
40 |
41 | # Assuming you have loaded your model and data
42 | model = PointTransformer(config)
43 | model.load_checkpoint(bert_ckpt_path)
44 |
45 | # Prepare a single ego-centric point cloud for visualization
46 | # This should be a point cloud from your dataset.
47 | input_pts = ... # Your point cloud data, shape (1, N_pts, 3)
48 |
49 | # Run the forward pass to get the output and attention weights
50 | output_features, attn_weights = model(input_pts)
51 |
52 | # Get the center points from the group divider (you might need to run it again)
53 | _, center = model.group_divider(input_pts)
54 |
55 | # Call the visualization function
56 | visualize_attention(input_pts, center, attn_weights)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="n2m",
5 | packages=[
6 | package for package in find_packages() if package.startswith("n2m")
7 | ],
8 | install_requires=[],
9 | version="0.0.1",
10 | )
--------------------------------------------------------------------------------