├── models
│   ├── __init__.py
│   └── minkunet.py
├── trains
│   ├── __init__.py
│   └── DistillationDPO.py
├── utils
│   ├── __init__.py
│   ├── ema.py
│   ├── scheduling.py
│   ├── render.py
│   ├── EMD.py
│   ├── histogram_metrics.py
│   ├── data_map.py
│   ├── collations.py
│   ├── pcd_transforms.py
│   ├── eval_path_get_pics.py
│   ├── pcd_preprocess.py
│   ├── eval_path.py
│   ├── metrics.py
│   └── diff_completion_pipeline.py
├── datasets
│   ├── SemanticKITTI_dataloader
│   │   ├── __init__.py
│   │   ├── SemanticKITTITemporalAggr.py
│   │   └── SemanticKITTITemporal.py
│   └── SemanticKITTI_dataset.py
├── pics
│   └── teaser.png
├── .gitignore
├── setup.py
├── map_from_scans.py
└── README.md
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/trains/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/datasets/SemanticKITTI_dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pics/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/happyw1nd/DistillationDPO/HEAD/pics/teaser.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # experiment results
2 | exp/
3 | 
4 | # model weights
5 | checkpoints/
6 | 
7 | *.egg-info/
8 | *.egg/
9 | 
10 | **/__pycache__/
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | pkg_name = 'DistillationDPO'
4 | setup(name=pkg_name, version='1.0', packages=find_packages())
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class EMA:
4 |     def __init__(self, model, decay=0.999):
5 |         self.model = model
6 |         self.decay = decay
7 |         self.shadow = {}
8 |         self._initialize()
9 | 
10 |     def _initialize(self):
11 |         for name, param in self.model.named_parameters():
12 |             if param.requires_grad:
13 |                 self.shadow[name] = param.clone().detach().to(param.device)
14 | 
15 |     def update(self):
16 |         """Update the shadow weights using Exponential Moving Average (EMA)."""
17 |         with torch.no_grad():
18 |             for name, param in self.model.named_parameters():
19 |                 if param.requires_grad:
20 |                     self.shadow[name] = (self.decay * self.shadow[name].to(param.device) + (1.0 - self.decay) * param)
21 | 
22 |     def apply(self):
23 |         """Apply the EMA weights to the model."""
24 |         for name, param in self.model.named_parameters():
25 |             if param.requires_grad:
26 |                 param.data = self.shadow[name].to(param.device)
--------------------------------------------------------------------------------
/utils/scheduling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from numpy import pi
3 | 
4 | def cosine_beta_schedule(timesteps, s=0.008):
5 |     """
6 | cosine schedule as proposed in https://arxiv.org/abs/2102.09672 7 | """ 8 | steps = timesteps + 1 9 | x = torch.linspace(0, timesteps, steps) 10 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * pi * 0.5) ** 2 11 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 12 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 13 | return torch.clip(betas, 0.0001, 0.9999) 14 | 15 | def linear_beta_schedule(timesteps, beta_start, beta_end): 16 | return torch.linspace(beta_start, beta_end, timesteps) 17 | 18 | def quadratic_beta_schedule(timesteps, beta_start, beta_end): 19 | return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 20 | 21 | def sigmoid_beta_schedule(timesteps, beta_start, beta_end): 22 | betas = torch.linspace(-6, 6, timesteps) 23 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 24 | 25 | beta_func = { 26 | 'cosine': cosine_beta_schedule, 27 | 'linear': linear_beta_schedule, 28 | 'quadratic': quadratic_beta_schedule, 29 | 'sigmoid': sigmoid_beta_schedule, 30 | } 31 | -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["PYOPENGL_PLATFORM"] = "osmesa" 3 | 4 | import numpy as np 5 | import pyrender 6 | import trimesh 7 | from PIL import Image 8 | import open3d as o3d 9 | import matplotlib.pyplot as plt 10 | import matplotlib.colors as mcolors 11 | 12 | def offscreen_render(pcd, output_name): 13 | 14 | # load point cloud 15 | points = np.asarray(pcd.points) 16 | z_values = points[:, 2] 17 | 18 | lower_bound = np.percentile(z_values, 1) 19 | upper_bound = np.percentile(z_values, 99) 20 | z_values = np.clip(z_values, lower_bound, upper_bound) 21 | 22 | norm = mcolors.Normalize(vmin=z_values.min(), vmax=z_values.max()) 23 | cmap = plt.get_cmap("gist_rainbow_r") 24 | colors = cmap(norm(z_values)) 25 | colors = colors * 0.5 26 | 27 | # object 28 | m = pyrender.Mesh.from_points(points, colors=colors) 29 | 30 | # light 31 | dl = pyrender.SpotLight(color=[1.0, 1.0, 1.0], intensity=2.0,innerConeAngle=0.05, outerConeAngle=0.5) 32 | light_theta = np.radians(-40) 33 | light_pose = np.array([ 34 | [1.0, 0.0, 0.0, 0.0], # x axis ← 35 | [0.0, np.cos(light_theta), -np.sin(light_theta), 0.0], # y axis ↓ 36 | [0.0, np.sin(light_theta), np.cos(light_theta), 0.0], # z axis · 37 | [0.0, 0.0, 0.0, 1.0] 38 | ]) 39 | 40 | # camera 41 | pc = pyrender.PerspectiveCamera(yfov=np.pi / 3.0) 42 | camera_theta = np.radians(20) 43 | camera_pose_rot = np.array([ 44 | [1.0, 0.0, 0.0, 0.0], # x axis ← 45 | [0.0, np.cos(camera_theta), -np.sin(camera_theta), 0.0], # y axis ↓ 46 | [0.0, np.sin(camera_theta), np.cos(camera_theta), 0.0], # z axis · 47 | [0.0, 0.0, 0.0, 1.0] 48 | ]) 49 | camera_pose_trans = np.array([ 50 | [1.0, 0.0, 0.0, 0.0], 51 | [0.0, 1.0, 0.0, 0.0], 52 | [0.0, 0.0, 1.0, 50.0], 53 | [0.0, 0.0, 0.0, 1.0] 54 | ]) 55 | camera_pose = np.matmul(camera_pose_rot, camera_pose_trans) 56 | 57 | # scene 58 | scene = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0], bg_color=[1.0, 1.0, 1.0]) 59 | scene.add(m) 60 | scene.add(dl, pose=light_pose) 61 | scene.add(pc, pose=camera_pose) 62 | 63 | # renderer 64 | r = pyrender.OffscreenRenderer(viewport_width=1000, viewport_height=700, point_size=3.0) 65 | flags = pyrender.RenderFlags.SHADOWS_ALL 66 | color, depth = r.render(scene, flags=flags) 67 | r.delete() 68 | 69 | # save pics 70 | Image.fromarray(color).save(output_name) 71 | 
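# --- example usage (illustrative sketch, not part of the training pipeline) ---
# assumes an OSMesa-enabled pyrender build and a point cloud file on disk;
# the input path and the output name below are hypothetical
if __name__ == '__main__':
    pcd = o3d.io.read_point_cloud('scan.ply')   # hypothetical input file
    offscreen_render(pcd, 'scan_render.png')    # writes a 1000x700 PNG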
-------------------------------------------------------------------------------- /utils/EMD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | from tqdm import tqdm 4 | 5 | def voxelize_point_cloud(points, voxel_size): 6 | if points.size == 0: 7 | raise ValueError() 8 | 9 | voxel_indices = np.floor(points / voxel_size).astype(int) 10 | 11 | unique_voxels, voxel_counts = np.unique(voxel_indices, axis=0, return_counts=True) 12 | 13 | weights = voxel_counts / voxel_counts.sum() 14 | 15 | voxel_centers = (unique_voxels + 0.5) * voxel_size 16 | 17 | return voxel_centers, weights 18 | 19 | def sinkhorn_knopp_emd(source_points, target_points, source_weights, target_weights, epsilon=0.01, max_iter=3000, tol=1e-3): 20 | 21 | if source_points.shape[1] != target_points.shape[1]: 22 | raise ValueError() 23 | if not (np.isclose(source_weights.sum(), 1) and np.isclose(target_weights.sum(), 1)): 24 | raise ValueError() 25 | 26 | cost_matrix = cdist(source_points, target_points, metric='euclidean') 27 | 28 | # Sinkhorn-Knopp 29 | K = np.exp(-cost_matrix / epsilon) 30 | K += 1e-9 31 | 32 | u = np.ones_like(source_weights) 33 | v = np.ones_like(target_weights) 34 | 35 | with tqdm(total=max_iter, desc="Running Sinkhorn", unit="iter") as pbar: 36 | for _ in range(max_iter): 37 | pbar.update(1) 38 | u_prev = u.copy() 39 | u = source_weights / (K @ v) 40 | v = target_weights / (K.T @ u) 41 | 42 | diff = np.linalg.norm(u - u_prev, 1) 43 | pbar.set_postfix({'diff': f"{diff:.5e}"}) 44 | if diff < tol: 45 | break 46 | 47 | transport_matrix = np.outer(u, v) * K 48 | emd_approx = np.sum(transport_matrix * cost_matrix) 49 | 50 | return emd_approx 51 | 52 | def calc_EMD_with_sinkhorn_knopp(point_cloud_1, point_cloud_2, voxel_size=0.5, epsilon=0.001, max_iter=3000, tol=1e-4): 53 | 54 | voxel_centers_1, weights_1 = voxelize_point_cloud(point_cloud_1, voxel_size) 55 | voxel_centers_2, weights_2 = voxelize_point_cloud(point_cloud_2, voxel_size) 56 | 57 | # clip 58 | if voxel_centers_1.shape[0] + voxel_centers_2.shape[0] > 130000: 59 | return None 60 | 61 | return sinkhorn_knopp_emd(voxel_centers_1, voxel_centers_2, weights_1, weights_2, epsilon, max_iter, tol) 62 | 63 | if __name__ == "__main__": 64 | 65 | num_points_1 = 10000 66 | num_points_2 = 120000 67 | point_cloud_1 = np.random.rand(num_points_1, 3) * 15 68 | point_cloud_2 = (np.random.rand(num_points_2, 3)-0.5) * 15 69 | 70 | emd_result = calc_EMD_with_sinkhorn_knopp(point_cloud_1, point_cloud_2, voxel_size=0.5, epsilon=0.001, max_iter=3000, tol=1e-4) 71 | 72 | print(f"Approximate Earth Mover's Distance (Sinkhorn): {emd_result:.4f}") 73 | -------------------------------------------------------------------------------- /datasets/SemanticKITTI_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from pytorch_lightning import LightningDataModule 4 | from datasets.SemanticKITTI_dataloader.SemanticKITTITemporal import TemporalKITTISet 5 | from utils.collations import SparseSegmentCollation 6 | import warnings 7 | 8 | warnings.filterwarnings('ignore') 9 | 10 | __all__ = ['TemporalKittiDataModule'] 11 | 12 | class TemporalKittiDataModule(LightningDataModule): 13 | def __init__(self, args): 14 | super().__init__() 15 | self.args = args 16 | 17 | def prepare_data(self): 18 | # Augmentations 19 | pass 20 | 21 | def setup(self, stage=None): 22 | # Create 
datasets 23 | pass 24 | 25 | def train_dataloader(self): 26 | collate = SparseSegmentCollation() 27 | 28 | data_set = TemporalKITTISet( 29 | data_dir=self.args.SemanticKITTI_path, 30 | seqs=[ '00', '01', '02', '03', '04', '05', '06', '07', '09', '10' ], 31 | split='train', 32 | resolution=0.05, 33 | num_points=180000, 34 | max_range=50.0, 35 | dataset_norm=False, 36 | std_axis_norm=False) 37 | loader = DataLoader(data_set, batch_size=self.args.batch_size, shuffle=True, 38 | num_workers=4, collate_fn=collate) 39 | return loader 40 | 41 | def val_dataloader(self, pre_training=True): 42 | collate = SparseSegmentCollation() 43 | 44 | data_set = TemporalKITTISet( 45 | data_dir=self.args.SemanticKITTI_path, 46 | seqs=[ '08' ], 47 | split='validation', 48 | resolution=0.05, 49 | num_points=180000, 50 | max_range=50.0, 51 | dataset_norm=False, 52 | std_axis_norm=False) 53 | loader = DataLoader(data_set, batch_size=1,#self.cfg['train']['batch_size'], 54 | shuffle=False, 55 | num_workers=4, 56 | collate_fn=collate) 57 | return loader 58 | 59 | def test_dataloader(self): 60 | collate = SparseSegmentCollation() 61 | 62 | data_set = TemporalKITTISet( 63 | data_dir=self.args.SemanticKITTI_path, 64 | seqs=[ '08' ], 65 | split='validation', 66 | resolution=0.05, 67 | num_points=180000, 68 | max_range=50.0, 69 | dataset_norm=False, 70 | std_axis_norm=False) 71 | loader = DataLoader(data_set, batch_size=self.args.batch_size, shuffle=False, 72 | num_workers=4, collate_fn=collate) 73 | return loader 74 | 75 | dataloaders = { 76 | 'KITTI': TemporalKittiDataModule, 77 | } 78 | 79 | -------------------------------------------------------------------------------- /utils/histogram_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | from scipy.spatial.distance import jensenshannon 4 | from utils.metrics import ChamferDistance, PrecisionRecall 5 | import matplotlib.pyplot as plt 6 | 7 | def histogram_point_cloud(pcd, resolution, max_range, bev=False): 8 | # get bins size by the number of voxels in the pcd 9 | bins = int(2 * max_range / resolution) 10 | 11 | hist = np.histogramdd(pcd, bins=bins, range=([-max_range,max_range],[-max_range,max_range],[-max_range,max_range])) 12 | 13 | return np.clip(hist[0], a_min=0., a_max=1.) 
if bev else hist[0]
14 | 
15 | def compute_jsd(hist_gt, hist_pred, bev=False, visualize=False):
16 |     bev_gt = hist_gt.sum(-1) if bev else hist_gt
17 |     norm_bev_gt = bev_gt / bev_gt.sum()
18 |     norm_bev_gt = norm_bev_gt.flatten()
19 | 
20 |     bev_pred = hist_pred.sum(-1) if bev else hist_pred
21 |     norm_bev_pred = bev_pred / bev_pred.sum()
22 |     norm_bev_pred = norm_bev_pred.flatten()
23 | 
24 |     if visualize:
25 |         # for visualization purposes
26 |         grid = np.meshgrid(np.arange(len(hist_gt)), np.arange(len(hist_gt)))
27 |         points = np.concatenate((grid[0].flatten()[:,None], grid[1].flatten()[:,None]), axis=-1)
28 |         points = np.concatenate((points, np.zeros((len(points),1))),axis=-1)
29 | 
30 |         # build bev histogram gt view
31 |         norm_hist_gt = bev_gt / bev_gt.max()
32 |         colors_gt = plt.get_cmap('viridis')(norm_hist_gt)
33 |         pcd_gt = o3d.geometry.PointCloud()
34 |         pcd_gt.points = o3d.utility.Vector3dVector(points)
35 |         pcd_gt.colors = o3d.utility.Vector3dVector(colors_gt.reshape(-1,4)[:,:3])
36 | 
37 |         # build bev histogram pred view
38 |         norm_hist_pred = bev_pred / bev_pred.max()
39 |         colors_pred = plt.get_cmap('viridis')(norm_hist_pred)
40 |         pcd_pred = o3d.geometry.PointCloud()
41 |         pcd_pred.points = o3d.utility.Vector3dVector(points)
42 |         pcd_pred.colors = o3d.utility.Vector3dVector(colors_pred.reshape(-1,4)[:,:3])
43 | 
44 |     return jensenshannon(norm_bev_gt, norm_bev_pred)
45 | 
46 | 
47 | def compute_hist_metrics(pcd_gt, pcd_pred, bev=False):
48 |     hist_pred = histogram_point_cloud(np.array(pcd_pred.points), 0.5, 50., bev)
49 |     hist_gt = histogram_point_cloud(np.array(pcd_gt.points), 0.5, 50., bev)
50 | 
51 |     return compute_jsd(hist_gt, hist_pred, bev)
52 | 
53 | def compute_chamfer(pcd_pred, pcd_gt):
54 |     chamfer_distance = ChamferDistance()
55 |     chamfer_distance.update(pcd_gt, pcd_pred)
56 |     cd_pred_mean, cd_pred_std = chamfer_distance.compute()
57 | 
58 |     return cd_pred_mean
59 | 
60 | def compute_precision_recall(pcd_pred, pcd_gt):
61 |     precision_recall = PrecisionRecall(0.05,2*0.05,100)
62 |     precision_recall.update(pcd_gt, pcd_pred)
63 |     pr, re, f1 = precision_recall.compute_auc()
64 | 
65 |     return pr, re, f1
66 | 
67 | def preprocess_pcd(pcd):
68 |     points = np.array(pcd.points)
69 |     dist = np.sqrt(np.sum(points**2, axis=-1))
70 |     pcd.points = o3d.utility.Vector3dVector(points[dist < 30.])
71 | 
72 |     return pcd
73 | 
74 | def compute_metrics(pred_path, gt_path):
75 |     pcd_pred = preprocess_pcd(o3d.io.read_point_cloud(pred_path))
76 |     points_pred = np.array(pcd_pred.points)
77 |     pcd_gt = preprocess_pcd(o3d.io.read_point_cloud(gt_path))
78 |     points_gt = np.array(pcd_gt.points)
79 | 
80 |     # compute_hist_metrics expects the point clouds themselves, in (gt, pred) order
81 |     jsd_pred = compute_hist_metrics(pcd_gt, pcd_pred)
82 | 
83 |     cd_pred = compute_chamfer(pcd_pred, pcd_gt)
84 | 
85 |     pr_pred, re_pred, f1_pred = compute_precision_recall(pcd_pred, pcd_gt)
86 | 
87 |     return cd_pred, pr_pred, re_pred, f1_pred
88 | 
--------------------------------------------------------------------------------
/utils/data_map.py:
--------------------------------------------------------------------------------
1 | learning_map = {
2 |   0 : 0,     # "unlabeled"
3 |   1 : 0,     # "outlier" mapped to "unlabeled" --------------------------mapped
4 |   10: 1,     # "car"
5 |   11: 2,     # "bicycle"
6 |   13: 5,     # "bus" mapped to "other-vehicle" --------------------------mapped
7 |   15: 3,     # "motorcycle"
8 |   16: 5,     # "on-rails" mapped to "other-vehicle" ---------------------mapped
9 |   18: 4,     # "truck"
10 |   20: 5,     # "other-vehicle"
11 |   30: 6,     # "person"
12 |   31: 7,     # "bicyclist"
13 |   32: 8,     # "motorcyclist"
14 |   40: 9,     # "road"
15 |   44:
10, # "parking" 16 | 48: 11, # "sidewalk" 17 | 49: 12, # "other-ground" 18 | 50: 13, # "building" 19 | 51: 14, # "fence" 20 | 52: 0, # "other-structure" mapped to "unlabeled" ------------------mapped 21 | 60: 9, # "lane-marking" to "road" ---------------------------------mapped 22 | 70: 15, # "vegetation" 23 | 71: 16, # "trunk" 24 | 72: 17, # "terrain" 25 | 80: 18, # "pole" 26 | 81: 19, # "traffic-sign" 27 | 99: 0, # "other-object" to "unlabeled" ----------------------------mapped 28 | 252: 1, # "moving-car" 29 | 253: 7, # "moving-bicyclist" 30 | 254: 6, # "moving-person" 31 | 255: 8, # "moving-motorcyclist" 32 | 256: 5, # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped 33 | 257: 5, # "moving-bus" mapped to "moving-other-vehicle" -----------mapped 34 | 258: 4, # "moving-truck" 35 | 259: 5, # "moving-other-vehicle" 36 | } 37 | 38 | content = { # as a ratio with the total number of points 39 | 0: 0.03150183342534689, 40 | 1: 0.042607828674502385, 41 | 2: 0.00016609538710764618, 42 | 3: 0.00039838616015114444, 43 | 4: 0.0021649398241338114, 44 | 5: 0.0018070552978863615, 45 | 6: 0.0003375832743104974, 46 | 7: 0.00012711105887399155, 47 | 8: 3.746106399997359e-05, 48 | 9: 0.19879647126983288, 49 | 10: 0.014717169549888214, 50 | 11: 0.14392298360372, 51 | 12: 0.0039048553037472045, 52 | 13: 0.1326861944777486, 53 | 14: 0.0723592229456223, 54 | 15: 0.26681502148037506, 55 | 16: 0.006035012012626033, 56 | 17: 0.07814222006271769, 57 | 18: 0.002855498193863172, 58 | 19: 0.0006155958086189918, 59 | } 60 | 61 | content_indoor = { 62 | 0: 0.18111755628849344, 63 | 1: 0.15350115272756307, 64 | 2: 0.264323444618407, 65 | 3: 0.017095487624768667, 66 | 4: 0.02018415055214108, 67 | 5: 0.025684283218171625, 68 | 6: 0.05237503359636922, 69 | 7: 0.03495118545614923, 70 | 8: 0.04252921527371275, 71 | 9: 0.004767541066020183, 72 | 10: 0.06899976905686542, 73 | 11: 0.012345517150886037, 74 | 12: 0.12212566337045223, 75 | } 76 | 77 | labels = { 78 | 0: "unlabeled", 79 | 1: "car", 80 | 2: "bicycle", 81 | 3: "motorcycle", 82 | 4: "truck", 83 | 5: "other-vehicle", 84 | 6: "person", 85 | 7: "bicyclist", 86 | 8: "motorcyclist", 87 | 9: "road", 88 | 10: "parking", 89 | 11: "sidewalk", 90 | 12: "other-ground", 91 | 13: "building", 92 | 14: "fence", 93 | 15: "vegetation", 94 | 16: "trunk", 95 | 17: "terrain", 96 | 18: "pole", 97 | 19: "traffic-sign", 98 | } 99 | 100 | color_map = { 101 | 0: [0, 0, 0], 102 | 1: [245, 150, 100], 103 | 2: [245, 230, 100], 104 | 3: [150, 60, 30], 105 | 4: [180, 30, 80], 106 | 5: [255, 0, 0], 107 | 6: [30, 30, 255], 108 | 7: [200, 40, 255], 109 | 8: [90, 30, 150], 110 | 9: [255, 0, 255], 111 | 10: [255, 150, 255], 112 | 11: [75, 0, 75], 113 | 12: [75, 0, 175], 114 | 13: [0, 200, 255], 115 | 14: [50, 120, 255], 116 | 15: [0, 175, 0], 117 | 16: [0, 60, 135], 118 | 17: [80, 240, 150], 119 | 18: [150, 240, 255], 120 | 19: [0, 0, 255], 121 | } 122 | -------------------------------------------------------------------------------- /map_from_scans.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import open3d as o3d 5 | from natsort import natsorted 6 | import click 7 | import tqdm 8 | import MinkowskiEngine as ME 9 | 10 | def parse_calibration(filename): 11 | calib = {} 12 | 13 | calib_file = open(filename) 14 | for line in calib_file: 15 | key, content = line.strip().split(":") 16 | values = [float(v) for v in content.strip().split()] 17 | 18 | pose = np.zeros((4, 4)) 19 | pose[0, 0:4] = 
values[0:4] 20 | pose[1, 0:4] = values[4:8] 21 | pose[2, 0:4] = values[8:12] 22 | pose[3, 3] = 1.0 23 | 24 | calib[key] = pose 25 | 26 | calib_file.close() 27 | 28 | return calib 29 | 30 | def load_poses(calib_fname, poses_fname): 31 | if os.path.exists(calib_fname): 32 | calibration = parse_calibration(calib_fname) 33 | Tr = calibration["Tr"] 34 | Tr_inv = np.linalg.inv(Tr) 35 | 36 | poses_file = open(poses_fname) 37 | poses = [] 38 | 39 | for line in poses_file: 40 | values = [float(v) for v in line.strip().split()] 41 | 42 | pose = np.zeros((4, 4)) 43 | pose[0, 0:4] = values[0:4] 44 | pose[1, 0:4] = values[4:8] 45 | pose[2, 0:4] = values[8:12] 46 | pose[3, 3] = 1.0 47 | 48 | if os.path.exists(calib_fname): 49 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 50 | else: 51 | poses.append(pose) 52 | 53 | return poses 54 | 55 | @click.command() 56 | @click.option('--path', '-p', type=str, help='path to the scan sequence') 57 | @click.option('--voxel_size', '-v', type=float, default=0.1, help='voxel size') 58 | @click.option('--cpu', '-c', is_flag=True, help='Use CPU') 59 | def main(path, voxel_size, cpu): 60 | device_label = 'cuda' if torch.cuda.is_available() else 'cpu' 61 | device_label = 'cpu' if cpu else device_label 62 | device = torch.device(device_label) 63 | for seq in ['00','01','02','03','04','05','06','07','08','09','10']: 64 | map_points = torch.empty((0,3)).to(device) 65 | 66 | poses = load_poses(os.path.join(path, seq, 'calib.txt'), os.path.join(path, seq, 'poses.txt')) 67 | for pose, pcd_path in tqdm.tqdm(list(zip(poses, natsorted(os.listdir(os.path.join(path, seq, 'velodyne')))))): 68 | pose = torch.from_numpy(pose).float().to(device) 69 | pcd_file = os.path.join(path, seq, 'velodyne', pcd_path) 70 | points = torch.from_numpy(np.fromfile(pcd_file, dtype=np.float32)).to(device) 71 | points = points.reshape(-1,4) 72 | 73 | label_file = pcd_file.replace('velodyne', 'labels').replace('.bin', '.label') 74 | l_set = np.fromfile(label_file, dtype=np.uint32) 75 | l_set = l_set.reshape((-1)) 76 | l_set = l_set & 0xFFFF 77 | 78 | # remove moving points 79 | static_idx = (l_set < 252) & (l_set > 1) 80 | points = points[static_idx] 81 | 82 | # remove flying artifacts 83 | dist = torch.pow(points, 2) 84 | dist = torch.sqrt(dist.sum(-1)) 85 | points = points[dist > 3.5] 86 | 87 | points[:,-1] = 1. 
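# the intensity column was just overwritten with 1.0, making each row a
# homogeneous point (x, y, z, 1), so the 4x4 pose below applies rotation
# and translation in a single matrix product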
88 |             points = points @ pose.T
89 | 
90 |             map_points = torch.cat((map_points, points[:,:3]), axis=0)
91 |             _, mapping = ME.utils.sparse_quantize(coordinates=map_points / voxel_size, return_index=True, device=device_label)
92 |             map_points = map_points[mapping]
93 | 
94 | 
95 |         print(f'saving map for sequence {seq}')
96 |         np.save(os.path.join(path, seq, 'map_clean.npy'), map_points.cpu().numpy())
97 | 
98 | 
99 | if __name__ == '__main__':
100 |     main()
--------------------------------------------------------------------------------
/utils/collations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import MinkowskiEngine as ME
3 | import torch
4 | import torch.nn.functional as F
5 | import open3d as o3d
6 | 
7 | 
8 | def feats_to_coord(p_feats, resolution, batch_size):
9 |     p_feats = p_feats.reshape(batch_size,-1,3)
10 |     p_coord = torch.round(p_feats / resolution)
11 | 
12 |     return p_coord.reshape(-1,3)
13 | 
14 | def normalize_pcd(points, mean, std):
15 |     return (points - mean[:,None,:]) / std[:,None,:] if len(mean.shape) == 2 else (points - mean) / std
16 | 
17 | def unormalize_pcd(points, mean, std):
18 |     return (points * std[:,None,:]) + mean[:,None,:] if len(mean.shape) == 2 else (points * std) + mean
19 | 
20 | def point_set_to_sparse_refine(p_full, p_part, n_full, n_part, resolution, filename):
21 |     concat_full = np.ceil(n_full / p_full.shape[0])
22 |     concat_part = np.ceil(n_part / p_part.shape[0])
23 | 
24 |     #if mode == 'diffusion':
25 |     #p_full = p_full[torch.randperm(p_full.shape[0])]
26 |     #p_part = p_part[torch.randperm(p_part.shape[0])]
27 |     #elif mode == 'refine':
28 |     p_full = p_full[torch.randperm(p_full.shape[0])]
29 |     p_full = torch.tensor(p_full.repeat(concat_full, 0)[:n_full])
30 | 
31 |     p_part = p_part[torch.randperm(p_part.shape[0])]
32 |     p_part = torch.tensor(p_part.repeat(concat_part, 0)[:n_part])
33 | 
34 |     #p_feats = ME.utils.batched_coordinates([p_feats], dtype=torch.float32)[:2000]
35 | 
36 |     # after creating the voxel coordinates we normalize the floating coordinates towards mean=0 and std=1
37 |     p_mean, p_std = p_full.mean(axis=0), p_full.std(axis=0)
38 | 
39 |     return [p_full, p_mean, p_std, p_part, filename]
40 | 
41 | def point_set_to_sparse(p_full, p_part, n_full, n_part, resolution, filename, p_mean=None, p_std=None):
42 |     concat_part = np.ceil(n_part / p_part.shape[0])
43 |     p_part = p_part.repeat(concat_part, 0)
44 |     pcd_part = o3d.geometry.PointCloud()
45 |     pcd_part.points = o3d.utility.Vector3dVector(p_part)
46 |     viewpoint_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd_part, voxel_size=10.)
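# the coarse 10 m voxel grid records which regions the partial scan observed;
# it is used below to crop the full cloud to the observed area, while farthest
# point sampling reduces the partial cloud to exactly n_part points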
47 | pcd_part = pcd_part.farthest_point_down_sample(n_part) 48 | p_part = torch.tensor(np.array(pcd_part.points)) 49 | 50 | in_viewpoint = viewpoint_grid.check_if_included(o3d.utility.Vector3dVector(p_full)) 51 | p_full = p_full[in_viewpoint] 52 | concat_full = np.ceil(n_full / p_full.shape[0]) 53 | 54 | p_full = p_full[torch.randperm(p_full.shape[0])] 55 | p_full = p_full.repeat(concat_full, 0)[:n_full] 56 | 57 | p_full = torch.tensor(p_full) 58 | 59 | # after creating the voxel coordinates we normalize the floating coordinates towards mean=0 and std=1 60 | p_mean = p_full.mean(axis=0) if p_mean is None else p_mean 61 | p_std = p_full.std(axis=0) if p_std is None else p_std 62 | 63 | return [p_full, p_mean, p_std, p_part, filename] 64 | 65 | def numpy_to_sparse_tensor(p_coord, p_feats, p_label=None): 66 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 67 | p_coord = ME.utils.batched_coordinates(p_coord, dtype=torch.float32) 68 | p_feats = torch.vstack(p_feats).float() 69 | 70 | if p_label is not None: 71 | p_label = ME.utils.batched_coordinates(p_label, device=torch.device('cpu')).numpy() 72 | 73 | return ME.SparseTensor( 74 | features=p_feats, 75 | coordinates=p_coord, 76 | device=device, 77 | ), p_label 78 | 79 | return ME.SparseTensor( 80 | features=p_feats, 81 | coordinates=p_coord, 82 | device=device, 83 | ) 84 | 85 | class SparseSegmentCollation: 86 | def __init__(self, mode='diffusion'): 87 | self.mode = mode 88 | return 89 | 90 | def __call__(self, data): 91 | # "transpose" the batch(pt, ptn) to batch(pt), batch(ptn) 92 | batch = list(zip(*data)) 93 | 94 | return {'pcd_full': torch.stack(batch[0]).float(), 95 | 'mean': torch.stack(batch[1]).float(), 96 | 'std': torch.stack(batch[2]).float(), 97 | 'pcd_part' if self.mode == 'diffusion' else 'pcd_noise': torch.stack(batch[3]).float(), 98 | 'filename': batch[4], 99 | } 100 | -------------------------------------------------------------------------------- /datasets/SemanticKITTI_dataloader/SemanticKITTITemporalAggr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from utils.pcd_preprocess import clusterize_pcd, visualize_pcd_clusters, point_set_to_coord_feats, overlap_clusters, aggregate_pcds 4 | from utils.pcd_transforms import * 5 | from utils.data_map import learning_map 6 | from utils.collations import point_set_to_sparse_refine 7 | import os 8 | import numpy as np 9 | import MinkowskiEngine as ME 10 | 11 | import warnings 12 | 13 | warnings.filterwarnings('ignore') 14 | 15 | ################################################# 16 | ################## Data loader ################## 17 | ################################################# 18 | 19 | class TemporalKITTISet(Dataset): 20 | def __init__(self, data_dir, scan_window, seqs, split, resolution, num_points, mode): 21 | super().__init__() 22 | self.data_dir = data_dir 23 | self.augmented_dir = 'segments_views' 24 | 25 | self.n_clusters = 50 26 | self.resolution = resolution 27 | self.scan_window = scan_window 28 | self.num_points = num_points 29 | self.seg_batch = True 30 | 31 | self.split = split 32 | self.seqs = seqs 33 | self.mode = mode 34 | 35 | # list of (shape_name, shape_txt_file_path) tuple 36 | self.datapath_list() 37 | 38 | self.nr_data = len(self.points_datapath) 39 | 40 | print('The size of %s data is %d'%(self.split,len(self.points_datapath))) 41 | 42 | def datapath_list(self): 43 | self.points_datapath = [] 44 | 45 | for seq in self.seqs: 46 | 
point_seq_path = os.path.join(self.data_dir, 'dataset', 'sequences', seq, 'velodyne') 47 | point_seq_bin = os.listdir(point_seq_path) 48 | point_seq_bin.sort() 49 | 50 | for file_num in range(0, len(point_seq_bin)): 51 | # we guarantee that the end of sequence will not generate single scans as aggregated pcds 52 | end_file = file_num + self.scan_window if len(point_seq_bin) - file_num > 1.5 * self.scan_window else len(point_seq_bin) 53 | self.points_datapath.append([os.path.join(point_seq_path, point_file) for point_file in point_seq_bin[file_num:end_file] ]) 54 | if end_file == len(point_seq_bin): 55 | break 56 | 57 | #self.points_datapath = self.points_datapath[:200] 58 | 59 | def transforms(self, points): 60 | points = points[None,...] 61 | 62 | points[:,:,:3] = rotate_point_cloud(points[:,:,:3]) 63 | points[:,:,:3] = rotate_perturbation_point_cloud(points[:,:,:3]) 64 | points[:,:,:3] = random_scale_point_cloud(points[:,:,:3]) 65 | points[:,:,:3] = random_flip_point_cloud(points[:,:,:3]) 66 | 67 | return points[0] 68 | 69 | def __getitem__(self, index): 70 | #index = 500 71 | seq_num = self.points_datapath[index][0].split('/')[-3] 72 | fname = self.points_datapath[index][0].split('/')[-1].split('.')[0] 73 | 74 | #t_frame = np.random.randint(len(self.points_datapath[index])) 75 | t_frame = int(len(self.points_datapath[index]) / 2) 76 | p_full, p_part = aggregate_pcds(self.points_datapath[index], self.data_dir, t_frame) 77 | 78 | p_concat = np.concatenate((p_full, p_part), axis=0) 79 | p_gt = p_concat.copy() 80 | p_concat = self.transforms(p_concat) if self.split == 'train' else p_concat 81 | 82 | p_full = p_concat.copy() 83 | p_noise = jitter_point_cloud(p_concat[None,:,:3], sigma=.2, clip=.3)[0] 84 | dist_noise = np.power(p_noise, 2) 85 | dist_noise = np.sqrt(dist_noise.sum(-1)) 86 | 87 | _, mapping = ME.utils.sparse_quantize(coordinates=p_full / 0.1, return_index=True) 88 | p_full = p_full[mapping] 89 | dist_full = np.power(p_full, 2) 90 | dist_full = np.sqrt(dist_full.sum(-1)) 91 | 92 | return point_set_to_sparse_refine( 93 | p_full[dist_full < 50.], 94 | p_noise[dist_noise < 50.], 95 | self.num_points*2, 96 | self.num_points, 97 | self.resolution, 98 | self.points_datapath[index], 99 | ) 100 | 101 | def __len__(self): 102 | #print('DATA SIZE: ', np.floor(self.nr_data / self.sampling_window), self.nr_data % self.sampling_window) 103 | return self.nr_data 104 | 105 | ################################################################################################## 106 | -------------------------------------------------------------------------------- /datasets/SemanticKITTI_dataloader/SemanticKITTITemporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from utils.pcd_preprocess import point_set_to_coord_feats, aggregate_pcds, load_poses 4 | from utils.pcd_transforms import * 5 | from utils.data_map import learning_map 6 | from utils.collations import point_set_to_sparse 7 | from natsort import natsorted 8 | import os 9 | import numpy as np 10 | import yaml 11 | 12 | import warnings 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | ################################################# 17 | ################## Data loader ################## 18 | ################################################# 19 | 20 | class TemporalKITTISet(Dataset): 21 | def __init__(self, data_dir, seqs, split, resolution, num_points, max_range, dataset_norm=False, std_axis_norm=False): 22 | super().__init__() 23 | 
self.data_dir = data_dir 24 | 25 | self.n_clusters = 50 26 | self.resolution = resolution 27 | self.num_points = num_points 28 | self.max_range = max_range 29 | 30 | self.split = split 31 | self.seqs = seqs 32 | self.cache_maps = {} 33 | 34 | # list of (shape_name, shape_txt_file_path) tuple 35 | self.datapath_list() 36 | self.data_stats = {'mean': None, 'std': None} 37 | 38 | if os.path.isfile(f'utils/data_stats_range_{int(self.max_range)}m.yml') and dataset_norm: 39 | stats = yaml.safe_load(open(f'utils/data_stats_range_{int(self.max_range)}m.yml')) 40 | data_mean = np.array([stats['mean_axis']['x'], stats['mean_axis']['y'], stats['mean_axis']['z']]) 41 | if std_axis_norm: 42 | data_std = np.array([stats['std_axis']['x'], stats['std_axis']['y'], stats['std_axis']['z']]) 43 | else: 44 | data_std = np.array([stats['std'], stats['std'], stats['std']]) 45 | self.data_stats = { 46 | 'mean': torch.tensor(data_mean), 47 | 'std': torch.tensor(data_std) 48 | } 49 | 50 | self.nr_data = len(self.points_datapath) 51 | 52 | print('The size of %s data is %d'%(self.split,len(self.points_datapath))) 53 | 54 | def datapath_list(self): 55 | self.points_datapath = [] 56 | self.seq_poses = [] 57 | 58 | for seq in self.seqs: 59 | point_seq_path = os.path.join(self.data_dir, 'dataset', 'sequences', seq) 60 | point_seq_bin = natsorted(os.listdir(os.path.join(point_seq_path, 'velodyne'))) 61 | poses = load_poses(os.path.join(point_seq_path, 'calib.txt'), os.path.join(point_seq_path, 'poses.txt')) 62 | p_full = np.load(f'{point_seq_path}/map_clean.npy') if self.split != 'test' else np.array([[1,0,0],[0,1,0],[0,0,1]]) 63 | self.cache_maps[seq] = p_full 64 | 65 | for file_num in range(0, len(point_seq_bin)): 66 | self.points_datapath.append(os.path.join(point_seq_path, 'velodyne', point_seq_bin[file_num])) 67 | self.seq_poses.append(poses[file_num]) 68 | 69 | def transforms(self, points): 70 | points = np.expand_dims(points, axis=0) 71 | points[:,:,:3] = rotate_point_cloud(points[:,:,:3]) 72 | points[:,:,:3] = rotate_perturbation_point_cloud(points[:,:,:3]) 73 | points[:,:,:3] = random_scale_point_cloud(points[:,:,:3]) 74 | points[:,:,:3] = random_flip_point_cloud(points[:,:,:3]) 75 | 76 | return np.squeeze(points, axis=0) 77 | 78 | def __getitem__(self, index): 79 | seq_num = self.points_datapath[index].split('/')[-3] 80 | fname = self.points_datapath[index].split('/')[-1].split('.')[0] 81 | 82 | p_part = np.fromfile(self.points_datapath[index], dtype=np.float32) 83 | p_part = p_part.reshape((-1,4))[:,:3] 84 | 85 | if self.split != 'test': 86 | label_file = self.points_datapath[index].replace('velodyne', 'labels').replace('.bin', '.label') 87 | l_set = np.fromfile(label_file, dtype=np.uint32) 88 | l_set = l_set.reshape((-1)) 89 | l_set = l_set & 0xFFFF 90 | static_idx = (l_set < 252) & (l_set > 1) 91 | p_part = p_part[static_idx] 92 | dist_part = np.sum(p_part**2, -1)**.5 93 | p_part = p_part[(dist_part < self.max_range) & (dist_part > 3.5)] 94 | p_part = p_part[p_part[:,2] > -4.] 95 | pose = self.seq_poses[index] 96 | 97 | p_map = self.cache_maps[seq_num] 98 | 99 | if self.split != 'test': 100 | trans = pose[:-1,-1] 101 | dist_full = np.sum((p_map - trans)**2, -1)**.5 102 | p_full = p_map[dist_full < self.max_range] 103 | p_full = np.concatenate((p_full, np.ones((len(p_full),1))), axis=-1) 104 | p_full = (p_full @ np.linalg.inv(pose).T)[:,:3] 105 | p_full = p_full[p_full[:,2] > -4.] 
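# at test time no aggregated ground-truth map is available, so the raw scan
# stands in for the complete point cloud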
106 |         else:
107 |             p_full = p_part
108 | 
109 |         if self.split == 'train':
110 |             p_concat = np.concatenate((p_full, p_part), axis=0)
111 |             p_concat = self.transforms(p_concat)
112 | 
113 |             p_full = p_concat[:-len(p_part)]
114 |             p_part = p_concat[-len(p_part):]
115 | 
116 |         # partial pcd has 1/10 of the complete pcd size
117 |         n_part = int(self.num_points / 10.)
118 | 
119 |         return point_set_to_sparse(
120 |             p_full,
121 |             p_part,
122 |             self.num_points,
123 |             n_part,
124 |             self.resolution,
125 |             self.points_datapath[index],
126 |             p_mean=self.data_stats['mean'],
127 |             p_std=self.data_stats['std'],
128 |         )
129 | 
130 |     def __len__(self):
131 |         return self.nr_data
132 | 
133 | ##################################################################################################
134 | 
--------------------------------------------------------------------------------
/utils/pcd_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def rotate_point_cloud(batch_data):
4 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
5 |     for k in range(batch_data.shape[0]):
6 |         rotation_angle = np.random.uniform() * 2 * np.pi
7 |         cosval = np.cos(rotation_angle)
8 |         sinval = np.sin(rotation_angle)
9 |         rotation_matrix = np.array([[cosval, -sinval, 0],
10 |                                     [sinval, cosval, 0],
11 |                                     [0, 0, 1]])
12 |         shape_pc = batch_data[k, ...]
13 |         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
14 |     return rotated_data
15 | 
16 | 
17 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
18 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
19 |     for k in range(batch_data.shape[0]):
20 |         angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
21 |         Rx = np.array([[1,0,0],
22 |                        [0,np.cos(angles[0]),-np.sin(angles[0])],
23 |                        [0,np.sin(angles[0]),np.cos(angles[0])]])
24 |         Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
25 |                        [0,1,0],
26 |                        [-np.sin(angles[1]),0,np.cos(angles[1])]])
27 |         Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
28 |                        [np.sin(angles[2]),np.cos(angles[2]),0],
29 |                        [0,0,1]])
30 |         R = np.dot(Rz, np.dot(Ry,Rx))
31 |         shape_pc = batch_data[k, ...]
32 |         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
33 |     return rotated_data
34 | 
35 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
36 |     B, N, C = batch_data.shape
37 |     assert(clip > 0)
38 |     jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
39 |     jittered_data += batch_data
40 |     return jittered_data
41 | 
42 | def random_drop_n_cuboids(batch_data):
43 |     batch_data = random_drop_point_cloud(batch_data)
44 |     cuboids_count = 1
45 |     while cuboids_count < 5 and np.random.uniform(0., 1.)
> 0.3: 46 | batch_data = random_drop_point_cloud(batch_data) 47 | cuboids_count += 1 48 | 49 | return batch_data 50 | 51 | def check_aspect2D(crop_range, aspect_min): 52 | xy_aspect = np.min(crop_range[:2])/np.max(crop_range[:2]) 53 | return (xy_aspect >= aspect_min) 54 | 55 | def random_cuboid_point_cloud(batch_data): 56 | batch_data = np.expand_dims(batch_data, axis=0) 57 | 58 | B, N, C = batch_data.shape 59 | new_batch_data = [] 60 | for batch_index in range(B): 61 | range_xyz = np.max(batch_data[batch_index,:,0:2], axis=0) - np.min(batch_data[batch_index,:,0:2], axis=0) 62 | 63 | crop_range = 0.5 + (np.random.rand(2) * 0.5) 64 | 65 | loop_count = 0 66 | while not check_aspect2D(crop_range, 0.75): 67 | loop_count += 1 68 | crop_range = 0.5 + (np.random.rand(2) * 0.5) 69 | if loop_count > 100: 70 | break 71 | 72 | loop_count = 0 73 | 74 | while True: 75 | loop_count += 1 76 | new_range = range_xyz * crop_range / 2.0 77 | sample_center = batch_data[batch_index,np.random.choice(len(batch_data[batch_index])), 0:3] 78 | max_xyz = sample_center[:2] + new_range 79 | min_xyz = sample_center[:2] - new_range 80 | 81 | upper_idx = np.sum((batch_data[batch_index,:,:2] < max_xyz).astype(np.int32), 1) == 2 82 | lower_idx = np.sum((batch_data[batch_index,:,:2] > min_xyz).astype(np.int32), 1) == 2 83 | 84 | new_pointidx = ((upper_idx) & (lower_idx)) 85 | 86 | # avoid having too small point clouds 87 | if (loop_count > 100) or (np.sum(new_pointidx) > 20000): 88 | break 89 | 90 | 91 | new_batch_data.append(batch_data[batch_index,new_pointidx,:]) 92 | 93 | new_batch_data = np.array(new_batch_data) 94 | 95 | return np.squeeze(new_batch_data, axis=0) 96 | 97 | def random_drop_point_cloud(batch_data): 98 | B, N, C = batch_data.shape 99 | new_batch_data = [] 100 | for batch_index in range(B): 101 | range_xyz = np.max(batch_data[batch_index,:,0:3], axis=0) - np.min(batch_data[batch_index,:,0:3], axis=0) 102 | 103 | crop_range = np.random.uniform(0.1, 0.15) 104 | new_range = range_xyz * crop_range / 2.0 105 | sample_center = batch_data[batch_index,np.random.choice(len(batch_data[batch_index])), 0:3] 106 | max_xyz = sample_center + new_range 107 | min_xyz = sample_center - new_range 108 | 109 | upper_idx = np.sum((batch_data[batch_index,:,0:3] < max_xyz).astype(np.int32), 1) == 3 110 | lower_idx = np.sum((batch_data[batch_index,:,0:3] > min_xyz).astype(np.int32), 1) == 3 111 | 112 | new_pointidx = ~((upper_idx) & (lower_idx)) 113 | new_batch_data.append(batch_data[batch_index,new_pointidx,:]) 114 | 115 | return np.array(new_batch_data) 116 | 117 | 118 | def random_flip_point_cloud(batch_data, scale_low=0.95, scale_high=1.05): 119 | B, N, C = batch_data.shape 120 | for batch_index in range(B): 121 | if np.random.random() > 0.5: 122 | batch_data[batch_index,:,1] = -1 * batch_data[batch_index,:,1] 123 | return batch_data 124 | 125 | def random_scale_point_cloud(batch_data, scale_low=0.95, scale_high=1.05): 126 | B, N, C = batch_data.shape 127 | scales = np.random.uniform(scale_low, scale_high, B) 128 | for batch_index in range(B): 129 | batch_data[batch_index,:,:] *= scales[batch_index] 130 | return batch_data 131 | 132 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 133 | for b in range(batch_pc.shape[0]): 134 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 135 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 136 | if len(drop_idx)>0: 137 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 138 | return batch_pc 139 | 
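# --- example usage (illustrative sketch) ---
# the train-time augmentation chain applied by the dataloaders: rotate,
# perturb, scale, flip; the (1, N, 3) batch shape and the random input are
# assumptions for this demo only
if __name__ == '__main__':
    batch = np.random.rand(1, 1024, 3).astype(np.float32) * 10.  # one cloud of 1024 points
    batch = rotate_point_cloud(batch)               # random rotation around the z axis
    batch = rotate_perturbation_point_cloud(batch)  # small random tilt
    batch = random_scale_point_cloud(batch)         # global scale in [0.95, 1.05]
    batch = random_flip_point_cloud(batch)          # randomly negates the y coordinate
    print(batch.shape)  # still (1, 1024, 3)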
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | # 💡**DistillationDPO**💡
4 | ## **[Diffusion Distillation With Direct Preference Optimization For Efficient 3D LiDAR Scene Completion [AAAI'26]](#)**
5 | 
6 | by *An Zhao1, [Shengyuan Zhang](https://github.com/SYZhang0805)1, [Ling Yang](https://github.com/YangLing0818)2, [Zejian Li*](https://zejianli.github.io/)1, Jiale Wu1, Haoran Xu3, AnYang Wei3, Perry Pengyun GU3, [Lingyun Sun](https://person.zju.edu.cn/sly)1*
7 | 
8 | *1Zhejiang University 2Peking University 3Zhejiang Green Zhixing Technology Co., Ltd*
9 | 
10 | ![](./pics/teaser.png)
11 | 
12 | 
13 | 
14 | ## **Abstract**
15 | 
16 | The application of diffusion models to 3D LiDAR scene completion is limited by their slow sampling speed.
17 | Score distillation accelerates diffusion sampling but at the cost of performance degradation, while post-training with direct preference optimization (DPO) boosts performance using preference data.
18 | This paper proposes Distillation-DPO, a novel diffusion distillation framework for LiDAR scene completion with preference alignment.
19 | First, the student model generates paired completion scenes with different initial noises.
20 | Second, using LiDAR scene evaluation metrics as the preference signal, we construct winning and losing sample pairs.
21 | Such construction is reasonable, since most LiDAR scene metrics are informative but non-differentiable and thus cannot be optimized directly.
22 | Third, Distillation-DPO optimizes the student model by exploiting the difference in score functions between the teacher and student models on the paired completion scenes.
23 | This procedure is repeated until convergence.
24 | Extensive experiments demonstrate that, compared to state-of-the-art LiDAR scene completion diffusion models, Distillation-DPO achieves higher-quality scene completion while accelerating completion by more than 5-fold.
25 | To the best of our knowledge, our method is the first to explore preference learning in distillation, and it provides insights into preference-aligned distillation.
26 | 
27 | ## **Environment setup**
28 | 
29 | The following commands are tested with Python 3.8 and CUDA 11.1.
30 | 
31 | Install the required packages:
32 | 
33 | `sudo apt install build-essential python3-dev libopenblas-dev`
34 | 
35 | `pip3 install -r requirements.txt`
36 | 
37 | Install [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) for sparse tensor processing:
38 | 
39 | `pip3 install -U MinkowskiEngine==0.5.4 --install-option="--blas=openblas" -v --no-deps`
40 | 
41 | Set up the package from the repository's main directory:
42 | 
43 | `pip3 install -U -e .`
44 | 
45 | ## **Training**
46 | 
47 | We use the SemanticKITTI dataset for training.
48 | 
49 | The SemanticKITTI dataset has to be downloaded from the [official site](http://www.semantic-kitti.org/dataset.html#download) and extracted into the following structure:
50 | 
51 | ```
52 | DistillationDPO/
53 | └── datasets/
54 |     └── SemanticKITTI
55 |         └── dataset
56 |             └── sequences
57 |                 ├── 00/
58 |                 │   ├── velodyne/
59 |                 |   |   ├── 000000.bin
60 |                 |   |   ├── 000001.bin
61 |                 |   |   └── ...
62 |                 │   └── labels/
63 |                 |       ├── 000000.label
64 |                 |       ├── 000001.label
65 |                 |       └── ...
66 |                 ├── 01/
67 |                 │   ...
68 |                 ...
69 | ```
70 | 
71 | Ground truth scenes are not provided explicitly in SemanticKITTI. To generate the ground truth complete scenes, you can run the `map_from_scans.py` script. It uses the dataset scans and poses to build each sequence map, which serves as ground truth during training:
72 | 
73 | ```
74 | python map_from_scans.py --path datasets/SemanticKITTI/dataset/sequences/
75 | ```
76 | 
77 | We use the diffusion-DPO fine-tuned version of [LiDiff](https://github.com/PRBonn/LiDiff) as the teacher model as well as the teacher assistant models. Download the pre-trained weights from [here](https://drive.google.com/drive/folders/1z7Iq6nPDZXtASUDP8R8sqhUAvVfRqKQH?usp=sharing) and place them at `checkpoints/lidiff_ddpo_refined.ckpt`.
78 | 
79 | Once the sequence maps are generated and the teacher model is downloaded, you can train the model. The training can be started with:
80 | 
81 | `python trains/DistillationDPO.py --SemanticKITTI_path datasets/SemanticKITTI --pre_trained_diff_path checkpoints/lidiff_ddpo_refined.ckpt`
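
For programmatic use, the data module in `datasets/SemanticKITTI_dataset.py` can also be driven directly. The snippet below is a minimal sketch, not taken from the repository's docs: it assumes the dataset layout and sequence maps prepared above, and the `batch_size` value is illustrative.

```python
from argparse import Namespace
from datasets.SemanticKITTI_dataset import TemporalKittiDataModule

# the module only reads these two fields from its args object
args = Namespace(SemanticKITTI_path='datasets/SemanticKITTI', batch_size=2)

dm = TemporalKittiDataModule(args)
loader = dm.train_dataloader()

# SparseSegmentCollation packs each batch into a dict with
# 'pcd_full', 'mean', 'std', 'pcd_part' and 'filename' entries
batch = next(iter(loader))
print(batch['pcd_full'].shape, batch['pcd_part'].shape)
```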
82 | 
83 | ## **Inference & Visualization**
84 | 
85 | We use [pyrender](https://github.com/mmatl/pyrender) for offscreen rendering. Please see [this guide](https://pyrender.readthedocs.io/en/latest/install/index.html#osmesa) for installing pyrender with OSMesa support.
86 | 
87 | Once pyrender is installed correctly, download the refinement model `refine_net.ckpt` from [here](https://drive.google.com/drive/folders/1z7Iq6nPDZXtASUDP8R8sqhUAvVfRqKQH?usp=sharing) and place it at `checkpoints/refine_net.ckpt`. We also provide the pre-trained weights of Distillation-DPO: download `distillationdpo_st.ckpt` from [here](https://drive.google.com/drive/folders/1z7Iq6nPDZXtASUDP8R8sqhUAvVfRqKQH?usp=sharing) and place it at `checkpoints/distillationdpo_st.ckpt`.
88 | 
89 | Then run the inference script with the following command:
90 | 
91 | `python utils/eval_path_get_pics.py --diff checkpoints/distillationdpo_st.ckpt --refine checkpoints/refine_net.ckpt`
92 | 
93 | This script reads all scans in a specified sequence of the SemanticKITTI dataset and saves the rendered images under `exp/`.
94 | 
95 | ## **Citation**
96 | 
97 | If you find our paper useful or relevant to your research, please kindly cite our paper:
98 | 
99 | ```bibtex
100 | @misc{zhao2025diffusiondistillationdirectpreference,
101 |       title={Diffusion Distillation With Direct Preference Optimization For Efficient 3D LiDAR Scene Completion},
102 |       author={An Zhao and Shengyuan Zhang and Ling Yang and Zejian Li and Jiale Wu and Haoran Xu and AnYang Wei and Perry Pengyun GU and Lingyun Sun},
103 |       year={2025},
104 |       eprint={2504.11447},
105 |       archivePrefix={arXiv},
106 |       primaryClass={cs.CV},
107 |       url={https://arxiv.org/abs/2504.11447},
108 | }
109 | ```
110 | 
111 | ## **Credits**
112 | 
113 | DistillationDPO builds heavily on the following amazing open-source project:
114 | 
115 | [LiDiff](https://github.com/PRBonn/LiDiff): Scaling Diffusion Models to Real-World 3D LiDAR Scene Completion
--------------------------------------------------------------------------------
/utils/eval_path_get_pics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import open3d as o3d
4 | import tqdm
5 | from natsort import natsorted
6 | import click
7 | import json
8 | 
9 | from utils.diff_completion_pipeline import DiffCompletion_DistillationDPO, DiffCompletion_lidiff
10 | from utils.histogram_metrics import compute_hist_metrics
11 | from utils.render import offscreen_render
12 | 
13 | def parse_calibration(filename):
14 |     calib = {}
15 | 
16 |     calib_file = open(filename)
17 |     for line in calib_file:
18 |         key, content = line.strip().split(":")
19 |         values = [float(v) for v in content.strip().split()]
20 | 
21 |         pose = np.zeros((4, 4))
22 |         pose[0, 0:4] = values[0:4]
23 |         pose[1, 0:4] = values[4:8]
24 |         pose[2, 0:4] = values[8:12]
25 |         pose[3, 3] = 1.0
26 | 
27 |         calib[key] = pose
28 | 
29 |     calib_file.close()
30 | 
31 |     return calib
32 | 
33 | def load_poses(calib_fname, poses_fname):
34 |     if os.path.exists(calib_fname):
35 |         calibration = parse_calibration(calib_fname)
36 |         Tr = calibration["Tr"]
37 |         Tr_inv = np.linalg.inv(Tr)
38 | 
39 |     poses_file = open(poses_fname)
40 |     poses = []
41 | 
42 |     for line in poses_file:
43 |         values = [float(v) for v in line.strip().split()]
44 | 
45 |         pose = np.zeros((4, 4))
46 |         pose[0,
0:4] = values[0:4] 47 | pose[1, 0:4] = values[4:8] 48 | pose[2, 0:4] = values[8:12] 49 | pose[3, 3] = 1.0 50 | 51 | if os.path.exists(calib_fname): 52 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 53 | else: 54 | poses.append(pose) 55 | 56 | return poses 57 | 58 | 59 | def get_scan_completion(scan_path, path, diff_completion, max_range): 60 | pcd_file = os.path.join(path, 'velodyne', scan_path) 61 | points = np.fromfile(pcd_file, dtype=np.float32) 62 | points = points.reshape(-1,4) 63 | dist = np.sqrt(np.sum(points[:,:3]**2, axis=-1)) 64 | input_points = points[dist < max_range, :3] 65 | if diff_completion is None: 66 | pred_path = f'{scan_path.split(".")[0]}.ply' 67 | pcd_pred = o3d.io.read_point_cloud(os.path.join(path, pred_path)) 68 | points = np.array(pcd_pred.points) 69 | dist = np.sqrt(np.sum(points**2, axis=-1)) 70 | pcd_pred.points = o3d.utility.Vector3dVector(points[dist < max_range]) 71 | else: 72 | complete_scan_refined, complete_scan = diff_completion.complete_scan(points) 73 | pcd_pred = o3d.geometry.PointCloud() 74 | pcd_pred.points = o3d.utility.Vector3dVector(complete_scan) 75 | pcd_pred_refined = o3d.geometry.PointCloud() 76 | pcd_pred_refined.points = o3d.utility.Vector3dVector(complete_scan_refined) 77 | 78 | return pcd_pred, pcd_pred_refined, input_points 79 | 80 | def get_ground_truth(pose, cur_scan, seq_map, max_range): 81 | trans = pose[:-1,-1] 82 | dist_gt = np.sum((seq_map - trans)**2, axis=-1)**.5 83 | scan_gt = seq_map[dist_gt < max_range] 84 | scan_gt = np.concatenate((scan_gt, np.ones((len(scan_gt),1))), axis=-1) 85 | scan_gt = (scan_gt @ np.linalg.inv(pose).T)[:,:3] 86 | scan_gt = scan_gt[(scan_gt[:,2] > -4.) & (scan_gt[:,2] < 4.4)] 87 | pcd_gt = o3d.geometry.PointCloud() 88 | pcd_gt.points = o3d.utility.Vector3dVector(scan_gt) 89 | 90 | # filter only over the view point 91 | cur_pcd = o3d.geometry.PointCloud() 92 | cur_pcd.points = o3d.utility.Vector3dVector(cur_scan) 93 | viewpoint_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(cur_pcd, voxel_size=10.) 
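# check_if_included keeps only map points that fall inside voxels occupied by
# the current scan, so the ground truth is cropped to the sensor's field of view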
94 |     in_viewpoint = viewpoint_grid.check_if_included(pcd_gt.points)
95 |     points_gt = np.array(pcd_gt.points)
96 |     pcd_gt.points = o3d.utility.Vector3dVector(points_gt[in_viewpoint])
97 | 
98 |     return pcd_gt
99 | 
100 | def pcd_denoise(pcd):
101 |     cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
102 |     pcd_clean = pcd.select_by_index(ind)
103 |     return pcd_clean
104 | 
105 | import torch
106 | import random
107 | def set_deterministic(seed=42):
108 |     random.seed(seed)
109 |     np.random.seed(seed)
110 |     torch.manual_seed(seed)
111 |     torch.cuda.manual_seed(seed)
112 |     torch.backends.cudnn.deterministic = True
113 | 
114 | 
115 | @click.command()
116 | @click.option('--path', '-p', type=str, default='datasets/SemanticKITTI/dataset/sequences/00', help='path to the scan sequence')
117 | @click.option('--max_range', '-m', type=float, default=50, help='max range')
118 | @click.option('--denoising_steps', '-t', type=int, default=8, help='number of denoising steps')
119 | @click.option('--cond_weight', '-s', type=float, default=3.5, help='conditioning weights')
120 | @click.option('--diff', '-d', type=str, default='checkpoints/distillationdpo_st.ckpt', help='trained diffusion model')
121 | @click.option('--refine', '-r', type=str, default='checkpoints/refine_net.ckpt', help='refinement model')
122 | @click.option('--save_path', '-sp', type=str, default='exp/pics/00', help='where to save pics')
123 | def main(path, max_range, denoising_steps, cond_weight, diff, refine, save_path):
124 | 
125 |     if not os.path.exists(save_path):
126 |         os.makedirs(save_path)
127 | 
128 |     diff_completion_distdpo = DiffCompletion_DistillationDPO(diff, refine, denoising_steps, cond_weight)
129 | 
130 |     poses = load_poses(os.path.join(path, 'calib.txt'), os.path.join(path, 'poses.txt'))
131 |     seq_map = np.load(f'{path}/map_clean.npy')
132 | 
133 |     eval_list = list(zip(poses, natsorted(os.listdir(f'{path}/velodyne'))))
134 |     random.shuffle(eval_list)
135 |     for pose, scan_path in tqdm.tqdm(eval_list):
136 | 
137 |         pcd_pred, pcd_pred_refined, cur_scan = get_scan_completion(scan_path, path, diff_completion_distdpo, max_range)
138 |         pcd_in = o3d.geometry.PointCloud()
139 |         pcd_in.points = o3d.utility.Vector3dVector(cur_scan)
140 |         pcd_gt = get_ground_truth(pose, cur_scan, seq_map, max_range)
141 | 
142 |         pic_dir = os.path.join(save_path, scan_path.split('.')[0])
143 |         if not os.path.exists(pic_dir):
144 |             os.makedirs(pic_dir)
145 |         pic_pred_dir = os.path.join(pic_dir, f'{denoising_steps}steps.png')
146 |         pic_pred_refined_dir = os.path.join(pic_dir, f'{denoising_steps}steps_refined.png')
147 |         pic_gt_dir = os.path.join(pic_dir, 'gt.png')
148 |         pic_in_dir = os.path.join(pic_dir, 'in.png')
149 | 
150 |         pcd_pred = pcd_denoise(pcd_pred)
151 |         pcd_pred_refined = pcd_denoise(pcd_pred_refined)
152 | 
153 |         offscreen_render(pcd_pred, pic_pred_dir)
154 |         offscreen_render(pcd_pred_refined, pic_pred_refined_dir)
155 |         offscreen_render(pcd_gt, pic_gt_dir)
156 |         offscreen_render(pcd_in, pic_in_dir)
157 | 
158 | 
159 | if __name__ == '__main__':
160 |     main()
161 | 
--------------------------------------------------------------------------------
/utils/pcd_preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import hdbscan
4 | import matplotlib.pyplot as plt
5 | import MinkowskiEngine as ME
6 | import os
7 | 
8 | def overlap_clusters(cluster_i, cluster_j, min_cluster_point=10):
9 |     # get unique labels from pcd_i and pcd_j from segments bigger than min_cluster_point
10 |     unique_i, count_i = np.unique(cluster_i, return_counts=True)
11 |     unique_i = unique_i[count_i > min_cluster_point]
12 | 
13 |     unique_j, count_j = np.unique(cluster_j, return_counts=True)
14 |     unique_j = unique_j[count_j > min_cluster_point]
15 | 
16 |     # get labels present on both pcd (intersection)
17 |     unique_ij = np.intersect1d(unique_i, unique_j)[1:]
18 | 
19 |     # labels not intersecting both pcd are assigned as -1 (unlabeled)
20 |     cluster_i[np.in1d(cluster_i, unique_ij, invert=True)] = -1
21 |     cluster_j[np.in1d(cluster_j, unique_ij, invert=True)] = -1
22 | 
23 |     return cluster_i, cluster_j
24 | 
25 | def parse_calibration(filename):
26 |     calib = {}
27 | 
28 |     calib_file = open(filename)
29 |     for line in calib_file:
30 |         key, content = line.strip().split(":")
31 |         values = [float(v) for v in content.strip().split()]
32 | 
33 |         pose = np.zeros((4, 4))
34 |         pose[0, 0:4] = values[0:4]
35 |         pose[1, 0:4] = values[4:8]
36 |         pose[2, 0:4] = values[8:12]
37 |         pose[3, 3] = 1.0
38 | 
39 |         calib[key] = pose
40 | 
41 |     calib_file.close()
42 | 
43 |     return calib
44 | 
45 | def load_poses(calib_fname, poses_fname):
46 |     if os.path.exists(calib_fname):
47 |         calibration = parse_calibration(calib_fname)
48 |         Tr = calibration["Tr"]
49 |         Tr_inv = np.linalg.inv(Tr)
50 | 
51 |     poses_file = open(poses_fname)
52 |     poses = []
53 | 
54 |     for line in poses_file:
55 |         values = [float(v) for v in line.strip().split()]
56 | 
57 |         pose = np.zeros((4, 4))
58 |         pose[0, 0:4] = values[0:4]
59 |         pose[1, 0:4] = values[4:8]
60 |         pose[2, 0:4] = values[8:12]
61 |         pose[3, 3] = 1.0
62 | 
63 |         if os.path.exists(calib_fname):
64 |             poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
65 |         else:
66 |             poses.append(pose)
67 | 
68 |     return poses
69 | 
70 | def apply_transform(points, pose):
71 |     hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
72 |     return np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)[:,:3]
73 | 
74 | def undo_transform(points, pose):
75 |     hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
76 |     return np.sum(np.expand_dims(hpoints, 2) * np.linalg.inv(pose).T, axis=1)[:,:3]
77 | 
78 | def aggregate_pcds(data_batch, data_dir, t_frame):
79 |     # load empty pcd point cloud to aggregate
80 |     pcd_full = np.empty((0,3))
81 |     pcd_part = None
82 | 
83 |     # define "namespace"
84 |     seq_num = data_batch[0].split('/')[-3]
85 |     fname = data_batch[0].split('/')[-1].split('.')[0]
86 | 
87 |     # load poses
88 |     datapath = data_batch[0].split('velodyne')[0]
89 |     poses = load_poses(os.path.join(datapath, 'calib.txt'), os.path.join(datapath, 'poses.txt'))
90 | 
91 |     for t in range(len(data_batch)):
92 |         # load the next t scan and aggregate
93 |         fname = data_batch[t].split('/')[-1].split('.')[0]
94 | 
95 |         # load the next t scan, apply pose and aggregate
96 |         p_set = np.fromfile(data_batch[t], dtype=np.float32)
97 |         p_set = p_set.reshape((-1, 4))[:,:3]
98 | 
99 |         label_file = data_batch[t].replace('velodyne', 'labels').replace('.bin', '.label')
100 |         l_set = np.fromfile(label_file, dtype=np.uint32)
101 |         l_set = l_set.reshape((-1))
102 |         l_set = l_set & 0xFFFF
103 | 
104 |         # remove moving points
105 |         static_idx = l_set < 252
106 |         p_set = p_set[static_idx]
107 | 
108 |         # remove flying artifacts
109 |         dist = np.power(p_set, 2)
110 |         dist = np.sqrt(dist.sum(-1))
111 |         p_set = p_set[dist > 3.5]
112 | 
113 |         pose_idx = int(fname)
114 |         p_set = apply_transform(p_set,
poses[pose_idx]) 115 | 116 | if t == t_frame: 117 | # hold out the t_frame scan as the partial (single-scan) cloud 118 | pcd_part = p_set.copy() 119 | else: 120 | pcd_full = np.vstack([pcd_full, p_set]) 121 | 122 | 123 | # re-express both aggregated clouds in the frame of the last loaded scan 124 | 125 | pose_idx = int(fname) 126 | pcd_full = undo_transform(pcd_full, poses[pose_idx]) 127 | pcd_part = undo_transform(pcd_part, poses[pose_idx]) 128 | 129 | return pcd_full, pcd_part 130 | 131 | def clusters_hdbscan(points_set, n_clusters=50): 132 | clusterer = hdbscan.HDBSCAN(algorithm='best', alpha=1., approx_min_span_tree=True, 133 | gen_min_span_tree=True, leaf_size=100, 134 | metric='euclidean', min_cluster_size=20, min_samples=None 135 | ) 136 | 137 | clusterer.fit(points_set) 138 | 139 | labels = clusterer.labels_.copy() 140 | 141 | lbls, counts = np.unique(labels, return_counts=True) 142 | cluster_info = np.array(list(zip(lbls[1:], counts[1:])))  # skip the -1 noise label 143 | cluster_info = cluster_info[cluster_info[:,1].argsort()] 144 | 145 | clusters_labels = cluster_info[::-1][:n_clusters, 0]  # keep the n_clusters largest clusters 146 | labels[np.in1d(labels, clusters_labels, invert=True)] = -1 147 | 148 | return labels 149 | 150 | def clusterize_pcd(points, ground): 151 | pcd = o3d.geometry.PointCloud() 152 | pcd.points = o3d.utility.Vector3dVector(points[:, :3]) 153 | 154 | # ground indices (label 9) come from a precomputed patchwork segmentation instead of RANSAC plane fitting 155 | inliers = list(np.where(ground == 9)[0]) 156 | 157 | pcd_ = pcd.select_by_index(inliers, invert=True) 158 | labels_ = np.expand_dims(clusters_hdbscan(np.asarray(pcd_.points)), axis=-1) 159 | 160 | # point clouds are ordered lists of points, so index lists and boolean masks 161 | # preserve order: for [a, b, c, d, e], the indices [1, 3] select [b, d], 162 | # while the inverted mask selects the remaining [a, c, e], still in order. 163 | # building mask with mask[inliers] = False therefore scatters the HDBSCAN 164 | # labels back onto the non-ground points at their original positions, 165 | # leaving the ground inliers at the default -1. 166 | labels = np.ones((points.shape[0], 1)) * -1 167 | mask = np.ones(labels.shape[0], dtype=bool) 168 | mask[inliers] = False 169 | 170 | labels[mask] = labels_ 171 | 172 | return labels 173 | 174 | def point_set_to_coord_feats(point_set, labels, resolution, num_points, deterministic=False): 175 | p_feats = point_set.copy() 176 | p_coord = np.round(point_set[:, :3] / resolution) 177 | p_coord -= p_coord.min(0, keepdims=1) 178 | 179 | _, mapping = ME.utils.sparse_quantize(coordinates=p_coord, return_index=True) 180 | if len(mapping) > num_points: 181 | np.random.seed(42)  # fixed seed keeps the subsampling reproducible (note: reseeds the global RNG) 182 | mapping = np.random.choice(mapping, num_points, replace=False) 183 | 184 | return p_coord[mapping], p_feats[mapping], labels[mapping] 185 | 186 | def visualize_pcd_clusters(points, labels): 187 | pcd = o3d.geometry.PointCloud() 188 | pcd.points = o3d.utility.Vector3dVector(points[:,:3]) 189 | 190 | colors = np.zeros((len(labels), 4)) 191 | flat_indices = np.unique(labels[:,-1]) 192 | max_instance = len(flat_indices) 193 | colors_instance = plt.get_cmap("prism")(np.arange(len(flat_indices)) / (max_instance if max_instance > 0 else 1)) 194 | 195 | for idx in range(len(flat_indices)): 196 | colors[labels[:,-1] == flat_indices[int(idx)]] = colors_instance[int(idx)] 197 | 198 | colors[labels[:,-1] == -1] = [0.,0.,0.,0.] 
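# points that kept the -1 label -- ground inliers and clusters rejected by
# clusters_hdbscan -- are drawn black, so only the surviving clusters receive
# a distinct "prism" colormap hue in the viewer.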
199 | 200 | pcd.colors = o3d.utility.Vector3dVector(colors[:,:3]) 201 | 202 | o3d.visualization.draw_geometries([pcd]) 203 | -------------------------------------------------------------------------------- /utils/eval_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import open3d as o3d 4 | import click 5 | import json 6 | import tqdm 7 | from natsort import natsorted 8 | 9 | from utils.metrics import ChamferDistance, PrecisionRecall, CompletionIoU, RMSE, EMD 10 | from utils.diff_completion_pipeline import DiffCompletion_DistillationDPO 11 | from utils.histogram_metrics import compute_hist_metrics 12 | 13 | completion_iou = CompletionIoU() 14 | rmse = RMSE() 15 | chamfer_distance = ChamferDistance() 16 | precision_recall = PrecisionRecall(0.05,2*0.05,100) 17 | emd = EMD(voxel_size=0.5) 18 | 19 | def parse_calibration(filename): 20 | calib = {} 21 | 22 | calib_file = open(filename) 23 | for line in calib_file: 24 | key, content = line.strip().split(":") 25 | values = [float(v) for v in content.strip().split()] 26 | 27 | pose = np.zeros((4, 4)) 28 | pose[0, 0:4] = values[0:4] 29 | pose[1, 0:4] = values[4:8] 30 | pose[2, 0:4] = values[8:12] 31 | pose[3, 3] = 1.0 32 | 33 | calib[key] = pose 34 | 35 | calib_file.close() 36 | 37 | return calib 38 | 39 | def load_poses(calib_fname, poses_fname): 40 | if os.path.exists(calib_fname): 41 | calibration = parse_calibration(calib_fname) 42 | Tr = calibration["Tr"] 43 | Tr_inv = np.linalg.inv(Tr) 44 | 45 | poses_file = open(poses_fname) 46 | poses = [] 47 | 48 | for line in poses_file: 49 | values = [float(v) for v in line.strip().split()] 50 | 51 | pose = np.zeros((4, 4)) 52 | pose[0, 0:4] = values[0:4] 53 | pose[1, 0:4] = values[4:8] 54 | pose[2, 0:4] = values[8:12] 55 | pose[3, 3] = 1.0 56 | 57 | if os.path.exists(calib_fname): 58 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 59 | else: 60 | poses.append(pose) 61 | 62 | return poses 63 | 64 | 65 | def get_scan_completion(scan_path, path, diff_completion, max_range, use_refine): 66 | pcd_file = os.path.join(path, 'velodyne', scan_path) 67 | points = np.fromfile(pcd_file, dtype=np.float32) 68 | points = points.reshape(-1,4) 69 | dist = np.sqrt(np.sum(points[:,:3]**2, axis=-1)) 70 | input_points = points[dist < max_range, :3] 71 | if diff_completion is None: 72 | pred_path = f'{scan_path.split(".")[0]}.ply' 73 | pcd_pred = o3d.io.read_point_cloud(os.path.join(path, pred_path)) 74 | points = np.array(pcd_pred.points) 75 | dist = np.sqrt(np.sum(points**2, axis=-1)) 76 | pcd_pred.points = o3d.utility.Vector3dVector(points[dist < max_range]) 77 | else: 78 | # points = points[:,:3] 79 | if use_refine: 80 | complete_scan, _ = diff_completion.complete_scan(points) 81 | else: 82 | _, complete_scan = diff_completion.complete_scan(points) 83 | pcd_pred = o3d.geometry.PointCloud() 84 | pcd_pred.points = o3d.utility.Vector3dVector(complete_scan) 85 | 86 | return pcd_pred, input_points 87 | 88 | def get_ground_truth(pose, cur_scan, seq_map, max_range): 89 | trans = pose[:-1,-1] 90 | dist_gt = np.sum((seq_map - trans)**2, axis=-1)**.5 91 | scan_gt = seq_map[dist_gt < max_range] 92 | scan_gt = np.concatenate((scan_gt, np.ones((len(scan_gt),1))), axis=-1) 93 | scan_gt = (scan_gt @ np.linalg.inv(pose).T)[:,:3] 94 | scan_gt = scan_gt[(scan_gt[:,2] > -4.) 
& (scan_gt[:,2] < 4.4)] 95 | pcd_gt = o3d.geometry.PointCloud() 96 | pcd_gt.points = o3d.utility.Vector3dVector(scan_gt) 97 | 98 | # keep only map points falling in voxels that the current scan actually observed (coarse visibility filter) 99 | cur_pcd = o3d.geometry.PointCloud() 100 | cur_pcd.points = o3d.utility.Vector3dVector(cur_scan) 101 | viewpoint_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(cur_pcd, voxel_size=10.) 102 | in_viewpoint = viewpoint_grid.check_if_included(pcd_gt.points) 103 | points_gt = np.array(pcd_gt.points) 104 | pcd_gt.points = o3d.utility.Vector3dVector(points_gt[in_viewpoint]) 105 | 106 | return pcd_gt 107 | 108 | import torch 109 | import random 110 | def set_deterministic(seed=42): 111 | random.seed(seed) 112 | np.random.seed(seed) 113 | torch.manual_seed(seed) 114 | torch.cuda.manual_seed(seed) 115 | torch.backends.cudnn.deterministic = True 116 | 117 | 118 | @click.command() 119 | @click.option('--path', '-p', type=str, default='datasets/SemanticKITTI/dataset/sequences/08', help='path to the scan sequence') 120 | @click.option('--max_range', '-m', type=float, default=50, help='max range') 121 | @click.option('--denoising_steps', '-t', type=int, default=8, help='number of denoising steps') 122 | @click.option('--cond_weight', '-s', type=float, default=3.5, help='conditioning weight') 123 | @click.option('--diff', '-d', type=str, help='trained diffusion model') 124 | @click.option('--refine', '-r', type=str, help='refinement model') 125 | @click.option('--use_refine', '-ur', type=bool, default=False, help='whether to use refine network') 126 | @click.option('--save_name', '-sn', type=str, help='name of saved file') 127 | def main(path, max_range, denoising_steps, cond_weight, diff, refine, use_refine, save_name): 128 | 129 | diff_completion = DiffCompletion_DistillationDPO(diff, refine, denoising_steps, cond_weight) 130 | 131 | poses = load_poses(os.path.join(path, 'calib.txt'), os.path.join(path, 'poses.txt')) 132 | seq_map = np.load(f'{path}/map_clean.npy') 133 | 134 | jsd_3d = [] 135 | jsd_bev = [] 136 | 137 | 138 | eval_list = list(zip(poses, natsorted(os.listdir(f'{path}/velodyne')))) 139 | random.shuffle(eval_list) 140 | for pose, scan_path in tqdm.tqdm(eval_list): 141 | 142 | pcd_pred, cur_scan = get_scan_completion(scan_path, path, diff_completion, max_range, use_refine) 143 | pcd_in = o3d.geometry.PointCloud() 144 | pcd_in.points = o3d.utility.Vector3dVector(cur_scan) 145 | pcd_gt = get_ground_truth(pose, cur_scan, seq_map, max_range) 146 | 147 | # o3d.io.write_point_cloud(f"exp/metrics/pcd_gt.ply", pcd_gt) 148 | # o3d.io.write_point_cloud(f"exp/metrics/pcd_in.ply", pcd_in) 149 | # o3d.io.write_point_cloud(f"exp/metrics/pcd_pred.ply", pcd_pred) 150 | 151 | emd.update(pcd_gt, pcd_pred) 152 | avg_emd, _ = emd.compute() 153 | print(f'Mean emd: {avg_emd}') 154 | 155 | jsd_3d.append(compute_hist_metrics(pcd_gt, pcd_pred, bev=False)) 156 | jsd_bev.append(compute_hist_metrics(pcd_gt, pcd_pred, bev=True)) 157 | print(f'JSD 3D mean: {np.array(jsd_3d).mean()}') 158 | print(f'JSD BEV mean: {np.array(jsd_bev).mean()}') 159 | 160 | rmse.update(pcd_gt, pcd_pred) 161 | completion_iou.update(pcd_gt, pcd_pred) 162 | chamfer_distance.update(pcd_gt, pcd_pred) 163 | precision_recall.update(pcd_gt, pcd_pred) 164 | 165 | rmse_mean, rmse_std = rmse.compute() 166 | print(f'RMSE Mean: {rmse_mean}\tRMSE Std: {rmse_std}') 167 | thr_ious = completion_iou.compute() 168 | for v_size in thr_ious.keys(): 169 | print(f'Voxel {v_size}m IOU: {thr_ious[v_size]}') 170 | cd_mean, cd_std = chamfer_distance.compute() 171 | print(f'CD Mean: 
{cd_mean}\tCD Std: {cd_std}') 172 | pr, re, f1 = precision_recall.compute_auc() 173 | print(f'Precision: {pr}\tRecall: {re}\tF-Score: {f1}') 174 | 175 | 176 | print('\n\n=================== FINAL RESULTS ===================\n\n') 177 | print(f'JSD 3D: {np.array(jsd_3d).mean()}') 178 | print(f'JSD BEV: {np.array(jsd_bev).mean()}') 179 | print(f'RMSE Mean: {rmse_mean}\tRMSE Std: {rmse_std}') 180 | thr_ious = completion_iou.compute() 181 | for v_size in thr_ious.keys(): 182 | print(f'Voxel {v_size}m IOU: {thr_ious[v_size]}') 183 | cd_mean, cd_std = chamfer_distance.compute() 184 | print(f'CD Mean: {cd_mean}\tCD Std: {cd_std}') 185 | pr, re, f1 = precision_recall.compute_auc() 186 | print(f'Precision: {pr}\tRecall: {re}\tF-Score: {f1}') 187 | 188 | res_dict = { 189 | 'jsd_bev': np.array(jsd_bev).mean(), 190 | 'jsd_noclip_3d': np.array(jsd_3d).mean(), 191 | 'rmse_mean': rmse_mean, 'rmse_std': rmse_std, 192 | 'ious': thr_ious, 193 | 'cd_mean': cd_mean, 'cd_std': cd_std, 194 | 'pr': pr, 're': re, 'f1': f1, 195 | } 196 | # JSON is valid YAML, so the results dict can be dumped into the .yaml file with json.dump 197 | from os.path import join, dirname, abspath 198 | with open(join(dirname(abspath(__file__)),f'{save_name}.yaml'), 'w+') as log_res: 199 | json.dump(res_dict, log_res) 200 | 201 | if __name__ == '__main__': 202 | main() 203 | 204 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | import scipy.integrate 4 | from utils.EMD import calc_EMD_with_sinkhorn_knopp 5 | import torch  # needed by the torch.Tensor branches below 6 | MESHTYPE = 6 7 | TETRATYPE = 10 8 | PCDTYPE = 1 9 | 10 | class Metrics3D(): 11 | def prediction_is_empty(self, geom): 12 | 13 | if isinstance(geom, o3d.geometry.Geometry): 14 | geom_type = geom.get_geometry_type().value 15 | if geom_type == MESHTYPE or geom_type == TETRATYPE: 16 | empty_v = self.is_empty(len(geom.vertices)) 17 | empty_t = self.is_empty(len(geom.triangles)) 18 | empty = empty_t or empty_v 19 | elif geom_type == PCDTYPE: 20 | empty = self.is_empty(len(geom.points)) 21 | else: 22 | assert False, '{} geometry not supported'.format(geom.get_geometry_type()) 23 | elif isinstance(geom, np.ndarray): 24 | empty = self.is_empty(len(geom[:, :3])) 25 | elif isinstance(geom, torch.Tensor): 26 | empty = self.is_empty(len(geom[:, :3])) 27 | else: 28 | assert False, '{} type not supported'.format(type(geom)) 29 | 30 | return empty 31 | 32 | @staticmethod 33 | def convert_to_pcd(geom): 34 | 35 | if isinstance(geom, o3d.geometry.Geometry): 36 | geom_type = geom.get_geometry_type().value 37 | if geom_type == MESHTYPE or geom_type == TETRATYPE: 38 | geom_pcd = geom.sample_points_uniformly(1000000) 39 | elif geom_type == PCDTYPE: 40 | geom_pcd = geom 41 | else: 42 | assert False, '{} geometry not supported'.format(geom.get_geometry_type()) 43 | elif isinstance(geom, np.ndarray): 44 | geom_pcd = o3d.geometry.PointCloud() 45 | geom_pcd.points = o3d.utility.Vector3dVector(geom[:, :3]) 46 | elif isinstance(geom, torch.Tensor): 47 | geom = geom.detach().cpu().numpy() 48 | geom_pcd = o3d.geometry.PointCloud() 49 | geom_pcd.points = o3d.utility.Vector3dVector(geom[:, :3]) 50 | else: 51 | assert False, '{} type not supported'.format(type(geom)) 52 | 53 | return geom_pcd 54 | 55 | @staticmethod 56 | def is_empty(length): 57 | empty = True 58 | if length: 59 | empty = False 60 | return empty 61 | 62 | 63 | class RMSE():  # note: accumulates the mean point-to-point distance per scan, not a squared-error RMSE 64 | 65 | def __init__(self): 66 | self.dists = [] 67 | 68 | return 69 | 70 | def update(self, gt_pcd, pt_pcd): 71 | dist_pt_2_gt = 
np.asarray(pt_pcd.compute_point_cloud_distance(gt_pcd)) 72 | 73 | self.dists.append(np.mean(dist_pt_2_gt)) 74 | 75 | def reset(self): 76 | self.dists = [] 77 | 78 | def compute(self): 79 | dist = np.array(self.dists) 80 | return dist.mean(), dist.std() 81 | 82 | class EMD(): 83 | def __init__(self, voxel_size=0.5, epsilon=0.001, max_iter=3000, tol=1e-4): 84 | self.emds = [] 85 | self.voxel_size = voxel_size 86 | self.epsilon = epsilon 87 | self.max_iter = max_iter 88 | self.tol = tol 89 | 90 | return 91 | 92 | def update(self, gt_pcd, pt_pcd): 93 | emd = calc_EMD_with_sinkhorn_knopp(np.array(pt_pcd.points), np.array(gt_pcd.points), voxel_size=self.voxel_size, epsilon=self.epsilon, max_iter=self.max_iter, tol=self.tol) 94 | 95 | if emd:  # skip falsy results (None or zero EMD) 96 | self.emds.append(emd) 97 | 98 | def reset(self): 99 | self.emds = [] 100 | 101 | def compute(self): 102 | emds_array = np.array(self.emds) 103 | return emds_array.mean(), emds_array.std() 104 | 105 | class CompletionIoU(): 106 | def __init__(self, voxel_sizes=[0.5, 0.2, 0.1]): 107 | self.voxel_sizes = voxel_sizes 108 | # rows: one per voxel size; columns: tp, fn, fp 109 | self.conf_matrix = np.zeros((len(self.voxel_sizes), 3)).astype(np.uint64) 110 | # NOTE: update() histograms voxel indices (points / vsize) against a range given in metres, so only voxels within max_range * vsize metres of the sensor are counted; kept as-is for comparability with prior evaluations 111 | def update(self, gt, pred): 112 | max_range = 50.  # metres 113 | for i, vsize in enumerate(self.voxel_sizes): 114 | bins = int(2 * max_range / vsize) 115 | vox_coords_gt = np.round(np.array(gt.points) / vsize).astype(int) 116 | hist_gt = np.histogramdd( 117 | vox_coords_gt, bins=bins, range=([-max_range,max_range],[-max_range,max_range],[-max_range,max_range]) 118 | )[0].astype(bool).astype(int) 119 | 120 | vox_coords_pred = np.round(np.array(pred.points) / vsize).astype(int) 121 | hist_pred = np.histogramdd( 122 | vox_coords_pred, bins=bins, range=([-max_range,max_range],[-max_range,max_range],[-max_range,max_range]) 123 | )[0].astype(bool).astype(int) 124 | 125 | self.conf_matrix[i][0] += ((hist_gt == 1) & (hist_pred == 1)).sum() # tp 126 | self.conf_matrix[i][1] += ((hist_gt == 1) & (hist_pred == 0)).sum() # fn 127 | self.conf_matrix[i][2] += ((hist_gt == 0) & (hist_pred == 1)).sum() # fp 128 | 129 | def compute(self): 130 | res_vsizes = {} 131 | for i, vsize in enumerate(self.voxel_sizes): 132 | tp = self.conf_matrix[i][0] 133 | fn = self.conf_matrix[i][1] 134 | fp = self.conf_matrix[i][2] 135 | 136 | intersection = tp 137 | union = tp + fn + fp + 1e-15 138 | 139 | res_vsizes[vsize] = intersection / union 140 | 141 | return res_vsizes 142 | 143 | def reset(self): 144 | self.conf_matrix = np.zeros((len(self.voxel_sizes), 3)).astype(np.uint64) 145 | 146 | class ChamferDistance(): 147 | def __init__(self): 148 | self.dists = [] 149 | 150 | return 151 | 152 | def update(self, gt_pcd, pt_pcd): 153 | dist_pt_2_gt = np.asarray(pt_pcd.compute_point_cloud_distance(gt_pcd)) 154 | dist_gt_2_pt = np.asarray(gt_pcd.compute_point_cloud_distance(pt_pcd)) 155 | 156 | self.dists.append((np.mean(dist_gt_2_pt) + np.mean(dist_pt_2_gt)) / 2) 157 | 158 | def reset(self): 159 | self.dists = [] 160 | 161 | def compute(self): 162 | cdist = np.array(self.dists) 163 | return cdist.mean(), cdist.std() 164 | 165 | class PrecisionRecall(Metrics3D): 166 | 167 | def __init__(self, min_t, max_t, num): 168 | self.thresholds = np.linspace(min_t, max_t, num) 169 | self.pr_dict = {t: [] for t in self.thresholds} 170 | self.re_dict = {t: [] for t in self.thresholds} 171 | self.f1_dict = {t: [] for t in self.thresholds} 172 | 173 | def update(self, gt_pcd, pt_pcd): 174 | # precision: predicted --> ground truth 175 | dist_pt_2_gt = 
np.asarray(pt_pcd.compute_point_cloud_distance(gt_pcd)) 176 | 177 | # recall: ground truth --> predicted 178 | dist_gt_2_pt = np.asarray(gt_pcd.compute_point_cloud_distance(pt_pcd)) 179 | 180 | for t in self.thresholds: 181 | p = np.where(dist_pt_2_gt < t)[0] 182 | p = 100 / len(dist_pt_2_gt) * len(p) 183 | self.pr_dict[t].append(p) 184 | 185 | r = np.where(dist_gt_2_pt < t)[0] 186 | r = 100 / len(dist_gt_2_pt) * len(r) 187 | self.re_dict[t].append(r) 188 | 189 | # fscore 190 | if p == 0 or r == 0: 191 | f = 0 192 | else: 193 | f = 2 * p * r / (p + r) 194 | self.f1_dict[t].append(f) 195 | 196 | def reset(self): 197 | self.pr_dict = {t: [] for t in self.thresholds} 198 | self.re_dict = {t: [] for t in self.thresholds} 199 | self.f1_dict = {t: [] for t in self.thresholds} 200 | 201 | def compute_at_threshold(self, threshold): 202 | t = self.find_nearest_threshold(threshold) 203 | # print('computing metrics at threshold:', t) 204 | pr = sum(self.pr_dict[t]) / len(self.pr_dict[t]) 205 | re = sum(self.re_dict[t]) / len(self.re_dict[t]) 206 | f1 = sum(self.f1_dict[t]) / len(self.f1_dict[t]) 207 | # print('precision: {}'.format(pr)) 208 | # print('recall: {}'.format(re)) 209 | # print('fscore: {}'.format(f1)) 210 | return pr, re, f1, t 211 | 212 | def compute_auc(self): 213 | dx = self.thresholds[1] - self.thresholds[0] 214 | perfect_predictor = scipy.integrate.simpson(np.ones_like(self.thresholds), dx=dx) 215 | 216 | pr, re, f1 = self.compute_at_all_thresholds() 217 | 218 | pr_area = scipy.integrate.simpson(pr, dx=dx) 219 | norm_pr_area = pr_area / perfect_predictor 220 | 221 | re_area = scipy.integrate.simpson(re, dx=dx) 222 | norm_re_area = re_area / perfect_predictor 223 | 224 | f1_area = scipy.integrate.simpson(f1, dx=dx) 225 | norm_f1_area = f1_area / perfect_predictor 226 | 227 | # print('computing area under curve') 228 | # print('precision: {}'.format(norm_pr_area)) 229 | # print('recall: {}'.format(norm_re_area)) 230 | # print('fscore: {}'.format(norm_f1_area)) 231 | 232 | return norm_pr_area, norm_re_area, norm_f1_area 233 | 234 | def compute_at_all_thresholds(self): 235 | pr = [sum(self.pr_dict[t]) / len(self.pr_dict[t]) for t in self.thresholds] 236 | re = [sum(self.re_dict[t]) / len(self.re_dict[t]) for t in self.thresholds] 237 | f1 = [sum(self.f1_dict[t]) / len(self.f1_dict[t]) for t in self.thresholds] 238 | return pr, re, f1 239 | 240 | def find_nearest_threshold(self, value): 241 | idx = (np.abs(self.thresholds - value)).argmin() 242 | return self.thresholds[idx] 243 | 244 | -------------------------------------------------------------------------------- /utils/diff_completion_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import MinkowskiEngine as ME 3 | import torch 4 | from models.minkunet import MinkRewardModel,MinkGlobalEnc,MinkUNetDiff, MinkUNet 5 | import open3d as o3d 6 | from diffusers import DPMSolverMultistepScheduler 7 | from pytorch_lightning.core.lightning import LightningModule 8 | import yaml 9 | import os 10 | import tqdm 11 | from natsort import natsorted 12 | import click 13 | import time 14 | from os.path import join, dirname, abspath 15 | from os import environ, makedirs 16 | 17 | class DiffCompletion_DistillationDPO(LightningModule): 18 | def __init__(self, diff_path, refine_path, denoising_steps, uncond_w): 19 | super().__init__() 20 | 21 | # load diff net 22 | ckpt_diff = torch.load(diff_path) 23 | 24 | dm_weights = {k.replace('generator.', ''): v for k, v in 
ckpt_diff["state_dict"].items() if k.startswith('generator.')} 25 | encoder_weights = {k.replace('partial_enc.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('partial_enc.')} 26 | 27 | # load encoder and model 28 | self.partial_enc = MinkGlobalEnc().cuda() 29 | self.partial_enc.load_state_dict(encoder_weights, strict=True) 30 | self.partial_enc.eval() 31 | 32 | self.model = MinkUNetDiff().cuda() 33 | self.model.load_state_dict(dm_weights, strict=True) 34 | self.model.eval() 35 | 36 | # load refiner 37 | ckpt_refine = torch.load(refine_path) 38 | refiner_weights = {k.replace('model_refine.', ''): v for k, v in ckpt_refine["state_dict"].items() if k.startswith('model_refine.')} 39 | self.model_refine = MinkUNet(in_channels=3, out_channels=3*6).cuda() 40 | self.model_refine.load_state_dict(refiner_weights, strict=True) 41 | self.model_refine.eval() 42 | 43 | self.cuda() 44 | 45 | # for fast sampling 46 | self.dpm_scheduler = DPMSolverMultistepScheduler( 47 | num_train_timesteps=1000, 48 | beta_start=3.5e-5, 49 | beta_end=0.007, 50 | beta_schedule='linear', 51 | algorithm_type='sde-dpmsolver++', 52 | solver_order=2, 53 | ) 54 | self.dpm_scheduler.set_timesteps(num_inference_steps=denoising_steps) 55 | self.scheduler_to_cuda() 56 | 57 | self.w_uncond = uncond_w 58 | 59 | def scheduler_to_cuda(self): 60 | self.dpm_scheduler.timesteps = self.dpm_scheduler.timesteps.cuda() 61 | self.dpm_scheduler.betas = self.dpm_scheduler.betas.cuda() 62 | self.dpm_scheduler.alphas = self.dpm_scheduler.alphas.cuda() 63 | self.dpm_scheduler.alphas_cumprod = self.dpm_scheduler.alphas_cumprod.cuda() 64 | self.dpm_scheduler.alpha_t = self.dpm_scheduler.alpha_t.cuda() 65 | self.dpm_scheduler.sigma_t = self.dpm_scheduler.sigma_t.cuda() 66 | self.dpm_scheduler.lambda_t = self.dpm_scheduler.lambda_t.cuda() 67 | self.dpm_scheduler.sigmas = self.dpm_scheduler.sigmas.cuda() 68 | 69 | def feats_to_coord(self, p_feats, resolution, mean=None, std=None): 70 | p_feats = p_feats.reshape(mean.shape[0],-1,3) 71 | p_coord = torch.round(p_feats / resolution) 72 | 73 | return p_coord.reshape(-1,3) 74 | 75 | def points_to_tensor(self, points): 76 | x_feats = ME.utils.batched_coordinates(list(points[:]), dtype=torch.float32, device=self.device) 77 | 78 | x_coord = x_feats.clone() 79 | x_coord = torch.round(x_coord / 0.05) 80 | 81 | x_t = ME.TensorField( 82 | features=x_feats[:,1:], 83 | coordinates=x_coord, 84 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE, 85 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED, 86 | device=self.device, 87 | ) 88 | 89 | torch.cuda.empty_cache() 90 | 91 | return x_t 92 | 93 | def reset_partial_pcd(self, x_part): 94 | x_part = self.points_to_tensor(x_part.F.reshape(1,-1,3).detach()) 95 | 96 | return x_part 97 | 98 | def preprocess_scan(self, scan): 99 | dist = np.sqrt(np.sum((scan)**2, -1)) 100 | scan = scan[(dist < 50.0) & (dist > 3.5)][:,:3] 101 | 102 | # use farthest point sampling 103 | pcd_scan = o3d.geometry.PointCloud() 104 | pcd_scan.points = o3d.utility.Vector3dVector(scan) 105 | pcd_scan = pcd_scan.farthest_point_down_sample(int(180000 / 10)) 106 | scan = torch.tensor(np.array(pcd_scan.points)).cuda() 107 | 108 | scan = scan.repeat(10,1) 109 | scan = scan[None,:,:] 110 | 111 | return scan 112 | 113 | 114 | def postprocess_scan(self, completed_scan, input_scan): 115 | dist = np.sqrt(np.sum((completed_scan)**2, -1)) 116 | post_scan = completed_scan[dist < 50.0] 117 | max_z = input_scan[...,2].max().item() 118 | min_z = (input_scan[...,2].mean() 
- 2 * input_scan[...,2].std()).item() 119 | 120 | post_scan = post_scan[(post_scan[:,2] < max_z) & (post_scan[:,2] > min_z)] 121 | 122 | return post_scan 123 | 124 | def complete_scan(self, pcd_part): 125 | pcd_part_rep = self.preprocess_scan(pcd_part).view(1,-1,3) 126 | # pcd_part = torch.tensor(pcd_part, device=self.device).view(1,-1,3) 127 | # print(f'pcd_part_rep.shape = {pcd_part_rep.shape}') 128 | # print(f'pcd_part.shape = {pcd_part.shape}') 129 | 130 | x_feats = pcd_part_rep + torch.randn(pcd_part_rep.shape, device=self.device) 131 | x_full = self.points_to_tensor(x_feats) # x_T 132 | x_cond = self.points_to_tensor(pcd_part_rep) # x_0 133 | x_uncond = self.points_to_tensor(torch.zeros_like(pcd_part_rep)) 134 | 135 | completed_scan = self.completion_loop(pcd_part_rep, x_full, x_cond, x_uncond) 136 | post_scan = self.postprocess_scan(completed_scan, pcd_part_rep) 137 | 138 | refine_in = self.points_to_tensor(post_scan[None,:,:]) 139 | offset = self.refine_forward(refine_in).reshape(-1,6,3) 140 | 141 | refine_complete_scan = post_scan[:,None,:] + offset.cpu().numpy() 142 | 143 | return refine_complete_scan.reshape(-1,3), post_scan 144 | 145 | 146 | def refine_forward(self, x_in): 147 | with torch.no_grad(): 148 | offset = self.model_refine(x_in) 149 | 150 | return offset 151 | 152 | def forward(self, x_full, x_full_sparse, x_part, t): 153 | with torch.no_grad(): 154 | part_feat = self.partial_enc(x_part) 155 | out = self.model(x_full, x_full_sparse, part_feat, t) 156 | 157 | torch.cuda.empty_cache() 158 | return out.reshape(t.shape[0],-1,3) 159 | 160 | def classfree_forward(self, x_t, x_cond, x_uncond, t): 161 | x_t_sparse = x_t.sparse() 162 | x_cond = self.forward(x_t, x_t_sparse, x_cond, t) 163 | x_uncond = self.forward(x_t, x_t_sparse, x_uncond, t) 164 | 165 | return x_uncond + self.w_uncond * (x_cond - x_uncond) 166 | 167 | def completion_loop(self, x_init, x_t, x_cond, x_uncond): 168 | self.scheduler_to_cuda() 169 | 170 | # for t in tqdm.tqdm(range(len(self.dpm_scheduler.timesteps))): 171 | for t in range(len(self.dpm_scheduler.timesteps)): 172 | t = self.dpm_scheduler.timesteps[t].cuda()[None] 173 | 174 | noise_t = self.classfree_forward(x_t, x_cond, x_uncond, t) 175 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init 176 | x_t = x_init + self.dpm_scheduler.step(noise_t, t, input_noise)['prev_sample'] 177 | x_t = self.points_to_tensor(x_t) 178 | 179 | x_cond = self.reset_partial_pcd(x_cond) 180 | torch.cuda.empty_cache() 181 | 182 | return x_t.F.cpu().detach().numpy() 183 | 184 | def load_pcd(pcd_file): 185 | if pcd_file.endswith('.bin'): 186 | return np.fromfile(pcd_file, dtype=np.float32).reshape((-1,4))[:,:3] 187 | elif pcd_file.endswith('.ply'): 188 | return np.array(o3d.io.read_point_cloud(pcd_file).points) 189 | else: 190 | print(f"Point cloud format '.{pcd_file.split('.')[-1]}' not supported. 
(supported formats: .bin (kitti format), .ply)") 191 | 192 | class DiffCompletion_lidiff(LightningModule): 193 | def __init__(self, diff_path, refine_path, denoising_steps, uncond_w): 194 | super().__init__() 195 | 196 | # load diff net 197 | ckpt_diff = torch.load(diff_path) 198 | 199 | dm_weights = {k.replace('model.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('model.')} 200 | encoder_weights = {k.replace('partial_enc.', ''): v for k, v in ckpt_diff["state_dict"].items() if k.startswith('partial_enc.')} 201 | 202 | # load encoder and model 203 | self.partial_enc = MinkGlobalEnc().cuda() 204 | self.partial_enc.load_state_dict(encoder_weights, strict=True) 205 | self.partial_enc.eval() 206 | 207 | self.model = MinkUNetDiff().cuda() 208 | self.model.load_state_dict(dm_weights, strict=True) 209 | self.model.eval() 210 | 211 | # load refiner 212 | ckpt_refine = torch.load(refine_path) 213 | refiner_weights = {k.replace('model_refine.', ''): v for k, v in ckpt_refine["state_dict"].items() if k.startswith('model_refine.')} 214 | self.model_refine = MinkUNet(in_channels=3, out_channels=3*6).cuda() 215 | self.model_refine.load_state_dict(refiner_weights, strict=True) 216 | self.model_refine.eval() 217 | 218 | self.cuda() 219 | 220 | # for fast sampling 221 | self.dpm_scheduler = DPMSolverMultistepScheduler( 222 | num_train_timesteps=1000, 223 | beta_start=3.5e-5, 224 | beta_end=0.007, 225 | beta_schedule='linear', 226 | algorithm_type='sde-dpmsolver++', 227 | solver_order=2, 228 | ) 229 | self.dpm_scheduler.set_timesteps(num_inference_steps=denoising_steps) 230 | self.scheduler_to_cuda() 231 | 232 | self.w_uncond = uncond_w 233 | 234 | def scheduler_to_cuda(self): 235 | self.dpm_scheduler.timesteps = self.dpm_scheduler.timesteps.cuda() 236 | self.dpm_scheduler.betas = self.dpm_scheduler.betas.cuda() 237 | self.dpm_scheduler.alphas = self.dpm_scheduler.alphas.cuda() 238 | self.dpm_scheduler.alphas_cumprod = self.dpm_scheduler.alphas_cumprod.cuda() 239 | self.dpm_scheduler.alpha_t = self.dpm_scheduler.alpha_t.cuda() 240 | self.dpm_scheduler.sigma_t = self.dpm_scheduler.sigma_t.cuda() 241 | self.dpm_scheduler.lambda_t = self.dpm_scheduler.lambda_t.cuda() 242 | self.dpm_scheduler.sigmas = self.dpm_scheduler.sigmas.cuda() 243 | 244 | def feats_to_coord(self, p_feats, resolution, mean=None, std=None): 245 | p_feats = p_feats.reshape(mean.shape[0],-1,3) 246 | p_coord = torch.round(p_feats / resolution) 247 | 248 | return p_coord.reshape(-1,3) 249 | 250 | def points_to_tensor(self, points): 251 | x_feats = ME.utils.batched_coordinates(list(points[:]), dtype=torch.float32, device=self.device) 252 | 253 | x_coord = x_feats.clone() 254 | x_coord = torch.round(x_coord / 0.05) 255 | 256 | x_t = ME.TensorField( 257 | features=x_feats[:,1:], 258 | coordinates=x_coord, 259 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE, 260 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED, 261 | device=self.device, 262 | ) 263 | 264 | torch.cuda.empty_cache() 265 | 266 | return x_t 267 | 268 | def reset_partial_pcd(self, x_part): 269 | x_part = self.points_to_tensor(x_part.F.reshape(1,-1,3).detach()) 270 | 271 | return x_part 272 | 273 | def preprocess_scan(self, scan): 274 | dist = np.sqrt(np.sum((scan)**2, -1)) 275 | scan = scan[(dist < 50.0) & (dist > 3.5)][:,:3] 276 | 277 | # use farthest point sampling 278 | pcd_scan = o3d.geometry.PointCloud() 279 | pcd_scan.points = o3d.utility.Vector3dVector(scan) 280 | pcd_scan = pcd_scan.farthest_point_down_sample(int(180000 / 
10)) 281 | scan = torch.tensor(np.array(pcd_scan.points)).cuda() 282 | 283 | scan = scan.repeat(10,1) 284 | scan = scan[None,:,:] 285 | 286 | return scan 287 | 288 | 289 | def postprocess_scan(self, completed_scan, input_scan): 290 | dist = np.sqrt(np.sum((completed_scan)**2, -1)) 291 | post_scan = completed_scan[dist < 50.0] 292 | max_z = input_scan[...,2].max().item() 293 | min_z = (input_scan[...,2].mean() - 2 * input_scan[...,2].std()).item() 294 | 295 | post_scan = post_scan[(post_scan[:,2] < max_z) & (post_scan[:,2] > min_z)] 296 | 297 | return post_scan 298 | 299 | def complete_scan(self, pcd_part): 300 | pcd_part_rep = self.preprocess_scan(pcd_part).view(1,-1,3) 301 | # pcd_part = torch.tensor(pcd_part, device=self.device).view(1,-1,3) 302 | # print(f'pcd_part_rep.shape = {pcd_part_rep.shape}') 303 | # print(f'pcd_part.shape = {pcd_part.shape}') 304 | 305 | x_feats = pcd_part_rep + torch.randn(pcd_part_rep.shape, device=self.device) 306 | x_full = self.points_to_tensor(x_feats) # x_T 307 | x_cond = self.points_to_tensor(pcd_part_rep) # x_0 308 | x_uncond = self.points_to_tensor(torch.zeros_like(pcd_part_rep)) 309 | 310 | completed_scan = self.completion_loop(pcd_part_rep, x_full, x_cond, x_uncond) 311 | post_scan = self.postprocess_scan(completed_scan, pcd_part_rep) 312 | 313 | refine_in = self.points_to_tensor(post_scan[None,:,:]) 314 | offset = self.refine_forward(refine_in).reshape(-1,6,3) 315 | 316 | refine_complete_scan = post_scan[:,None,:] + offset.cpu().numpy() 317 | 318 | return refine_complete_scan.reshape(-1,3), post_scan 319 | 320 | 321 | def refine_forward(self, x_in): 322 | with torch.no_grad(): 323 | offset = self.model_refine(x_in) 324 | 325 | return offset 326 | 327 | def forward(self, x_full, x_full_sparse, x_part, t): 328 | with torch.no_grad(): 329 | part_feat = self.partial_enc(x_part) 330 | out = self.model(x_full, x_full_sparse, part_feat, t) 331 | 332 | torch.cuda.empty_cache() 333 | return out.reshape(t.shape[0],-1,3) 334 | 335 | def classfree_forward(self, x_t, x_cond, x_uncond, t): 336 | x_t_sparse = x_t.sparse() 337 | x_cond = self.forward(x_t, x_t_sparse, x_cond, t) 338 | x_uncond = self.forward(x_t, x_t_sparse, x_uncond, t) 339 | 340 | return x_uncond + self.w_uncond * (x_cond - x_uncond) 341 | 342 | def completion_loop(self, x_init, x_t, x_cond, x_uncond): 343 | self.scheduler_to_cuda() 344 | 345 | # for t in tqdm.tqdm(range(len(self.dpm_scheduler.timesteps))): 346 | for t in range(len(self.dpm_scheduler.timesteps)): 347 | t = self.dpm_scheduler.timesteps[t].cuda()[None] 348 | 349 | noise_t = self.classfree_forward(x_t, x_cond, x_uncond, t) 350 | input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init 351 | x_t = x_init + self.dpm_scheduler.step(noise_t, t, input_noise)['prev_sample'] 352 | x_t = self.points_to_tensor(x_t) 353 | 354 | x_cond = self.reset_partial_pcd(x_cond) 355 | torch.cuda.empty_cache() 356 | 357 | return x_t.F.cpu().detach().numpy() 358 | 359 | def load_pcd(pcd_file): 360 | if pcd_file.endswith('.bin'): 361 | return np.fromfile(pcd_file, dtype=np.float32).reshape((-1,4))[:,:3] 362 | elif pcd_file.endswith('.ply'): 363 | return np.array(o3d.io.read_point_cloud(pcd_file).points) 364 | else: 365 | print(f"Point cloud format '.{pcd_file.split('.')[-1]}' not supported. 
(supported formats: .bin (kitti format), .ply)") 366 | -------------------------------------------------------------------------------- /models/minkunet.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import MinkowskiEngine as ME 7 | import numpy as np 8 | from pykeops.torch import LazyTensor 9 | 10 | __all__ = ['MinkUNetDiff'] 11 | 12 | 13 | class BasicConvolutionBlock(nn.Module): 14 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3): 15 | super().__init__() 16 | self.net = nn.Sequential( 17 | ME.MinkowskiConvolution(inc, 18 | outc, 19 | kernel_size=ks, 20 | dilation=dilation, 21 | stride=stride, 22 | dimension=D), 23 | ME.MinkowskiBatchNorm(outc), 24 | ME.MinkowskiReLU(inplace=True) 25 | ) 26 | 27 | def forward(self, x): 28 | out = self.net(x) 29 | return out 30 | 31 | 32 | class BasicDeconvolutionBlock(nn.Module): 33 | def __init__(self, inc, outc, ks=3, stride=1, D=3): 34 | super().__init__() 35 | self.net = nn.Sequential( 36 | ME.MinkowskiConvolutionTranspose(inc, 37 | outc, 38 | kernel_size=ks, 39 | stride=stride, 40 | dimension=D), 41 | ME.MinkowskiBatchNorm(outc), 42 | ME.MinkowskiReLU(inplace=True) 43 | ) 44 | 45 | def forward(self, x): 46 | return self.net(x) 47 | 48 | 49 | class ResidualBlock(nn.Module): 50 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3): 51 | super().__init__() 52 | self.net = nn.Sequential( 53 | ME.MinkowskiConvolution(inc, 54 | outc, 55 | kernel_size=ks, 56 | dilation=dilation, 57 | stride=stride, 58 | dimension=D), 59 | ME.MinkowskiBatchNorm(outc), 60 | ME.MinkowskiReLU(inplace=True), 61 | ME.MinkowskiConvolution(outc, 62 | outc, 63 | kernel_size=ks, 64 | dilation=dilation, 65 | stride=1, 66 | dimension=D), 67 | ME.MinkowskiBatchNorm(outc) 68 | ) 69 | 70 | self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ 71 | nn.Sequential( 72 | ME.MinkowskiConvolution(inc, outc, kernel_size=1, dilation=1, stride=stride, dimension=D), 73 | ME.MinkowskiBatchNorm(outc) 74 | ) 75 | 76 | self.relu = ME.MinkowskiReLU(inplace=True) 77 | 78 | def forward(self, x): 79 | out = self.relu(self.net(x) + self.downsample(x)) 80 | return out 81 | 82 | 83 | class MinkGlobalEnc(nn.Module): 84 | def __init__(self, **kwargs): 85 | super().__init__() 86 | cr = kwargs.get('cr', 1.0) 87 | in_channels = kwargs.get('in_channels', 3) 88 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 89 | cs = [int(cr * x) for x in cs] 90 | self.embed_dim = cs[-1] 91 | self.run_up = kwargs.get('run_up', True) 92 | self.D = kwargs.get('D', 3) 93 | self.stem = nn.Sequential( 94 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D), 95 | ME.MinkowskiBatchNorm(cs[0]), 96 | ME.MinkowskiReLU(True), 97 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D), 98 | ME.MinkowskiBatchNorm(cs[0]), 99 | ME.MinkowskiReLU(inplace=True) 100 | ) 101 | 102 | self.stage1 = nn.Sequential( 103 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D), 104 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D), 105 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D), 106 | ) 107 | 108 | self.stage2 = nn.Sequential( 109 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D), 110 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D), 111 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, 
D=self.D), 112 | ) 113 | 114 | self.stage3 = nn.Sequential( 115 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D), 116 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D), 117 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D), 118 | ) 119 | 120 | self.stage4 = nn.Sequential( 121 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D), 122 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D), 123 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D), 124 | ) 125 | 126 | self.weight_initialization() 127 | 128 | def weight_initialization(self): 129 | for m in self.modules(): 130 | if isinstance(m, nn.BatchNorm1d): 131 | nn.init.constant_(m.weight, 1) 132 | nn.init.constant_(m.bias, 0) 133 | 134 | def forward(self, x): 135 | x0 = self.stem(x.sparse()) 136 | x1 = self.stage1(x0) 137 | x2 = self.stage2(x1) 138 | x3 = self.stage3(x2) 139 | x4 = self.stage4(x3) 140 | 141 | return x4 142 | 143 | 144 | class MinkUNetDiff(nn.Module): 145 | def __init__(self, **kwargs): 146 | super().__init__() 147 | 148 | cr = kwargs.get('cr', 1.0) 149 | in_channels = kwargs.get('in_channels', 3) 150 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 151 | cs = [int(cr * x) for x in cs] 152 | self.embed_dim = cs[-1] 153 | self.run_up = kwargs.get('run_up', True) 154 | self.D = kwargs.get('D', 3) 155 | self.stem = nn.Sequential( 156 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D), 157 | ME.MinkowskiBatchNorm(cs[0]), 158 | ME.MinkowskiReLU(True), 159 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D), 160 | ME.MinkowskiBatchNorm(cs[0]), 161 | ME.MinkowskiReLU(inplace=True) 162 | ) 163 | 164 | # Stage1 temp embed proj and conv 165 | self.latent_stage1 = nn.Sequential( 166 | nn.Linear(cs[4], cs[4]), 167 | nn.LeakyReLU(0.1, inplace=True), 168 | nn.Linear(cs[4], cs[4]), 169 | ) 170 | 171 | self.latemp_stage1 = nn.Sequential( 172 | nn.Linear(cs[4]+cs[4], cs[4]), 173 | nn.LeakyReLU(0.1, inplace=True), 174 | nn.Linear(cs[4], cs[0]), 175 | ) 176 | 177 | self.stage1_temp = nn.Sequential( 178 | nn.Linear(self.embed_dim, self.embed_dim), 179 | nn.LeakyReLU(0.1, inplace=True), 180 | nn.Linear(self.embed_dim, cs[4]), 181 | ) 182 | 183 | self.stage1 = nn.Sequential( 184 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D), 185 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D), 186 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D), 187 | ) 188 | 189 | # Stage2 temp embed proj and conv 190 | self.latent_stage2 = nn.Sequential( 191 | nn.Linear(cs[4], cs[4]), 192 | nn.LeakyReLU(0.1, inplace=True), 193 | nn.Linear(cs[4], cs[4]), 194 | ) 195 | 196 | self.latemp_stage2 = nn.Sequential( 197 | nn.Linear(cs[4]+cs[4], cs[4]), 198 | nn.LeakyReLU(0.1, inplace=True), 199 | nn.Linear(cs[4], cs[1]), 200 | ) 201 | 202 | self.stage2_temp = nn.Sequential( 203 | nn.Linear(self.embed_dim, self.embed_dim), 204 | nn.LeakyReLU(0.1, inplace=True), 205 | nn.Linear(self.embed_dim, cs[4]), 206 | ) 207 | 208 | self.stage2 = nn.Sequential( 209 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D), 210 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D), 211 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D) 212 | ) 213 | 214 | # Stage3 temp embed proj and conv 215 | self.latent_stage3 = nn.Sequential( 216 | nn.Linear(cs[4], cs[4]), 217 | nn.LeakyReLU(0.1, inplace=True), 218 
| nn.Linear(cs[4], cs[4]), 219 | ) 220 | 221 | self.latemp_stage3 = nn.Sequential( 222 | nn.Linear(cs[4]+cs[4], cs[4]), 223 | nn.LeakyReLU(0.1, inplace=True), 224 | nn.Linear(cs[4], cs[2]), 225 | ) 226 | 227 | self.stage3_temp = nn.Sequential( 228 | nn.Linear(self.embed_dim, self.embed_dim), 229 | nn.LeakyReLU(0.1, inplace=True), 230 | nn.Linear(self.embed_dim, cs[4]), 231 | ) 232 | 233 | self.stage3 = nn.Sequential( 234 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D), 235 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D), 236 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D), 237 | ) 238 | 239 | # Stage4 temp embed proj and conv 240 | self.latent_stage4 = nn.Sequential( 241 | nn.Linear(cs[4], cs[4]), 242 | nn.LeakyReLU(0.1, inplace=True), 243 | nn.Linear(cs[4], cs[4]), 244 | ) 245 | 246 | self.latemp_stage4 = nn.Sequential( 247 | nn.Linear(cs[4]+cs[4], cs[4]), 248 | nn.LeakyReLU(0.1, inplace=True), 249 | nn.Linear(cs[4], cs[3]), 250 | ) 251 | 252 | self.stage4_temp = nn.Sequential( 253 | nn.Linear(self.embed_dim, self.embed_dim), 254 | nn.LeakyReLU(0.1, inplace=True), 255 | nn.Linear(self.embed_dim, cs[4]), 256 | ) 257 | 258 | self.stage4 = nn.Sequential( 259 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D), 260 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D), 261 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D), 262 | ) 263 | 264 | # Up1 temp embed proj and conv 265 | self.latent_up1 = nn.Sequential( 266 | nn.Linear(cs[4], cs[4]), 267 | nn.LeakyReLU(0.1, inplace=True), 268 | nn.Linear(cs[4], cs[4]), 269 | ) 270 | 271 | self.latemp_up1 = nn.Sequential( 272 | nn.Linear(cs[4]+cs[4], cs[4]), 273 | nn.LeakyReLU(0.1, inplace=True), 274 | nn.Linear(cs[4], cs[4]), 275 | ) 276 | 277 | self.up1_temp = nn.Sequential( 278 | nn.Linear(self.embed_dim, self.embed_dim), 279 | nn.LeakyReLU(0.1, inplace=True), 280 | nn.Linear(self.embed_dim, cs[4]), 281 | ) 282 | 283 | self.up1 = nn.ModuleList([ 284 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2, D=self.D), 285 | nn.Sequential( 286 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, 287 | dilation=1, D=self.D), 288 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1, D=self.D), 289 | ) 290 | ]) 291 | 292 | # Up2 temp embed proj and conv 293 | self.latent_up2 = nn.Sequential( 294 | nn.Linear(cs[4], cs[4]), 295 | nn.LeakyReLU(0.1, inplace=True), 296 | nn.Linear(cs[4], cs[4]), 297 | ) 298 | 299 | self.latemp_up2 = nn.Sequential( 300 | nn.Linear(cs[4]+cs[4], cs[5]), 301 | nn.LeakyReLU(0.1, inplace=True), 302 | nn.Linear(cs[5], cs[5]), 303 | ) 304 | 305 | self.up2_temp = nn.Sequential( 306 | nn.Linear(self.embed_dim, self.embed_dim), 307 | nn.LeakyReLU(0.1, inplace=True), 308 | nn.Linear(self.embed_dim, cs[4]), 309 | ) 310 | 311 | self.up2 = nn.ModuleList([ 312 | BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2, D=self.D), 313 | nn.Sequential( 314 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, 315 | dilation=1, D=self.D), 316 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1, D=self.D), 317 | ) 318 | ]) 319 | 320 | # Up3 temp embed proj and conv 321 | self.latent_up3 = nn.Sequential( 322 | nn.Linear(cs[4], cs[4]), 323 | nn.LeakyReLU(0.1, inplace=True), 324 | nn.Linear(cs[4], cs[4]), 325 | ) 326 | 327 | self.latemp_up3 = nn.Sequential( 328 | nn.Linear(cs[4]+cs[4], cs[6]), 329 | nn.LeakyReLU(0.1, inplace=True), 330 | nn.Linear(cs[6], cs[6]), 331 | ) 332 | 333 | self.up3_temp = nn.Sequential( 334 | 
nn.Linear(self.embed_dim, self.embed_dim), 335 | nn.LeakyReLU(0.1, inplace=True), 336 | nn.Linear(self.embed_dim, cs[4]), 337 | ) 338 | 339 | self.up3 = nn.ModuleList([ 340 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2, D=self.D), 341 | nn.Sequential( 342 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, 343 | dilation=1, D=self.D), 344 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1, D=self.D), 345 | ) 346 | ]) 347 | 348 | # Up4 temp embed proj and conv 349 | self.latent_up4 = nn.Sequential( 350 | nn.Linear(cs[4], cs[4]), 351 | nn.LeakyReLU(0.1, inplace=True), 352 | nn.Linear(cs[4], cs[4]), 353 | ) 354 | 355 | self.latemp_up4 = nn.Sequential( 356 | nn.Linear(cs[4]+cs[4], cs[7]), 357 | nn.LeakyReLU(0.1, inplace=True), 358 | nn.Linear(cs[7], cs[7]), 359 | ) 360 | 361 | self.up4_temp = nn.Sequential( 362 | nn.Linear(self.embed_dim, self.embed_dim), 363 | nn.LeakyReLU(0.1, inplace=True), 364 | nn.Linear(self.embed_dim, cs[4]), 365 | ) 366 | 367 | self.up4 = nn.ModuleList([ 368 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2, D=self.D), 369 | nn.Sequential( 370 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, 371 | dilation=1, D=self.D), 372 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1, D=self.D), 373 | ) 374 | ]) 375 | 376 | self.last = nn.Sequential( 377 | nn.Linear(cs[8], 20), 378 | nn.LeakyReLU(0.1, inplace=True), 379 | nn.Linear(20, 3), 380 | ) 381 | 382 | self.weight_initialization() 383 | 384 | def weight_initialization(self): 385 | for m in self.modules(): 386 | if isinstance(m, nn.BatchNorm1d): 387 | nn.init.constant_(m.weight, 1) 388 | nn.init.constant_(m.bias, 0) 389 | 390 | def get_timestep_embedding(self, timesteps): 391 | assert len(timesteps.shape) == 1 392 | 393 | half_dim = self.embed_dim // 2 394 | emb = np.log(10000) / (half_dim - 1) 395 | emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(torch.device('cuda')) 396 | emb = timesteps[:, None] * emb[None, :] 397 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 398 | if self.embed_dim % 2 == 1: # zero pad 399 | emb = nn.functional.pad(emb, (0, 1), "constant", 0) 400 | assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim]) 401 | return emb 402 | 403 | def match_part_to_full(self, x_full, x_part): 404 | full_c = x_full.C.clone().float() 405 | part_c = x_part.C.clone().float() 406 | 407 | # hash batch coord 408 | max_coord = full_c.max() 409 | full_c[:,0] *= max_coord * 2. 410 | part_c[:,0] *= max_coord * 2. 
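# the batch index (column 0) is scaled far beyond the spatial coordinate range
# so that the batch term dominates the squared distances computed below; the
# argKmin(1) lookup is thereby effectively restricted to points of the same
# batch entry -- a cheap way of making the nearest-neighbour matching batch-aware.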
411 | 412 | f_coord = LazyTensor(full_c[:,None,:]) 413 | p_coord = LazyTensor(part_c[None,:,:]) 414 | 415 | dist_fp = ((f_coord - p_coord)**2).sum(-1) 416 | match_feats = dist_fp.argKmin(1,dim=1)[:,0] 417 | 418 | return x_part.F[match_feats] 419 | 420 | def forward(self, x, x_sparse, part_feats, t): 421 | temp_emb = self.get_timestep_embedding(t) 422 | 423 | x0 = self.stem(x_sparse) 424 | match0 = self.match_part_to_full(x0, part_feats) 425 | p0 = self.latent_stage1(match0) 426 | t0 = self.stage1_temp(temp_emb) 427 | batch_temp = torch.unique(x0.C[:,0], return_counts=True)[1] 428 | t0 = torch.repeat_interleave(t0, batch_temp, dim=0) 429 | w0 = self.latemp_stage1(torch.cat((p0,t0),-1)) 430 | 431 | x1 = self.stage1(x0*w0) 432 | match1 = self.match_part_to_full(x1, part_feats) 433 | p1 = self.latent_stage2(match1) 434 | t1 = self.stage2_temp(temp_emb) 435 | batch_temp = torch.unique(x1.C[:,0], return_counts=True)[1] 436 | t1 = torch.repeat_interleave(t1, batch_temp, dim=0) 437 | w1 = self.latemp_stage2(torch.cat((p1,t1),-1)) 438 | 439 | x2 = self.stage2(x1*w1) 440 | match2 = self.match_part_to_full(x2, part_feats) 441 | p2 = self.latent_stage3(match2) 442 | t2 = self.stage3_temp(temp_emb) 443 | batch_temp = torch.unique(x2.C[:,0], return_counts=True)[1] 444 | t2 = torch.repeat_interleave(t2, batch_temp, dim=0) 445 | w2 = self.latemp_stage3(torch.cat((p2,t2),-1)) 446 | 447 | x3 = self.stage3(x2*w2) 448 | match3 = self.match_part_to_full(x3, part_feats) 449 | p3 = self.latent_stage4(match3) 450 | t3 = self.stage4_temp(temp_emb) 451 | batch_temp = torch.unique(x3.C[:,0], return_counts=True)[1] 452 | t3 = torch.repeat_interleave(t3, batch_temp, dim=0) 453 | w3 = self.latemp_stage4(torch.cat((p3,t3),-1)) 454 | 455 | x4 = self.stage4(x3*w3) 456 | match4 = self.match_part_to_full(x4, part_feats) 457 | p4 = self.latent_up1(match4) 458 | t4 = self.up1_temp(temp_emb) 459 | batch_temp = torch.unique(x4.C[:,0], return_counts=True)[1] 460 | t4 = torch.repeat_interleave(t4, batch_temp, dim=0) 461 | w4 = self.latemp_up1(torch.cat((t4,p4),-1)) 462 | 463 | y1 = self.up1[0](x4*w4) 464 | y1 = ME.cat(y1, x3) 465 | y1 = self.up1[1](y1) 466 | match5 = self.match_part_to_full(y1, part_feats) 467 | p5 = self.latent_up2(match5) 468 | t5 = self.up2_temp(temp_emb) 469 | batch_temp = torch.unique(y1.C[:,0], return_counts=True)[1] 470 | t5 = torch.repeat_interleave(t5, batch_temp, dim=0) 471 | w5 = self.latemp_up2(torch.cat((p5,t5),-1)) 472 | 473 | y2 = self.up2[0](y1*w5) 474 | y2 = ME.cat(y2, x2) 475 | y2 = self.up2[1](y2) 476 | match6 = self.match_part_to_full(y2, part_feats) 477 | p6 = self.latent_up3(match6) 478 | t6 = self.up3_temp(temp_emb) 479 | batch_temp = torch.unique(y2.C[:,0], return_counts=True)[1] 480 | t6 = torch.repeat_interleave(t6, batch_temp, dim=0) 481 | w6 = self.latemp_up3(torch.cat((p6,t6),-1)) 482 | 483 | y3 = self.up3[0](y2*w6) 484 | y3 = ME.cat(y3, x1) 485 | y3 = self.up3[1](y3) 486 | match7 = self.match_part_to_full(y3, part_feats) 487 | p7 = self.latent_up4(match7) 488 | t7 = self.up4_temp(temp_emb) 489 | batch_temp = torch.unique(y3.C[:,0], return_counts=True)[1] 490 | t7 = torch.repeat_interleave(t7, batch_temp, dim=0) 491 | w7 = self.latemp_up4(torch.cat((p7,t7),-1)) 492 | 493 | y4 = self.up4[0](y3*w7) 494 | y4 = ME.cat(y4, x0) 495 | y4 = self.up4[1](y4) 496 | 497 | return self.last(y4.slice(x).F) 498 | 499 | class MinkUNet(nn.Module): 500 | def __init__(self, **kwargs): 501 | super().__init__() 502 | 503 | cr = kwargs.get('cr', 1.0) 504 | in_channels = kwargs.get('in_channels', 3) 505 | 
out_channels = kwargs.get('out_channels', 3) 506 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 507 | cs = [int(cr * x) for x in cs] 508 | self.run_up = kwargs.get('run_up', True) 509 | self.D = kwargs.get('D', 3) 510 | self.stem = nn.Sequential( 511 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D), 512 | ME.MinkowskiBatchNorm(cs[0]), 513 | ME.MinkowskiReLU(True), 514 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D), 515 | ME.MinkowskiBatchNorm(cs[0]), 516 | ME.MinkowskiReLU(inplace=True) 517 | ) 518 | 519 | self.stage1 = nn.Sequential( 520 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D), 521 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D), 522 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D), 523 | ) 524 | 525 | self.stage2 = nn.Sequential( 526 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D), 527 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D), 528 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D) 529 | ) 530 | 531 | self.stage3 = nn.Sequential( 532 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D), 533 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D), 534 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D), 535 | ) 536 | 537 | self.stage4 = nn.Sequential( 538 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D), 539 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D), 540 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D), 541 | ) 542 | 543 | self.up1 = nn.ModuleList([ 544 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2, D=self.D), 545 | nn.Sequential( 546 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, 547 | dilation=1, D=self.D), 548 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1, D=self.D), 549 | ) 550 | ]) 551 | 552 | self.up2 = nn.ModuleList([ 553 | BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2, D=self.D), 554 | nn.Sequential( 555 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, 556 | dilation=1, D=self.D), 557 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1, D=self.D), 558 | ) 559 | ]) 560 | 561 | self.up3 = nn.ModuleList([ 562 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2, D=self.D), 563 | nn.Sequential( 564 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, 565 | dilation=1, D=self.D), 566 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1, D=self.D), 567 | ) 568 | ]) 569 | 570 | self.up4 = nn.ModuleList([ 571 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2, D=self.D), 572 | nn.Sequential( 573 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, 574 | dilation=1, D=self.D), 575 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1, D=self.D), 576 | ) 577 | ]) 578 | 579 | self.last = nn.Sequential( 580 | nn.Linear(cs[8], 20), 581 | nn.LeakyReLU(0.1, inplace=True), 582 | nn.Linear(20, out_channels), 583 | nn.Tanh(), 584 | ) 585 | 586 | self.weight_initialization() 587 | self.dropout = nn.Dropout(0.3, True) 588 | 589 | def weight_initialization(self): 590 | for m in self.modules(): 591 | if isinstance(m, nn.BatchNorm1d): 592 | nn.init.constant_(m.weight, 1) 593 | nn.init.constant_(m.bias, 0) 594 | 595 | def forward(self, x): 596 | x0 = self.stem(x.sparse()) 597 | x1 = self.stage1(x0) 598 | x2 = self.stage2(x1) 599 | x3 = self.stage3(x2) 600 | x4 = self.stage4(x3) 601 | 602 | y1 = 
self.up1[0](x4) 603 | y1 = ME.cat(y1, x3) 604 | y1 = self.up1[1](y1) 605 | 606 | y2 = self.up2[0](y1) 607 | y2 = ME.cat(y2, x2) 608 | y2 = self.up2[1](y2) 609 | 610 | y3 = self.up3[0](y2) 611 | y3 = ME.cat(y3, x1) 612 | y3 = self.up3[1](y3) 613 | 614 | y4 = self.up4[0](y3) 615 | y4 = ME.cat(y4, x0) 616 | y4 = self.up4[1](y4) 617 | 618 | return self.last(y4.slice(x).F) 619 | 620 | 621 | 622 | -------------------------------------------------------------------------------- /trains/DistillationDPO.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import MinkowskiEngine as ME 6 | import open3d as o3d 7 | import datetime 8 | from tqdm import tqdm 9 | from os import makedirs, path 10 | import os 11 | import copy 12 | 13 | from pytorch_lightning.core.lightning import LightningModule 14 | from pytorch_lightning import Trainer 15 | from pytorch_lightning import loggers as pl_loggers 16 | from pytorch_lightning.callbacks import ModelCheckpoint 17 | from diffusers import DPMSolverMultistepScheduler 18 | 19 | from utils.collations import * 20 | from utils.metrics import ChamferDistance, PrecisionRecall 21 | from utils.scheduling import beta_func 22 | from utils.metrics import ChamferDistance, PrecisionRecall, CompletionIoU, RMSE, EMD 23 | from utils.histogram_metrics import compute_hist_metrics 24 | from models.minkunet import MinkRewardModel,MinkGlobalEnc,MinkUNetDiff 25 | import datasets.SemanticKITTI_dataset as SemanticKITTI_dataset 26 | 27 | class DistillationDPO(LightningModule): 28 | def __init__(self, args): 29 | super().__init__() 30 | 31 | # configs 32 | self.lr = args.lr 33 | self.timestamp = args.timestamp 34 | self.args = args 35 | self.w_uncond = 3.5 36 | 37 | # Load pre-trained DM weights, init reference model 38 | dm_ckpt = torch.load(args.pre_trained_diff_path) 39 | # dm_weights = {k.replace('model.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('model.')} 40 | dm_weights = {k.replace('DiffModel.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('DiffModel.')} 41 | generator_weights = copy.deepcopy(dm_weights) 42 | auxDiffBetter_weights = copy.deepcopy(dm_weights) 43 | auxDiffWorse_weights = copy.deepcopy(dm_weights) 44 | teacher_weights = dm_weights 45 | # encoder_weights = {k.replace('partial_enc.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('partial_enc.')} 46 | encoder_weights = {k.replace('DM_encoder.', ''): v for k, v in dm_ckpt["state_dict"].items() if k.startswith('DM_encoder.')} 47 | 48 | self.partial_enc = MinkGlobalEnc() 49 | self.generator = MinkUNetDiff() 50 | self.teacher = MinkUNetDiff() 51 | self.auxDiffBetter = MinkUNetDiff() 52 | self.auxDiffWorse = MinkUNetDiff() 53 | self.partial_enc.load_state_dict(encoder_weights, strict=True) 54 | self.generator.load_state_dict(generator_weights, strict=True) 55 | self.teacher.load_state_dict(teacher_weights, strict=True) 56 | self.auxDiffBetter.load_state_dict(auxDiffBetter_weights, strict=True) 57 | self.auxDiffWorse.load_state_dict(auxDiffWorse_weights, strict=True) 58 | 59 | for param in self.partial_enc.parameters(): 60 | param.requires_grad = False 61 | for param in self.generator.parameters(): 62 | param.requires_grad = True 63 | for param in self.teacher.parameters(): 64 | param.requires_grad = False 65 | for param in self.auxDiffBetter.parameters(): 66 | param.requires_grad = True 67 | for param in self.auxDiffWorse.parameters(): 
 72 |         # init scheduler for DM
 73 |         self.betas = beta_func['linear'](1000, 3.5e-5, 0.007)
 74 | 
 75 |         self.t_steps = 1000  # diffusion training timesteps
 76 |         self.s_steps = 1  # sampler steps during training (few-step generator)
 77 |         self.s_steps_val = 8  # sampler steps during validation
 78 | 
 79 |         self.alphas = 1. - self.betas
 80 |         self.alphas_cumprod = torch.tensor(
 81 |             np.cumprod(self.alphas, axis=0), dtype=torch.float32, device=self.device
 82 |         )
 83 | 
 84 |         self.alphas_cumprod_prev = torch.tensor(
 85 |             np.append(1., self.alphas_cumprod[:-1].cpu().numpy()), dtype=torch.float32, device=self.device
 86 |         )
 87 | 
 88 |         self.betas = torch.tensor(self.betas, device=self.device)
 89 |         self.alphas = torch.tensor(self.alphas, device=self.device)
 90 | 
 91 |         self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
 92 |         self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
 93 |         self.log_one_minus_alphas_cumprod = torch.log(1. - self.alphas_cumprod)
 94 |         self.sqrt_recip_alphas = torch.sqrt(1. / self.alphas)
 95 |         self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
 96 |         self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1.)
 97 | 
 98 |         self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
 99 |         self.sqrt_posterior_variance = torch.sqrt(self.posterior_variance)
100 |         self.posterior_log_var = torch.log(
101 |             torch.max(self.posterior_variance, 1e-20 * torch.ones_like(self.posterior_variance))
102 |         )
103 | 
104 |         self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
105 |         self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)
106 | 
107 |         self.dpm_scheduler = DPMSolverMultistepScheduler(
108 |             num_train_timesteps=self.t_steps,
109 |             beta_start=3.5e-5,
110 |             beta_end=0.007,
111 |             beta_schedule='linear',
112 |             algorithm_type='sde-dpmsolver++',
113 |             solver_order=2,
114 |         )
115 |         self.dpm_scheduler.set_timesteps(num_inference_steps=self.s_steps)
116 |         self.dpm_scheduler_val = DPMSolverMultistepScheduler(
117 |             num_train_timesteps=self.t_steps,
118 |             beta_start=3.5e-5,
119 |             beta_end=0.007,
120 |             beta_schedule='linear',
121 |             algorithm_type='sde-dpmsolver++',
122 |             solver_order=2,
123 |         )
124 |         self.dpm_scheduler_val.set_timesteps(num_inference_steps=self.s_steps_val)
125 |         self.scheduler_to_cuda()
126 | 
127 |         # metrics for validation
128 |         self.chamfer_distance = ChamferDistance()
129 |         self.precision_recall = PrecisionRecall(0.05, 2*0.05, 100)
130 |         self.completion_iou = CompletionIoU()
131 | 
132 |     def scheduler_to_cuda(self):
133 |         self.dpm_scheduler.timesteps = self.dpm_scheduler.timesteps.to(self.device)
134 |         self.dpm_scheduler.betas = self.dpm_scheduler.betas.to(self.device)
135 |         self.dpm_scheduler.alphas = self.dpm_scheduler.alphas.to(self.device)
136 |         self.dpm_scheduler.alphas_cumprod = self.dpm_scheduler.alphas_cumprod.to(self.device)
137 |         self.dpm_scheduler.alpha_t = self.dpm_scheduler.alpha_t.to(self.device)
138 |         self.dpm_scheduler.sigma_t = self.dpm_scheduler.sigma_t.to(self.device)
139 |         self.dpm_scheduler.lambda_t = self.dpm_scheduler.lambda_t.to(self.device)
140 |         self.dpm_scheduler.sigmas = self.dpm_scheduler.sigmas.to(self.device)
141 | 
142 |         self.dpm_scheduler_val.timesteps = self.dpm_scheduler_val.timesteps.to(self.device)
143 |         self.dpm_scheduler_val.betas = self.dpm_scheduler_val.betas.to(self.device)
144 |         self.dpm_scheduler_val.alphas = self.dpm_scheduler_val.alphas.to(self.device)
145 |         self.dpm_scheduler_val.alphas_cumprod = self.dpm_scheduler_val.alphas_cumprod.to(self.device)
146 |         self.dpm_scheduler_val.alpha_t = self.dpm_scheduler_val.alpha_t.to(self.device)
147 |         self.dpm_scheduler_val.sigma_t = self.dpm_scheduler_val.sigma_t.to(self.device)
148 |         self.dpm_scheduler_val.lambda_t = self.dpm_scheduler_val.lambda_t.to(self.device)
149 |         self.dpm_scheduler_val.sigmas = self.dpm_scheduler_val.sigmas.to(self.device)
150 | 
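    |     # Note: q_sample() below is the closed-form forward diffusion step,
    |     #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
    |     # built from the per-timestep coefficients precomputed in __init__.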
151 |     def q_sample(self, x, t, noise):
152 |         return self.sqrt_alphas_cumprod[t][:,None,None].to(self.device) * x + \
153 |             self.sqrt_one_minus_alphas_cumprod[t][:,None,None].to(self.device) * noise
154 | 
155 |     def reset_partial_pcd(self, x_part, x_uncond, x_mean, x_std):
156 |         x_part = self.points_to_tensor(x_part.F.reshape(x_mean.shape[0],-1,3).detach(), x_mean, x_std)
157 |         x_uncond = self.points_to_tensor(
158 |             torch.zeros_like(x_part.F.reshape(x_mean.shape[0],-1,3)), torch.zeros_like(x_mean), torch.zeros_like(x_std)
159 |         )
160 | 
161 |         return x_part, x_uncond
162 | 
163 |     def reset_partial_pcd_part(self, x_part, x_mean, x_std):
164 |         x_part = self.points_to_tensor(x_part.F.reshape(x_mean.shape[0],-1,3).detach(), x_mean, x_std)
165 | 
166 |         return x_part
167 | 
168 |     def p_sample_loop(self, x_init, x_t, x_cond, x_uncond, gt_pts, x_mean, x_std):
169 |         pcd = o3d.geometry.PointCloud()
170 |         self.scheduler_to_cuda()
171 | 
172 |         for t in tqdm(range(len(self.dpm_scheduler.timesteps))):
173 |             t = torch.ones(gt_pts.shape[0]).to(self.device).long() * self.dpm_scheduler.timesteps[t].to(self.device)
174 | 
175 |             noise_t = self.classfree_forward_generator(x_t, x_cond, x_uncond, t)
176 |             input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
177 |             x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
178 |             x_t = self.points_to_tensor(x_t, x_mean, x_std)
179 | 
180 |             # this is needed, otherwise MinkowskiEngine keeps "stacking" coordinate maps over the partial input,
181 |             # i.e. a memory leak
182 |             x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
183 |             torch.cuda.empty_cache()
184 | 
185 |         makedirs(f'{self.logger.log_dir}/generated_pcd/', exist_ok=True)
186 | 
187 |         return x_t
188 | 
189 |     def p_sample_with_final_step_grad(self, x_init, x_t, x_cond, x_mean, x_std):
190 | 
191 |         assert len(self.dpm_scheduler.timesteps) == self.s_steps
192 | 
193 |         with torch.no_grad():
194 |             for t in range(len(self.dpm_scheduler.timesteps)-1):
195 |                 t = torch.full((x_init.shape[0],), self.dpm_scheduler.timesteps[t]).to(self.device)
196 | 
197 |                 noise_t = self.forward_generator(x_t, x_t.sparse(), x_cond, t)
198 |                 input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
199 |                 x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
200 |                 x_t = self.points_to_tensor(x_t, x_mean, x_std)
201 | 
202 |         t = torch.full((x_init.shape[0],), self.dpm_scheduler.timesteps[-1]).to(self.device)
203 | 
204 |         noise_t = self.forward_generator(x_t, x_t.sparse(), x_cond, t)  # final step runs with gradients enabled
205 |         input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
206 |         x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
207 |         x_t = self.points_to_tensor(x_t, x_mean, x_std)
208 | 
209 |         x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
210 |         torch.cuda.empty_cache()
211 | 
212 |         return x_t
213 | 
214 |     def p_sample_one_step(self, x_init, x_t, x_cond, t, x_mean, x_std):
215 | 
216 |         noise_t = self.forward_generator(x_t, x_t.sparse(), x_cond, t)
217 |         input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
218 |         x_t = x_init + self.dpm_scheduler.step(noise_t, t[0], input_noise)['prev_sample']
219 |         x_t = self.points_to_tensor(x_t, x_mean, x_std)
220 | 
221 |         x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
222 |         torch.cuda.empty_cache()
223 | 
224 |         return x_t
225 | 
226 |     def pred_noise(self, model, sparse_TF, x_t_TF, t: torch.Tensor):
227 | 
228 |         condition = self.partial_enc(sparse_TF)
229 |         noise_pred = model(x_t_TF, x_t_TF.sparse(), condition, t)
230 | 
231 |         torch.cuda.empty_cache()
232 |         return noise_pred
233 | 
234 | 
235 |     def sample_val(self, x_init, x_t, x_cond, x_uncond, x_mean, x_std):
236 | 
237 |         assert len(self.dpm_scheduler_val.timesteps) == self.s_steps_val
238 | 
239 |         for t in range(len(self.dpm_scheduler_val.timesteps)):
240 |             t = torch.full((x_init.shape[0],), self.dpm_scheduler_val.timesteps[t]).to(self.device)
241 | 
242 |             noise_t = self.classfree_forward_generator(x_t, x_cond, x_uncond, t)
243 |             input_noise = x_t.F.reshape(t.shape[0],-1,3) - x_init
244 |             x_t = x_init + self.dpm_scheduler_val.step(noise_t, t[0], input_noise)['prev_sample']
245 |             x_t = self.points_to_tensor(x_t, x_mean, x_std)
246 | 
247 |             x_cond = self.reset_partial_pcd_part(x_cond, x_mean, x_std)
248 |             torch.cuda.empty_cache()
249 | 
250 |         return x_t
251 | 
252 |     def do_forward(self, model, x_full, x_full_sparse, x_part, t):
253 |         part_feat = self.partial_enc(x_part)
254 |         out = model(x_full, x_full_sparse, part_feat, t)
255 |         torch.cuda.empty_cache()
256 |         return out.reshape(t.shape[0],-1,3)
257 | 
258 |     def forward_generator(self, x_full, x_full_sparse, x_part, t):
259 |         part_feat = self.partial_enc(x_part)
260 |         out = self.generator(x_full, x_full_sparse, part_feat, t)
261 |         torch.cuda.empty_cache()
262 |         return out.reshape(t.shape[0],-1,3)
263 | 
264 |     def classfree_forward_generator(self, x_t, x_cond, x_uncond, t):
265 |         x_t_sparse = x_t.sparse()
266 |         x_cond = self.forward_generator(x_t, x_t_sparse, x_cond, t)
267 |         x_uncond = self.forward_generator(x_t, x_t_sparse, x_uncond, t)
268 | 
269 |         return x_uncond + self.w_uncond * (x_cond - x_uncond)
270 | 
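    |     # Note: classfree_forward_generator() above applies classifier-free guidance,
    |     #     eps = eps_uncond + w * (eps_cond - eps_uncond),
    |     # with guidance weight w = self.w_uncond (3.5).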
271 |     def feats_to_coord(self, p_feats, resolution, batch_size):
272 |         p_feats = p_feats.reshape(batch_size,-1,3)
273 |         p_coord = torch.round(p_feats / resolution)
274 | 
275 |         return p_coord.reshape(-1,3)
276 | 
277 |     def points_to_tensor(self, x_feats, mean=None, std=None):
278 |         if mean is None:
279 |             batch_size = x_feats.shape[0]
280 |             x_feats = ME.utils.batched_coordinates(list(x_feats[:]), dtype=torch.float32, device=self.device)
281 | 
282 |             x_coord = x_feats.clone()
283 |             x_coord[:,1:] = self.feats_to_coord(x_feats[:,1:], 0.05, batch_size)
284 | 
285 |             x_t = ME.TensorField(
286 |                 features=x_feats[:,1:],
287 |                 coordinates=x_coord,
288 |                 quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
289 |                 minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
290 |                 device=self.device,
291 |             )
292 | 
293 |             torch.cuda.empty_cache()
294 | 
295 |             return x_t
296 | 
297 |         else:
298 |             x_feats = ME.utils.batched_coordinates(list(x_feats[:]), dtype=torch.float32, device=self.device)
299 | 
300 |             x_coord = x_feats.clone()
301 |             x_coord[:,1:] = self.feats_to_coord(x_feats[:,1:], 0.05, mean.shape[0])
302 | 
303 |             x_t = ME.TensorField(
304 |                 features=x_feats[:,1:],
305 |                 coordinates=x_coord,
306 |                 quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
307 |                 minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
308 |                 device=self.device,
309 |             )
310 | 
311 |             torch.cuda.empty_cache()
312 | 
313 |             return x_t
314 | 
315 |     def point_set_to_sparse(self, p_full, n_part):
316 |         p_full = p_full[0].cpu().detach().numpy()
317 | 
318 |         dist_part = np.sum(p_full**2, -1)**.5
319 |         p_full = p_full[(dist_part < 50.) & (dist_part > 3.5)]
320 |         p_full = p_full[p_full[:,2] > -4.]
321 | 
322 |         pcd_part = o3d.geometry.PointCloud()
323 |         pcd_part.points = o3d.utility.Vector3dVector(p_full)
324 | 
325 |         pcd_part = pcd_part.farthest_point_down_sample(n_part)
326 |         p_part = torch.tensor(np.array(pcd_part.points))
327 | 
328 |         return p_part
329 | 
330 |     def calc_cd(self, pred: torch.Tensor, gt: torch.Tensor):
331 | 
332 |         assert pred.shape[0] == 1
333 |         assert gt.shape[0] == 1
334 | 
335 |         pcd_pred = o3d.geometry.PointCloud()
336 |         pcd_pred.points = o3d.utility.Vector3dVector(pred[0].cpu().detach().numpy())
337 |         pcd_gt = o3d.geometry.PointCloud()
338 |         pcd_gt.points = o3d.utility.Vector3dVector(gt[0].cpu().detach().numpy())
339 | 
340 |         dist_pt_2_gt = np.asarray(pcd_pred.compute_point_cloud_distance(pcd_gt))
341 |         dist_gt_2_pt = np.asarray(pcd_gt.compute_point_cloud_distance(pcd_pred))
342 | 
343 |         cd = (np.mean(dist_gt_2_pt) + np.mean(dist_pt_2_gt)) / 2
344 | 
345 |         return cd
346 | 
347 |     def calc_jsd(self, pred: torch.Tensor, gt: torch.Tensor):
348 | 
349 |         assert pred.shape[0] == 1
350 |         assert gt.shape[0] == 1
351 | 
352 |         pcd_pred = o3d.geometry.PointCloud()
353 |         pcd_pred.points = o3d.utility.Vector3dVector(pred[0].cpu().detach().numpy())
354 |         pcd_gt = o3d.geometry.PointCloud()
355 |         pcd_gt.points = o3d.utility.Vector3dVector(gt[0].cpu().detach().numpy())
356 | 
357 |         jsd = compute_hist_metrics(pcd_gt, pcd_pred, bev=False)
358 | 
359 |         return jsd
360 | 
361 |     def training_step(self, batch: dict, batch_idx, optimizer_idx):
362 | 
363 |         # vars to be used
364 |         partial_tensor = batch['pcd_part']
365 |         partial_x10_tensor = partial_tensor.repeat(1,10,1)
366 |         gt_tensor = batch['pcd_full']
367 |         B = batch['pcd_part'].shape[0]
368 |         partial_TF = self.points_to_tensor(partial_tensor)
369 | 
370 |         if optimizer_idx == 0: # train auxDiffs
371 | 
372 |             # generate a batch of better & worse samples from the generator
373 |             with torch.no_grad():
374 |                 self.generator.eval()
375 |                 t_gen = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
376 |                 noise = torch.randn(partial_x10_tensor.shape, device=self.device)
377 | 
378 |                 # better
379 |                 partial_x10_better_noised_tensor = partial_x10_tensor + noise
380 |                 partial_x10_better_noised_TF = self.points_to_tensor(partial_x10_better_noised_tensor)
381 |                 generated_sample1 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_better_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
382 | 
383 | 
384 |                 # worse
385 |                 partial_x10_worse_noised_tensor = partial_x10_tensor + 1.1*noise
386 |                 partial_x10_worse_noised_TF = self.points_to_tensor(partial_x10_worse_noised_tensor)
387 |                 generated_sample2 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_worse_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
388 | 
389 |                 # calc cd
390 |                 cd1 = self.calc_cd(generated_sample1, gt_tensor).item()
391 |                 cd2 = self.calc_cd(generated_sample2, gt_tensor).item()
392 | 
393 |                 # compare
394 |                 generated_better_sample = generated_sample1 if cd1 < cd2 else generated_sample2
395 |                 generated_worse_sample = generated_sample2 if cd1 < cd2 else generated_sample1
396 |                 cd_better = cd1 if cd1 < cd2 else cd2
397 |                 cd_worse = cd2 if cd1 < cd2 else cd1
398 |                 switch = cd1 > cd2
399 |                 self.log('cd/better', cd_better, prog_bar=False)
400 |                 self.log('cd/worse', cd_worse, prog_bar=False)
401 |                 self.log('cd/switch', switch, prog_bar=False)
402 | 
403 |             # add noise to generated samples
404 |             with torch.no_grad():
405 |                 # better
406 |                 t_better = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
407 |                 noise_better = torch.randn(generated_better_sample.shape, device=self.device)
408 |                 generated_better_noised_sample_tensor = generated_better_sample + self.q_sample(torch.zeros_like(generated_better_sample),
409 |                                                                                                 t_better, noise_better)
410 |                 generated_better_noised_sample_TF = self.points_to_tensor(generated_better_noised_sample_tensor)
411 | 
412 |                 # worse
413 |                 t_worse = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
414 |                 noise_worse = torch.randn(generated_worse_sample.shape, device=self.device)
415 |                 generated_worse_noised_sample_tensor = generated_worse_sample + self.q_sample(torch.zeros_like(generated_worse_sample),
416 |                                                                                               t_worse, noise_worse)
417 |                 generated_worse_noised_sample_TF = self.points_to_tensor(generated_worse_noised_sample_tensor)
418 | 
419 |             # denoise generated samples with auxDiffs
420 |             self.auxDiffBetter.train()
421 |             pred_noise_auxB = self.do_forward(self.auxDiffBetter, generated_better_noised_sample_TF, generated_better_noised_sample_TF.sparse(),
422 |                                               partial_TF, t_better)
423 |             self.auxDiffWorse.train()
424 |             pred_noise_auxW = self.do_forward(self.auxDiffWorse, generated_worse_noised_sample_TF, generated_worse_noised_sample_TF.sparse(),
425 |                                               partial_TF, t_worse)
426 | 
427 |             # calculate loss
428 |             auxDiffBetter_loss = F.mse_loss(noise_better, pred_noise_auxB)
429 |             auxDiffWorse_loss = F.mse_loss(noise_worse, pred_noise_auxW)
430 |             loss_aux = auxDiffBetter_loss + auxDiffWorse_loss
431 | 
432 |             # log info on progress bar
433 |             self.log('loss_auxB', auxDiffBetter_loss, prog_bar=True)
434 |             self.log('loss_auxW', auxDiffWorse_loss, prog_bar=True)
435 | 
436 |             torch.cuda.empty_cache()
437 | 
438 |             return loss_aux
439 | 
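    |         # Both optimizer branches build a preference pair the same way: the
    |         # few-step sampler is run twice from the same partial scan, seeded once
    |         # with noise and once with 1.1 * noise; the result closer to the ground
    |         # truth in Chamfer distance is labeled "better", the other "worse".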
440 |         if optimizer_idx == 1: # train generator
441 | 
442 |             # get a batch of samples from the generator
443 |             self.generator.train()
444 |             t_gen = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
445 |             noise = torch.randn(partial_x10_tensor.shape, device=self.device)
446 | 
447 |             partial_x10_better_noised_tensor = partial_x10_tensor + noise
448 |             partial_x10_better_noised_TF = self.points_to_tensor(partial_x10_better_noised_tensor)
449 |             generated_sample1 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_better_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
450 | 
451 |             partial_x10_worse_noised_tensor = partial_x10_tensor + 1.1*noise
452 |             partial_x10_worse_noised_TF = self.points_to_tensor(partial_x10_worse_noised_tensor)
453 |             generated_sample2 = self.p_sample_with_final_step_grad(partial_x10_tensor, partial_x10_worse_noised_TF, partial_TF, batch['mean'], batch['std']).F.reshape(B,-1,3)
454 | 
455 |             cd1 = self.calc_cd(generated_sample1, gt_tensor).item()
456 |             cd2 = self.calc_cd(generated_sample2, gt_tensor).item()
457 | 
458 |             generated_better_sample = generated_sample1 if cd1 < cd2 else generated_sample2
459 |             generated_worse_sample = generated_sample2 if cd1 < cd2 else generated_sample1
460 | 
461 |             # add noise to generated samples
462 |             t = torch.randint(0, self.t_steps, size=(B,)).to(self.device)
463 |             noise = torch.randn(generated_better_sample.shape, device=self.device)
464 |             generated_better_noised_sample_tensor = generated_better_sample + self.q_sample(torch.zeros_like(generated_better_sample), t, noise)
465 |             generated_better_noised_sample_TF = self.points_to_tensor(generated_better_noised_sample_tensor)
466 | 
467 |             generated_worse_noised_sample_tensor = generated_worse_sample + self.q_sample(torch.zeros_like(generated_worse_sample), t, noise)
468 |             generated_worse_noised_sample_TF = self.points_to_tensor(generated_worse_noised_sample_tensor)
469 | 
470 |             # denoise generated samples with auxDiffs and teacher, respectively
471 |             with torch.no_grad():
472 |                 self.auxDiffBetter.eval()
473 |                 noise_auxB = self.do_forward(self.auxDiffBetter, generated_better_noised_sample_TF, generated_better_noised_sample_TF.sparse(), partial_TF, t)
474 | 
475 |                 generated_better_noised_sample_TF = self.points_to_tensor(generated_better_noised_sample_tensor)  # rebuild to reset Minkowski coordinate maps
476 |                 self.auxDiffWorse.eval()
477 |                 noise_auxW = self.do_forward(self.auxDiffWorse, generated_worse_noised_sample_TF, generated_worse_noised_sample_TF.sparse(), partial_TF, t)
478 | 
479 |                 generated_worse_noised_sample_TF = self.points_to_tensor(generated_worse_noised_sample_tensor)
480 |                 self.teacher.eval()
481 |                 noise_better_teacher = self.do_forward(self.teacher, generated_better_noised_sample_TF, generated_better_noised_sample_TF.sparse(), partial_TF, t)
482 |                 noise_worse_teacher = self.do_forward(self.teacher, generated_worse_noised_sample_TF, generated_worse_noised_sample_TF.sparse(), partial_TF, t)
483 | 
484 | 
485 |             distil_loss = ((noise_worse_teacher - noise_auxW) * (generated_worse_noised_sample_tensor - generated_worse_sample)).mean() \
486 |                 - ((noise_better_teacher - noise_auxB) * (generated_better_noised_sample_tensor - generated_better_sample)).mean()
487 | 
488 |             generator_loss = distil_loss
489 | 
490 |             # log info on progress bar
491 |             self.log('loss_g', generator_loss, prog_bar=True)
492 | 
493 |             torch.cuda.empty_cache()
494 | 
495 |             return generator_loss
496 | 
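    |     # The generator objective above is a score-difference surrogate: with eps_T
    |     # the teacher's and eps_A the matching auxiliary model's noise prediction on
    |     # a noised version x_noised of a clean sample x, the loss is roughly
    |     #     L = mean[(eps_T - eps_A) * (x_noised - x)]  on the worse sample
    |     #       - mean[(eps_T - eps_A) * (x_noised - x)]  on the better sample,
    |     # so minimizing L steers the generator toward its preferred ("better")
    |     # samples and away from the "worse" ones; gradients reach the generator only
    |     # through the final sampler step (p_sample_with_final_step_grad).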
497 |     def configure_optimizers(self):
498 | 
499 |         optimizer_aux = torch.optim.SGD(list(self.auxDiffBetter.parameters()) + list(self.auxDiffWorse.parameters()), lr=self.args.lr)
500 |         optimizer_g = torch.optim.SGD(self.generator.parameters(), lr=self.args.lr)
501 | 
502 |         from torch.optim.lr_scheduler import StepLR
503 |         scheduler_aux = StepLR(optimizer_aux, step_size=1, gamma=0.999)
504 |         scheduler_g = StepLR(optimizer_g, step_size=1, gamma=0.999)
505 | 
506 |         return [optimizer_aux, optimizer_g], [scheduler_aux, scheduler_g]
507 | 
508 |     def validation_step(self, batch: dict, batch_idx):
509 | 
510 |         with torch.no_grad():
511 |             gt_pts = batch['pcd_full'].detach().cpu().numpy()
512 | 
513 |             # for inference we get the partial pcd and sample the noise around the partial
514 |             x_init = batch['pcd_part'].repeat(1,10,1)
515 |             x_feats = x_init + torch.randn(x_init.shape, device=self.device)
516 |             x_full = self.points_to_tensor(x_feats, batch['mean'], batch['std'])
517 |             x_part = self.points_to_tensor(batch['pcd_part'], batch['mean'], batch['std'])
518 |             x_uncond = self.points_to_tensor(
519 |                 torch.zeros_like(batch['pcd_part']), torch.zeros_like(batch['mean']), torch.zeros_like(batch['std'])
520 |             )
521 | 
522 |             x_gen_eval = self.sample_val(x_init, x_full, x_part, x_uncond, batch['mean'], batch['std'])
523 |             x_gen_eval = x_gen_eval.F.reshape((gt_pts.shape[0],-1,3))
524 | 
525 |             for i in range(len(batch['pcd_full'])):
526 | 
527 | 
528 |                 pcd_pred = o3d.geometry.PointCloud()
529 |                 # pcd_pred_all = o3d.geometry.PointCloud()
530 |                 c_pred = x_gen_eval[i].cpu().detach().numpy()
531 | 
532 |                 # pcd_pred_all.points = o3d.utility.Vector3dVector(c_pred)
533 |                 dist_pts = np.sqrt(np.sum((c_pred)**2, axis=-1))
534 |                 dist_idx = dist_pts < 50.0
535 |                 points = c_pred[dist_idx]
536 |                 max_z = x_init[i][...,2].max().item()
537 |                 min_z = (x_init[i][...,2].mean() - 2 * x_init[i][...,2].std()).item()
538 |                 pcd_pred.points = o3d.utility.Vector3dVector(points[(points[:,2] < max_z) & (points[:,2] > min_z)])
539 |                 pcd_pred.paint_uniform_color([1.0, 0., 0.])
540 | 
541 |                 file_path = f'exp/distill/sdpo_{self.args.timestamp}/samples/{self.args.batch_size*batch_idx+i}.pcd'
542 |                 os.makedirs(os.path.dirname(file_path), exist_ok=True)
543 |                 o3d.io.write_point_cloud(file_path, pcd_pred)
544 | 
545 |                 pcd_gt = o3d.geometry.PointCloud()
546 |                 # pcd_gt_all = o3d.geometry.PointCloud()
547 |                 g_pred = batch['pcd_full'][i].cpu().detach().numpy()
548 |                 # pcd_gt_all.points = o3d.utility.Vector3dVector(g_pred)
549 |                 pcd_gt.points = o3d.utility.Vector3dVector(g_pred)
550 |                 pcd_gt.paint_uniform_color([0., 1., 0.])
551 | 
552 |                 pcd_part = o3d.geometry.PointCloud()
553 |                 pcd_part.points = o3d.utility.Vector3dVector(batch['pcd_part'][i].cpu().detach().numpy())
554 |                 pcd_part.paint_uniform_color([0., 1., 0.])
555 | 
556 |                 self.chamfer_distance.update(pcd_gt, pcd_pred)
557 |                 self.precision_recall.update(pcd_gt, pcd_pred)
558 |                 self.completion_iou.update(pcd_gt, pcd_pred)
559 | 
560 |             torch.cuda.empty_cache()
561 | 
562 |         cd_mean, cd_std = self.chamfer_distance.compute()
563 |         pr, re, f1 = self.precision_recall.compute_auc()
564 |         thr_ious = self.completion_iou.compute()
565 | 
566 |         return {'val_cd_mean': cd_mean, 'val_cd_std': cd_std, 'val_precision': pr, 'val_recall': re, 'val_fscore': f1, 'val_iou0.5': thr_ious[0.5], 'val_iou0.2': thr_ious[0.2], 'val_iou0.1': thr_ious[0.1]}
567 | 
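    |     # Validation metrics: Chamfer distance, precision/recall/F-score (AUC), and
    |     # completion IoU at the 0.5 / 0.2 / 0.1 thresholds (see
    |     # utils.metrics.CompletionIoU), aggregated in validation_epoch_end() below.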
568 |     def validation_epoch_end(self, outputs):
569 | 
570 |         cd_mean = np.mean(np.stack([x["val_cd_mean"] for x in outputs]))
571 |         cd_std = np.mean(np.stack([x["val_cd_std"] for x in outputs]))
572 |         pr = np.mean(np.stack([x["val_precision"] for x in outputs]))
573 |         re = np.mean(np.stack([x["val_recall"] for x in outputs]))
574 |         f1 = np.mean(np.stack([x["val_fscore"] for x in outputs]))
575 |         iou0_5 = np.mean(np.stack([x["val_iou0.5"] for x in outputs]))
576 |         iou0_2 = np.mean(np.stack([x["val_iou0.2"] for x in outputs]))
577 |         iou0_1 = np.mean(np.stack([x["val_iou0.1"] for x in outputs]))
578 | 
579 | 
580 |         self.log('val_cd_mean', cd_mean, prog_bar=False)
581 |         self.log('val_cd_std', cd_std)
582 |         self.log('val_precision', pr)
583 |         self.log('val_recall', re)
584 |         self.log('val_fscore', f1, prog_bar=False)
585 |         self.log('val_iou05', iou0_5, prog_bar=True)
586 |         self.log('val_iou02', iou0_2, prog_bar=True)
587 |         self.log('val_iou01', iou0_1, prog_bar=True)
588 | 
589 |         self.chamfer_distance.reset()
590 |         self.precision_recall.reset()
591 |         self.completion_iou.reset()
592 | 
593 |     def valid_paths(self, filenames):
594 |         output_paths = []
595 |         skip = []
596 | 
597 |         for fname in filenames:
598 |             seq_dir = f'{self.logger.log_dir}/generated_pcd/{fname.split("/")[-3]}'
599 |             ply_name = f'{fname.split("/")[-1].split(".")[0]}.ply'
600 | 
601 |             skip.append(path.isfile(f'{seq_dir}/{ply_name}'))
602 |             makedirs(seq_dir, exist_ok=True)
603 |             output_paths.append(f'{seq_dir}/{ply_name}')
604 | 
605 |         return np.all(skip), output_paths
606 | 
607 | 
608 | 
609 | def main(args):
610 |     # metadata
611 |     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
612 |     args.timestamp = timestamp
613 | 
614 |     # model
615 |     model = DistillationDPO(args)
616 | 
617 |     # dataset
618 |     dataloader = SemanticKITTI_dataset.dataloaders['KITTI'](args)
619 | 
620 |     # ckpt saving config
621 |     checkpoint_callback = ModelCheckpoint(
622 |         dirpath=f"exp/distill/sdpo_{timestamp}/checkpoints",
623 |         filename="{epoch}-{step}-{val_iou05:.3f}-{val_iou02:.3f}-{val_iou01:.3f}",
624 |         save_top_k=-1, # save every checkpoint
625 |         every_n_epochs=1,
626 |     )
627 | 
628 |     # logger
629 |     tb_logger = pl_loggers.TensorBoardLogger(f"exp/distill/sdpo_{timestamp}", default_hp_metric=False)
630 | 
631 |     # setup trainer
632 |     trainer = Trainer(
633 |         gpus=2, strategy='ddp', accelerator='gpu',
634 |         max_epochs=10,
635 |         logger=tb_logger,
636 |         log_every_n_steps=1,
637 |         callbacks=[checkpoint_callback],
638 |         gradient_clip_val=0.05,
639 |         val_check_interval=200,
640 |         limit_val_batches=1,
641 |         detect_anomaly=True,
642 |     )
643 |     trainer.fit(model, dataloader)
644 | 
645 | if __name__ == "__main__":
646 |     parser = argparse.ArgumentParser()
647 | 
648 |     parser.add_argument("--batch_size", default=1, type=int, required=False, help="batch size")
649 |     parser.add_argument("--lr", default=1e-5, type=float, required=False, help="learning rate")
650 |     parser.add_argument(
651 |         "--SemanticKITTI_path",
652 |         default='datasets/SemanticKITTI',
653 |         type=str,
654 |         required=False,
655 |         help="path to SemanticKITTI dataset"
656 |     )
657 |     parser.add_argument(
658 |         "--pre_trained_diff_path",
659 |         default='checkpoints/lidiff_ddpo_refined.ckpt',
660 |         type=str,
661 |         required=False,
662 |         help="path to pre-trained diffusion model weights"
663 |     )
664 | 
665 |     args = parser.parse_args()
666 |     main(args)
--------------------------------------------------------------------------------
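Usage sketch (not part of the repository): a minimal, hypothetical way to launch
training, equivalent to running trains/DistillationDPO.py as a script with its
default arguments. It assumes the package was installed via setup.py, that the
SemanticKITTI data and the pre-trained diffusion checkpoint are in place, and
that two GPUs are available (the Trainer above hard-codes gpus=2 with DDP). The
Namespace fields mirror the argparse flags; the paths are the script's defaults
and may need adjusting.

    import argparse

    from trains.DistillationDPO import main

    # Hypothetical launch; field values below are the script's argparse defaults.
    args = argparse.Namespace(
        batch_size=1,                                 # per-GPU batch size
        lr=1e-5,                                      # SGD learning rate
        SemanticKITTI_path='datasets/SemanticKITTI',  # dataset root
        pre_trained_diff_path='checkpoints/lidiff_ddpo_refined.ckpt',
    )
    main(args)  # builds the model, dataloader, and Lightning Trainer, then fits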