├── models
│   ├── __init__.py
│   └── minkunet.py
├── trains
│   ├── __init__.py
│   └── DistillationDPO.py
├── utils
│   ├── __init__.py
│   ├── ema.py
│   ├── scheduling.py
│   ├── render.py
│   ├── EMD.py
│   ├── histogram_metrics.py
│   ├── data_map.py
│   ├── collations.py
│   ├── pcd_transforms.py
│   ├── eval_path_get_pics.py
│   ├── pcd_preprocess.py
│   ├── eval_path.py
│   ├── metrics.py
│   └── diff_completion_pipeline.py
├── datasets
│   ├── SemanticKITTI_dataloader
│   │   ├── __init__.py
│   │   ├── SemanticKITTITemporalAggr.py
│   │   └── SemanticKITTITemporal.py
│   └── SemanticKITTI_dataset.py
├── pics
│   └── teaser.png
├── .gitignore
├── setup.py
├── map_from_scans.py
└── README.md
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/trains/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/SemanticKITTI_dataloader/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pics/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/happyw1nd/DistillationDPO/HEAD/pics/teaser.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # experiment results
2 | exp/
3 |
4 | # model weights
5 | checkpoints/
6 |
7 | *.egg-info/
8 | *.egg/
9 |
10 | **/__pycache__/
11 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | pkg_name = 'DistillationDPO'
4 | setup(name=pkg_name, version='1.0', packages=find_packages())
5 |
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class EMA:
4 | def __init__(self, model, decay=0.999):
5 | self.model = model
6 | self.decay = decay
7 | self.shadow = {}
8 | self._initialize()
9 |
10 | def _initialize(self):
11 | for name, param in self.model.named_parameters():
12 | if param.requires_grad:
13 | self.shadow[name] = param.clone().detach().to(param.device)
14 |
15 | def update(self):
16 | """Update the shadow weights using Exponential Moving Average (EMA)."""
17 | with torch.no_grad():
18 | for name, param in self.model.named_parameters():
19 | if param.requires_grad:
20 | self.shadow[name] = (self.decay * self.shadow[name].to(param.device) + (1.0 - self.decay) * param)
21 |
22 | def apply(self):
23 | """Apply the EMA weights to the model."""
24 | for name, param in self.model.named_parameters():
25 | if param.requires_grad:
26 |                 param.data.copy_(self.shadow[name].to(param.device))  # copy values rather than alias, so later optimizer steps cannot mutate the shadow in place
27 |
--------------------------------------------------------------------------------
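A minimal usage sketch for the EMA helper above, assuming a standard PyTorch training loop (the model, optimizer, and data are placeholders, not part of this repo):

    import torch
    from utils.ema import EMA

    net = torch.nn.Linear(3, 3)
    opt = torch.optim.SGD(net.parameters(), lr=1e-2)
    ema = EMA(net, decay=0.999)

    for _ in range(10):
        loss = net(torch.randn(8, 3)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update()  # track smoothed weights after every optimizer step
    ema.apply()       # load the EMA weights into the model for evaluation
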
/utils/scheduling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from numpy import pi
3 |
4 | def cosine_beta_schedule(timesteps, s=0.008):
5 | """
6 | cosine schedule as proposed in https://arxiv.org/abs/2102.09672
7 | """
8 | steps = timesteps + 1
9 | x = torch.linspace(0, timesteps, steps)
10 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * pi * 0.5) ** 2
11 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
12 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
13 | return torch.clip(betas, 0.0001, 0.9999)
14 |
15 | def linear_beta_schedule(timesteps, beta_start, beta_end):
16 | return torch.linspace(beta_start, beta_end, timesteps)
17 |
18 | def quadratic_beta_schedule(timesteps, beta_start, beta_end):
19 | return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
20 |
21 | def sigmoid_beta_schedule(timesteps, beta_start, beta_end):
22 | betas = torch.linspace(-6, 6, timesteps)
23 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
24 |
25 | beta_func = {
26 | 'cosine': cosine_beta_schedule,
27 | 'linear': linear_beta_schedule,
28 | 'quadratic': quadratic_beta_schedule,
29 | 'sigmoid': sigmoid_beta_schedule,
30 | }
31 |
--------------------------------------------------------------------------------
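A small sketch of how these schedules are typically consumed by a DDPM-style sampler, turning betas into the cumulative alphas used for noising (illustrative only; the 1000-step horizon is an assumption):

    import torch
    from utils.scheduling import beta_func

    betas = beta_func['cosine'](timesteps=1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha-bar_t per timestep
    print(betas.shape, float(alphas_cumprod[-1]))       # alpha-bar_T is close to 0
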
/utils/render.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["PYOPENGL_PLATFORM"] = "osmesa"
3 |
4 | import numpy as np
5 | import pyrender
6 | import trimesh
7 | from PIL import Image
8 | import open3d as o3d
9 | import matplotlib.pyplot as plt
10 | import matplotlib.colors as mcolors
11 |
12 | def offscreen_render(pcd, output_name):
13 |
14 | # load point cloud
15 | points = np.asarray(pcd.points)
16 | z_values = points[:, 2]
17 |
18 | lower_bound = np.percentile(z_values, 1)
19 | upper_bound = np.percentile(z_values, 99)
20 | z_values = np.clip(z_values, lower_bound, upper_bound)
21 |
22 | norm = mcolors.Normalize(vmin=z_values.min(), vmax=z_values.max())
23 | cmap = plt.get_cmap("gist_rainbow_r")
24 | colors = cmap(norm(z_values))
25 | colors = colors * 0.5
26 |
27 | # object
28 | m = pyrender.Mesh.from_points(points, colors=colors)
29 |
30 | # light
31 |     dl = pyrender.SpotLight(color=[1.0, 1.0, 1.0], intensity=2.0, innerConeAngle=0.05, outerConeAngle=0.5)
32 | light_theta = np.radians(-40)
33 | light_pose = np.array([
34 | [1.0, 0.0, 0.0, 0.0], # x axis ←
35 | [0.0, np.cos(light_theta), -np.sin(light_theta), 0.0], # y axis ↓
36 | [0.0, np.sin(light_theta), np.cos(light_theta), 0.0], # z axis ·
37 | [0.0, 0.0, 0.0, 1.0]
38 | ])
39 |
40 | # camera
41 | pc = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
42 | camera_theta = np.radians(20)
43 | camera_pose_rot = np.array([
44 | [1.0, 0.0, 0.0, 0.0], # x axis ←
45 | [0.0, np.cos(camera_theta), -np.sin(camera_theta), 0.0], # y axis ↓
46 | [0.0, np.sin(camera_theta), np.cos(camera_theta), 0.0], # z axis ·
47 | [0.0, 0.0, 0.0, 1.0]
48 | ])
49 | camera_pose_trans = np.array([
50 | [1.0, 0.0, 0.0, 0.0],
51 | [0.0, 1.0, 0.0, 0.0],
52 | [0.0, 0.0, 1.0, 50.0],
53 | [0.0, 0.0, 0.0, 1.0]
54 | ])
55 | camera_pose = np.matmul(camera_pose_rot, camera_pose_trans)
56 |
57 | # scene
58 | scene = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0], bg_color=[1.0, 1.0, 1.0])
59 | scene.add(m)
60 | scene.add(dl, pose=light_pose)
61 | scene.add(pc, pose=camera_pose)
62 |
63 | # renderer
64 | r = pyrender.OffscreenRenderer(viewport_width=1000, viewport_height=700, point_size=3.0)
65 | flags = pyrender.RenderFlags.SHADOWS_ALL
66 | color, depth = r.render(scene, flags=flags)
67 | r.delete()
68 |
69 | # save pics
70 | Image.fromarray(color).save(output_name)
71 |
--------------------------------------------------------------------------------
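A hypothetical call for the renderer above; it requires OSMesa (forced via the environment variable at the top of the file), and the scan path is a placeholder:

    import open3d as o3d
    from utils.render import offscreen_render

    pcd = o3d.io.read_point_cloud('some_scan.ply')  # any format open3d can read
    offscreen_render(pcd, 'some_scan.png')
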
/utils/EMD.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import cdist
3 | from tqdm import tqdm
4 |
5 | def voxelize_point_cloud(points, voxel_size):
6 |     if points.size == 0:
7 |         raise ValueError("cannot voxelize an empty point cloud")
8 |
9 | voxel_indices = np.floor(points / voxel_size).astype(int)
10 |
11 | unique_voxels, voxel_counts = np.unique(voxel_indices, axis=0, return_counts=True)
12 |
13 | weights = voxel_counts / voxel_counts.sum()
14 |
15 | voxel_centers = (unique_voxels + 0.5) * voxel_size
16 |
17 | return voxel_centers, weights
18 |
19 | def sinkhorn_knopp_emd(source_points, target_points, source_weights, target_weights, epsilon=0.01, max_iter=3000, tol=1e-3):
20 |
21 |     if source_points.shape[1] != target_points.shape[1]:
22 |         raise ValueError("source and target points must have the same dimensionality")
23 |     if not (np.isclose(source_weights.sum(), 1) and np.isclose(target_weights.sum(), 1)):
24 |         raise ValueError("source and target weights must each sum to 1")
25 |
26 | cost_matrix = cdist(source_points, target_points, metric='euclidean')
27 |
28 | # Sinkhorn-Knopp
29 | K = np.exp(-cost_matrix / epsilon)
30 | K += 1e-9
31 |
32 | u = np.ones_like(source_weights)
33 | v = np.ones_like(target_weights)
34 |
35 | with tqdm(total=max_iter, desc="Running Sinkhorn", unit="iter") as pbar:
36 | for _ in range(max_iter):
37 | pbar.update(1)
38 | u_prev = u.copy()
39 | u = source_weights / (K @ v)
40 | v = target_weights / (K.T @ u)
41 |
42 | diff = np.linalg.norm(u - u_prev, 1)
43 | pbar.set_postfix({'diff': f"{diff:.5e}"})
44 | if diff < tol:
45 | break
46 |
47 | transport_matrix = np.outer(u, v) * K
48 | emd_approx = np.sum(transport_matrix * cost_matrix)
49 |
50 | return emd_approx
51 |
52 | def calc_EMD_with_sinkhorn_knopp(point_cloud_1, point_cloud_2, voxel_size=0.5, epsilon=0.001, max_iter=3000, tol=1e-4):
53 |
54 | voxel_centers_1, weights_1 = voxelize_point_cloud(point_cloud_1, voxel_size)
55 | voxel_centers_2, weights_2 = voxelize_point_cloud(point_cloud_2, voxel_size)
56 |
57 | # clip
58 | if voxel_centers_1.shape[0] + voxel_centers_2.shape[0] > 130000:
59 | return None
60 |
61 | return sinkhorn_knopp_emd(voxel_centers_1, voxel_centers_2, weights_1, weights_2, epsilon, max_iter, tol)
62 |
63 | if __name__ == "__main__":
64 |
65 | num_points_1 = 10000
66 | num_points_2 = 120000
67 | point_cloud_1 = np.random.rand(num_points_1, 3) * 15
68 | point_cloud_2 = (np.random.rand(num_points_2, 3)-0.5) * 15
69 |
70 | emd_result = calc_EMD_with_sinkhorn_knopp(point_cloud_1, point_cloud_2, voxel_size=0.5, epsilon=0.001, max_iter=3000, tol=1e-4)
71 |
72 |     if emd_result is None:
73 |         print("EMD skipped: too many voxels for the Sinkhorn solver")
74 |     else:
75 |         print(f"Approximate Earth Mover's Distance (Sinkhorn): {emd_result:.4f}")
76 |
--------------------------------------------------------------------------------
/datasets/SemanticKITTI_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | from pytorch_lightning import LightningDataModule
4 | from datasets.SemanticKITTI_dataloader.SemanticKITTITemporal import TemporalKITTISet
5 | from utils.collations import SparseSegmentCollation
6 | import warnings
7 |
8 | warnings.filterwarnings('ignore')
9 |
10 | __all__ = ['TemporalKittiDataModule']
11 |
12 | class TemporalKittiDataModule(LightningDataModule):
13 | def __init__(self, args):
14 | super().__init__()
15 | self.args = args
16 |
17 | def prepare_data(self):
18 | # Augmentations
19 | pass
20 |
21 | def setup(self, stage=None):
22 | # Create datasets
23 | pass
24 |
25 | def train_dataloader(self):
26 | collate = SparseSegmentCollation()
27 |
28 | data_set = TemporalKITTISet(
29 | data_dir=self.args.SemanticKITTI_path,
30 | seqs=[ '00', '01', '02', '03', '04', '05', '06', '07', '09', '10' ],
31 | split='train',
32 | resolution=0.05,
33 | num_points=180000,
34 | max_range=50.0,
35 | dataset_norm=False,
36 | std_axis_norm=False)
37 | loader = DataLoader(data_set, batch_size=self.args.batch_size, shuffle=True,
38 | num_workers=4, collate_fn=collate)
39 | return loader
40 |
41 | def val_dataloader(self, pre_training=True):
42 | collate = SparseSegmentCollation()
43 |
44 | data_set = TemporalKITTISet(
45 | data_dir=self.args.SemanticKITTI_path,
46 | seqs=[ '08' ],
47 | split='validation',
48 | resolution=0.05,
49 | num_points=180000,
50 | max_range=50.0,
51 | dataset_norm=False,
52 | std_axis_norm=False)
53 |         loader = DataLoader(data_set, batch_size=1,
54 | shuffle=False,
55 | num_workers=4,
56 | collate_fn=collate)
57 | return loader
58 |
59 | def test_dataloader(self):
60 | collate = SparseSegmentCollation()
61 |
62 | data_set = TemporalKITTISet(
63 | data_dir=self.args.SemanticKITTI_path,
64 | seqs=[ '08' ],
65 | split='validation',
66 | resolution=0.05,
67 | num_points=180000,
68 | max_range=50.0,
69 | dataset_norm=False,
70 | std_axis_norm=False)
71 | loader = DataLoader(data_set, batch_size=self.args.batch_size, shuffle=False,
72 | num_workers=4, collate_fn=collate)
73 | return loader
74 |
75 | dataloaders = {
76 | 'KITTI': TemporalKittiDataModule,
77 | }
78 |
79 |
--------------------------------------------------------------------------------
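A minimal sketch of wiring the data module into a script; the argparse-style namespace mirrors the attributes the module reads (`SemanticKITTI_path`, `batch_size`), and the dataset path is a placeholder:

    from argparse import Namespace
    from datasets.SemanticKITTI_dataset import dataloaders

    args = Namespace(SemanticKITTI_path='datasets/SemanticKITTI', batch_size=2)
    dm = dataloaders['KITTI'](args)
    batch = next(iter(dm.train_dataloader()))  # dict with 'pcd_full', 'pcd_part', 'mean', 'std', 'filename'
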
/utils/histogram_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | from scipy.spatial.distance import jensenshannon
4 | from utils.metrics import ChamferDistance, PrecisionRecall
5 | import matplotlib.pyplot as plt
6 |
7 | def histogram_point_cloud(pcd, resolution, max_range, bev=False):
8 | # get bins size by the number of voxels in the pcd
9 | bins = int(2 * max_range / resolution)
10 |
11 | hist = np.histogramdd(pcd, bins=bins, range=([-max_range,max_range],[-max_range,max_range],[-max_range,max_range]))
12 |
13 | return np.clip(hist[0], a_min=0., a_max=1.) if bev else hist[0]
14 |
15 | def compute_jsd(hist_gt, hist_pred, bev=False, visualize=False):
16 | bev_gt = hist_gt.sum(-1) if bev else hist_gt
17 | norm_bev_gt = bev_gt / bev_gt.sum()
18 | norm_bev_gt = norm_bev_gt.flatten()
19 |
20 | bev_pred = hist_pred.sum(-1) if bev else hist_pred
21 | norm_bev_pred = bev_pred / bev_pred.sum()
22 | norm_bev_pred = norm_bev_pred.flatten()
23 |
24 | if visualize:
25 | # for visualization purposes
26 | grid = np.meshgrid(np.arange(len(hist_gt)), np.arange(len(hist_gt)))
27 | points = np.concatenate((grid[0].flatten()[:,None], grid[1].flatten()[:,None]), axis=-1)
28 | points = np.concatenate((points, np.zeros((len(points),1))),axis=-1)
29 |
30 | # build bev histogram gt view
31 | norm_hist_gt = bev_gt / bev_gt.max()
32 | colors_gt = plt.get_cmap('viridis')(norm_hist_gt)
33 | pcd_gt = o3d.geometry.PointCloud()
34 | pcd_gt.points = o3d.utility.Vector3dVector(points)
35 | pcd_gt.colors = o3d.utility.Vector3dVector(colors_gt.reshape(-1,4)[:,:3])
36 |
37 | # build bev histogram pred view
38 | norm_hist_pred = bev_pred / bev_pred.max()
39 | colors_pred = plt.get_cmap('viridis')(norm_hist_pred)
40 | pcd_pred = o3d.geometry.PointCloud()
41 | pcd_pred.points = o3d.utility.Vector3dVector(points)
42 | pcd_pred.colors = o3d.utility.Vector3dVector(colors_pred.reshape(-1,4)[:,:3])
43 |
44 | return jensenshannon(norm_bev_gt, norm_bev_pred)
45 |
46 |
47 | def compute_hist_metrics(pcd_gt, pcd_pred, bev=False):
48 | hist_pred = histogram_point_cloud(np.array(pcd_pred.points), 0.5, 50., bev)
49 | hist_gt = histogram_point_cloud(np.array(pcd_gt.points), 0.5, 50., bev)
50 |
51 | return compute_jsd(hist_gt, hist_pred, bev)
52 |
53 | def compute_chamfer(pcd_pred, pcd_gt):
54 | chamfer_distance = ChamferDistance()
55 | chamfer_distance.update(pcd_gt, pcd_pred)
56 | cd_pred_mean, cd_pred_std = chamfer_distance.compute()
57 |
58 | return cd_pred_mean
59 |
60 | def compute_precision_recall(pcd_pred, pcd_gt):
61 | precision_recall = PrecisionRecall(0.05,2*0.05,100)
62 | precision_recall.update(pcd_gt, pcd_pred)
63 | pr, re, f1 = precision_recall.compute_auc()
64 |
65 | return pr, re, f1
66 |
67 | def preprocess_pcd(pcd):
68 | points = np.array(pcd.points)
69 | dist = np.sqrt(np.sum(points**2, axis=-1))
70 | pcd.points = o3d.utility.Vector3dVector(points[dist < 30.])
71 |
72 | return pcd
73 |
74 | def compute_metrics(pred_path, gt_path):
75 | pcd_pred = preprocess_pcd(o3d.io.read_point_cloud(pred_path))
76 |     pcd_gt = preprocess_pcd(o3d.io.read_point_cloud(gt_path))
77 |
78 |     # note: compute_hist_metrics expects the point cloud objects themselves,
79 |     # not raw numpy arrays (it reads `.points` internally)
80 |     jsd_pred = compute_hist_metrics(pcd_gt, pcd_pred)
81 |
82 | cd_pred = compute_chamfer(pcd_pred, pcd_gt)
83 |
84 | pr_pred, re_pred, f1_pred = compute_precision_recall(pcd_pred, pcd_gt)
85 |
86 | return cd_pred, pr_pred, re_pred, f1_pred
87 |
88 |
--------------------------------------------------------------------------------
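An illustrative comparison of two scenes loaded from disk with the helpers above (the file paths are placeholders):

    import open3d as o3d
    from utils.histogram_metrics import compute_hist_metrics, compute_chamfer

    pcd_gt = o3d.io.read_point_cloud('gt.ply')
    pcd_pred = o3d.io.read_point_cloud('pred.ply')
    jsd = compute_hist_metrics(pcd_gt, pcd_pred, bev=True)  # JSD over BEV occupancy
    cd = compute_chamfer(pcd_pred, pcd_gt)                  # mean Chamfer distance
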
/utils/data_map.py:
--------------------------------------------------------------------------------
1 | learning_map = {
2 | 0 : 0, # "unlabeled"
3 | 1 : 0, # "outlier" mapped to "unlabeled" --------------------------mapped
4 | 10: 1, # "car"
5 | 11: 2, # "bicycle"
6 | 13: 5, # "bus" mapped to "other-vehicle" --------------------------mapped
7 | 15: 3, # "motorcycle"
8 | 16: 5, # "on-rails" mapped to "other-vehicle" ---------------------mapped
9 | 18: 4, # "truck"
10 | 20: 5, # "other-vehicle"
11 | 30: 6, # "person"
12 | 31: 7, # "bicyclist"
13 | 32: 8, # "motorcyclist"
14 | 40: 9, # "road"
15 | 44: 10, # "parking"
16 | 48: 11, # "sidewalk"
17 | 49: 12, # "other-ground"
18 | 50: 13, # "building"
19 | 51: 14, # "fence"
20 | 52: 0, # "other-structure" mapped to "unlabeled" ------------------mapped
21 | 60: 9, # "lane-marking" to "road" ---------------------------------mapped
22 | 70: 15, # "vegetation"
23 | 71: 16, # "trunk"
24 | 72: 17, # "terrain"
25 | 80: 18, # "pole"
26 | 81: 19, # "traffic-sign"
27 | 99: 0, # "other-object" to "unlabeled" ----------------------------mapped
28 | 252: 1, # "moving-car"
29 | 253: 7, # "moving-bicyclist"
30 | 254: 6, # "moving-person"
31 | 255: 8, # "moving-motorcyclist"
32 | 256: 5, # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped
33 | 257: 5, # "moving-bus" mapped to "moving-other-vehicle" -----------mapped
34 | 258: 4, # "moving-truck"
35 | 259: 5, # "moving-other-vehicle"
36 | }
37 |
38 | content = { # as a ratio with the total number of points
39 | 0: 0.03150183342534689,
40 | 1: 0.042607828674502385,
41 | 2: 0.00016609538710764618,
42 | 3: 0.00039838616015114444,
43 | 4: 0.0021649398241338114,
44 | 5: 0.0018070552978863615,
45 | 6: 0.0003375832743104974,
46 | 7: 0.00012711105887399155,
47 | 8: 3.746106399997359e-05,
48 | 9: 0.19879647126983288,
49 | 10: 0.014717169549888214,
50 | 11: 0.14392298360372,
51 | 12: 0.0039048553037472045,
52 | 13: 0.1326861944777486,
53 | 14: 0.0723592229456223,
54 | 15: 0.26681502148037506,
55 | 16: 0.006035012012626033,
56 | 17: 0.07814222006271769,
57 | 18: 0.002855498193863172,
58 | 19: 0.0006155958086189918,
59 | }
60 |
61 | content_indoor = {
62 | 0: 0.18111755628849344,
63 | 1: 0.15350115272756307,
64 | 2: 0.264323444618407,
65 | 3: 0.017095487624768667,
66 | 4: 0.02018415055214108,
67 | 5: 0.025684283218171625,
68 | 6: 0.05237503359636922,
69 | 7: 0.03495118545614923,
70 | 8: 0.04252921527371275,
71 | 9: 0.004767541066020183,
72 | 10: 0.06899976905686542,
73 | 11: 0.012345517150886037,
74 | 12: 0.12212566337045223,
75 | }
76 |
77 | labels = {
78 | 0: "unlabeled",
79 | 1: "car",
80 | 2: "bicycle",
81 | 3: "motorcycle",
82 | 4: "truck",
83 | 5: "other-vehicle",
84 | 6: "person",
85 | 7: "bicyclist",
86 | 8: "motorcyclist",
87 | 9: "road",
88 | 10: "parking",
89 | 11: "sidewalk",
90 | 12: "other-ground",
91 | 13: "building",
92 | 14: "fence",
93 | 15: "vegetation",
94 | 16: "trunk",
95 | 17: "terrain",
96 | 18: "pole",
97 | 19: "traffic-sign",
98 | }
99 |
100 | color_map = {
101 | 0: [0, 0, 0],
102 | 1: [245, 150, 100],
103 | 2: [245, 230, 100],
104 | 3: [150, 60, 30],
105 | 4: [180, 30, 80],
106 | 5: [255, 0, 0],
107 | 6: [30, 30, 255],
108 | 7: [200, 40, 255],
109 | 8: [90, 30, 150],
110 | 9: [255, 0, 255],
111 | 10: [255, 150, 255],
112 | 11: [75, 0, 75],
113 | 12: [75, 0, 175],
114 | 13: [0, 200, 255],
115 | 14: [50, 120, 255],
116 | 15: [0, 175, 0],
117 | 16: [0, 60, 135],
118 | 17: [80, 240, 150],
119 | 18: [150, 240, 255],
120 | 19: [0, 0, 255],
121 | }
122 |
--------------------------------------------------------------------------------
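An example of applying `learning_map` to raw SemanticKITTI label IDs, including the lower-16-bit masking the dataloaders use (the label values here are made up):

    import numpy as np
    from utils.data_map import learning_map, labels

    raw = np.array([10, 40, 60, 252], dtype=np.uint32)        # raw SemanticKITTI IDs
    train_ids = np.vectorize(learning_map.get)(raw & 0xFFFF)  # lower 16 bits hold the semantic label
    print([labels[int(i)] for i in train_ids])                # ['car', 'road', 'road', 'car']
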
/map_from_scans.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import open3d as o3d
5 | from natsort import natsorted
6 | import click
7 | import tqdm
8 | import MinkowskiEngine as ME
9 |
10 | def parse_calibration(filename):
11 | calib = {}
12 |
13 | calib_file = open(filename)
14 | for line in calib_file:
15 | key, content = line.strip().split(":")
16 | values = [float(v) for v in content.strip().split()]
17 |
18 | pose = np.zeros((4, 4))
19 | pose[0, 0:4] = values[0:4]
20 | pose[1, 0:4] = values[4:8]
21 | pose[2, 0:4] = values[8:12]
22 | pose[3, 3] = 1.0
23 |
24 | calib[key] = pose
25 |
26 | calib_file.close()
27 |
28 | return calib
29 |
30 | def load_poses(calib_fname, poses_fname):
31 | if os.path.exists(calib_fname):
32 | calibration = parse_calibration(calib_fname)
33 | Tr = calibration["Tr"]
34 | Tr_inv = np.linalg.inv(Tr)
35 |
36 | poses_file = open(poses_fname)
37 | poses = []
38 |
39 | for line in poses_file:
40 | values = [float(v) for v in line.strip().split()]
41 |
42 | pose = np.zeros((4, 4))
43 | pose[0, 0:4] = values[0:4]
44 | pose[1, 0:4] = values[4:8]
45 | pose[2, 0:4] = values[8:12]
46 | pose[3, 3] = 1.0
47 |
48 | if os.path.exists(calib_fname):
49 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
50 | else:
51 | poses.append(pose)
52 |
53 | return poses
54 |
55 | @click.command()
56 | @click.option('--path', '-p', type=str, help='path to the scan sequence')
57 | @click.option('--voxel_size', '-v', type=float, default=0.1, help='voxel size')
58 | @click.option('--cpu', '-c', is_flag=True, help='Use CPU')
59 | def main(path, voxel_size, cpu):
60 | device_label = 'cuda' if torch.cuda.is_available() else 'cpu'
61 | device_label = 'cpu' if cpu else device_label
62 | device = torch.device(device_label)
63 | for seq in ['00','01','02','03','04','05','06','07','08','09','10']:
64 | map_points = torch.empty((0,3)).to(device)
65 |
66 | poses = load_poses(os.path.join(path, seq, 'calib.txt'), os.path.join(path, seq, 'poses.txt'))
67 | for pose, pcd_path in tqdm.tqdm(list(zip(poses, natsorted(os.listdir(os.path.join(path, seq, 'velodyne')))))):
68 | pose = torch.from_numpy(pose).float().to(device)
69 | pcd_file = os.path.join(path, seq, 'velodyne', pcd_path)
70 | points = torch.from_numpy(np.fromfile(pcd_file, dtype=np.float32)).to(device)
71 | points = points.reshape(-1,4)
72 |
73 | label_file = pcd_file.replace('velodyne', 'labels').replace('.bin', '.label')
74 | l_set = np.fromfile(label_file, dtype=np.uint32)
75 | l_set = l_set.reshape((-1))
76 | l_set = l_set & 0xFFFF
77 |
78 | # remove moving points
79 | static_idx = (l_set < 252) & (l_set > 1)
80 | points = points[static_idx]
81 |
82 | # remove flying artifacts
83 |             dist = torch.pow(points[:, :3], 2)  # xyz only; column 3 still holds intensity here
84 | dist = torch.sqrt(dist.sum(-1))
85 | points = points[dist > 3.5]
86 |
87 | points[:,-1] = 1.
88 | points = points @ pose.T
89 |
90 | map_points = torch.cat((map_points, points[:,:3]), axis=0)
91 | _, mapping = ME.utils.sparse_quantize(coordinates=map_points / voxel_size, return_index=True, device=device_label)
92 | map_points = map_points[mapping]
93 |
94 |
95 | print(f'saving map for sequence {seq}')
96 | np.save(os.path.join(path, seq, 'map_clean.npy'), map_points.cpu().numpy())
97 |
98 |
99 | if __name__ == '__main__':
100 | main()
--------------------------------------------------------------------------------
/utils/collations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import MinkowskiEngine as ME
3 | import torch
4 | import torch.nn.functional as F
5 | import open3d as o3d
7 |
8 | def feats_to_coord(p_feats, resolution, batch_size):
9 | p_feats = p_feats.reshape(batch_size,-1,3)
10 | p_coord = torch.round(p_feats / resolution)
11 |
12 | return p_coord.reshape(-1,3)
13 |
14 | def normalize_pcd(points, mean, std):
15 | return (points - mean[:,None,:]) / std[:,None,:] if len(mean.shape) == 2 else (points - mean) / std
16 |
17 | def unormalize_pcd(points, mean, std):
18 | return (points * std[:,None,:]) + mean[:,None,:] if len(mean.shape) == 2 else (points * std) + mean
19 |
20 | def point_set_to_sparse_refine(p_full, p_part, n_full, n_part, resolution, filename):
21 |     concat_full = int(np.ceil(n_full / p_full.shape[0]))
22 |     concat_part = int(np.ceil(n_part / p_part.shape[0]))
23 |
24 | #if mode == 'diffusion':
25 | #p_full = p_full[torch.randperm(p_full.shape[0])]
26 | #p_part = p_part[torch.randperm(p_part.shape[0])]
27 | #elif mode == 'refine':
28 | p_full = p_full[torch.randperm(p_full.shape[0])]
29 | p_full = torch.tensor(p_full.repeat(concat_full, 0)[:n_full])
30 |
31 | p_part = p_part[torch.randperm(p_part.shape[0])]
32 | p_part = torch.tensor(p_part.repeat(concat_part, 0)[:n_part])
33 |
34 | #p_feats = ME.utils.batched_coordinates([p_feats], dtype=torch.float32)[:2000]
35 |
36 | # after creating the voxel coordinates we normalize the floating coordinates towards mean=0 and std=1
37 | p_mean, p_std = p_full.mean(axis=0), p_full.std(axis=0)
38 |
39 | return [p_full, p_mean, p_std, p_part, filename]
40 |
41 | def point_set_to_sparse(p_full, p_part, n_full, n_part, resolution, filename, p_mean=None, p_std=None):
42 |     concat_part = int(np.ceil(n_part / p_part.shape[0]))
43 | p_part = p_part.repeat(concat_part, 0)
44 | pcd_part = o3d.geometry.PointCloud()
45 | pcd_part.points = o3d.utility.Vector3dVector(p_part)
46 | viewpoint_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd_part, voxel_size=10.)
47 | pcd_part = pcd_part.farthest_point_down_sample(n_part)
48 | p_part = torch.tensor(np.array(pcd_part.points))
49 |
50 | in_viewpoint = viewpoint_grid.check_if_included(o3d.utility.Vector3dVector(p_full))
51 | p_full = p_full[in_viewpoint]
52 |     concat_full = int(np.ceil(n_full / p_full.shape[0]))
53 |
54 | p_full = p_full[torch.randperm(p_full.shape[0])]
55 | p_full = p_full.repeat(concat_full, 0)[:n_full]
56 |
57 | p_full = torch.tensor(p_full)
58 |
59 | # after creating the voxel coordinates we normalize the floating coordinates towards mean=0 and std=1
60 | p_mean = p_full.mean(axis=0) if p_mean is None else p_mean
61 | p_std = p_full.std(axis=0) if p_std is None else p_std
62 |
63 | return [p_full, p_mean, p_std, p_part, filename]
64 |
65 | def numpy_to_sparse_tensor(p_coord, p_feats, p_label=None):
66 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67 | p_coord = ME.utils.batched_coordinates(p_coord, dtype=torch.float32)
68 | p_feats = torch.vstack(p_feats).float()
69 |
70 | if p_label is not None:
71 | p_label = ME.utils.batched_coordinates(p_label, device=torch.device('cpu')).numpy()
72 |
73 | return ME.SparseTensor(
74 | features=p_feats,
75 | coordinates=p_coord,
76 | device=device,
77 | ), p_label
78 |
79 | return ME.SparseTensor(
80 | features=p_feats,
81 | coordinates=p_coord,
82 | device=device,
83 | )
84 |
85 | class SparseSegmentCollation:
86 | def __init__(self, mode='diffusion'):
87 | self.mode = mode
88 | return
89 |
90 | def __call__(self, data):
91 | # "transpose" the batch(pt, ptn) to batch(pt), batch(ptn)
92 | batch = list(zip(*data))
93 |
94 | return {'pcd_full': torch.stack(batch[0]).float(),
95 | 'mean': torch.stack(batch[1]).float(),
96 | 'std': torch.stack(batch[2]).float(),
97 | 'pcd_part' if self.mode == 'diffusion' else 'pcd_noise': torch.stack(batch[3]).float(),
98 | 'filename': batch[4],
99 | }
100 |
--------------------------------------------------------------------------------
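A tiny sketch of `numpy_to_sparse_tensor`, with two random clouds standing in for a real batch (integral coordinate values, since MinkowskiEngine quantizes on a voxel grid):

    import torch
    from utils.collations import numpy_to_sparse_tensor

    coords = [torch.randint(0, 100, (500, 3)).float() for _ in range(2)]
    feats = [c.clone() for c in coords]             # here features simply mirror coordinates
    sparse = numpy_to_sparse_tensor(coords, feats)  # batched ME.SparseTensor
    print(sparse.C.shape)                           # (N, 4): batch index + xyz
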
/datasets/SemanticKITTI_dataloader/SemanticKITTITemporalAggr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from utils.pcd_preprocess import clusterize_pcd, visualize_pcd_clusters, point_set_to_coord_feats, overlap_clusters, aggregate_pcds
4 | from utils.pcd_transforms import *
5 | from utils.data_map import learning_map
6 | from utils.collations import point_set_to_sparse_refine
7 | import os
8 | import numpy as np
9 | import MinkowskiEngine as ME
10 |
11 | import warnings
12 |
13 | warnings.filterwarnings('ignore')
14 |
15 | #################################################
16 | ################## Data loader ##################
17 | #################################################
18 |
19 | class TemporalKITTISet(Dataset):
20 | def __init__(self, data_dir, scan_window, seqs, split, resolution, num_points, mode):
21 | super().__init__()
22 | self.data_dir = data_dir
23 | self.augmented_dir = 'segments_views'
24 |
25 | self.n_clusters = 50
26 | self.resolution = resolution
27 | self.scan_window = scan_window
28 | self.num_points = num_points
29 | self.seg_batch = True
30 |
31 | self.split = split
32 | self.seqs = seqs
33 | self.mode = mode
34 |
35 | # list of (shape_name, shape_txt_file_path) tuple
36 | self.datapath_list()
37 |
38 | self.nr_data = len(self.points_datapath)
39 |
40 | print('The size of %s data is %d'%(self.split,len(self.points_datapath)))
41 |
42 | def datapath_list(self):
43 | self.points_datapath = []
44 |
45 | for seq in self.seqs:
46 | point_seq_path = os.path.join(self.data_dir, 'dataset', 'sequences', seq, 'velodyne')
47 | point_seq_bin = os.listdir(point_seq_path)
48 | point_seq_bin.sort()
49 |
50 | for file_num in range(0, len(point_seq_bin)):
51 |             # guarantee that the end of a sequence does not produce aggregated pcds built from single scans
52 | end_file = file_num + self.scan_window if len(point_seq_bin) - file_num > 1.5 * self.scan_window else len(point_seq_bin)
53 | self.points_datapath.append([os.path.join(point_seq_path, point_file) for point_file in point_seq_bin[file_num:end_file] ])
54 | if end_file == len(point_seq_bin):
55 | break
56 |
57 | #self.points_datapath = self.points_datapath[:200]
58 |
59 | def transforms(self, points):
60 | points = points[None,...]
61 |
62 | points[:,:,:3] = rotate_point_cloud(points[:,:,:3])
63 | points[:,:,:3] = rotate_perturbation_point_cloud(points[:,:,:3])
64 | points[:,:,:3] = random_scale_point_cloud(points[:,:,:3])
65 | points[:,:,:3] = random_flip_point_cloud(points[:,:,:3])
66 |
67 | return points[0]
68 |
69 | def __getitem__(self, index):
70 | #index = 500
71 | seq_num = self.points_datapath[index][0].split('/')[-3]
72 | fname = self.points_datapath[index][0].split('/')[-1].split('.')[0]
73 |
74 | #t_frame = np.random.randint(len(self.points_datapath[index]))
75 | t_frame = int(len(self.points_datapath[index]) / 2)
76 | p_full, p_part = aggregate_pcds(self.points_datapath[index], self.data_dir, t_frame)
77 |
78 | p_concat = np.concatenate((p_full, p_part), axis=0)
79 | p_gt = p_concat.copy()
80 | p_concat = self.transforms(p_concat) if self.split == 'train' else p_concat
81 |
82 | p_full = p_concat.copy()
83 | p_noise = jitter_point_cloud(p_concat[None,:,:3], sigma=.2, clip=.3)[0]
84 | dist_noise = np.power(p_noise, 2)
85 | dist_noise = np.sqrt(dist_noise.sum(-1))
86 |
87 | _, mapping = ME.utils.sparse_quantize(coordinates=p_full / 0.1, return_index=True)
88 | p_full = p_full[mapping]
89 | dist_full = np.power(p_full, 2)
90 | dist_full = np.sqrt(dist_full.sum(-1))
91 |
92 | return point_set_to_sparse_refine(
93 | p_full[dist_full < 50.],
94 | p_noise[dist_noise < 50.],
95 | self.num_points*2,
96 | self.num_points,
97 | self.resolution,
98 | self.points_datapath[index],
99 | )
100 |
101 | def __len__(self):
102 | #print('DATA SIZE: ', np.floor(self.nr_data / self.sampling_window), self.nr_data % self.sampling_window)
103 | return self.nr_data
104 |
105 | ##################################################################################################
106 |
--------------------------------------------------------------------------------
/datasets/SemanticKITTI_dataloader/SemanticKITTITemporal.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from utils.pcd_preprocess import point_set_to_coord_feats, aggregate_pcds, load_poses
4 | from utils.pcd_transforms import *
5 | from utils.data_map import learning_map
6 | from utils.collations import point_set_to_sparse
7 | from natsort import natsorted
8 | import os
9 | import numpy as np
10 | import yaml
11 |
12 | import warnings
13 |
14 | warnings.filterwarnings('ignore')
15 |
16 | #################################################
17 | ################## Data loader ##################
18 | #################################################
19 |
20 | class TemporalKITTISet(Dataset):
21 | def __init__(self, data_dir, seqs, split, resolution, num_points, max_range, dataset_norm=False, std_axis_norm=False):
22 | super().__init__()
23 | self.data_dir = data_dir
24 |
25 | self.n_clusters = 50
26 | self.resolution = resolution
27 | self.num_points = num_points
28 | self.max_range = max_range
29 |
30 | self.split = split
31 | self.seqs = seqs
32 | self.cache_maps = {}
33 |
34 | # list of (shape_name, shape_txt_file_path) tuple
35 | self.datapath_list()
36 | self.data_stats = {'mean': None, 'std': None}
37 |
38 | if os.path.isfile(f'utils/data_stats_range_{int(self.max_range)}m.yml') and dataset_norm:
39 | stats = yaml.safe_load(open(f'utils/data_stats_range_{int(self.max_range)}m.yml'))
40 | data_mean = np.array([stats['mean_axis']['x'], stats['mean_axis']['y'], stats['mean_axis']['z']])
41 | if std_axis_norm:
42 | data_std = np.array([stats['std_axis']['x'], stats['std_axis']['y'], stats['std_axis']['z']])
43 | else:
44 | data_std = np.array([stats['std'], stats['std'], stats['std']])
45 | self.data_stats = {
46 | 'mean': torch.tensor(data_mean),
47 | 'std': torch.tensor(data_std)
48 | }
49 |
50 | self.nr_data = len(self.points_datapath)
51 |
52 | print('The size of %s data is %d'%(self.split,len(self.points_datapath)))
53 |
54 | def datapath_list(self):
55 | self.points_datapath = []
56 | self.seq_poses = []
57 |
58 | for seq in self.seqs:
59 | point_seq_path = os.path.join(self.data_dir, 'dataset', 'sequences', seq)
60 | point_seq_bin = natsorted(os.listdir(os.path.join(point_seq_path, 'velodyne')))
61 | poses = load_poses(os.path.join(point_seq_path, 'calib.txt'), os.path.join(point_seq_path, 'poses.txt'))
62 | p_full = np.load(f'{point_seq_path}/map_clean.npy') if self.split != 'test' else np.array([[1,0,0],[0,1,0],[0,0,1]])
63 | self.cache_maps[seq] = p_full
64 |
65 | for file_num in range(0, len(point_seq_bin)):
66 | self.points_datapath.append(os.path.join(point_seq_path, 'velodyne', point_seq_bin[file_num]))
67 | self.seq_poses.append(poses[file_num])
68 |
69 | def transforms(self, points):
70 | points = np.expand_dims(points, axis=0)
71 | points[:,:,:3] = rotate_point_cloud(points[:,:,:3])
72 | points[:,:,:3] = rotate_perturbation_point_cloud(points[:,:,:3])
73 | points[:,:,:3] = random_scale_point_cloud(points[:,:,:3])
74 | points[:,:,:3] = random_flip_point_cloud(points[:,:,:3])
75 |
76 | return np.squeeze(points, axis=0)
77 |
78 | def __getitem__(self, index):
79 | seq_num = self.points_datapath[index].split('/')[-3]
80 | fname = self.points_datapath[index].split('/')[-1].split('.')[0]
81 |
82 | p_part = np.fromfile(self.points_datapath[index], dtype=np.float32)
83 | p_part = p_part.reshape((-1,4))[:,:3]
84 |
85 | if self.split != 'test':
86 | label_file = self.points_datapath[index].replace('velodyne', 'labels').replace('.bin', '.label')
87 | l_set = np.fromfile(label_file, dtype=np.uint32)
88 | l_set = l_set.reshape((-1))
89 | l_set = l_set & 0xFFFF
90 | static_idx = (l_set < 252) & (l_set > 1)
91 | p_part = p_part[static_idx]
92 | dist_part = np.sum(p_part**2, -1)**.5
93 | p_part = p_part[(dist_part < self.max_range) & (dist_part > 3.5)]
94 | p_part = p_part[p_part[:,2] > -4.]
95 | pose = self.seq_poses[index]
96 |
97 | p_map = self.cache_maps[seq_num]
98 |
99 | if self.split != 'test':
100 | trans = pose[:-1,-1]
101 | dist_full = np.sum((p_map - trans)**2, -1)**.5
102 | p_full = p_map[dist_full < self.max_range]
103 | p_full = np.concatenate((p_full, np.ones((len(p_full),1))), axis=-1)
104 | p_full = (p_full @ np.linalg.inv(pose).T)[:,:3]
105 | p_full = p_full[p_full[:,2] > -4.]
106 | else:
107 | p_full = p_part
108 |
109 | if self.split == 'train':
110 | p_concat = np.concatenate((p_full, p_part), axis=0)
111 | p_concat = self.transforms(p_concat)
112 |
113 | p_full = p_concat[:-len(p_part)]
114 | p_part = p_concat[-len(p_part):]
115 |
116 |         # partial pcd has 1/10 of the complete pcd size
117 | n_part = int(self.num_points / 10.)
118 |
119 | return point_set_to_sparse(
120 | p_full,
121 | p_part,
122 | self.num_points,
123 | n_part,
124 | self.resolution,
125 | self.points_datapath[index],
126 | p_mean=self.data_stats['mean'],
127 | p_std=self.data_stats['std'],
128 | )
129 |
130 | def __len__(self):
131 | return self.nr_data
132 |
133 | ##################################################################################################
134 |
--------------------------------------------------------------------------------
/utils/pcd_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def rotate_point_cloud(batch_data):
4 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
5 | for k in range(batch_data.shape[0]):
6 | rotation_angle = np.random.uniform() * 2 * np.pi
7 | cosval = np.cos(rotation_angle)
8 | sinval = np.sin(rotation_angle)
9 | rotation_matrix = np.array([[cosval, -sinval, 0],
10 | [sinval, cosval, 0],
11 | [0, 0, 1]])
12 | shape_pc = batch_data[k, ...]
13 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
14 | return rotated_data
15 |
16 |
17 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
18 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
19 | for k in range(batch_data.shape[0]):
20 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
21 | Rx = np.array([[1,0,0],
22 | [0,np.cos(angles[0]),-np.sin(angles[0])],
23 | [0,np.sin(angles[0]),np.cos(angles[0])]])
24 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
25 | [0,1,0],
26 | [-np.sin(angles[1]),0,np.cos(angles[1])]])
27 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
28 | [np.sin(angles[2]),np.cos(angles[2]),0],
29 | [0,0,1]])
30 | R = np.dot(Rz, np.dot(Ry,Rx))
31 | shape_pc = batch_data[k, ...]
32 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
33 | return rotated_data
34 |
35 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
36 | B, N, C = batch_data.shape
37 | assert(clip > 0)
38 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
39 | jittered_data += batch_data
40 | return jittered_data
41 |
42 | def random_drop_n_cuboids(batch_data):
43 | batch_data = random_drop_point_cloud(batch_data)
44 | cuboids_count = 1
45 | while cuboids_count < 5 and np.random.uniform(0., 1.) > 0.3:
46 | batch_data = random_drop_point_cloud(batch_data)
47 | cuboids_count += 1
48 |
49 | return batch_data
50 |
51 | def check_aspect2D(crop_range, aspect_min):
52 | xy_aspect = np.min(crop_range[:2])/np.max(crop_range[:2])
53 | return (xy_aspect >= aspect_min)
54 |
55 | def random_cuboid_point_cloud(batch_data):
56 | batch_data = np.expand_dims(batch_data, axis=0)
57 |
58 | B, N, C = batch_data.shape
59 | new_batch_data = []
60 | for batch_index in range(B):
61 | range_xyz = np.max(batch_data[batch_index,:,0:2], axis=0) - np.min(batch_data[batch_index,:,0:2], axis=0)
62 |
63 | crop_range = 0.5 + (np.random.rand(2) * 0.5)
64 |
65 | loop_count = 0
66 | while not check_aspect2D(crop_range, 0.75):
67 | loop_count += 1
68 | crop_range = 0.5 + (np.random.rand(2) * 0.5)
69 | if loop_count > 100:
70 | break
71 |
72 | loop_count = 0
73 |
74 | while True:
75 | loop_count += 1
76 | new_range = range_xyz * crop_range / 2.0
77 | sample_center = batch_data[batch_index,np.random.choice(len(batch_data[batch_index])), 0:3]
78 | max_xyz = sample_center[:2] + new_range
79 | min_xyz = sample_center[:2] - new_range
80 |
81 | upper_idx = np.sum((batch_data[batch_index,:,:2] < max_xyz).astype(np.int32), 1) == 2
82 | lower_idx = np.sum((batch_data[batch_index,:,:2] > min_xyz).astype(np.int32), 1) == 2
83 |
84 | new_pointidx = ((upper_idx) & (lower_idx))
85 |
86 | # avoid having too small point clouds
87 | if (loop_count > 100) or (np.sum(new_pointidx) > 20000):
88 | break
89 |
90 |
91 | new_batch_data.append(batch_data[batch_index,new_pointidx,:])
92 |
93 | new_batch_data = np.array(new_batch_data)
94 |
95 | return np.squeeze(new_batch_data, axis=0)
96 |
97 | def random_drop_point_cloud(batch_data):
98 | B, N, C = batch_data.shape
99 | new_batch_data = []
100 | for batch_index in range(B):
101 | range_xyz = np.max(batch_data[batch_index,:,0:3], axis=0) - np.min(batch_data[batch_index,:,0:3], axis=0)
102 |
103 | crop_range = np.random.uniform(0.1, 0.15)
104 | new_range = range_xyz * crop_range / 2.0
105 | sample_center = batch_data[batch_index,np.random.choice(len(batch_data[batch_index])), 0:3]
106 | max_xyz = sample_center + new_range
107 | min_xyz = sample_center - new_range
108 |
109 | upper_idx = np.sum((batch_data[batch_index,:,0:3] < max_xyz).astype(np.int32), 1) == 3
110 | lower_idx = np.sum((batch_data[batch_index,:,0:3] > min_xyz).astype(np.int32), 1) == 3
111 |
112 | new_pointidx = ~((upper_idx) & (lower_idx))
113 | new_batch_data.append(batch_data[batch_index,new_pointidx,:])
114 |
115 | return np.array(new_batch_data)
116 |
117 |
118 | def random_flip_point_cloud(batch_data, scale_low=0.95, scale_high=1.05):
119 | B, N, C = batch_data.shape
120 | for batch_index in range(B):
121 | if np.random.random() > 0.5:
122 | batch_data[batch_index,:,1] = -1 * batch_data[batch_index,:,1]
123 | return batch_data
124 |
125 | def random_scale_point_cloud(batch_data, scale_low=0.95, scale_high=1.05):
126 | B, N, C = batch_data.shape
127 | scales = np.random.uniform(scale_low, scale_high, B)
128 | for batch_index in range(B):
129 | batch_data[batch_index,:,:] *= scales[batch_index]
130 | return batch_data
131 |
132 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
133 | for b in range(batch_pc.shape[0]):
134 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
135 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
136 | if len(drop_idx)>0:
137 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
138 | return batch_pc
139 |
--------------------------------------------------------------------------------
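A quick sketch chaining the augmentations above on one batch, mirroring the `transforms` methods in the dataloaders (random data stands in for a real scan):

    import numpy as np
    from utils.pcd_transforms import (rotate_point_cloud, rotate_perturbation_point_cloud,
                                      random_scale_point_cloud, random_flip_point_cloud,
                                      jitter_point_cloud)

    batch = np.random.randn(1, 2048, 3).astype(np.float32)  # (B, N, 3)
    batch = rotate_point_cloud(batch)
    batch = rotate_perturbation_point_cloud(batch)
    batch = random_scale_point_cloud(batch)
    batch = random_flip_point_cloud(batch)
    batch = jitter_point_cloud(batch, sigma=0.01, clip=0.05)
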
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 💡**DistillationDPO**💡
4 | ## **[Diffusion Distillation With Direct Preference Optimization For Efficient 3D LiDAR Scene Completion [AAAI'26]](#)**
5 |
6 | by *An Zhao<sup>1</sup>, [Shengyuan Zhang](https://github.com/SYZhang0805)<sup>1</sup>, [Ling Yang](https://github.com/YangLing0818)<sup>2</sup>, [Zejian Li*](https://zejianli.github.io/)<sup>1</sup>, Jiale Wu<sup>1</sup>, Haoran Xu<sup>3</sup>, AnYang Wei<sup>3</sup>, Perry Pengyun GU<sup>3</sup>, [Lingyun Sun](https://person.zju.edu.cn/sly)<sup>1</sup>*
7 |
8 | *<sup>1</sup>Zhejiang University  <sup>2</sup>Peking University  <sup>3</sup>Zhejiang Green Zhixing Technology Co., Ltd.*
9 |
10 | 
11 |
12 |
13 |
14 | ## **Abstract**
15 |
16 | The application of diffusion models in 3D LiDAR scene completion is limited due to diffusion's slow sampling speed.
17 | Score distillation accelerates diffusion sampling at the cost of degraded performance, while post-training with direct preference optimization (DPO) boosts performance using preference data.
18 | This paper proposes Distillation-DPO, a novel diffusion distillation framework for LiDAR scene completion with preference alignment.
19 | First, the student model generates paired completion scenes from different initial noises.
20 | Second, using LiDAR scene evaluation metrics as the preference signal, we construct winning and losing sample pairs.
21 | Such construction is reasonable, since most LiDAR scene metrics are informative yet non-differentiable and thus cannot be optimized directly.
22 | Third, Distillation-DPO optimizes the student model by exploiting the difference in score functions between the teacher and student models on the paired completion scenes.
23 | This procedure is repeated until convergence.
24 | Extensive experiments demonstrate that, compared to state-of-the-art LiDAR scene completion diffusion models, Distillation-DPO achieves higher-quality scene completion while accelerating completion by more than 5-fold.
25 | To the best of our knowledge, our method is the first to explore adopting preference learning in distillation, and it provides insights into preference-aligned distillation.
26 |
27 | ## **Environment setup**
28 |
29 | The following commands are tested with Python 3.8 and CUDA 11.1.
30 |
31 | Install required packages:
32 |
33 | `sudo apt install build-essential python3-dev libopenblas-dev`
34 |
35 | `pip3 install -r requirements.txt`
36 |
37 | Install [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) for sparse tensor processing:
38 |
39 | `pip3 install -U MinkowskiEngine==0.5.4 --install-option="--blas=openblas" -v --no-deps`
40 |
41 | Set up the package from the repository root:
42 |
43 | `pip3 install -U -e .`
44 |
45 | ## **Training**
46 |
47 | We use the SemanticKITTI dataset for training.
48 |
49 | The SemanticKITTI dataset has to be downloaded from the [official site](http://www.semantic-kitti.org/dataset.html#download) and extracted into the following structure:
50 |
51 | ```
52 | DistillationDPO/
53 | └── datasets/
54 |     └── SemanticKITTI
55 |         └── dataset
56 |             └── sequences
57 |                 ├── 00/
58 |                 │   ├── velodyne/
59 |                 │   │   ├── 000000.bin
60 |                 │   │   ├── 000001.bin
61 |                 │   │   └── ...
62 |                 │   └── labels/
63 |                 │       ├── 000000.label
64 |                 │       ├── 000001.label
65 |                 │       └── ...
66 |                 ├── 01/
67 |                 │   ...
68 |                 ...
69 | ```
70 |
71 | Ground truth scenes are not provided explicitly in SemanticKITTI. To generate the ground-truth complete scenes, run the `map_from_scans.py` script. It uses the dataset scans and poses to build each sequence map, which serves as ground truth during training:
72 |
73 | ```
74 | python map_from_scans.py --path datasets/SemanticKITTI/dataset/sequences/
75 | ```
76 |
77 | We use the diffusion-DPO fine-tuned version of [LiDiff](https://github.com/PRBonn/LiDiff) as both the teacher and the teacher assistant models. Download the pre-trained weights from [here](https://drive.google.com/drive/folders/1z7Iq6nPDZXtASUDP8R8sqhUAvVfRqKQH?usp=sharing) and place the checkpoint at `checkpoints/lidiff_ddpo_refined.ckpt`.
78 |
79 | Once the sequence maps are generated and the teacher model is downloaded, you can train the model. Start training with:
80 |
81 | `python trains/DistillationDPO.py --SemanticKITTI_path datasets/SemanticKITTI --pre_trained_diff_path checkpoints/lidiff_ddpo_refined.ckpt`
82 |
83 | ## **Inference & Visualization**
84 |
85 | We use [pyrender](https://github.com/mmatl/pyrender) for offscreen rendering. Please see [this guide](https://pyrender.readthedocs.io/en/latest/install/index.html#osmesa) for installation of pyrender.
86 |
87 | Once pyrender is installed, download the refinement model `refine_net.ckpt` from [here](https://drive.google.com/drive/folders/1z7Iq6nPDZXtASUDP8R8sqhUAvVfRqKQH?usp=sharing) and place it at `checkpoints/refine_net.ckpt`. We also provide pre-trained Distillation-DPO weights: download `distillationdpo_st.ckpt` from [here](https://drive.google.com/drive/folders/1z7Iq6nPDZXtASUDP8R8sqhUAvVfRqKQH?usp=sharing) and place it at `checkpoints/distillationdpo_st.ckpt`.
88 |
89 | Then run the inference script with the following command:
90 |
91 | `python utils/eval_path_get_pics.py --diff checkpoints/distillationdpo_st.ckpt --refine checkpoints/refine_net.ckpt`
92 |
93 | This script reads all scans in the specified sequence of the SemanticKITTI dataset and saves the resulting images under `exp/`.
94 |
95 | ## **Citation**
96 |
97 | If you find our paper useful or relevant to your research, please kindly cite our paper:
98 |
99 | ```bibtex
100 | @misc{zhao2025diffusiondistillationdirectpreference,
101 | title={Diffusion Distillation With Direct Preference Optimization For Efficient 3D LiDAR Scene Completion},
102 | author={An Zhao and Shengyuan Zhang and Ling Yang and Zejian Li and Jiale Wu and Haoran Xu and AnYang Wei and Perry Pengyun GU and Lingyun Sun},
103 | year={2025},
104 | eprint={2504.11447},
105 | archivePrefix={arXiv},
106 | primaryClass={cs.CV},
107 | url={https://arxiv.org/abs/2504.11447},
108 | }
109 | ```
110 |
111 | ## **Credits**
112 |
113 | DistillationDPO builds heavily on the following amazing open-source project:
114 |
115 | [LiDiff](https://github.com/PRBonn/LiDiff): Scaling Diffusion Models to Real-World 3D LiDAR Scene Completion
116 |
--------------------------------------------------------------------------------
/utils/eval_path_get_pics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import open3d as o3d
4 | import tqdm
5 | from natsort import natsorted
6 | import click
7 | import json
8 |
9 | from utils.diff_completion_pipeline import DiffCompletion_DistillationDPO, DiffCompletion_lidiff
10 | from utils.histogram_metrics import compute_hist_metrics
11 | from utils.render import offscreen_render
12 |
13 | def parse_calibration(filename):
14 | calib = {}
15 |
16 | calib_file = open(filename)
17 | for line in calib_file:
18 | key, content = line.strip().split(":")
19 | values = [float(v) for v in content.strip().split()]
20 |
21 | pose = np.zeros((4, 4))
22 | pose[0, 0:4] = values[0:4]
23 | pose[1, 0:4] = values[4:8]
24 | pose[2, 0:4] = values[8:12]
25 | pose[3, 3] = 1.0
26 |
27 | calib[key] = pose
28 |
29 | calib_file.close()
30 |
31 | return calib
32 |
33 | def load_poses(calib_fname, poses_fname):
34 | if os.path.exists(calib_fname):
35 | calibration = parse_calibration(calib_fname)
36 | Tr = calibration["Tr"]
37 | Tr_inv = np.linalg.inv(Tr)
38 |
39 | poses_file = open(poses_fname)
40 | poses = []
41 |
42 | for line in poses_file:
43 | values = [float(v) for v in line.strip().split()]
44 |
45 | pose = np.zeros((4, 4))
46 | pose[0, 0:4] = values[0:4]
47 | pose[1, 0:4] = values[4:8]
48 | pose[2, 0:4] = values[8:12]
49 | pose[3, 3] = 1.0
50 |
51 | if os.path.exists(calib_fname):
52 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
53 | else:
54 | poses.append(pose)
55 |
56 | return poses
57 |
58 |
59 | def get_scan_completion(scan_path, path, diff_completion, max_range):
60 | pcd_file = os.path.join(path, 'velodyne', scan_path)
61 | points = np.fromfile(pcd_file, dtype=np.float32)
62 | points = points.reshape(-1,4)
63 | dist = np.sqrt(np.sum(points[:,:3]**2, axis=-1))
64 | input_points = points[dist < max_range, :3]
65 | if diff_completion is None:
66 | pred_path = f'{scan_path.split(".")[0]}.ply'
67 | pcd_pred = o3d.io.read_point_cloud(os.path.join(path, pred_path))
68 | points = np.array(pcd_pred.points)
69 | dist = np.sqrt(np.sum(points**2, axis=-1))
70 | pcd_pred.points = o3d.utility.Vector3dVector(points[dist < max_range])
71 | else:
72 | complete_scan_refined, complete_scan = diff_completion.complete_scan(points)
73 | pcd_pred = o3d.geometry.PointCloud()
74 | pcd_pred.points = o3d.utility.Vector3dVector(complete_scan)
75 | pcd_pred_refined = o3d.geometry.PointCloud()
76 | pcd_pred_refined.points = o3d.utility.Vector3dVector(complete_scan_refined)
77 |
78 | return pcd_pred, pcd_pred_refined, input_points
79 |
80 | def get_ground_truth(pose, cur_scan, seq_map, max_range):
81 | trans = pose[:-1,-1]
82 | dist_gt = np.sum((seq_map - trans)**2, axis=-1)**.5
83 | scan_gt = seq_map[dist_gt < max_range]
84 | scan_gt = np.concatenate((scan_gt, np.ones((len(scan_gt),1))), axis=-1)
85 | scan_gt = (scan_gt @ np.linalg.inv(pose).T)[:,:3]
86 | scan_gt = scan_gt[(scan_gt[:,2] > -4.) & (scan_gt[:,2] < 4.4)]
87 | pcd_gt = o3d.geometry.PointCloud()
88 | pcd_gt.points = o3d.utility.Vector3dVector(scan_gt)
89 |
90 | # filter only over the view point
91 | cur_pcd = o3d.geometry.PointCloud()
92 | cur_pcd.points = o3d.utility.Vector3dVector(cur_scan)
93 | viewpoint_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(cur_pcd, voxel_size=10.)
94 | in_viewpoint = viewpoint_grid.check_if_included(pcd_gt.points)
95 | points_gt = np.array(pcd_gt.points)
96 | pcd_gt.points = o3d.utility.Vector3dVector(points_gt[in_viewpoint])
97 |
98 | return pcd_gt
99 |
100 | def pcd_denoise(pcd):
101 | cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
102 | pcd_clean = pcd.select_by_index(ind)
103 | return pcd_clean
104 |
105 | import torch
106 | import random
107 | def set_deterministic(seed=42):
108 | random.seed(seed)
109 | np.random.seed(seed)
110 | torch.manual_seed(seed)
111 | torch.cuda.manual_seed(seed)
112 | torch.backends.cudnn.deterministic = True
113 |
114 |
115 | @click.command()
116 | @click.option('--path', '-p', type=str, default='datasets/SemanticKITTI/dataset/sequences/00', help='path to the scan sequence')
117 | @click.option('--max_range', '-m', type=float, default=50, help='max range')
118 | @click.option('--denoising_steps', '-t', type=int, default=8, help='number of denoising steps')
119 | @click.option('--cond_weight', '-s', type=float, default=3.5, help='conditioning weights')
120 | @click.option('--diff', '-d', type=str, default='checkpoints/distillationdpo_st.ckpt', help='trained diffusion model')
121 | @click.option('--refine', '-r', type=str, default='checkpoints/refine_net.ckpt', help='refinement model')
122 | @click.option('--save_path', '-sp', type=str, default='exp/pics/00', help='where to save pics')
123 | def main(path, max_range, denoising_steps, cond_weight, diff, refine, save_path):
124 |
125 | if not os.path.exists(save_path):
126 | os.makedirs(save_path)
127 |
128 | diff_completion_distdpo = DiffCompletion_DistillationDPO(diff, refine, denoising_steps, cond_weight)
129 |
130 | poses = load_poses(os.path.join(path, 'calib.txt'), os.path.join(path, 'poses.txt'))
131 | seq_map = np.load(f'{path}/map_clean.npy')
132 |
134 | eval_list = list(zip(poses, natsorted(os.listdir(f'{path}/velodyne'))))
135 | random.shuffle(eval_list)
136 | for pose, scan_path in tqdm.tqdm(eval_list):
137 |
138 | pcd_pred, pcd_pred_refined, cur_scan = get_scan_completion(scan_path, path, diff_completion_distdpo, max_range)
139 | pcd_in = o3d.geometry.PointCloud()
140 | pcd_in.points = o3d.utility.Vector3dVector(cur_scan)
141 | pcd_gt = get_ground_truth(pose, cur_scan, seq_map, max_range)
142 |
143 | pic_dir = os.path.join(save_path, scan_path.split('.')[0])
144 | if not os.path.exists(pic_dir):
145 | os.makedirs(pic_dir)
146 | pic_pred_dir = os.path.join(pic_dir, f'{denoising_steps}steps.png')
147 | pic_pred_refined_dir = os.path.join(pic_dir, f'{denoising_steps}steps_refined.png')
148 | pic_gt_dir = os.path.join(pic_dir, 'gt.png')
149 | pic_in_dir = os.path.join(pic_dir, 'in.png')
150 |
151 | pcd_pred = pcd_denoise(pcd_pred)
152 | pcd_pred_refined = pcd_denoise(pcd_pred_refined)
155 |
156 | offscreen_render(pcd_pred, pic_pred_dir)
157 | offscreen_render(pcd_pred_refined, pic_pred_refined_dir)
158 | offscreen_render(pcd_gt, pic_gt_dir)
159 | offscreen_render(pcd_in, pic_in_dir)
160 |
161 |
162 | if __name__ == '__main__':
163 | main()
164 |
165 |
--------------------------------------------------------------------------------
/utils/pcd_preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import hdbscan
4 | import matplotlib.pyplot as plt
5 | import MinkowskiEngine as ME
6 | import os
7 |
8 | def overlap_clusters(cluster_i, cluster_j, min_cluster_point=10):
9 |     # get unique labels from pcd_i and pcd_j for segments bigger than min_cluster_point
10 | unique_i, count_i = np.unique(cluster_i, return_counts=True)
11 | unique_i = unique_i[count_i > min_cluster_point]
12 |
13 | unique_j, count_j = np.unique(cluster_j, return_counts=True)
14 | unique_j = unique_j[count_j > min_cluster_point]
15 |
16 | # get labels present on both pcd (intersection)
17 |     unique_ij = np.intersect1d(unique_i, unique_j)[1:]  # [1:] drops the shared noise label (-1)
18 |
19 | # labels not intersecting both pcd are assigned as -1 (unlabeled)
20 | cluster_i[np.in1d(cluster_i, unique_ij, invert=True)] = -1
21 | cluster_j[np.in1d(cluster_j, unique_ij, invert=True)] = -1
22 |
23 | return cluster_i, cluster_j
24 |
25 | def parse_calibration(filename):
26 | calib = {}
27 |
28 | calib_file = open(filename)
29 | for line in calib_file:
30 | key, content = line.strip().split(":")
31 | values = [float(v) for v in content.strip().split()]
32 |
33 | pose = np.zeros((4, 4))
34 | pose[0, 0:4] = values[0:4]
35 | pose[1, 0:4] = values[4:8]
36 | pose[2, 0:4] = values[8:12]
37 | pose[3, 3] = 1.0
38 |
39 | calib[key] = pose
40 |
41 | calib_file.close()
42 |
43 | return calib
44 |
45 | def load_poses(calib_fname, poses_fname):
46 | if os.path.exists(calib_fname):
47 | calibration = parse_calibration(calib_fname)
48 | Tr = calibration["Tr"]
49 | Tr_inv = np.linalg.inv(Tr)
50 |
51 | poses_file = open(poses_fname)
52 | poses = []
53 |
54 | for line in poses_file:
55 | values = [float(v) for v in line.strip().split()]
56 |
57 | pose = np.zeros((4, 4))
58 | pose[0, 0:4] = values[0:4]
59 | pose[1, 0:4] = values[4:8]
60 | pose[2, 0:4] = values[8:12]
61 | pose[3, 3] = 1.0
62 |
63 | if os.path.exists(calib_fname):
64 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
65 | else:
66 | poses.append(pose)
67 |
68 | return poses
69 |
70 | def apply_transform(points, pose):
71 |     hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
72 |     return np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)[:,:3]  # == (hpoints @ pose.T)[:,:3]
73 |
74 | def undo_transform(points, pose):
75 |     hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
76 |     return np.sum(np.expand_dims(hpoints, 2) * np.linalg.inv(pose).T, axis=1)[:,:3]  # == (hpoints @ inv(pose).T)[:,:3]
77 |
78 | def aggregate_pcds(data_batch, data_dir, t_frame):
79 | # load empty pcd point cloud to aggregate
80 | pcd_full = np.empty((0,3))
81 | pcd_part = None
82 |
83 | # define "namespace"
84 | seq_num = data_batch[0].split('/')[-3]
85 | fname = data_batch[0].split('/')[-1].split('.')[0]
86 |
87 | # load poses
88 | datapath = data_batch[0].split('velodyne')[0]
89 | poses = load_poses(os.path.join(datapath, 'calib.txt'), os.path.join(datapath, 'poses.txt'))
90 |
91 | for t in range(len(data_batch)):
92 | # load the next t scan and aggregate
93 | fname = data_batch[t].split('/')[-1].split('.')[0]
94 |
95 | # load the next t scan, apply pose and aggregate
96 | p_set = np.fromfile(data_batch[t], dtype=np.float32)
97 | p_set = p_set.reshape((-1, 4))[:,:3]
98 |
99 | label_file = data_batch[t].replace('velodyne', 'labels').replace('.bin', '.label')
100 | l_set = np.fromfile(label_file, dtype=np.uint32)
101 | l_set = l_set.reshape((-1))
102 | l_set = l_set & 0xFFFF
103 |
104 | # remove moving points
105 | static_idx = l_set < 252
106 | p_set = p_set[static_idx]
107 |
108 | # remove flying artifacts
109 | dist = np.power(p_set, 2)
110 | dist = np.sqrt(dist.sum(-1))
111 | p_set = p_set[dist > 3.5]
112 |
113 | pose_idx = int(fname)
114 | p_set = apply_transform(p_set, poses[pose_idx])
115 |
116 | if t == t_frame:
117 | # will be aggregated later to the full pcd
118 | pcd_part = p_set.copy()
119 | else:
120 | pcd_full = np.vstack([pcd_full, p_set])
121 |
122 |
123 |     # undo the reference scan's pose so both clouds are expressed in its local frame
124 |
125 | pose_idx = int(fname)
126 | pcd_full = undo_transform(pcd_full, poses[pose_idx])
127 | pcd_part = undo_transform(pcd_part, poses[pose_idx])
128 |
129 | return pcd_full, pcd_part
130 |
131 | def clusters_hdbscan(points_set, n_clusters=50):
132 | clusterer = hdbscan.HDBSCAN(algorithm='best', alpha=1., approx_min_span_tree=True,
133 | gen_min_span_tree=True, leaf_size=100,
134 | metric='euclidean', min_cluster_size=20, min_samples=None
135 | )
136 |
137 | clusterer.fit(points_set)
138 |
139 | labels = clusterer.labels_.copy()
140 |
141 | lbls, counts = np.unique(labels, return_counts=True)
142 | cluster_info = np.array(list(zip(lbls[1:], counts[1:])))
143 | cluster_info = cluster_info[cluster_info[:,1].argsort()]
144 |
145 | clusters_labels = cluster_info[::-1][:n_clusters, 0]
146 | labels[np.in1d(labels, clusters_labels, invert=True)] = -1
147 |
148 | return labels
149 |
150 | def clusterize_pcd(points, ground):
151 | pcd = o3d.geometry.PointCloud()
152 | pcd.points = o3d.utility.Vector3dVector(points[:, :3])
153 |
154 |     # ground indices come from precomputed labels (Patchwork) instead of RANSAC plane fitting
155 |     inliers = list(np.where(ground == 9)[0])
156 |
157 | pcd_ = pcd.select_by_index(inliers, invert=True)
158 | labels_ = np.expand_dims(clusters_hdbscan(np.asarray(pcd_.points)), axis=-1)
159 |
160 |     # pcd points form an ordered list, so index selection is invertible:
161 |     # selecting the inlier (ground) indices with invert=True yields the
162 |     # outlier indices, still in their original order. We use this to
163 |     # scatter the cluster labels computed on the non-ground points back
164 |     # into a label array aligned with the full cloud, leaving ground
165 |     # and unclustered points labeled -1.
166 | labels = np.ones((points.shape[0], 1)) * -1
167 | mask = np.ones(labels.shape[0], dtype=bool)
168 | mask[inliers] = False
169 |
170 | labels[mask] = labels_
171 |
172 | return labels
173 |
174 | def point_set_to_coord_feats(point_set, labels, resolution, num_points, deterministic=False):
175 | p_feats = point_set.copy()
176 | p_coord = np.round(point_set[:, :3] / resolution)
177 | p_coord -= p_coord.min(0, keepdims=1)
178 |
179 | _, mapping = ME.utils.sparse_quantize(coordinates=p_coord, return_index=True)
180 |     if len(mapping) > num_points:
181 |         if deterministic: np.random.seed(42)  # fixed seed only when deterministic sampling is requested
182 |         mapping = np.random.choice(mapping, num_points, replace=False)
183 |
184 | return p_coord[mapping], p_feats[mapping], labels[mapping]
185 |
186 | def visualize_pcd_clusters(points, labels):
187 | pcd = o3d.geometry.PointCloud()
188 | pcd.points = o3d.utility.Vector3dVector(points[:,:3])
189 |
190 | colors = np.zeros((len(labels), 4))
191 | flat_indices = np.unique(labels[:,-1])
192 | max_instance = len(flat_indices)
193 | colors_instance = plt.get_cmap("prism")(np.arange(len(flat_indices)) / (max_instance if max_instance > 0 else 1))
194 |
195 | for idx in range(len(flat_indices)):
196 | colors[labels[:,-1] == flat_indices[int(idx)]] = colors_instance[int(idx)]
197 |
198 | colors[labels[:,-1] == -1] = [0.,0.,0.,0.]
199 |
200 | pcd.colors = o3d.utility.Vector3dVector(colors[:,:3])
201 |
202 | o3d.visualization.draw_geometries([pcd])
203 |
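204 | # Usage sketch (hypothetical scan list): aggregate the scans around frame 0,
205 | # cluster the non-ground points, and visualize:
206 | #   seq = 'datasets/SemanticKITTI/dataset/sequences/08'
207 | #   scans = [f'{seq}/velodyne/{i:06d}.bin' for i in range(4)]
208 | #   pcd_full, pcd_part = aggregate_pcds(scans, seq, t_frame=0)
209 | #   labels = clusters_hdbscan(pcd_full, n_clusters=50)
210 | #   visualize_pcd_clusters(pcd_full, np.expand_dims(labels, -1))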
--------------------------------------------------------------------------------
/utils/eval_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import open3d as o3d
4 | import click
5 | import json
6 | import tqdm
7 | from natsort import natsorted
8 |
9 | from utils.metrics import ChamferDistance, PrecisionRecall, CompletionIoU, RMSE, EMD
10 | from utils.diff_completion_pipeline import DiffCompletion_DistillationDPO
11 | from utils.histogram_metrics import compute_hist_metrics
12 |
13 | completion_iou = CompletionIoU()
14 | rmse = RMSE()
15 | chamfer_distance = ChamferDistance()
16 | precision_recall = PrecisionRecall(0.05,2*0.05,100)
17 | emd = EMD(voxel_size=0.5)
18 |
19 | def parse_calibration(filename):
20 | calib = {}
21 |
22 | calib_file = open(filename)
23 | for line in calib_file:
24 | key, content = line.strip().split(":")
25 | values = [float(v) for v in content.strip().split()]
26 |
27 | pose = np.zeros((4, 4))
28 | pose[0, 0:4] = values[0:4]
29 | pose[1, 0:4] = values[4:8]
30 | pose[2, 0:4] = values[8:12]
31 | pose[3, 3] = 1.0
32 |
33 | calib[key] = pose
34 |
35 | calib_file.close()
36 |
37 | return calib
38 |
39 | def load_poses(calib_fname, poses_fname):
40 | if os.path.exists(calib_fname):
41 | calibration = parse_calibration(calib_fname)
42 | Tr = calibration["Tr"]
43 | Tr_inv = np.linalg.inv(Tr)
44 |
45 | poses_file = open(poses_fname)
46 | poses = []
47 |
48 | for line in poses_file:
49 | values = [float(v) for v in line.strip().split()]
50 |
51 | pose = np.zeros((4, 4))
52 | pose[0, 0:4] = values[0:4]
53 | pose[1, 0:4] = values[4:8]
54 | pose[2, 0:4] = values[8:12]
55 | pose[3, 3] = 1.0
56 |
57 | if os.path.exists(calib_fname):
58 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
59 | else:
60 | poses.append(pose)
61 |
62 | return poses
63 |
64 |
65 | def get_scan_completion(scan_path, path, diff_completion, max_range, use_refine):
66 | pcd_file = os.path.join(path, 'velodyne', scan_path)
67 | points = np.fromfile(pcd_file, dtype=np.float32)
68 | points = points.reshape(-1,4)
69 | dist = np.sqrt(np.sum(points[:,:3]**2, axis=-1))
70 | input_points = points[dist < max_range, :3]
71 | if diff_completion is None:
72 | pred_path = f'{scan_path.split(".")[0]}.ply'
73 | pcd_pred = o3d.io.read_point_cloud(os.path.join(path, pred_path))
74 | points = np.array(pcd_pred.points)
75 | dist = np.sqrt(np.sum(points**2, axis=-1))
76 | pcd_pred.points = o3d.utility.Vector3dVector(points[dist < max_range])
77 | else:
78 | # points = points[:,:3]
79 | if use_refine:
80 | complete_scan, _ = diff_completion.complete_scan(points)
81 | else:
82 | _, complete_scan = diff_completion.complete_scan(points)
83 | pcd_pred = o3d.geometry.PointCloud()
84 | pcd_pred.points = o3d.utility.Vector3dVector(complete_scan)
85 |
86 | return pcd_pred, input_points
87 |
88 | def get_ground_truth(pose, cur_scan, seq_map, max_range):
89 | trans = pose[:-1,-1]
90 | dist_gt = np.sum((seq_map - trans)**2, axis=-1)**.5
91 | scan_gt = seq_map[dist_gt < max_range]
92 | scan_gt = np.concatenate((scan_gt, np.ones((len(scan_gt),1))), axis=-1)
93 | scan_gt = (scan_gt @ np.linalg.inv(pose).T)[:,:3]
94 | scan_gt = scan_gt[(scan_gt[:,2] > -4.) & (scan_gt[:,2] < 4.4)]
95 | pcd_gt = o3d.geometry.PointCloud()
96 | pcd_gt.points = o3d.utility.Vector3dVector(scan_gt)
97 |
98 | # filter only over the view point
99 | cur_pcd = o3d.geometry.PointCloud()
100 | cur_pcd.points = o3d.utility.Vector3dVector(cur_scan)
101 | viewpoint_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(cur_pcd, voxel_size=10.)
102 | in_viewpoint = viewpoint_grid.check_if_included(pcd_gt.points)
103 | points_gt = np.array(pcd_gt.points)
104 | pcd_gt.points = o3d.utility.Vector3dVector(points_gt[in_viewpoint])
105 |
106 | return pcd_gt
107 |
108 | import torch
109 | import random
110 | def set_deterministic(seed=42):
111 | random.seed(seed)
112 | np.random.seed(seed)
113 | torch.manual_seed(seed)
114 | torch.cuda.manual_seed(seed)
115 | torch.backends.cudnn.deterministic = True
116 |
117 |
118 | @click.command()
119 | @click.option('--path', '-p', type=str, default='datasets/SemanticKITTI/dataset/sequences/08', help='path to the scan sequence')
120 | @click.option('--max_range', '-m', type=float, default=50, help='max range')
121 | @click.option('--denoising_steps', '-t', type=int, default=8, help='number of denoising steps')
122 | @click.option('--cond_weight', '-s', type=float, default=3.5, help='conditioning weights')
123 | @click.option('--diff', '-d', type=str, help='trained diffusion model')
124 | @click.option('--refine', '-r', type=str, help='refinement model')
125 | @click.option('--use_refine', '-ur', type=bool, default=False, help='whether to use refine network')
126 | @click.option('--save_name', '-sn', type=str, help='name of saved file')
127 | def main(path, max_range, denoising_steps, cond_weight, diff, refine, use_refine, save_name):
128 |
129 | diff_completion = DiffCompletion_DistillationDPO(diff, refine, denoising_steps, cond_weight)
130 |
131 | poses = load_poses(os.path.join(path, 'calib.txt'), os.path.join(path, 'poses.txt'))
132 | seq_map = np.load(f'{path}/map_clean.npy')
133 |
134 | jsd_3d = []
135 | jsd_bev = []
136 |
137 |     # shuffle the scan order so partial runs still cover the whole sequence
138 | eval_list = list(zip(poses, natsorted(os.listdir(f'{path}/velodyne'))))
139 | random.shuffle(eval_list)
140 | for pose, scan_path in tqdm.tqdm(eval_list):
141 |
142 | pcd_pred, cur_scan = get_scan_completion(scan_path, path, diff_completion, max_range, use_refine)
143 | pcd_in = o3d.geometry.PointCloud()
144 | pcd_in.points = o3d.utility.Vector3dVector(cur_scan)
145 | pcd_gt = get_ground_truth(pose, cur_scan, seq_map, max_range)
146 |
147 | # o3d.io.write_point_cloud(f"exp/metrics/pcd_gt.ply", pcd_gt)
148 | # o3d.io.write_point_cloud(f"exp/metrics/pcd_in.ply", pcd_in)
149 | # o3d.io.write_point_cloud(f"exp/metrics/pcd_pred.ply", pcd_pred)
150 |
151 | emd.update(pcd_gt, pcd_pred)
152 | avg_emd, _ = emd.compute()
153 | print(f'Mean emd: {avg_emd}')
154 |
155 | jsd_3d.append(compute_hist_metrics(pcd_gt, pcd_pred, bev=False))
156 | jsd_bev.append(compute_hist_metrics(pcd_gt, pcd_pred, bev=True))
157 | print(f'JSD 3D mean: {np.array(jsd_3d).mean()}')
158 | print(f'JSD BEV mean: {np.array(jsd_bev).mean()}')
159 |
160 | rmse.update(pcd_gt, pcd_pred)
161 | completion_iou.update(pcd_gt, pcd_pred)
162 | chamfer_distance.update(pcd_gt, pcd_pred)
163 | precision_recall.update(pcd_gt, pcd_pred)
164 |
165 | rmse_mean, rmse_std = rmse.compute()
166 | print(f'RMSE Mean: {rmse_mean}\tRMSE Std: {rmse_std}')
167 | thr_ious = completion_iou.compute()
168 | for v_size in thr_ious.keys():
169 | print(f'Voxel {v_size}cm IOU: {thr_ious[v_size]}')
170 | cd_mean, cd_std = chamfer_distance.compute()
171 | print(f'CD Mean: {cd_mean}\tCD Std: {cd_std}')
172 | pr, re, f1 = precision_recall.compute_auc()
173 | print(f'Precision: {pr}\tRecall: {re}\tF-Score: {f1}')
174 |
175 |
176 | print('\n\n=================== FINAL RESULTS ===================\n\n')
177 | print(f'JSD 3D: {np.array(jsd_3d).mean()}')
178 | print(f'JSD BEV: {np.array(jsd_bev).mean()}')
179 | print(f'RMSE Mean: {rmse_mean}\tRMSE Std: {rmse_std}')
180 | thr_ious = completion_iou.compute()
181 | for v_size in thr_ious.keys():
182 | print(f'Voxel {v_size}cm IOU: {thr_ious[v_size]}')
183 | cd_mean, cd_std = chamfer_distance.compute()
184 | print(f'CD Mean: {cd_mean}\tCD Std: {cd_std}')
185 | pr, re, f1 = precision_recall.compute_auc()
186 | print(f'Precision: {pr}\tRecall: {re}\tF-Score: {f1}')
187 |
188 | res_dict = {
189 | 'jsd_bev': np.array(jsd_bev).mean(),
190 | 'jsd_noclip_3d': np.array(jsd_3d).mean(),
191 | 'rmse_mean': rmse_mean, 'rmse_std': rmse_std,
192 | 'ious': thr_ious,
193 | 'cd_mean': cd_mean, 'cd_std': cd_std,
194 | 'pr': pr, 're': re, 'f1': f1,
195 | }
196 |
197 |     from os.path import join, dirname, abspath
198 |     with open(join(dirname(abspath(__file__)), f'{save_name}.yaml'), 'w+') as log_res:
199 |         json.dump(res_dict, log_res)  # json is a valid subset of yaml, so the .yaml file still parses
200 |
201 | if __name__ == '__main__':
202 | main()
203 |
204 |
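205 | # Usage sketch (hypothetical checkpoint paths):
206 | #   python -m utils.eval_path -p datasets/SemanticKITTI/dataset/sequences/08 \
207 | #       -d checkpoints/diff.ckpt -r checkpoints/refine.ckpt -t 8 -s 3.5 -sn eval_seq08
208 | # Per-scan metrics are printed as the loop runs; the final dict is json-dumped to
209 | # '<save_name>.yaml' next to this file.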
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import numpy as np
3 | import scipy.integrate
4 | import torch
5 | from utils.EMD import calc_EMD_with_sinkhorn_knopp
6 | MESHTYPE = 6
7 | TETRATYPE = 10
8 | PCDTYPE = 1
9 |
10 | class Metrics3D():
11 | def prediction_is_empty(self, geom):
12 |
13 | if isinstance(geom, o3d.geometry.Geometry):
14 | geom_type = geom.get_geometry_type().value
15 | if geom_type == MESHTYPE or geom_type == TETRATYPE:
16 | empty_v = self.is_empty(len(geom.vertices))
17 | empty_t = self.is_empty(len(geom.triangles))
18 | empty = empty_t or empty_v
19 | elif geom_type == PCDTYPE:
20 | empty = self.is_empty(len(geom.points))
21 | else:
22 | assert False, '{} geometry not supported'.format(geom.get_geometry_type())
23 | elif isinstance(geom, np.ndarray):
24 | empty = self.is_empty(len(geom[:, :3]))
25 | elif isinstance(geom, torch.Tensor):
26 | empty = self.is_empty(len(geom[:, :3]))
27 | else:
28 | assert False, '{} type not supported'.format(type(geom))
29 |
30 | return empty
31 |
32 | @staticmethod
33 | def convert_to_pcd(geom):
34 |
35 | if isinstance(geom, o3d.geometry.Geometry):
36 | geom_type = geom.get_geometry_type().value
37 | if geom_type == MESHTYPE or geom_type == TETRATYPE:
38 | geom_pcd = geom.sample_points_uniformly(1000000)
39 | elif geom_type == PCDTYPE:
40 | geom_pcd = geom
41 | else:
42 | assert False, '{} geometry not supported'.format(geom.get_geometry_type())
43 | elif isinstance(geom, np.ndarray):
44 | geom_pcd = o3d.geometry.PointCloud()
45 | geom_pcd.points = o3d.utility.Vector3dVector(geom[:, :3])
46 | elif isinstance(geom, torch.Tensor):
47 | geom = geom.detach().cpu().numpy()
48 | geom_pcd = o3d.geometry.PointCloud()
49 | geom_pcd.points = o3d.utility.Vector3dVector(geom[:, :3])
50 | else:
51 | assert False, '{} type not supported'.format(type(geom))
52 |
53 | return geom_pcd
54 |
55 | @staticmethod
56 | def is_empty(length):
57 | empty = True
58 | if length:
59 | empty = False
60 | return empty
61 |
62 |
63 |
64 | class RMSE():
65 | def __init__(self):
66 | self.dists = []
67 |
68 | return
69 |
70 | def update(self, gt_pcd, pt_pcd):
71 | dist_pt_2_gt = np.asarray(pt_pcd.compute_point_cloud_distance(gt_pcd))
72 |
73 | self.dists.append(np.mean(dist_pt_2_gt))
74 |
75 | def reset(self):
76 | self.dists = []
77 |
78 | def compute(self):
79 | dist = np.array(self.dists)
80 | return dist.mean(), dist.std()
81 |
82 | class EMD():
83 | def __init__(self, voxel_size=0.5, epsilon=0.001, max_iter=3000, tol=1e-4):
84 | self.emds = []
85 | self.voxel_size = voxel_size
86 | self.epsilon = epsilon
87 | self.max_iter = max_iter
88 | self.tol = tol
89 |
90 | return
91 |
92 | def update(self, gt_pcd, pt_pcd):
93 | emd = calc_EMD_with_sinkhorn_knopp(np.array(pt_pcd.points), np.array(gt_pcd.points), voxel_size=self.voxel_size, epsilon=self.epsilon, max_iter=self.max_iter, tol=self.tol)
94 |
95 | if emd:
96 | self.emds.append(emd)
97 |
98 | def reset(self):
99 | self.emds = []
100 |
101 | def compute(self):
102 | emds_array = np.array(self.emds)
103 | return emds_array.mean(), emds_array.std()
104 |
105 | class CompletionIoU():
106 | def __init__(self, voxel_sizes=[0.5, 0.2, 0.1]):
107 | self.voxel_sizes = voxel_sizes
108 |         # per voxel size, accumulated counts: [tp, fn, fp]
109 | self.conf_matrix = np.zeros((len(self.voxel_sizes), 3)).astype(np.uint64)
110 |
111 | def update(self, gt, pred):
112 | max_range = 50.
113 | for i, vsize in enumerate(self.voxel_sizes):
114 | bins = int(2 * max_range / vsize)
115 | vox_coords_gt = np.round(np.array(gt.points) / vsize).astype(int)
116 | hist_gt = np.histogramdd(
117 | vox_coords_gt, bins=bins, range=([-max_range,max_range],[-max_range,max_range],[-max_range,max_range])
118 | )[0].astype(bool).astype(int)
119 |
120 | vox_coords_pred = np.round(np.array(pred.points) / vsize).astype(int)
121 | hist_pred = np.histogramdd(
122 | vox_coords_pred, bins=bins, range=([-max_range,max_range],[-max_range,max_range],[-max_range,max_range])
123 | )[0].astype(bool).astype(int)
124 |
125 | self.conf_matrix[i][0] += ((hist_gt == 1) & (hist_pred == 1)).sum() # tp
126 | self.conf_matrix[i][1] += ((hist_gt == 1) & (hist_pred == 0)).sum() # fn
127 | self.conf_matrix[i][2] += ((hist_gt == 0) & (hist_pred == 1)).sum() # fp
128 |
129 | def compute(self):
130 | res_vsizes = {}
131 | for i, vsize in enumerate(self.voxel_sizes):
132 | tp = self.conf_matrix[i][0]
133 | fn = self.conf_matrix[i][1]
134 | fp = self.conf_matrix[i][2]
135 |
136 | intersection = tp
137 | union = tp + fn + fp + 1e-15
138 |
139 | res_vsizes[vsize] = intersection / union
140 |
141 | return res_vsizes
142 |
143 | def reset(self):
144 |         self.conf_matrix = np.zeros((len(self.voxel_sizes), 3)).astype(np.uint64)
145 |
146 | class ChamferDistance():
147 | def __init__(self):
148 | self.dists = []
149 |
150 | return
151 |
152 | def update(self, gt_pcd, pt_pcd):
153 | dist_pt_2_gt = np.asarray(pt_pcd.compute_point_cloud_distance(gt_pcd))
154 | dist_gt_2_pt = np.asarray(gt_pcd.compute_point_cloud_distance(pt_pcd))
155 |
156 | self.dists.append((np.mean(dist_gt_2_pt) + np.mean(dist_pt_2_gt)) / 2)
157 |
158 | def reset(self):
159 | self.dists = []
160 |
161 | def compute(self):
162 | cdist = np.array(self.dists)
163 | return cdist.mean(), cdist.std()
164 |
165 | class PrecisionRecall(Metrics3D):
166 |
167 | def __init__(self, min_t, max_t, num):
168 | self.thresholds = np.linspace(min_t, max_t, num)
169 | self.pr_dict = {t: [] for t in self.thresholds}
170 | self.re_dict = {t: [] for t in self.thresholds}
171 | self.f1_dict = {t: [] for t in self.thresholds}
172 |
173 | def update(self, gt_pcd, pt_pcd):
174 | # precision: predicted --> ground truth
175 | dist_pt_2_gt = np.asarray(pt_pcd.compute_point_cloud_distance(gt_pcd))
176 |
177 | # recall: ground truth --> predicted
178 | dist_gt_2_pt = np.asarray(gt_pcd.compute_point_cloud_distance(pt_pcd))
179 |
180 | for t in self.thresholds:
181 | p = np.where(dist_pt_2_gt < t)[0]
182 | p = 100 / len(dist_pt_2_gt) * len(p)
183 | self.pr_dict[t].append(p)
184 |
185 | r = np.where(dist_gt_2_pt < t)[0]
186 | r = 100 / len(dist_gt_2_pt) * len(r)
187 | self.re_dict[t].append(r)
188 |
189 | # fscore
190 | if p == 0 or r == 0:
191 | f = 0
192 | else:
193 | f = 2 * p * r / (p + r)
194 | self.f1_dict[t].append(f)
195 |
196 | def reset(self):
197 | self.pr_dict = {t: [] for t in self.thresholds}
198 | self.re_dict = {t: [] for t in self.thresholds}
199 | self.f1_dict = {t: [] for t in self.thresholds}
200 |
201 | def compute_at_threshold(self, threshold):
202 | t = self.find_nearest_threshold(threshold)
203 | # print('computing metrics at threshold:', t)
204 | pr = sum(self.pr_dict[t]) / len(self.pr_dict[t])
205 | re = sum(self.re_dict[t]) / len(self.re_dict[t])
206 | f1 = sum(self.f1_dict[t]) / len(self.f1_dict[t])
207 | # print('precision: {}'.format(pr))
208 | # print('recall: {}'.format(re))
209 | # print('fscore: {}'.format(f1))
210 | return pr, re, f1, t
211 |
212 | def compute_auc(self):
213 | dx = self.thresholds[1] - self.thresholds[0]
214 | perfect_predictor = scipy.integrate.simpson(np.ones_like(self.thresholds), dx=dx)
215 |
216 | pr, re, f1 = self.compute_at_all_thresholds()
217 |
218 | pr_area = scipy.integrate.simpson(pr, dx=dx)
219 | norm_pr_area = pr_area / perfect_predictor
220 |
221 | re_area = scipy.integrate.simpson(re, dx=dx)
222 | norm_re_area = re_area / perfect_predictor
223 |
224 | f1_area = scipy.integrate.simpson(f1, dx=dx)
225 | norm_f1_area = f1_area / perfect_predictor
226 |
227 | # print('computing area under curve')
228 | # print('precision: {}'.format(norm_pr_area))
229 | # print('recall: {}'.format(norm_re_area))
230 | # print('fscore: {}'.format(norm_f1_area))
231 |
232 | return norm_pr_area, norm_re_area, norm_f1_area
233 |
234 | def compute_at_all_thresholds(self):
235 | pr = [sum(self.pr_dict[t]) / len(self.pr_dict[t]) for t in self.thresholds]
236 | re = [sum(self.re_dict[t]) / len(self.re_dict[t]) for t in self.thresholds]
237 | f1 = [sum(self.f1_dict[t]) / len(self.f1_dict[t]) for t in self.thresholds]
238 | return pr, re, f1
239 |
240 | def find_nearest_threshold(self, value):
241 | idx = (np.abs(self.thresholds - value)).argmin()
242 | return self.thresholds[idx]
243 |
244 |
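245 | # Usage sketch (hypothetical clouds): all metrics share an update()/compute()/reset()
246 | # protocol over pairs of open3d point clouds, e.g.
247 | #   cd = ChamferDistance()
248 | #   cd.update(pcd_gt, pcd_pred)
249 | #   cd_mean, cd_std = cd.compute()
250 | #   cd.reset()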
--------------------------------------------------------------------------------
/utils/diff_completion_pipeline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import MinkowskiEngine as ME
3 | import torch
4 | from models.minkunet import MinkRewardModel,MinkGlobalEnc,MinkUNetDiff, MinkUNet
5 | import open3d as o3d
6 | from diffusers import DPMSolverMultistepScheduler
7 | from pytorch_lightning.core.lightning import LightningModule
8 | import yaml
9 | import os
10 | import tqdm
11 | from natsort import natsorted
12 | import click
13 | import time
14 | from os.path import join, dirname, abspath
15 | from os import environ, makedirs
16 |
17 | class DiffCompletion_DistillationDPO(LightningModule):
18 | def __init__(self, diff_path, refine_path, denoising_steps, uncond_w):
19 | super().__init__()
20 |
21 | # load diff net
22 | ckpt_diff = torch.load(diff_path)
23 |
24 | dm_weights = {k.replace('generator.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('generator.')}
25 | encoder_weights = {k.replace('partial_enc.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('partial_enc.')}
26 |
27 | # load encoder and model
28 | self.partial_enc = MinkGlobalEnc().cuda()
29 | self.partial_enc.load_state_dict(encoder_weights, strict=True)
30 | self.partial_enc.eval()
31 |
32 | self.model = MinkUNetDiff().cuda()
33 | self.model.load_state_dict(dm_weights, strict=True)
34 | self.model.eval()
35 |
36 | # load refiner
37 | ckpt_refine = torch.load(refine_path)
38 | refiner_weights = {k.replace('model_refine.', ''): v for k, v in ckpt_refine["state_dict"].items() if k.startswith('model_refine.')}
39 | self.model_refine = MinkUNet(in_channels=3, out_channels=3*6).cuda()
40 | self.model_refine.load_state_dict(refiner_weights, strict=True)
41 | self.model_refine.eval()
42 |
43 | self.cuda()
44 |
45 | # for fast sampling
46 | self.dpm_scheduler = DPMSolverMultistepScheduler(
47 | num_train_timesteps=1000,
48 | beta_start=3.5e-5,
49 | beta_end=0.007,
50 | beta_schedule='linear',
51 | algorithm_type='sde-dpmsolver++',
52 | solver_order=2,
53 | )
54 | self.dpm_scheduler.set_timesteps(num_inference_steps=denoising_steps)
55 | self.scheduler_to_cuda()
56 |
57 | self.w_uncond = uncond_w
58 |
59 | def scheduler_to_cuda(self):
60 | self.dpm_scheduler.timesteps = self.dpm_scheduler.timesteps.cuda()
61 | self.dpm_scheduler.betas = self.dpm_scheduler.betas.cuda()
62 | self.dpm_scheduler.alphas = self.dpm_scheduler.alphas.cuda()
63 | self.dpm_scheduler.alphas_cumprod = self.dpm_scheduler.alphas_cumprod.cuda()
64 | self.dpm_scheduler.alpha_t = self.dpm_scheduler.alpha_t.cuda()
65 | self.dpm_scheduler.sigma_t = self.dpm_scheduler.sigma_t.cuda()
66 | self.dpm_scheduler.lambda_t = self.dpm_scheduler.lambda_t.cuda()
67 | self.dpm_scheduler.sigmas = self.dpm_scheduler.sigmas.cuda()
68 |
69 | def feats_to_coord(self, p_feats, resolution, mean=None, std=None):
70 | p_feats = p_feats.reshape(mean.shape[0],-1,3)
71 | p_coord = torch.round(p_feats / resolution)
72 |
73 | return p_coord.reshape(-1,3)
74 |
75 | def points_to_tensor(self, points):
76 | x_feats = ME.utils.batched_coordinates(list(points[:]), dtype=torch.float32, device=self.device)
77 |
78 | x_coord = x_feats.clone()
79 | x_coord = torch.round(x_coord / 0.05)
80 |
81 | x_t = ME.TensorField(
82 | features=x_feats[:,1:],
83 | coordinates=x_coord,
84 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
85 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
86 | device=self.device,
87 | )
88 |
89 | torch.cuda.empty_cache()
90 |
91 | return x_t
92 |
93 | def reset_partial_pcd(self, x_part):
94 | x_part = self.points_to_tensor(x_part.F.reshape(1,-1,3).detach())
95 |
96 | return x_part
97 |
98 | def preprocess_scan(self, scan):
99 | dist = np.sqrt(np.sum((scan)**2, -1))
100 | scan = scan[(dist < 50.0) & (dist > 3.5)][:,:3]
101 |
102 | # use farthest point sampling
103 | pcd_scan = o3d.geometry.PointCloud()
104 | pcd_scan.points = o3d.utility.Vector3dVector(scan)
105 | pcd_scan = pcd_scan.farthest_point_down_sample(int(180000 / 10))
106 | scan = torch.tensor(np.array(pcd_scan.points)).cuda()
107 |
108 | scan = scan.repeat(10,1)
109 | scan = scan[None,:,:]
110 |
111 | return scan
112 |
113 |
114 | def postprocess_scan(self, completed_scan, input_scan):
115 | dist = np.sqrt(np.sum((completed_scan)**2, -1))
116 | post_scan = completed_scan[dist < 50.0]
117 | max_z = input_scan[...,2].max().item()
118 | min_z = (input_scan[...,2].mean() - 2 * input_scan[...,2].std()).item()
119 |
120 | post_scan = post_scan[(post_scan[:,2] < max_z) & (post_scan[:,2] > min_z)]
121 |
122 | return post_scan
123 |
124 | def complete_scan(self, pcd_part):
125 | pcd_part_rep = self.preprocess_scan(pcd_part).view(1,-1,3)
126 | # pcd_part = torch.tensor(pcd_part, device=self.device).view(1,-1,3)
127 | # print(f'pcd_part_rep.shape = {pcd_part_rep.shape}')
128 | # print(f'pcd_part.shape = {pcd_part.shape}')
129 |
130 | x_feats = pcd_part_rep + torch.randn(pcd_part_rep.shape, device=self.device)
131 | x_full = self.points_to_tensor(x_feats) # x_T
132 | x_cond = self.points_to_tensor(pcd_part_rep) # x_0
133 | x_uncond = self.points_to_tensor(torch.zeros_like(pcd_part_rep))
134 |
135 | completed_scan = self.completion_loop(pcd_part_rep, x_full, x_cond, x_uncond)
136 | post_scan = self.postprocess_scan(completed_scan, pcd_part_rep)
137 |
138 | refine_in = self.points_to_tensor(post_scan[None,:,:])
139 | offset = self.refine_forward(refine_in).reshape(-1,6,3)
140 |
141 | refine_complete_scan = post_scan[:,None,:] + offset.cpu().numpy()
142 |
143 | return refine_complete_scan.reshape(-1,3), post_scan
144 |
145 |
146 | def refine_forward(self, x_in):
147 | with torch.no_grad():
148 | offset = self.model_refine(x_in)
149 |
150 | return offset
151 |
152 | def forward(self, x_full, x_full_sparse, x_part, t):
153 | with torch.no_grad():
154 | part_feat = self.partial_enc(x_part)
155 | out = self.model(x_full, x_full_sparse, part_feat, t)
156 |
157 | torch.cuda.empty_cache()
158 | return out.reshape(t.shape[0],-1,3)
159 |
160 | def classfree_forward(self, x_t, x_cond, x_uncond, t):
161 | x_t_sparse = x_t.sparse()
162 | x_cond = self.forward(x_t, x_t_sparse, x_cond, t)
163 | x_uncond = self.forward(x_t, x_t_sparse, x_uncond, t)
164 |
165 | return x_uncond + self.w_uncond * (x_cond - x_uncond)
166 |
167 | def completion_loop(self, x_init, x_t, x_cond, x_uncond):
168 | self.scheduler_to_cuda()
169 |
170 | # for t in tqdm.tqdm(range(len(self.dpm_scheduler.timesteps))):
171 | for t in range(len(self.dpm_scheduler.timesteps)):
172 | t = self.dpm_scheduler.timesteps[t].cuda()[None]
173 |
174 | noise_t = self.classfree_forward(x_t, x_cond, x_uncond, t)
175 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
176 | x_t = x_init + self.dpm_scheduler.step(noise_t, t, input_noise)['prev_sample']
177 | x_t = self.points_to_tensor(x_t)
178 |
179 | x_cond = self.reset_partial_pcd(x_cond)
180 | torch.cuda.empty_cache()
181 |
182 | return x_t.F.cpu().detach().numpy()
183 |
184 | def load_pcd(pcd_file):
185 | if pcd_file.endswith('.bin'):
186 | return np.fromfile(pcd_file, dtype=np.float32).reshape((-1,4))[:,:3]
187 | elif pcd_file.endswith('.ply'):
188 | return np.array(o3d.io.read_point_cloud(pcd_file).points)
189 |     else:
190 |         raise ValueError(f"Point cloud format '.{pcd_file.split('.')[-1]}' not supported. (supported formats: .bin (kitti format), .ply)")
191 |
192 | class DiffCompletion_lidiff(LightningModule):
193 | def __init__(self, diff_path, refine_path, denoising_steps, uncond_w):
194 | super().__init__()
195 |
196 | # load diff net
197 | ckpt_diff = torch.load(diff_path)
198 |
199 | dm_weights = {k.replace('model.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('model.')}
200 | encoder_weights = {k.replace('partial_enc.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('partial_enc.')}
201 |
202 | # load encoder and model
203 | self.partial_enc = MinkGlobalEnc().cuda()
204 | self.partial_enc.load_state_dict(encoder_weights, strict=True)
205 | self.partial_enc.eval()
206 |
207 | self.model = MinkUNetDiff().cuda()
208 | self.model.load_state_dict(dm_weights, strict=True)
209 | self.model.eval()
210 |
211 | # load refiner
212 | ckpt_refine = torch.load(refine_path)
213 | refiner_weights = {k.replace('model_refine.', ''): v for k, v in ckpt_refine["state_dict"].items() if k.startswith('model_refine.')}
214 | self.model_refine = MinkUNet(in_channels=3, out_channels=3*6).cuda()
215 | self.model_refine.load_state_dict(refiner_weights, strict=True)
216 | self.model_refine.eval()
217 |
218 | self.cuda()
219 |
220 | # for fast sampling
221 | self.dpm_scheduler = DPMSolverMultistepScheduler(
222 | num_train_timesteps=1000,
223 | beta_start=3.5e-5,
224 | beta_end=0.007,
225 | beta_schedule='linear',
226 | algorithm_type='sde-dpmsolver++',
227 | solver_order=2,
228 | )
229 | self.dpm_scheduler.set_timesteps(num_inference_steps=denoising_steps)
230 | self.scheduler_to_cuda()
231 |
232 | self.w_uncond = uncond_w
233 |
234 | def scheduler_to_cuda(self):
235 | self.dpm_scheduler.timesteps = self.dpm_scheduler.timesteps.cuda()
236 | self.dpm_scheduler.betas = self.dpm_scheduler.betas.cuda()
237 | self.dpm_scheduler.alphas = self.dpm_scheduler.alphas.cuda()
238 | self.dpm_scheduler.alphas_cumprod = self.dpm_scheduler.alphas_cumprod.cuda()
239 | self.dpm_scheduler.alpha_t = self.dpm_scheduler.alpha_t.cuda()
240 | self.dpm_scheduler.sigma_t = self.dpm_scheduler.sigma_t.cuda()
241 | self.dpm_scheduler.lambda_t = self.dpm_scheduler.lambda_t.cuda()
242 | self.dpm_scheduler.sigmas = self.dpm_scheduler.sigmas.cuda()
243 |
244 | def feats_to_coord(self, p_feats, resolution, mean=None, std=None):
245 | p_feats = p_feats.reshape(mean.shape[0],-1,3)
246 | p_coord = torch.round(p_feats / resolution)
247 |
248 | return p_coord.reshape(-1,3)
249 |
250 | def points_to_tensor(self, points):
251 | x_feats = ME.utils.batched_coordinates(list(points[:]), dtype=torch.float32, device=self.device)
252 |
253 | x_coord = x_feats.clone()
254 | x_coord = torch.round(x_coord / 0.05)
255 |
256 | x_t = ME.TensorField(
257 | features=x_feats[:,1:],
258 | coordinates=x_coord,
259 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
260 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
261 | device=self.device,
262 | )
263 |
264 | torch.cuda.empty_cache()
265 |
266 | return x_t
267 |
268 | def reset_partial_pcd(self, x_part):
269 | x_part = self.points_to_tensor(x_part.F.reshape(1,-1,3).detach())
270 |
271 | return x_part
272 |
273 | def preprocess_scan(self, scan):
274 | dist = np.sqrt(np.sum((scan)**2, -1))
275 | scan = scan[(dist < 50.0) & (dist > 3.5)][:,:3]
276 |
277 | # use farthest point sampling
278 | pcd_scan = o3d.geometry.PointCloud()
279 | pcd_scan.points = o3d.utility.Vector3dVector(scan)
280 | pcd_scan = pcd_scan.farthest_point_down_sample(int(180000 / 10))
281 | scan = torch.tensor(np.array(pcd_scan.points)).cuda()
282 |
283 | scan = scan.repeat(10,1)
284 | scan = scan[None,:,:]
285 |
286 | return scan
287 |
288 |
289 | def postprocess_scan(self, completed_scan, input_scan):
290 | dist = np.sqrt(np.sum((completed_scan)**2, -1))
291 | post_scan = completed_scan[dist < 50.0]
292 | max_z = input_scan[...,2].max().item()
293 | min_z = (input_scan[...,2].mean() - 2 * input_scan[...,2].std()).item()
294 |
295 | post_scan = post_scan[(post_scan[:,2] < max_z) & (post_scan[:,2] > min_z)]
296 |
297 | return post_scan
298 |
299 | def complete_scan(self, pcd_part):
300 | pcd_part_rep = self.preprocess_scan(pcd_part).view(1,-1,3)
301 | # pcd_part = torch.tensor(pcd_part, device=self.device).view(1,-1,3)
302 | # print(f'pcd_part_rep.shape = {pcd_part_rep.shape}')
303 | # print(f'pcd_part.shape = {pcd_part.shape}')
304 |
305 | x_feats = pcd_part_rep + torch.randn(pcd_part_rep.shape, device=self.device)
306 | x_full = self.points_to_tensor(x_feats) # x_T
307 | x_cond = self.points_to_tensor(pcd_part_rep) # x_0
308 | x_uncond = self.points_to_tensor(torch.zeros_like(pcd_part_rep))
309 |
310 | completed_scan = self.completion_loop(pcd_part_rep, x_full, x_cond, x_uncond)
311 | post_scan = self.postprocess_scan(completed_scan, pcd_part_rep)
312 |
313 | refine_in = self.points_to_tensor(post_scan[None,:,:])
314 | offset = self.refine_forward(refine_in).reshape(-1,6,3)
315 |
316 | refine_complete_scan = post_scan[:,None,:] + offset.cpu().numpy()
317 |
318 | return refine_complete_scan.reshape(-1,3), post_scan
319 |
320 |
321 | def refine_forward(self, x_in):
322 | with torch.no_grad():
323 | offset = self.model_refine(x_in)
324 |
325 | return offset
326 |
327 | def forward(self, x_full, x_full_sparse, x_part, t):
328 | with torch.no_grad():
329 | part_feat = self.partial_enc(x_part)
330 | out = self.model(x_full, x_full_sparse, part_feat, t)
331 |
332 | torch.cuda.empty_cache()
333 | return out.reshape(t.shape[0],-1,3)
334 |
335 | def classfree_forward(self, x_t, x_cond, x_uncond, t):
336 | x_t_sparse = x_t.sparse()
337 | x_cond = self.forward(x_t, x_t_sparse, x_cond, t)
338 | x_uncond = self.forward(x_t, x_t_sparse, x_uncond, t)
339 |
340 | return x_uncond + self.w_uncond * (x_cond - x_uncond)
341 |
342 | def completion_loop(self, x_init, x_t, x_cond, x_uncond):
343 | self.scheduler_to_cuda()
344 |
345 | # for t in tqdm.tqdm(range(len(self.dpm_scheduler.timesteps))):
346 | for t in range(len(self.dpm_scheduler.timesteps)):
347 | t = self.dpm_scheduler.timesteps[t].cuda()[None]
348 |
349 | noise_t = self.classfree_forward(x_t, x_cond, x_uncond, t)
350 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
351 | x_t = x_init + self.dpm_scheduler.step(noise_t, t, input_noise)['prev_sample']
352 | x_t = self.points_to_tensor(x_t)
353 |
354 | x_cond = self.reset_partial_pcd(x_cond)
355 | torch.cuda.empty_cache()
356 |
357 | return x_t.F.cpu().detach().numpy()
358 |
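359 | # Usage sketch (hypothetical checkpoint/scan paths): complete one KITTI scan.
360 | #   pipeline = DiffCompletion_DistillationDPO('checkpoints/diff.ckpt',
361 | #                                             'checkpoints/refine.ckpt',
362 | #                                             denoising_steps=8, uncond_w=3.5)
363 | #   scan = load_pcd('datasets/SemanticKITTI/dataset/sequences/08/velodyne/000000.bin')
364 | #   refined, coarse = pipeline.complete_scan(scan)  # refined: (M*6,3), coarse: (M,3)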
--------------------------------------------------------------------------------
/models/minkunet.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | import MinkowskiEngine as ME
7 | import numpy as np
8 | from pykeops.torch import LazyTensor
9 |
10 | __all__ = ['MinkGlobalEnc', 'MinkUNetDiff', 'MinkUNet']
11 |
12 |
13 | class BasicConvolutionBlock(nn.Module):
14 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3):
15 | super().__init__()
16 | self.net = nn.Sequential(
17 | ME.MinkowskiConvolution(inc,
18 | outc,
19 | kernel_size=ks,
20 | dilation=dilation,
21 | stride=stride,
22 | dimension=D),
23 | ME.MinkowskiBatchNorm(outc),
24 | ME.MinkowskiReLU(inplace=True)
25 | )
26 |
27 | def forward(self, x):
28 | out = self.net(x)
29 | return out
30 |
31 |
32 | class BasicDeconvolutionBlock(nn.Module):
33 | def __init__(self, inc, outc, ks=3, stride=1, D=3):
34 | super().__init__()
35 | self.net = nn.Sequential(
36 | ME.MinkowskiConvolutionTranspose(inc,
37 | outc,
38 | kernel_size=ks,
39 | stride=stride,
40 | dimension=D),
41 | ME.MinkowskiBatchNorm(outc),
42 | ME.MinkowskiReLU(inplace=True)
43 | )
44 |
45 | def forward(self, x):
46 | return self.net(x)
47 |
48 |
49 | class ResidualBlock(nn.Module):
50 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3):
51 | super().__init__()
52 | self.net = nn.Sequential(
53 | ME.MinkowskiConvolution(inc,
54 | outc,
55 | kernel_size=ks,
56 | dilation=dilation,
57 | stride=stride,
58 | dimension=D),
59 | ME.MinkowskiBatchNorm(outc),
60 | ME.MinkowskiReLU(inplace=True),
61 | ME.MinkowskiConvolution(outc,
62 | outc,
63 | kernel_size=ks,
64 | dilation=dilation,
65 | stride=1,
66 | dimension=D),
67 | ME.MinkowskiBatchNorm(outc)
68 | )
69 |
70 | self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \
71 | nn.Sequential(
72 | ME.MinkowskiConvolution(inc, outc, kernel_size=1, dilation=1, stride=stride, dimension=D),
73 | ME.MinkowskiBatchNorm(outc)
74 | )
75 |
76 | self.relu = ME.MinkowskiReLU(inplace=True)
77 |
78 | def forward(self, x):
79 | out = self.relu(self.net(x) + self.downsample(x))
80 | return out
81 |
82 |
83 | class MinkGlobalEnc(nn.Module):
84 | def __init__(self, **kwargs):
85 | super().__init__()
86 | cr = kwargs.get('cr', 1.0)
87 | in_channels = kwargs.get('in_channels', 3)
88 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96]
89 | cs = [int(cr * x) for x in cs]
90 | self.embed_dim = cs[-1]
91 | self.run_up = kwargs.get('run_up', True)
92 | self.D = kwargs.get('D', 3)
93 | self.stem = nn.Sequential(
94 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D),
95 | ME.MinkowskiBatchNorm(cs[0]),
96 | ME.MinkowskiReLU(True),
97 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D),
98 | ME.MinkowskiBatchNorm(cs[0]),
99 | ME.MinkowskiReLU(inplace=True)
100 | )
101 |
102 | self.stage1 = nn.Sequential(
103 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D),
104 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D),
105 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D),
106 | )
107 |
108 | self.stage2 = nn.Sequential(
109 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D),
110 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D),
111 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D),
112 | )
113 |
114 | self.stage3 = nn.Sequential(
115 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D),
116 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D),
117 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D),
118 | )
119 |
120 | self.stage4 = nn.Sequential(
121 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D),
122 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D),
123 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D),
124 | )
125 |
126 | self.weight_initialization()
127 |
128 | def weight_initialization(self):
129 | for m in self.modules():
130 | if isinstance(m, nn.BatchNorm1d):
131 | nn.init.constant_(m.weight, 1)
132 | nn.init.constant_(m.bias, 0)
133 |
134 | def forward(self, x):
135 | x0 = self.stem(x.sparse())
136 | x1 = self.stage1(x0)
137 | x2 = self.stage2(x1)
138 | x3 = self.stage3(x2)
139 | x4 = self.stage4(x3)
140 |
141 | return x4
142 |
143 |
144 | class MinkUNetDiff(nn.Module):
145 | def __init__(self, **kwargs):
146 | super().__init__()
147 |
148 | cr = kwargs.get('cr', 1.0)
149 | in_channels = kwargs.get('in_channels', 3)
150 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96]
151 | cs = [int(cr * x) for x in cs]
152 | self.embed_dim = cs[-1]
153 | self.run_up = kwargs.get('run_up', True)
154 | self.D = kwargs.get('D', 3)
155 | self.stem = nn.Sequential(
156 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D),
157 | ME.MinkowskiBatchNorm(cs[0]),
158 | ME.MinkowskiReLU(True),
159 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D),
160 | ME.MinkowskiBatchNorm(cs[0]),
161 | ME.MinkowskiReLU(inplace=True)
162 | )
163 |
164 | # Stage1 temp embed proj and conv
165 | self.latent_stage1 = nn.Sequential(
166 | nn.Linear(cs[4], cs[4]),
167 | nn.LeakyReLU(0.1, inplace=True),
168 | nn.Linear(cs[4], cs[4]),
169 | )
170 |
171 | self.latemp_stage1 = nn.Sequential(
172 | nn.Linear(cs[4]+cs[4], cs[4]),
173 | nn.LeakyReLU(0.1, inplace=True),
174 | nn.Linear(cs[4], cs[0]),
175 | )
176 |
177 | self.stage1_temp = nn.Sequential(
178 | nn.Linear(self.embed_dim, self.embed_dim),
179 | nn.LeakyReLU(0.1, inplace=True),
180 | nn.Linear(self.embed_dim, cs[4]),
181 | )
182 |
183 | self.stage1 = nn.Sequential(
184 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D),
185 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D),
186 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D),
187 | )
188 |
189 | # Stage2 temp embed proj and conv
190 | self.latent_stage2 = nn.Sequential(
191 | nn.Linear(cs[4], cs[4]),
192 | nn.LeakyReLU(0.1, inplace=True),
193 | nn.Linear(cs[4], cs[4]),
194 | )
195 |
196 | self.latemp_stage2 = nn.Sequential(
197 | nn.Linear(cs[4]+cs[4], cs[4]),
198 | nn.LeakyReLU(0.1, inplace=True),
199 | nn.Linear(cs[4], cs[1]),
200 | )
201 |
202 | self.stage2_temp = nn.Sequential(
203 | nn.Linear(self.embed_dim, self.embed_dim),
204 | nn.LeakyReLU(0.1, inplace=True),
205 | nn.Linear(self.embed_dim, cs[4]),
206 | )
207 |
208 | self.stage2 = nn.Sequential(
209 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D),
210 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D),
211 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D)
212 | )
213 |
214 | # Stage3 temp embed proj and conv
215 | self.latent_stage3 = nn.Sequential(
216 | nn.Linear(cs[4], cs[4]),
217 | nn.LeakyReLU(0.1, inplace=True),
218 | nn.Linear(cs[4], cs[4]),
219 | )
220 |
221 | self.latemp_stage3 = nn.Sequential(
222 | nn.Linear(cs[4]+cs[4], cs[4]),
223 | nn.LeakyReLU(0.1, inplace=True),
224 | nn.Linear(cs[4], cs[2]),
225 | )
226 |
227 | self.stage3_temp = nn.Sequential(
228 | nn.Linear(self.embed_dim, self.embed_dim),
229 | nn.LeakyReLU(0.1, inplace=True),
230 | nn.Linear(self.embed_dim, cs[4]),
231 | )
232 |
233 | self.stage3 = nn.Sequential(
234 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D),
235 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D),
236 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D),
237 | )
238 |
239 | # Stage4 temp embed proj and conv
240 | self.latent_stage4 = nn.Sequential(
241 | nn.Linear(cs[4], cs[4]),
242 | nn.LeakyReLU(0.1, inplace=True),
243 | nn.Linear(cs[4], cs[4]),
244 | )
245 |
246 | self.latemp_stage4 = nn.Sequential(
247 | nn.Linear(cs[4]+cs[4], cs[4]),
248 | nn.LeakyReLU(0.1, inplace=True),
249 | nn.Linear(cs[4], cs[3]),
250 | )
251 |
252 | self.stage4_temp = nn.Sequential(
253 | nn.Linear(self.embed_dim, self.embed_dim),
254 | nn.LeakyReLU(0.1, inplace=True),
255 | nn.Linear(self.embed_dim, cs[4]),
256 | )
257 |
258 | self.stage4 = nn.Sequential(
259 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D),
260 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D),
261 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D),
262 | )
263 |
264 | # Up1 temp embed proj and conv
265 | self.latent_up1 = nn.Sequential(
266 | nn.Linear(cs[4], cs[4]),
267 | nn.LeakyReLU(0.1, inplace=True),
268 | nn.Linear(cs[4], cs[4]),
269 | )
270 |
271 | self.latemp_up1 = nn.Sequential(
272 | nn.Linear(cs[4]+cs[4], cs[4]),
273 | nn.LeakyReLU(0.1, inplace=True),
274 | nn.Linear(cs[4], cs[4]),
275 | )
276 |
277 | self.up1_temp = nn.Sequential(
278 | nn.Linear(self.embed_dim, self.embed_dim),
279 | nn.LeakyReLU(0.1, inplace=True),
280 | nn.Linear(self.embed_dim, cs[4]),
281 | )
282 |
283 | self.up1 = nn.ModuleList([
284 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2, D=self.D),
285 | nn.Sequential(
286 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1,
287 | dilation=1, D=self.D),
288 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1, D=self.D),
289 | )
290 | ])
291 |
292 | # Up2 temp embed proj and conv
293 | self.latent_up2 = nn.Sequential(
294 | nn.Linear(cs[4], cs[4]),
295 | nn.LeakyReLU(0.1, inplace=True),
296 | nn.Linear(cs[4], cs[4]),
297 | )
298 |
299 | self.latemp_up2 = nn.Sequential(
300 | nn.Linear(cs[4]+cs[4], cs[5]),
301 | nn.LeakyReLU(0.1, inplace=True),
302 | nn.Linear(cs[5], cs[5]),
303 | )
304 |
305 | self.up2_temp = nn.Sequential(
306 | nn.Linear(self.embed_dim, self.embed_dim),
307 | nn.LeakyReLU(0.1, inplace=True),
308 | nn.Linear(self.embed_dim, cs[4]),
309 | )
310 |
311 | self.up2 = nn.ModuleList([
312 | BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2, D=self.D),
313 | nn.Sequential(
314 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1,
315 | dilation=1, D=self.D),
316 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1, D=self.D),
317 | )
318 | ])
319 |
320 | # Up3 temp embed proj and conv
321 | self.latent_up3 = nn.Sequential(
322 | nn.Linear(cs[4], cs[4]),
323 | nn.LeakyReLU(0.1, inplace=True),
324 | nn.Linear(cs[4], cs[4]),
325 | )
326 |
327 | self.latemp_up3 = nn.Sequential(
328 | nn.Linear(cs[4]+cs[4], cs[6]),
329 | nn.LeakyReLU(0.1, inplace=True),
330 | nn.Linear(cs[6], cs[6]),
331 | )
332 |
333 | self.up3_temp = nn.Sequential(
334 | nn.Linear(self.embed_dim, self.embed_dim),
335 | nn.LeakyReLU(0.1, inplace=True),
336 | nn.Linear(self.embed_dim, cs[4]),
337 | )
338 |
339 | self.up3 = nn.ModuleList([
340 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2, D=self.D),
341 | nn.Sequential(
342 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1,
343 | dilation=1, D=self.D),
344 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1, D=self.D),
345 | )
346 | ])
347 |
348 | # Up4 temp embed proj and conv
349 | self.latent_up4 = nn.Sequential(
350 | nn.Linear(cs[4], cs[4]),
351 | nn.LeakyReLU(0.1, inplace=True),
352 | nn.Linear(cs[4], cs[4]),
353 | )
354 |
355 | self.latemp_up4 = nn.Sequential(
356 | nn.Linear(cs[4]+cs[4], cs[7]),
357 | nn.LeakyReLU(0.1, inplace=True),
358 | nn.Linear(cs[7], cs[7]),
359 | )
360 |
361 | self.up4_temp = nn.Sequential(
362 | nn.Linear(self.embed_dim, self.embed_dim),
363 | nn.LeakyReLU(0.1, inplace=True),
364 | nn.Linear(self.embed_dim, cs[4]),
365 | )
366 |
367 | self.up4 = nn.ModuleList([
368 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2, D=self.D),
369 | nn.Sequential(
370 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1,
371 | dilation=1, D=self.D),
372 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1, D=self.D),
373 | )
374 | ])
375 |
376 | self.last = nn.Sequential(
377 | nn.Linear(cs[8], 20),
378 | nn.LeakyReLU(0.1, inplace=True),
379 | nn.Linear(20, 3),
380 | )
381 |
382 | self.weight_initialization()
383 |
384 | def weight_initialization(self):
385 | for m in self.modules():
386 | if isinstance(m, nn.BatchNorm1d):
387 | nn.init.constant_(m.weight, 1)
388 | nn.init.constant_(m.bias, 0)
389 |
390 | def get_timestep_embedding(self, timesteps):
391 | assert len(timesteps.shape) == 1
392 |
393 | half_dim = self.embed_dim // 2
394 | emb = np.log(10000) / (half_dim - 1)
395 | emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(torch.device('cuda'))
396 | emb = timesteps[:, None] * emb[None, :]
397 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
398 | if self.embed_dim % 2 == 1: # zero pad
399 | emb = nn.functional.pad(emb, (0, 1), "constant", 0)
400 | assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim])
401 | return emb
402 |
403 | def match_part_to_full(self, x_full, x_part):
404 | full_c = x_full.C.clone().float()
405 | part_c = x_part.C.clone().float()
406 |
407 | # hash batch coord
408 | max_coord = full_c.max()
409 | full_c[:,0] *= max_coord * 2.
410 | part_c[:,0] *= max_coord * 2.
411 |
412 | f_coord = LazyTensor(full_c[:,None,:])
413 | p_coord = LazyTensor(part_c[None,:,:])
414 |
415 | dist_fp = ((f_coord - p_coord)**2).sum(-1)
416 | match_feats = dist_fp.argKmin(1,dim=1)[:,0]
417 |
418 | return x_part.F[match_feats]
419 |
420 | def forward(self, x, x_sparse, part_feats, t):
421 | temp_emb = self.get_timestep_embedding(t)
422 |
423 | x0 = self.stem(x_sparse)
424 | match0 = self.match_part_to_full(x0, part_feats)
425 | p0 = self.latent_stage1(match0)
426 | t0 = self.stage1_temp(temp_emb)
427 | batch_temp = torch.unique(x0.C[:,0], return_counts=True)[1]
428 | t0 = torch.repeat_interleave(t0, batch_temp, dim=0)
429 | w0 = self.latemp_stage1(torch.cat((p0,t0),-1))
430 |
431 | x1 = self.stage1(x0*w0)
432 | match1 = self.match_part_to_full(x1, part_feats)
433 | p1 = self.latent_stage2(match1)
434 | t1 = self.stage2_temp(temp_emb)
435 | batch_temp = torch.unique(x1.C[:,0], return_counts=True)[1]
436 | t1 = torch.repeat_interleave(t1, batch_temp, dim=0)
437 | w1 = self.latemp_stage2(torch.cat((p1,t1),-1))
438 |
439 | x2 = self.stage2(x1*w1)
440 | match2 = self.match_part_to_full(x2, part_feats)
441 | p2 = self.latent_stage3(match2)
442 | t2 = self.stage3_temp(temp_emb)
443 | batch_temp = torch.unique(x2.C[:,0], return_counts=True)[1]
444 | t2 = torch.repeat_interleave(t2, batch_temp, dim=0)
445 | w2 = self.latemp_stage3(torch.cat((p2,t2),-1))
446 |
447 | x3 = self.stage3(x2*w2)
448 | match3 = self.match_part_to_full(x3, part_feats)
449 | p3 = self.latent_stage4(match3)
450 | t3 = self.stage4_temp(temp_emb)
451 | batch_temp = torch.unique(x3.C[:,0], return_counts=True)[1]
452 | t3 = torch.repeat_interleave(t3, batch_temp, dim=0)
453 | w3 = self.latemp_stage4(torch.cat((p3,t3),-1))
454 |
455 | x4 = self.stage4(x3*w3)
456 | match4 = self.match_part_to_full(x4, part_feats)
457 | p4 = self.latent_up1(match4)
458 | t4 = self.up1_temp(temp_emb)
459 | batch_temp = torch.unique(x4.C[:,0], return_counts=True)[1]
460 | t4 = torch.repeat_interleave(t4, batch_temp, dim=0)
461 |         w4 = self.latemp_up1(torch.cat((t4,p4),-1))  # note: (t4, p4) order here, unlike the (p, t) order at every other stage
462 |
463 | y1 = self.up1[0](x4*w4)
464 | y1 = ME.cat(y1, x3)
465 | y1 = self.up1[1](y1)
466 | match5 = self.match_part_to_full(y1, part_feats)
467 | p5 = self.latent_up2(match5)
468 | t5 = self.up2_temp(temp_emb)
469 | batch_temp = torch.unique(y1.C[:,0], return_counts=True)[1]
470 | t5 = torch.repeat_interleave(t5, batch_temp, dim=0)
471 | w5 = self.latemp_up2(torch.cat((p5,t5),-1))
472 |
473 | y2 = self.up2[0](y1*w5)
474 | y2 = ME.cat(y2, x2)
475 | y2 = self.up2[1](y2)
476 | match6 = self.match_part_to_full(y2, part_feats)
477 | p6 = self.latent_up3(match6)
478 | t6 = self.up3_temp(temp_emb)
479 | batch_temp = torch.unique(y2.C[:,0], return_counts=True)[1]
480 | t6 = torch.repeat_interleave(t6, batch_temp, dim=0)
481 | w6 = self.latemp_up3(torch.cat((p6,t6),-1))
482 |
483 | y3 = self.up3[0](y2*w6)
484 | y3 = ME.cat(y3, x1)
485 | y3 = self.up3[1](y3)
486 | match7 = self.match_part_to_full(y3, part_feats)
487 | p7 = self.latent_up4(match7)
488 | t7 = self.up4_temp(temp_emb)
489 | batch_temp = torch.unique(y3.C[:,0], return_counts=True)[1]
490 | t7 = torch.repeat_interleave(t7, batch_temp, dim=0)
491 | w7 = self.latemp_up4(torch.cat((p7,t7),-1))
492 |
493 | y4 = self.up4[0](y3*w7)
494 | y4 = ME.cat(y4, x0)
495 | y4 = self.up4[1](y4)
496 |
497 | return self.last(y4.slice(x).F)
498 |
499 | class MinkUNet(nn.Module):
500 | def __init__(self, **kwargs):
501 | super().__init__()
502 |
503 | cr = kwargs.get('cr', 1.0)
504 | in_channels = kwargs.get('in_channels', 3)
505 | out_channels = kwargs.get('out_channels', 3)
506 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96]
507 | cs = [int(cr * x) for x in cs]
508 | self.run_up = kwargs.get('run_up', True)
509 | self.D = kwargs.get('D', 3)
510 | self.stem = nn.Sequential(
511 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D),
512 | ME.MinkowskiBatchNorm(cs[0]),
513 | ME.MinkowskiReLU(True),
514 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D),
515 | ME.MinkowskiBatchNorm(cs[0]),
516 | ME.MinkowskiReLU(inplace=True)
517 | )
518 |
519 | self.stage1 = nn.Sequential(
520 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D),
521 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D),
522 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D),
523 | )
524 |
525 | self.stage2 = nn.Sequential(
526 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D),
527 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D),
528 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D)
529 | )
530 |
531 | self.stage3 = nn.Sequential(
532 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D),
533 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D),
534 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D),
535 | )
536 |
537 | self.stage4 = nn.Sequential(
538 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D),
539 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D),
540 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D),
541 | )
542 |
543 | self.up1 = nn.ModuleList([
544 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2, D=self.D),
545 | nn.Sequential(
546 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1,
547 | dilation=1, D=self.D),
548 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1, D=self.D),
549 | )
550 | ])
551 |
552 | self.up2 = nn.ModuleList([
553 | BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2, D=self.D),
554 | nn.Sequential(
555 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1,
556 | dilation=1, D=self.D),
557 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1, D=self.D),
558 | )
559 | ])
560 |
561 | self.up3 = nn.ModuleList([
562 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2, D=self.D),
563 | nn.Sequential(
564 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1,
565 | dilation=1, D=self.D),
566 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1, D=self.D),
567 | )
568 | ])
569 |
570 | self.up4 = nn.ModuleList([
571 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2, D=self.D),
572 | nn.Sequential(
573 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1,
574 | dilation=1, D=self.D),
575 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1, D=self.D),
576 | )
577 | ])
578 |
579 | self.last = nn.Sequential(
580 | nn.Linear(cs[8], 20),
581 | nn.LeakyReLU(0.1, inplace=True),
582 | nn.Linear(20, out_channels),
583 | nn.Tanh(),
584 | )
585 |
586 | self.weight_initialization()
587 | self.dropout = nn.Dropout(0.3, True)
588 |
589 | def weight_initialization(self):
590 | for m in self.modules():
591 | if isinstance(m, nn.BatchNorm1d):
592 | nn.init.constant_(m.weight, 1)
593 | nn.init.constant_(m.bias, 0)
594 |
595 | def forward(self, x):
596 | x0 = self.stem(x.sparse())
597 | x1 = self.stage1(x0)
598 | x2 = self.stage2(x1)
599 | x3 = self.stage3(x2)
600 | x4 = self.stage4(x3)
601 |
602 | y1 = self.up1[0](x4)
603 | y1 = ME.cat(y1, x3)
604 | y1 = self.up1[1](y1)
605 |
606 | y2 = self.up2[0](y1)
607 | y2 = ME.cat(y2, x2)
608 | y2 = self.up2[1](y2)
609 |
610 | y3 = self.up3[0](y2)
611 | y3 = ME.cat(y3, x1)
612 | y3 = self.up3[1](y3)
613 |
614 | y4 = self.up4[0](y3)
615 | y4 = ME.cat(y4, x0)
616 | y4 = self.up4[1](y4)
617 |
618 | return self.last(y4.slice(x).F)
619 |
620 |
621 |
622 |
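623 | # Shape sketch: MinkUNetDiff.forward takes a TensorField x, its sparse() projection,
624 | # encoder features from MinkGlobalEnc, and a (B,)-shaped timestep tensor t.
625 | # get_timestep_embedding maps t to a (B, embed_dim) sinusoidal embedding that, fused
626 | # with matched partial-scan features, rescales the activations at every stage; the
627 | # output is one 3-vector (predicted noise) per input point via y4.slice(x).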
--------------------------------------------------------------------------------
/trains/DistillationDPO.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import MinkowskiEngine as ME
6 | import open3d as o3d
7 | import datetime
8 | from tqdm import tqdm
9 | from os import makedirs, path
10 | import os
11 | import copy
12 |
13 | from pytorch_lightning.core.lightning import LightningModule
14 | from pytorch_lightning import Trainer
15 | from pytorch_lightning import loggers as pl_loggers
16 | from pytorch_lightning.callbacks import ModelCheckpoint
17 | from diffusers import DPMSolverMultistepScheduler
18 |
19 | from utils.collations import *
20 | from utils.scheduling import beta_func
21 | from utils.metrics import ChamferDistance, PrecisionRecall, CompletionIoU, RMSE, EMD
22 | 
23 | from utils.histogram_metrics import compute_hist_metrics
24 | from models.minkunet import MinkRewardModel, MinkGlobalEnc, MinkUNetDiff
25 | import datasets.SemanticKITTI_dataset as SemanticKITTI_dataset
26 |
27 | class DistillationDPO(LightningModule):
28 | def __init__(self, args):
29 | super().__init__()
30 |
31 | # configs
32 | self.lr = args.lr
33 | self.timestamp = args.timestamp
34 | self.args = args
35 | self.w_uncond = 3.5
36 |
37 | # Load pre-trained DM weights; initialize the generator, teacher and auxiliary models
38 | dm_ckpt = torch.load(args.pre_trained_diff_path)
39 | # dm_weights = {k.replace('model.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('model.')}
40 | dm_weights = {k.replace('DiffModel.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('DiffModel.')}
41 | generator_weights = copy.deepcopy(dm_weights)
42 | auxDiffBetter_weights = copy.deepcopy(dm_weights)
43 | auxDiffWorse_weights = copy.deepcopy(dm_weights)
44 | teacher_weights = dm_weights
45 | # encoder_weights = {k.replace('partial_enc.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('partial_enc.')}
46 | encoder_weights = {k.replace('DM_encoder.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('DM_encoder.')}
47 |
48 | self.partial_enc = MinkGlobalEnc()
49 | self.generator = MinkUNetDiff()
50 | self.teacher = MinkUNetDiff()
51 | self.auxDiffBetter = MinkUNetDiff()
52 | self.auxDiffWorse = MinkUNetDiff()
53 | self.partial_enc.load_state_dict(encoder_weights, strict=True)
54 | self.generator.load_state_dict(generator_weights, strict=True)
55 | self.teacher.load_state_dict(teacher_weights, strict=True)
56 | self.auxDiffBetter.load_state_dict(auxDiffBetter_weights, strict=True)
57 | self.auxDiffWorse.load_state_dict(auxDiffWorse_weights, strict=True)
58 |
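# Five copies of the same pre-trained weights play different roles: the encoder and
# the teacher stay frozen, while the generator and the two auxiliary diffusion models
# (which track the generator's better / worse sample distributions) are trained, as
# the requires_grad flags below make explicit.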
59 | for param in self.partial_enc.parameters():
60 | param.requires_grad = False
61 | for param in self.generator.parameters():
62 | param.requires_grad = True
63 | for param in self.teacher.parameters():
64 | param.requires_grad = False
65 | for param in self.auxDiffBetter.parameters():
66 | param.requires_grad = True
67 | for param in self.auxDiffWorse.parameters():
68 | param.requires_grad = True
69 |
70 | self.partial_enc.eval()
71 |
72 | # init scheduler for DM
73 | self.betas = beta_func['linear'](1000, 3.5e-5, 0.007)
74 |
75 | self.t_steps = 1000
76 | self.s_steps = 1
77 | self.s_steps_val = 8
78 |
79 | self.alphas = 1. - self.betas
80 | self.alphas_cumprod = torch.tensor(
81 | np.cumprod(self.alphas, axis=0), dtype=torch.float32, device=self.device
82 | )
83 |
84 | self.alphas_cumprod_prev = torch.tensor(
85 | np.append(1., self.alphas_cumprod[:-1].cpu().numpy()), dtype=torch.float32, device=self.device
86 | )
87 |
88 | self.betas = torch.tensor(self.betas, device=self.device)
89 | self.alphas = torch.tensor(self.alphas, device=self.device)
90 |
91 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
92 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
93 | self.log_one_minus_alphas_cumprod = torch.log(1. - self.alphas_cumprod)
94 | self.sqrt_recip_alphas = torch.sqrt(1. / self.alphas)
95 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
96 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1.)
97 |
98 | self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
99 | self.sqrt_posterior_variance = torch.sqrt(self.posterior_variance)
100 | self.posterior_log_var = torch.log(
101 | torch.max(self.posterior_variance, 1e-20 * torch.ones_like(self.posterior_variance))
102 | )
103 |
104 | self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
105 | self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)
106 |
107 | self.dpm_scheduler = DPMSolverMultistepScheduler(
108 | num_train_timesteps=self.t_steps,
109 | beta_start=3.5e-5,
110 | beta_end=0.007,
111 | beta_schedule='linear',
112 | algorithm_type='sde-dpmsolver++',
113 | solver_order=2,
114 | )
115 | self.dpm_scheduler.set_timesteps(num_inference_steps=self.s_steps)
116 | self.dpm_scheduler_val = DPMSolverMultistepScheduler(
117 | num_train_timesteps=self.t_steps,
118 | beta_start=3.5e-5,
119 | beta_end=0.007,
120 | beta_schedule='linear',
121 | algorithm_type='sde-dpmsolver++',
122 | solver_order=2,
123 | )
124 | self.dpm_scheduler_val.set_timesteps(num_inference_steps=self.s_steps_val)
125 | self.scheduler_to_cuda()
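# Two DPM-Solver++ schedulers share the same linear noise schedule: a 1-step one
# (s_steps) used while training the distilled generator, and an 8-step one
# (s_steps_val) used for validation sampling.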
126 |
127 | # metrics for validation
128 | self.chamfer_distance = ChamferDistance()
129 | self.precision_recall = PrecisionRecall(0.05, 2*0.05, 100)
130 | self.completion_iou = CompletionIoU()
131 |
132 | def scheduler_to_cuda(self):
133 | self.dpm_scheduler.timesteps = self.dpm_scheduler.timesteps.to(self.device)
134 | self.dpm_scheduler.betas = self.dpm_scheduler.betas.to(self.device)
135 | self.dpm_scheduler.alphas = self.dpm_scheduler.alphas.to(self.device)
136 | self.dpm_scheduler.alphas_cumprod = self.dpm_scheduler.alphas_cumprod.to(self.device)
137 | self.dpm_scheduler.alpha_t = self.dpm_scheduler.alpha_t.to(self.device)
138 | self.dpm_scheduler.sigma_t = self.dpm_scheduler.sigma_t.to(self.device)
139 | self.dpm_scheduler.lambda_t = self.dpm_scheduler.lambda_t.to(self.device)
140 | self.dpm_scheduler.sigmas = self.dpm_scheduler.sigmas.to(self.device)
141 |
142 | self.dpm_scheduler_val.timesteps = self.dpm_scheduler_val.timesteps.to(self.device)
143 | self.dpm_scheduler_val.betas = self.dpm_scheduler_val.betas.to(self.device)
144 | self.dpm_scheduler_val.alphas = self.dpm_scheduler_val.alphas.to(self.device)
145 | self.dpm_scheduler_val.alphas_cumprod = self.dpm_scheduler_val.alphas_cumprod.to(self.device)
146 | self.dpm_scheduler_val.alpha_t = self.dpm_scheduler_val.alpha_t.to(self.device)
147 | self.dpm_scheduler_val.sigma_t = self.dpm_scheduler_val.sigma_t.to(self.device)
148 | self.dpm_scheduler_val.lambda_t = self.dpm_scheduler_val.lambda_t.to(self.device)
149 | self.dpm_scheduler_val.sigmas = self.dpm_scheduler_val.sigmas.to(self.device)
150 |
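# Lightning moves registered modules and buffers to the GPU automatically, but the
# diffusers schedulers keep plain tensor attributes, so they are moved by hand above.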
151 | def q_sample(self, x, t, noise):
152 | return self.sqrt_alphas_cumprod[t][:,None,None].to(self.device) * x + \
153 | self.sqrt_one_minus_alphas_cumprod[t][:,None,None].to(self.device) * noise
154 |
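# q_sample is the closed-form forward diffusion q(x_t | x_0):
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I).
# A minimal self-contained sketch of the same computation (toy 10-step schedule,
# not the training configuration):
#
#   betas = torch.linspace(3.5e-5, 0.007, 10)
#   alpha_bar = torch.cumprod(1. - betas, dim=0)
#   x0 = torch.randn(2, 5, 3)                 # (B, N, 3) batch of points
#   t = torch.tensor([3, 7])                  # one timestep per batch element
#   eps = torch.randn_like(x0)
#   x_t = alpha_bar[t].sqrt()[:, None, None] * x0 \
#       + (1. - alpha_bar[t]).sqrt()[:, None, None] * eps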
155 | def reset_partial_pcd(self, x_part, x_uncond, x_mean, x_std):
156 | x_part = self.points_to_tensor(x_part.F.reshape(x_mean.shape[0],-1,3).detach(), x_mean, x_std)
157 | x_uncond = self.points_to_tensor(
158 | torch.zeros_like(x_part.F.reshape(x_mean.shape[0],-1,3)), torch.zeros_like(x_mean), torch.zeros_like(x_std)
159 | )
160 |
161 | return x_part, x_uncond
162 |
163 | def reset_partial_pcd_part(self, x_part, x_mean, x_std):
164 | x_part = self.points_to_tensor(x_part.F.reshape(x_mean.shape[0],-1,3).detach(), x_mean, x_std)
165 |
166 | return x_part
167 |
168 | def p_sample_loop(self, x_init, x_t, x_cond, x_uncond, gt_pts, x_mean, x_std):
169 | pcd = o3d.geometry.PointCloud()
170 | self.scheduler_to_cuda()
171 |
172 | for i in tqdm(range(len(self.dpm_scheduler.timesteps))):
173 | t = torch.ones(gt_pts.shape[0]).to(self.device).long() * self.dpm_scheduler.timesteps[i].to(self.device)
174 | 
175 | noise_t = self.classfree_forward_generator(x_t, x_cond, x_uncond, t)
176 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
177 | x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
178 | x_t = self.points_to_tensor(x_t, x_mean, x_std)
179 |
180 | # reset the partial-scan tensor, otherwise MinkowskiEngine keeps "stacking"
181 | # coordinate maps over it at every step, i.e. a memory leak
182 | x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
183 | torch.cuda.empty_cache()
184 |
185 | makedirs(f'{self.logger.log_dir}/generated_pcd/', exist_ok=True)
186 |
187 | return x_t
188 |
189 | def p_sample_with_final_step_grad(self, x_init, x_t, x_cond, x_mean, x_std):
190 |
191 | assert len(self.dpm_scheduler.timesteps) == self.s_steps
192 |
193 | with torch.no_grad():
194 | for i in range(len(self.dpm_scheduler.timesteps)-1):
195 | t = torch.full((x_init.shape[0],), self.dpm_scheduler.timesteps[i]).to(self.device)
196 |
197 | noise_t = self.forward_generator(x_t, x_t.sparse(), x_cond, t)
198 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
199 | x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
200 | x_t = self.points_to_tensor(x_t, x_mean, x_std)
201 |
202 | t = torch.full((x_init.shape[0],), self.dpm_scheduler.timesteps[-1]).to(self.device)
203 |
204 | noise_t = self.forward_generator(x_t, x_t.sparse(), x_cond, t)
205 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
206 | x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
207 | x_t = self.points_to_tensor(x_t, x_mean, x_std)
208 |
209 | x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
210 | torch.cuda.empty_cache()
211 |
212 | return x_t
213 |
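# All but the last denoising step above run under no_grad; only the final generator
# call keeps a computation graph. Back-propagating through the full sampling chain
# would be prohibitively expensive, so the gradient flows through the last step only.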
214 | def p_sample_one_step(self, x_init, x_t, x_cond, t, x_mean, x_std):
215 |
216 | noise_t = self.forward_generator(x_t, x_t.sparse(), x_cond, t)
217 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
218 | x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
219 | x_t = self.points_to_tensor(x_t, x_mean, x_std)
220 |
221 | x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
222 | torch.cuda.empty_cache()
223 |
224 | return x_t
225 |
226 | def pred_noise(self, model, sparse_TF, x_t_TF, t:torch.Tensor):
227 | 
228 | condition = self.partial_enc(sparse_TF)
229 | noise_pred = model(x_t_TF, x_t_TF.sparse(), condition, t)
230 | 
231 | torch.cuda.empty_cache()
232 | return noise_pred
233 |
234 |
235 | def sample_val(self, x_init, x_t, x_cond, x_uncond, x_mean, x_std):
236 |
237 | assert len(self.dpm_scheduler_val.timesteps) == self.s_steps_val
238 |
239 | for i in range(len(self.dpm_scheduler_val.timesteps)):
240 | t = torch.full((x_init.shape[0],), self.dpm_scheduler_val.timesteps[i]).to(self.device)
241 |
242 | noise_t = self.classfree_forward_generator(x_t, x_cond, x_uncond, t)
243 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
244 | x_t = x_init + self.dpm_scheduler_val.step(noise_t, t[0], input_noise)['prev_sample']
245 | x_t = self.points_to_tensor(x_t, x_mean, x_std)
246 |
247 | x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
248 | torch.cuda.empty_cache()
249 |
250 | return x_t
251 |
252 | def do_forward(self, model, x_full, x_full_sparse, x_part, t):
253 | part_feat = self.partial_enc(x_part)
254 | out = model(x_full, x_full_sparse, part_feat, t)
255 | torch.cuda.empty_cache()
256 | return out.reshape(t.shape[0],-1,3)
257 |
258 | def forward_generator(self, x_full, x_full_sparse, x_part, t):
259 | part_feat = self.partial_enc(x_part)
260 | out = self.generator(x_full, x_full_sparse, part_feat, t)
261 | torch.cuda.empty_cache()
262 | return out.reshape(t.shape[0],-1,3)
263 |
264 | def classfree_forward_generator(self, x_t, x_cond, x_uncond, t):
265 | x_t_sparse = x_t.sparse()
266 | noise_cond = self.forward_generator(x_t, x_t_sparse, x_cond, t)
267 | noise_uncond = self.forward_generator(x_t, x_t_sparse, x_uncond, t)
268 | 
269 | return noise_uncond + self.w_uncond * (noise_cond - noise_uncond)
270 |
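# Classifier-free guidance: blend the conditional and unconditional predictions as
#   eps = eps_uncond + w * (eps_cond - eps_uncond),   here w = self.w_uncond = 3.5.
# Sketch of the blending rule on dummy tensors (hypothetical shapes):
#
#   eps_cond, eps_uncond = torch.randn(2, 5, 3), torch.randn(2, 5, 3)
#   eps = eps_uncond + 3.5 * (eps_cond - eps_uncond)  # w > 1 extrapolates toward the condition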
271 | def feats_to_coord(self, p_feats, resolution, batch_size):
272 | p_feats = p_feats.reshape(batch_size,-1,3)
273 | p_coord = torch.round(p_feats / resolution)
274 |
275 | return p_coord.reshape(-1,3)
276 |
277 | def points_to_tensor(self, x_feats, mean=None, std=None):
278 | if mean is None:
279 | batch_size = x_feats.shape[0]
280 | x_feats = ME.utils.batched_coordinates(list(x_feats[:]), dtype=torch.float32, device=self.device)
281 |
282 | x_coord = x_feats.clone()
283 | x_coord[:,1:] = self.feats_to_coord(x_feats[:,1:], 0.05, batch_size)
284 |
285 | x_t = ME.TensorField(
286 | features=x_feats[:,1:],
287 | coordinates=x_coord,
288 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
289 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
290 | device=self.device,
291 | )
292 |
293 | torch.cuda.empty_cache()
294 |
295 | return x_t
296 |
297 | else:
298 | x_feats = ME.utils.batched_coordinates(list(x_feats[:]), dtype=torch.float32, device=self.device)
299 |
300 | x_coord = x_feats.clone()
301 | x_coord[:,1:] = self.feats_to_coord(x_feats[:,1:], 0.05, mean.shape[0])
302 |
303 | x_t = ME.TensorField(
304 | features=x_feats[:,1:],
305 | coordinates=x_coord,
306 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
307 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
308 | device=self.device,
309 | )
310 |
311 | torch.cuda.empty_cache()
312 |
313 | return x_t
314 |
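# points_to_tensor pairs continuous xyz features with voxel coordinates obtained by
# rounding at a 0.05 m resolution; UNWEIGHTED_AVERAGE pools the features of points
# that land in the same voxel. Quantization sketch (hypothetical values):
#
#   pts = torch.tensor([[0.012, 0.026, 0.049]])
#   torch.round(pts / 0.05)   # -> tensor([[0., 1., 1.]]), the voxel indices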
315 | def point_set_to_sparse(self, p_full, n_part):
316 | p_full = p_full[0].cpu().detach().numpy()
317 |
318 | dist_part = np.sum(p_full**2, -1)**.5
319 | p_full = p_full[(dist_part < 50.) & (dist_part > 3.5)]
320 | p_full = p_full[p_full[:,2] > -4.]
321 |
322 | pcd_part = o3d.geometry.PointCloud()
323 | pcd_part.points = o3d.utility.Vector3dVector(p_full)
324 |
325 | pcd_part = pcd_part.farthest_point_down_sample(n_part)
326 | p_part = torch.tensor(np.array(pcd_part.points))
327 |
328 | return p_part
329 |
330 | def calc_cd(self, pred:torch.Tensor, gt:torch.Tensor):
331 |
332 | assert pred.shape[0] == 1
333 | assert gt.shape[0] == 1
334 |
335 | pcd_pred = o3d.geometry.PointCloud()
336 | pcd_pred.points = o3d.utility.Vector3dVector(pred[0].cpu().detach().numpy())
337 | pcd_gt = o3d.geometry.PointCloud()
338 | pcd_gt.points = o3d.utility.Vector3dVector(gt[0].cpu().detach().numpy())
339 |
340 | dist_pt_2_gt = np.asarray(pcd_pred.compute_point_cloud_distance(pcd_gt))
341 | dist_gt_2_pt = np.asarray(pcd_gt.compute_point_cloud_distance(pcd_pred))
342 |
343 | cd = (np.mean(dist_gt_2_pt) + np.mean(dist_pt_2_gt)) / 2
344 |
345 | return cd
346 |
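# The symmetric Chamfer distance above averages the two one-directional mean
# nearest-neighbour distances. Equivalent numpy-only sketch on two tiny clouds:
#
#   a = np.array([[0., 0., 0.], [1., 0., 0.]])
#   b = np.array([[0., 0., 0.]])
#   d_ab = np.linalg.norm(a[:, None] - b[None], axis=-1).min(1).mean()  # a -> b: 0.5
#   d_ba = np.linalg.norm(b[:, None] - a[None], axis=-1).min(1).mean()  # b -> a: 0.0
#   cd = (d_ab + d_ba) / 2                                              # -> 0.25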
347 | def calc_jsd(self, pred:torch.Tensor, gt:torch.Tensor):
348 |
349 | assert pred.shape[0] == 1
350 | assert gt.shape[0] == 1
351 |
352 | pcd_pred = o3d.geometry.PointCloud()
353 | pcd_pred.points = o3d.utility.Vector3dVector(pred[0].cpu().detach().numpy())
354 | pcd_gt = o3d.geometry.PointCloud()
355 | pcd_gt.points = o3d.utility.Vector3dVector(gt[0].cpu().detach().numpy())
356 |
357 | jsd = compute_hist_metrics(pcd_gt, pcd_pred, bev=False)
358 |
359 | return jsd
360 |
361 | def training_step(self, batch:dict, batch_idx, optimizer_idx):
362 |
363 | # unpack the batch variables shared by both optimizer branches
364 | partial_tensor = batch['pcd_part']
365 | partial_x10_tensor = partial_tensor.repeat(1,10,1)
366 | gt_tensor = batch['pcd_full']
367 | B = batch['pcd_part'].shape[0]
368 | partial_TF = self.points_to_tensor(partial_tensor)
369 |
370 | if optimizer_idx == 0: # train auxDiffs
371 |
372 | # generate a batch of better & worse samples from Generator
373 | with torch.no_grad():
374 | self.generator.eval()
375 | t_gen = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
376 | noise = torch.randn(partial_x10_tensor.shape, device=self.device)
377 |
378 | # better
379 | partial_x10_better_noised_tensor = partial_x10_tensor + noise
380 | partial_x10_better_noised_TF = self.points_to_tensor(partial_x10_better_noised_tensor)
381 | generated_sample1 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_better_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
382 |
383 |
384 | # worse
385 | partial_x10_worse_noised_tensor = partial_x10_tensor + 1.1*noise
386 | partial_x10_worse_noised_TF = self.points_to_tensor(partial_x10_worse_noised_tensor)
387 | generated_sample2 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_worse_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
388 |
389 | # calc cd
390 | cd1 = self.calc_cd(generated_sample1, gt_tensor).item()
391 | cd2 = self.calc_cd(generated_sample2, gt_tensor).item()
392 |
393 | # compare
394 | generated_better_sample = generated_sample1 if cd1 < cd2 else generated_sample2
395 | generated_worse_sample = generated_sample2 if cd1 < cd2 else generated_sample1
396 | cd_better = cd1 if cd1 < cd2 else cd2
397 | cd_worse = cd2 if cd1 < cd2 else cd1
398 | switch = float(cd1 > cd2)
399 | self.log('cd/better', cd_better, prog_bar=False)
400 | self.log('cd/worse', cd_worse, prog_bar=False)
401 | self.log('cd/switch', switch, prog_bar=False)
402 |
403 | # add noise to generated samples
404 | with torch.no_grad():
405 | # better
406 | t_better = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
407 | noise_better = torch.randn(generated_better_sample.shape, device=self.device)
408 | generated_better_noised_sample_tensor = generated_better_sample + self.q_sample(torch.zeros_like(generated_better_sample),
409 | t_better, noise_better)
410 | generated_better_noised_sample_TF = self.points_to_tensor(generated_better_noised_sample_tensor)
411 |
412 | # worse
413 | t_worse = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
414 | noise_worse = torch.randn(generated_worse_sample.shape, device=self.device)
415 | generated_worse_noised_sample_tensor = generated_worse_sample + self.q_sample(torch.zeros_like(generated_worse_sample),
416 | t_worse, noise_worse)
417 | generated_worse_noised_sample_TF = self.points_to_tensor(generated_worse_noised_sample_tensor)
418 |
419 | # denoise generated samples with auxDiffs
420 | self.auxDiffBetter.train()
421 | pred_noise_auxB = self.do_forward(self.auxDiffBetter, generated_better_noised_sample_TF, generated_better_noised_sample_TF.sparse(),
422 | partial_TF, t_better)
423 | self.auxDiffWorse.train()
424 | pred_noise_auxW = self.do_forward(self.auxDiffWorse, generated_worse_noised_sample_TF, generated_worse_noised_sample_TF.sparse(),
425 | partial_TF, t_worse)
426 |
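# The auxiliary models are trained with the standard denoising objective
# E[ || eps - eps_theta(x_t, t) ||^2 ], but on the generator's own outputs, so that
# auxDiffBetter / auxDiffWorse keep tracking the current better / worse sample
# distributions as the generator evolves.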
427 | # calculate loss
428 | auxDiffBetter_loss = F.mse_loss(noise_better, pred_noise_auxB)
429 | auxDiffWorse_loss = F.mse_loss(noise_worse, pred_noise_auxW)
430 | loss_aux = auxDiffBetter_loss + auxDiffWorse_loss
431 |
432 | # log info on progress bar
433 | self.log('loss_auxB', auxDiffBetter_loss, prog_bar=True)
434 | self.log('loss_auxW', auxDiffWorse_loss, prog_bar=True)
435 |
436 | torch.cuda.empty_cache()
437 |
438 | return loss_aux
439 |
440 | if optimizer_idx == 1: # train generator
441 |
442 | # get a batch of samples from the generator
443 | self.generator.train()
444 | t_gen = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
445 | noise = torch.randn(partial_x10_tensor.shape, device=self.device)
446 |
447 | partial_x10_better_noised_tensor = partial_x10_tensor + noise
448 | partial_x10_better_noised_TF = self.points_to_tensor(partial_x10_better_noised_tensor)
449 | generated_sample1 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_better_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
450 |
451 | partial_x10_worse_noised_tensor = partial_x10_tensor + 1.1*noise
452 | partial_x10_worse_noised_TF = self.points_to_tensor(partial_x10_worse_noised_tensor)
453 | generated_sample2 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_worse_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
454 |
455 | cd1 = self.calc_cd(generated_sample1, gt_tensor).item()
456 | cd2 = self.calc_cd(generated_sample2, gt_tensor).item()
457 |
458 | generated_better_sample = generated_sample1 if cd1 < cd2 else generated_sample2
459 | generated_worse_sample = generated_sample2 if cd1 < cd2 else generated_sample1
460 |
461 | # add noise to generated samples
462 | t = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
463 | noise = torch.randn(generated_better_sample.shape, device=self.device)
464 | generated_better_noised_sample_tensor = generated_better_sample + self.q_sample(torch.zeros_like(generated_better_sample), t, noise)
465 | generated_better_noised_sample_TF = self.points_to_tensor(generated_better_noised_sample_tensor)
466 |
467 | generated_worse_noised_sample_tensor = generated_worse_sample + self.q_sample(torch.zeros_like(generated_worse_sample), t, noise)
468 | generated_worse_noised_sample_TF = self.points_to_tensor(generated_worse_noised_sample_tensor)
469 |
470 | # denoise generated samples with the auxDiffs and the teacher, respectively
471 | with torch.no_grad():
472 | self.auxDiffBetter.eval()
473 | noise_auxB = self.do_forward(self.auxDiffBetter, generated_better_noised_sample_TF, generated_better_noised_sample_TF.sparse(), partial_TF, t)
474 |
475 | generated_better_noised_sample_TF = self.points_to_tensor(generated_better_noised_sample_tensor)
476 | self.auxDiffWorse.eval()
477 | noise_auxW = self.do_forward(self.auxDiffWorse, generated_worse_noised_sample_TF, generated_worse_noised_sample_TF.sparse(), partial_TF, t)
478 |
479 | generated_worse_noised_sample_TF = self.points_to_tensor(generated_worse_noised_sample_tensor)
480 | self.teacher.eval()
481 | noise_better_teacher = self.do_forward(self.teacher, generated_better_noised_sample_TF, generated_better_noised_sample_TF.sparse(), partial_TF, t)
482 | noise_worse_teacher = self.do_forward(self.teacher, generated_worse_noised_sample_TF, generated_worse_noised_sample_TF.sparse(), partial_TF, t)
483 |
484 |
485 | distil_loss = ((noise_worse_teacher - noise_auxW) * (generated_worse_noised_sample_tensor - generated_worse_sample)).mean() \
486 | - ((noise_better_teacher - noise_auxB) * (generated_better_noised_sample_tensor - generated_better_sample)).mean()
487 |
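# Preference-distillation surrogate: for each pair, (eps_teacher - eps_aux) acts as a
# score-difference weight on the displacement between the noised and the clean generated
# sample. The worse-sample term is minimized and the better-sample term maximized (note
# the minus sign), nudging the generator toward outputs the teacher prefers relative to
# each auxiliary model, with no explicit reward model in the loop.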
488 | generator_loss = distil_loss
489 |
490 | # log info on progress bar
491 | self.log('loss_g', generator_loss, prog_bar=True)
492 |
493 | torch.cuda.empty_cache()
494 |
495 | return generator_loss
496 |
497 | def configure_optimizers(self):
498 |
499 | optimizer_aux = torch.optim.SGD(list(self.auxDiffBetter.parameters())+list(self.auxDiffWorse.parameters()), lr=self.args.lr)
500 | optimizer_g = torch.optim.SGD(self.generator.parameters(), lr=self.args.lr)
501 |
502 | from torch.optim.lr_scheduler import StepLR
503 | scheduler_aux = StepLR(optimizer_aux, step_size=1, gamma=0.999)
504 | scheduler_g = StepLR(optimizer_g, step_size=1, gamma=0.999)
505 |
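# Returning two optimizers makes Lightning call training_step once per optimizer on
# every batch, with optimizer_idx=0 (auxiliary models) and optimizer_idx=1 (generator)
# in turn, which realizes the alternating two-player update used above.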
506 | return [optimizer_aux, optimizer_g], [scheduler_aux, scheduler_g]
507 |
508 | def validation_step(self, batch:dict, batch_idx):
509 |
510 | with torch.no_grad():
511 | gt_pts = batch['pcd_full'].detach().cpu().numpy()
512 |
513 | # for inference we take the partial pcd and sample the noise around it
514 | x_init = batch['pcd_part'].repeat(1,10,1)
515 | x_feats = x_init + torch.randn(x_init.shape, device=self.device)
516 | x_full = self.points_to_tensor(x_feats, batch['mean'], batch['std'])
517 | x_part = self.points_to_tensor(batch['pcd_part'], batch['mean'], batch['std'])
518 | x_uncond = self.points_to_tensor(
519 | torch.zeros_like(batch['pcd_part']), torch.zeros_like(batch['mean']), torch.zeros_like(batch['std'])
520 | )
521 |
522 | x_gen_eval = self.sample_val(x_init, x_full, x_part, x_uncond, batch['mean'], batch['std'])
523 | x_gen_eval = x_gen_eval.F.reshape((gt_pts.shape[0],-1,3))
524 |
525 | for i in range(len(batch['pcd_full'])):
526 |
527 |
528 | pcd_pred = o3d.geometry.PointCloud()
529 | # pcd_pred_all = o3d.geometry.PointCloud()
530 | c_pred = x_gen_eval[i].cpu().detach().numpy()
531 |
532 | # pcd_pred_all.points = o3d.utility.Vector3dVector(c_pred)
533 | dist_pts = np.sqrt(np.sum((c_pred)**2, axis=-1))
534 | dist_idx = dist_pts < 50.0
535 | points = c_pred[dist_idx]
536 | max_z = x_init[i][...,2].max().item()
537 | min_z = (x_init[i][...,2].mean() - 2 * x_init[i][...,2].std()).item()
538 | pcd_pred.points = o3d.utility.Vector3dVector(points[(points[:,2] < max_z) & (points[:,2] > min_z)])
539 | pcd_pred.paint_uniform_color([1.0, 0.,0.])
540 |
541 | file_path = f'exp/distill/sdpo_{self.args.timestamp}/samples/{self.args.batch_size*batch_idx+i}.pcd'
542 | os.makedirs(os.path.dirname(file_path), exist_ok=True)
543 | o3d.io.write_point_cloud(file_path, pcd_pred)
544 |
545 | pcd_gt = o3d.geometry.PointCloud()
546 | # pcd_gt_all = o3d.geometry.PointCloud()
547 | g_pred = batch['pcd_full'][i].cpu().detach().numpy()
548 | # pcd_gt_all.points = o3d.utility.Vector3dVector(g_pred)
549 | pcd_gt.points = o3d.utility.Vector3dVector(g_pred)
550 | pcd_gt.paint_uniform_color([0., 1.,0.])
551 |
552 | pcd_part = o3d.geometry.PointCloud()
553 | pcd_part.points = o3d.utility.Vector3dVector(batch['pcd_part'][i].cpu().detach().numpy())
554 | pcd_part.paint_uniform_color([0., 1.,0.])
555 |
556 | self.chamfer_distance.update(pcd_gt, pcd_pred)
557 | self.precision_recall.update(pcd_gt, pcd_pred)
558 | self.completion_iou.update(pcd_gt, pcd_pred)
559 |
560 | torch.cuda.empty_cache()
561 |
562 | cd_mean, cd_std = self.chamfer_distance.compute()
563 | pr, re, f1 = self.precision_recall.compute_auc()
564 | thr_ious = self.completion_iou.compute()
565 |
566 | return {'val_cd_mean': cd_mean, 'val_cd_std': cd_std, 'val_precision': pr, 'val_recall': re, 'val_fscore': f1, 'val_iou0.5': thr_ious[0.5], 'val_iou0.2': thr_ious[0.2], 'val_iou0.1': thr_ious[0.1]}
567 |
568 | def validation_epoch_end(self, outputs):
569 |
570 | cd_mean = np.mean(np.stack([x["val_cd_mean"] for x in outputs]))
571 | cd_std = np.mean(np.stack([x["val_cd_std"] for x in outputs]))
572 | pr = np.mean(np.stack([x["val_precision"] for x in outputs]))
573 | re = np.mean(np.stack([x["val_recall"] for x in outputs]))
574 | f1 = np.mean(np.stack([x["val_fscore"] for x in outputs]))
575 | iou0_5 = np.mean(np.stack([x["val_iou0.5"] for x in outputs]))
576 | iou0_2 = np.mean(np.stack([x["val_iou0.2"] for x in outputs]))
577 | iou0_1 = np.mean(np.stack([x["val_iou0.1"] for x in outputs]))
578 |
579 |
580 | self.log('val_cd_mean', cd_mean, prog_bar=False)
581 | self.log('val_cd_std', cd_std)
582 | self.log('val_precision', pr)
583 | self.log('val_recall', re)
584 | self.log('val_fscore', f1, prog_bar=False)
585 | self.log('val_iou05', iou0_5, prog_bar=True)
586 | self.log('val_iou02', iou0_2, prog_bar=True)
587 | self.log('val_iou01', iou0_1, prog_bar=True)
588 |
589 | self.chamfer_distance.reset()
590 | self.precision_recall.reset()
591 | self.completion_iou.reset()
592 |
593 | def valid_paths(self, filenames):
594 | output_paths = []
595 | skip = []
596 |
597 | for fname in filenames:
598 | seq_dir = f'{self.logger.log_dir}/generated_pcd/{fname.split("/")[-3]}'
599 | ply_name = f'{fname.split("/")[-1].split(".")[0]}.ply'
600 |
601 | skip.append(path.isfile(f'{seq_dir}/{ply_name}'))
602 | makedirs(seq_dir, exist_ok=True)
603 | output_paths.append(f'{seq_dir}/{ply_name}')
604 |
605 | return np.all(skip), output_paths
606 |
607 |
608 |
609 | def main(args):
610 | # metadata
611 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
612 | args.timestamp = timestamp
613 |
614 | # model
615 | model = DistillationDPO(args)
616 |
617 | # dataset
618 | dataloader = SemanticKITTI_dataset.dataloaders['KITTI'](args)
619 |
620 | # ckpt saving config
621 | checkpoint_callback = ModelCheckpoint(
622 | dirpath=f"exp/distill/sdpo_{timestamp}/checkpoints",
623 | filename="{epoch}-{step}-{val_iou05:.3f}-{val_iou02:.3f}-{val_iou01:.3f}",
624 | save_top_k=-1, # save every ckpt
625 | every_n_epochs=1,
626 | )
627 |
628 | # logger
629 | tb_logger = pl_loggers.TensorBoardLogger(f"exp/distill/sdpo_{timestamp}", default_hp_metric=False)
630 |
631 | # setup trainer
632 | trainer = Trainer(
633 | gpus=2, strategy='ddp', accelerator='gpu',
634 | max_epochs=10,
635 | logger=tb_logger,
636 | log_every_n_steps=1,
637 | callbacks=[checkpoint_callback],
638 | gradient_clip_val=0.05,
639 | val_check_interval=200,
640 | limit_val_batches=1,
641 | detect_anomaly=True,
642 | )
643 | trainer.fit(model, dataloader)
644 |
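# Example invocation (hypothetical setup; both paths fall back to the defaults below):
#
#   python trains/DistillationDPO.py \
#       --SemanticKITTI_path datasets/SemanticKITTI \
#       --pre_trained_diff_path checkpoints/lidiff_ddpo_refined.ckpt \
#       --batch_size 1 --lr 1e-5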
645 | if __name__ == "__main__":
646 | parser = argparse.ArgumentParser()
647 |
648 | parser.add_argument("--batch_size", default=1, type=int, required=False, help="batch size")
649 | parser.add_argument("--lr", default=1e-5, type=float, required=False, help="learning rate")
650 | parser.add_argument(
651 | "--SemanticKITTI_path",
652 | default='datasets/SemanticKITTI',
653 | type=str,
654 | required=False,
655 | help="path to SementicKITTI dataset"
656 | )
657 | parser.add_argument(
658 | "--pre_trained_diff_path",
659 | default='checkpoints/lidiff_ddpo_refined.ckpt',
660 | type=str,
661 | required=False,
662 | help="path to pre-trained diffusion model weights"
663 | )
664 |
665 | args = parser.parse_args()
666 | main(args)
--------------------------------------------------------------------------------