├── .gitignore
├── docs
│   └── framework.png
├── configs
│   ├── desc_net.yaml
│   ├── desc_net_self.yaml
│   ├── cape.yaml
│   ├── cape128.yaml
│   └── human4d.yaml
├── dataset
│   ├── __init__.py
│   ├── single.py
│   ├── cape.py
│   ├── load.py
│   ├── human4d.py
│   └── common.py
├── README.md
├── tools
│   ├── point.py
│   └── exp.py
├── test.py
├── visualize.py
├── metric.py
├── models
│   ├── base_model.py
│   ├── spconv.py
│   ├── desc_net.py
│   └── desc_net_self.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | scripts/
3 | wandb/
4 | weights/
5 | preprocess/
6 | temp*
--------------------------------------------------------------------------------
/docs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenyifanthu/HumanReg/HEAD/docs/framework.png
--------------------------------------------------------------------------------
/configs/desc_net.yaml:
--------------------------------------------------------------------------------
1 | model: "desc_net"
2 | 
3 | # td parameter in Eq. (2)
4 | td_init: 1.0
5 | td_min: 0.02
6 | 
7 | # Descriptor network configuration
8 | backbone_args:
9 |   channels: [-1, 32, 96, 64, 192]
10 |   tr_channels: [-1, 32, 32, 64, 96]
11 |   feat_channels: 64
12 |   n_classes: 14
13 | 
14 | optimizer: "Adam"
15 | learning_rate:
16 |   init: 1.0e-3
17 |   decay_mult: 0.7
18 |   decay_step: 500000
19 |   clip: 1.0e-6
20 | weight_decay: 0.0
21 | grad_clip: 0.5
22 | 
23 | # Supervised loss
24 | sup_loss:
25 |   lmda: 1.0
--------------------------------------------------------------------------------
/configs/desc_net_self.yaml:
--------------------------------------------------------------------------------
1 | model: "desc_net_self"
2 | 
3 | # td parameter in Eq. (2)
4 | td_init: 1.0
5 | td_min: 0.02
6 | 
7 | # Descriptor network configuration
8 | backbone_args:
9 |   channels: [-1, 32, 96, 64, 192]
10 |   tr_channels: [-1, 32, 32, 64, 96]
11 |   feat_channels: 64
12 |   n_classes: 14
13 | 
14 | optimizer: "Adam"
15 | learning_rate:
16 |   init: 1.0e-4
17 |   decay_mult: 0.7
18 |   decay_step: 500000
19 |   clip: 1.0e-6
20 | weight_decay: 0.0
21 | grad_clip: 0.5
22 | 
23 | # Self-supervised loss
24 | self_sup_loss:
25 |   k_neigh: 3
26 |   chamfer_weight: 1.0
27 |   smooth_weight: 1.0
28 |   class_weight: 0.1
29 |   translate_weight: 10.0
--------------------------------------------------------------------------------
/configs/cape.yaml:
--------------------------------------------------------------------------------
1 | name: 'CAPEDataset'
2 | batch_size: 64
3 | 
4 | data_dir: &data_dir "/disk1/chenyifan/mpc/mpc-cape/"
5 | voxel_size: &voxel_size 0.01
6 | 
7 | train:
8 |   data_dir: *data_dir
9 |   split: 'train'
10 |   augmentation:
11 |     centralize: true
12 |     together:
13 |       scale_low: 1.0
14 |       scale_high: 1.0
15 |       degree_range: 180
16 |     pc2:
17 |       degree_range: 180.0
18 |       jitter_sigma: 0.0
19 |       jitter_clip: 0.0
20 |       dof: 'z'
21 | 
22 | valid:
23 |   data_dir: *data_dir
24 |   split: 'val'
25 |   augmentation:
26 |     centralize: true
27 | 
28 | test:
29 |   data_dir: *data_dir
30 |   split: 'test'
31 |   augmentation:
32 |     centralize: true
--------------------------------------------------------------------------------
/configs/cape128.yaml:
--------------------------------------------------------------------------------
1 | name: 'CAPEDataset'
2 | batch_size: 64
3 | 
4 | data_dir: &data_dir "/disk1/chenyifan/mpc/mpc-cape128/"
5 | voxel_size: &voxel_size 0.01
6 | 
7 | train:
8 |   data_dir: *data_dir
9 |   split: 'train'
10 |   augmentation:
11 |
centralize: true 12 | together: 13 | scale_low: 1.0 14 | scale_high: 1.0 15 | degree_range: 180 16 | pc2: 17 | degree_range: 10.0 18 | jitter_sigma: 0.0 19 | jitter_clip: 0.0 20 | dof: 'z' 21 | 22 | valid: 23 | data_dir: *data_dir 24 | split: 'val' 25 | augmentation: 26 | centralize: true 27 | 28 | test: 29 | data_dir: *data_dir 30 | split: 'test' 31 | augmentation: 32 | centralize: true -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(0) 3 | from torch.utils.data import DataLoader 4 | from .common import list_collate 5 | from .human4d import Human4dDataset 6 | from .cape import CAPEDataset 7 | from .single import SingleDataset 8 | 9 | def get_dataloader(cfg, mode): 10 | assert mode in ['train', 'valid', 'test'] 11 | dataset = eval(cfg.name)(**cfg[mode]) 12 | loader = DataLoader(dataset, 13 | batch_size = cfg.batch_size, 14 | shuffle = mode == 'train', 15 | num_workers = 10, 16 | collate_fn = list_collate) 17 | return loader 18 | 19 | def get_dataset(cfg, mode): 20 | assert mode in ['train', 'valid', 'test'] 21 | dataset = eval(cfg.name)(**cfg[mode]) 22 | return dataset -------------------------------------------------------------------------------- /configs/human4d.yaml: -------------------------------------------------------------------------------- 1 | name: 'Human4dDataset' 2 | data_dir: &data_dir "/disk1/chenyifan/human4d" 3 | voxel_size: &voxel_size 0.01 4 | batch_size: 64 5 | 6 | train: 7 | data_dir: *data_dir 8 | ped_ids: [0,1,2,3,4,5,6] 9 | intervals: [] 10 | voxel_size: *voxel_size 11 | augmentation: 12 | centralize: true 13 | together: 14 | scale_low: 0.9 15 | scale_high: 1.1 16 | degree_range: 180 17 | pc2: 18 | degree_range: 180.0 19 | jitter_sigma: 0.03 20 | jitter_clip: 0.01 21 | dof: 'z' 22 | 23 | valid: 24 | data_dir: *data_dir 25 | ped_ids: [7,8] 26 | intervals: [10,20,40] 27 | voxel_size: *voxel_size 28 | augmentation: 29 | centralize: true 30 | 31 | test: 32 | data_dir: *data_dir 33 | ped_ids: [9] 34 | intervals: [] 35 | voxel_size: *voxel_size 36 | augmentation: 37 | centralize: true -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HumanReg 2 | Implementation of "[HumanReg: Self-supervised Non-rigid Registration of Sparse Human Point Cloud](https://arxiv.org/abs/2312.05462)" 3 | 4 | Authors: [Yifan Chen](https://github.com/chenyifanthu/), Zhiyu Pan, Zhicheng Zhong, Wenxuan Guo, [Jianjiang Feng](https://scholar.google.cz/citations?hl=zh-CN&user=qlcjuzcAAAAJ), [Jie Zhou](https://scholar.google.cz/citations?user=6a79aPwAAAAJ&hl=zh-CN&oi=ao) 5 | 6 | ![method](docs/framework.png "model arch") 7 | 8 | ## Installation 9 | 10 | ## Dataset 11 | | Dataset | Download Link | 12 | |:---:|:---:| 13 | |HumanSyn4D|[[Google Drive](https://drive.google.com/file/d/1JOeVJ8PsI48SPKGfPOQVux3AOQ55tq3H/view?usp=drive_link)]| 14 | |CAPE-512|[[Google Drive](https://drive.google.com/file/d/1R0_5qK-CNKfW8wScZgFvY3MdFKJ9njd7/view?usp=drive_link)]| 15 | |BasketballPlayer|[[Google Drive](https://drive.google.com/file/d/1cxQHXPDmy-I0mA0Ue0DSUHxYb4LgZuvI/view?usp=drive_link)] 16 | 17 | ## Model Zoo 18 | We have put our model checkpoints in Google Drive. 
19 | 20 | | Dataset | Download Link | Remark | 21 | |:---:|:---:|:---:| 22 | |HumanSyn4D|[[Google Drive](https://drive.google.com/file/d/1s466b7WNV-C5P9xKjFrfUkgYDive7coQ/view?usp=drive_link)]| Pretrain model| 23 | |CAPE-512|[[Google Drive](https://drive.google.com/file/d/1IbN9_y8a8Dt2_XfxWxcajQZI1Azmmdj9/view?usp=drive_link)]| | 24 | 25 | ## Citation 26 | ``` 27 | @inproceedings{chen2024humanreg, 28 | title={HumanReg: Self-supervised Non-rigid Registration of Human Point Cloud}, 29 | author={Chen, Yifan and Pan, Zhiyu and Zhong, Zhicheng and Guo, Wenxuan and Feng, Jianjiang and Zhou, Jie}, 30 | booktitle={2024 International Conference on 3D Vision (3DV)}, 31 | pages={954--964}, 32 | year={2024}, 33 | organization={IEEE} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /tools/point.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def index_points_group(points, knn_idx, t=False): 5 | """ 6 | Input: 7 | points: input points data, [B, N', C], or [B, C, N'](transposed) 8 | knn_idx: sample index data, [B, N, K] 9 | Return: 10 | new_points:, indexed points data, [B, N, K, C] or [B, C, N, K](transposed) 11 | """ 12 | B, Np, C = points.size() 13 | if t: 14 | Np, C = C, Np 15 | 16 | _, N, K = knn_idx.size() 17 | knn_idx = knn_idx.reshape(B, -1) 18 | if not t: 19 | new_points = torch.gather(points, dim=1, index=knn_idx.unsqueeze(-1).expand(-1, -1, points.size(-1))) 20 | new_points = new_points.reshape(B, N, K, C) 21 | else: 22 | new_points = torch.gather(points, dim=-1, index=knn_idx.unsqueeze(1).expand(-1, points.size(1), -1)) 23 | new_points = new_points.reshape(B, C, N, K) 24 | 25 | return new_points 26 | 27 | 28 | def propagate_features(source_pc: torch.Tensor, target_pc: torch.Tensor, 29 | source_feat: torch.Tensor, nk: int = 3, batched: bool = True): 30 | """ 31 | Propagate features from the domain of source to the domain of target. 32 | :param source_pc: (B, N, 3) point coordinates 33 | :param target_pc: (B, M, 3) point coordinates 34 | :param source_feat: (B, N, F) source features 35 | :param nk: propagate k number 36 | :param batched: whether dimension B is present or not. 37 | :return: (B, M, F) target feature 38 | """ 39 | if not batched: 40 | source_pc = source_pc.unsqueeze(0) 41 | target_pc = target_pc.unsqueeze(0) 42 | source_feat = source_feat.unsqueeze(0) 43 | 44 | dist = torch.cdist(target_pc, source_pc) # (B, N, M) 45 | dist, group_idx = torch.topk(dist, nk, dim=-1, largest=False, sorted=False) # (B, N, K) 46 | 47 | # Shifted reciprocal function. 
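# Inverse-distance weighting: w_i = 1 / (d_i + eps); after the normalization
# below, the propagated feature is a convex combination of the k nearest
# source features (closer points contribute more).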
48 |     w_func = 1 / (dist + 1.0e-6)
49 |     weight = (w_func / torch.sum(w_func, dim=-1, keepdim=True)).unsqueeze(-1)  # (B, N, k, 1)
50 |     sparse_feature = index_points_group(source_feat, group_idx)
51 |     full_flow = (sparse_feature * weight).sum(-2)  # (B, N, C)
52 | 
53 |     if not batched:
54 |         full_flow = full_flow[0]
55 | 
56 |     return full_flow
57 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os; os.environ["CUDA_VISIBLE_DEVICES"] = "4"
2 | import torch
3 | import importlib
4 | import numpy as np
5 | from omegaconf import OmegaConf
6 | from tqdm import tqdm
7 | 
8 | from tools.exp import to_target_device
9 | from dataset import get_dataloader, get_dataset
10 | 
11 | ckpt_path = './weights/finetune-cape512.pth'
12 | model_cfg_path = './configs/desc_net_self.yaml'
13 | data_cfg_path = './configs/cape.yaml'
14 | 
15 | model_args = OmegaConf.load(model_cfg_path)
16 | data_args = OmegaConf.load(data_cfg_path)
17 | model_args.batch_size = data_args.batch_size = 1
18 | 
19 | net_module = importlib.import_module("models." + model_args.model).Model
20 | net_model = net_module(model_args)
21 | net_model.load_state_dict(torch.load(ckpt_path)['state_dict'])
22 | print(f"Checkpoint loaded from {ckpt_path}.")
23 | 
24 | 
25 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
26 | net_model = to_target_device(net_model, device)
27 | net_model.device = device
28 | net_model.eval()
29 | net_model.hparams.is_training = False
30 | 
31 | # test_loader = get_dataloader(data_args, 'test')
32 | from torch.utils.data import DataLoader
33 | from dataset.common import list_collate
34 | test_set = get_dataset(data_args, 'test')
35 | test_loader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=10, collate_fn=list_collate)
36 | 
37 | pbar = tqdm(test_loader, desc='Testing', ncols=120)
38 | with torch.no_grad():
39 |     for batch_idx, batch in enumerate(pbar):
40 |         batch = to_target_device(batch, device)
41 |         net_model.step(batch, 'test')
42 |         # if batch_idx == 20: break
43 | 
44 | d = net_model.log_cache.loss_dict
45 | epe3d = d['test/epe3d']
46 | AccS = d['test/acc3d_strict']
47 | AccR = d['test/acc3d_relax']
48 | outlier = d['test/outlier']
49 | # print(d['test/acc3d_strict'])
50 | # print(len(d['test/acc3d_strict']))
51 | print("Test metrics:")
52 | print(" + EPE3D: \t %.2f\t+/-\t%.2f" % (np.mean(epe3d) * 100, np.std(epe3d) * 100))
53 | print(" + AccS (%%): \t %.1f\t+/-\t%.1f" % (np.mean(AccS) * 100, np.std(AccS) * 100))
54 | print(" + AccR (%%): \t %.1f\t+/-\t%.1f" % (np.mean(AccR) * 100, np.std(AccR) * 100))
55 | print(" + Outlier: \t %.2f\t+/-\t%.2f" % (np.mean(outlier) * 100, np.std(outlier) * 100))
--------------------------------------------------------------------------------
/dataset/single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import MinkowskiEngine as ME
5 | 
6 | from torch.utils.data import Dataset
7 | from .common import DataAugmentor
8 | 
9 | class SingleDataset(Dataset):
10 |     def __init__(self,
11 |                  data_dir: str,
12 |                  split: str,
13 |                  random_seed: int = 0,
14 |                  voxel_size: float = 0.01,
15 |                  augmentation: dict = None) -> None:
16 |         super().__init__()
17 |         assert split in ["train", "val", "test"]
18 |         self.data_dir = data_dir
19 |         self.split = split
20 |         self.voxel_size = voxel_size
21 |         if augmentation is
None: 22 | self.augmentor = None 23 | else: 24 | self.augmentor = DataAugmentor(**augmentation) 25 | self.rng = np.random.RandomState(random_seed) 26 | self.generate_cases() 27 | 28 | def generate_cases(self) -> None: 29 | self.filelist = [] 30 | meta = json.load(open(os.path.join(self.data_dir, "meta.json")))[self.split] 31 | for file, n_frames in meta: 32 | self.filelist.append(os.path.join(self.data_dir, 'data', file)) 33 | self.idxs = [] 34 | for i in range(4): 35 | for j in range(4): 36 | if i < j: 37 | self.idxs.append((i, j)) 38 | 39 | def __len__(self) -> int: 40 | return len(self.idxs) * len(self.filelist) 41 | 42 | def __getitem__(self, index) -> dict: 43 | filepath = self.filelist[index // len(self.idxs)] 44 | i, j = self.idxs[index % len(self.idxs)] 45 | data = np.load(filepath, allow_pickle=True) 46 | 47 | ret = {} 48 | ret = {"file": filepath, "idxs": [i, j]} 49 | 50 | # Load point clouds 51 | pc1 = data['pcs'][i].astype(np.float32) 52 | pc2 = data['pcs'][j].astype(np.float32) 53 | ret["pcs"] = [pc1, pc2] 54 | 55 | if self.augmentor is not None: 56 | self.augmentor.process(ret, self.rng) 57 | 58 | # Quantize point clouds 59 | quan_coords1 = self.quantize(ret["pcs"][0]) 60 | quan_coords2 = self.quantize(ret["pcs"][1]) 61 | ret["coords"] = [quan_coords1, quan_coords2] 62 | 63 | return ret 64 | 65 | def quantize(self, points): 66 | coords = np.floor(points / self.voxel_size) 67 | inds = ME.utils.sparse_quantize(coords, return_index=True, return_maps_only=True) 68 | return coords[inds], inds -------------------------------------------------------------------------------- /dataset/cape.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import MinkowskiEngine as ME 5 | 6 | from torch.utils.data import Dataset 7 | from .common import DataAugmentor 8 | 9 | class CAPEDataset(Dataset): 10 | def __init__(self, 11 | data_dir: str, 12 | split: str, 13 | random_seed: int = 0, 14 | voxel_size: float = 0.01, 15 | augmentation: dict = None) -> None: 16 | super().__init__() 17 | assert split in ["train", "val", "test"] 18 | self.data_dir = data_dir 19 | self.split = split 20 | self.voxel_size = voxel_size 21 | if augmentation is None: 22 | self.augmentor = None 23 | else: 24 | self.augmentor = DataAugmentor(**augmentation) 25 | self.rng = np.random.RandomState(random_seed) 26 | self.generate_cases() 27 | 28 | def generate_cases(self) -> None: 29 | self.filelist = [] 30 | meta = json.load(open(os.path.join(self.data_dir, "meta.json")))[self.split] 31 | for file, n_frames in meta: 32 | self.filelist.append(os.path.join(self.data_dir, 'data', file)) 33 | self.idxs = [] 34 | for i in range(4): 35 | for j in range(4): 36 | if i < j: 37 | self.idxs.append((i, j)) 38 | 39 | def __len__(self) -> int: 40 | return len(self.idxs) * len(self.filelist) 41 | 42 | def __getitem__(self, index) -> dict: 43 | filepath = self.filelist[index // len(self.idxs)] 44 | i, j = self.idxs[index % len(self.idxs)] 45 | data = np.load(filepath) 46 | reaxis = [2, 0, 1] 47 | 48 | ret = {} 49 | ret = {"file": filepath, "idxs": [i, j]} 50 | 51 | # Load point clouds 52 | pc1 = data['pcs'][i].astype(np.float32) 53 | pc1 = np.ascontiguousarray(pc1[:, reaxis]) 54 | pc2 = data['pcs'][j].astype(np.float32) 55 | pc2 = np.ascontiguousarray(pc2[:, reaxis]) 56 | ret["pcs"] = [pc1, pc2] 57 | 58 | # Load labels 59 | ret["labels"] = [None, None] 60 | 61 | # Load flows 62 | flow12 = data['flows'][i, j].astype(np.float32) 63 | flow21 = data['flows'][j, 
i].astype(np.float32) 64 | ret["flows"] = [flow12[:, reaxis], flow21[:, reaxis]] 65 | 66 | if self.augmentor is not None: 67 | self.augmentor.process(ret, self.rng) 68 | 69 | # Quantize point clouds 70 | quan_coords1 = self.quantize(ret["pcs"][0]) 71 | quan_coords2 = self.quantize(ret["pcs"][1]) 72 | ret["coords"] = [quan_coords1, quan_coords2] 73 | 74 | return ret 75 | 76 | def quantize(self, points): 77 | coords = np.floor(points / self.voxel_size) 78 | inds = ME.utils.sparse_quantize(coords, return_index=True, return_maps_only=True) 79 | return coords[inds], inds -------------------------------------------------------------------------------- /dataset/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import open3d as o3d 4 | from scipy.spatial import cKDTree 5 | 6 | # DATA_DIR = "/disk1/panzhiyu/sync_cam4_dense/" 7 | BONE_LINK = np.array([[0, 1], [0, 2], [0, 3], [3, 4], [4, 5], [0, 9], [9, 10], [10, 11], [2, 6], [2, 12], [6, 7], [7, 8], [12, 13], [13, 14]]) 8 | 9 | def innerp_mat(a, b=None): 10 | if b is None: b = a 11 | assert a.shape == b.shape 12 | return np.sum(a * b, axis=-1) 13 | 14 | def calculate_labels(points, joints): 15 | # Calculate the distance between each point to each bone 16 | # and assign the label to the point with the minimum distance 17 | N, M = points.shape[0], len(BONE_LINK) 18 | distpb = np.zeros((N, M)) 19 | a, b = joints[BONE_LINK[:, 0]], joints[BONE_LINK[:, 1]] 20 | a = np.repeat(a[None, :], N, axis=0) 21 | b = np.repeat(b[None, :], N, axis=0) 22 | points = np.repeat(points[:, None, :], M, axis=1) 23 | ap, bp, ab = points - a, points - b, b - a 24 | t = innerp_mat(points - a, b - a) / innerp_mat(b - a, b - a) 25 | idx1 = np.where(t < 0) 26 | distpb[idx1] = innerp_mat(ap)[idx1] 27 | idx2 = np.where(t > 1) 28 | distpb[idx2] = innerp_mat(bp)[idx2] 29 | idx3 = np.where((t>=0) & (t<=1)) 30 | distpb[idx3] = innerp_mat(ap - t[:, :, None] * ab)[idx3] 31 | labels = np.argmin(distpb, axis=1) 32 | return labels 33 | 34 | def calculate_flow(pc1, ind1, mesh2): 35 | flow12 = mesh2[ind1] - pc1 36 | return flow12 37 | 38 | 39 | class SyncDataLoader: 40 | def __init__(self, data_dir): 41 | self.data_dir = data_dir 42 | self.data_len = [] 43 | for i in range(10): 44 | self.data_len.append(self._get_data_length(i)) 45 | 46 | def load_mesh_points(self, ped_id, frame_id): 47 | filepath = os.path.join(self.data_dir, "ped_mesh", str(ped_id), f"{frame_id}.txt") 48 | mesh = np.loadtxt(filepath) 49 | mesh = mesh[:, [0, 2, 1]] 50 | return mesh 51 | 52 | def load_joints(self, ped_id, frame_id): 53 | filepath = os.path.join(self.data_dir, "joints", str(ped_id), f"joints_{frame_id}.txt") 54 | joints = np.loadtxt(filepath) 55 | joints = joints[:, [0, 2, 1]] 56 | return joints 57 | 58 | def load_points(self, ped_id, frame_id): 59 | filepath = os.path.join(self.data_dir, "points_ped", str(ped_id), f"{frame_id:05d}.ply") 60 | pcd = o3d.io.read_point_cloud(filepath) 61 | points = np.asarray(pcd.points) 62 | return points 63 | 64 | def load_pcd(self, ped_id, frame_id): 65 | filepath = os.path.join(self.data_dir, "points_ped", str(ped_id), f"{frame_id:05d}.ply") 66 | return o3d.io.read_point_cloud(filepath) 67 | 68 | def _get_data_length(self, ped_id): 69 | filepath = os.path.join(self.data_dir, "points_ped", str(ped_id)) 70 | return len(os.listdir(filepath)) 71 | 72 | def get_data_length(self, ped_id): 73 | return self.data_len[ped_id] 74 | 75 | 
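For orientation, a minimal sketch of how these helpers chain together on one sequence; the data root echoes the commented-out `DATA_DIR` above, and `ind0`/`mesh1` (per-point mesh correspondences and the next frame's mesh) come from preprocessing, so both are assumptions here:

```python
# Illustrative usage only; paths and ids are assumptions.
from dataset.load import SyncDataLoader, calculate_labels, calculate_flow

loader = SyncDataLoader("/disk1/panzhiyu/sync_cam4_dense/")  # assumed data root
pc0 = loader.load_points(ped_id=0, frame_id=0)      # (N, 3) scanned points
joints0 = loader.load_joints(0, 0)                  # (15, 3) skeleton joints
labels0 = calculate_labels(pc0, joints0)            # (N,) nearest-bone id in [0, 13]

# Scene flow additionally needs per-point mesh correspondences:
# mesh1 = loader.load_mesh_points(0, 1)             # next frame's mesh vertices
# flow01 = calculate_flow(pc0, ind0, mesh1)         # (N, 3) motion of each point
```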
-------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import os; os.environ["CUDA_VISIBLE_DEVICES"] = "4" 2 | import torch 3 | import numpy as np 4 | import open3d as o3d 5 | import MinkowskiEngine as ME 6 | import matplotlib.pyplot as plt 7 | 8 | from tqdm import tqdm 9 | from omegaconf import OmegaConf 10 | from tools.exp import to_target_device 11 | from models.desc_net import Model 12 | from dataset.common import list_collate 13 | 14 | ckpt_path = './weights/best-cls-flow.pth' 15 | model_cfg_path = './configs/desc_net.yaml' 16 | ped_id = 9 17 | start_idx = 100 18 | n_frames = 10 19 | n_skip = 2 20 | 21 | def get_loader(voxel_size=0.01): 22 | def quantize(points): 23 | coords = np.floor(points / voxel_size) 24 | inds = ME.utils.sparse_quantize(coords, return_index=True, return_maps_only=True) 25 | return coords[inds], inds 26 | 27 | datai = np.load(f'./human4d/{ped_id}/{start_idx:05d}.npz') 28 | meani = np.mean(datai['pc'], axis=0) 29 | for k in range(1, n_frames): 30 | j = start_idx + k * n_skip 31 | dataj = np.load(f'./human4d/{ped_id}/{j:05d}.npz') 32 | meanj = np.mean(dataj['pc'], axis=0) 33 | pci = datai['pc'] - meani 34 | pcj = dataj['pc'] - meanj 35 | ret = { 36 | 'pcs': [pci, pcj], 37 | # 'pcs': [datai['pc'], dataj['pc']], 38 | 'coords': [quantize(pci), quantize(pcj)], 39 | } 40 | yield list_collate([ret]) 41 | 42 | 43 | model_args = OmegaConf.load(model_cfg_path) 44 | model_args.batch_size = 1 45 | 46 | net_model = Model(model_args) 47 | net_model.load_state_dict(torch.load(ckpt_path)['state_dict']) 48 | print(f"Checkpoint loaded from {ckpt_path}.") 49 | 50 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 51 | net_model = to_target_device(net_model, device) 52 | net_model.device = device 53 | net_model.eval() 54 | net_model.hparams.is_training = False 55 | 56 | res_flow_vis = o3d.geometry.PointCloud() 57 | res_cls_vis = o3d.geometry.PointCloud() 58 | cmap = plt.get_cmap('jet') 59 | with torch.no_grad(): 60 | for i, batch in tqdm(enumerate(get_loader()), total=n_frames-1): 61 | batch = to_target_device(batch, device) 62 | pred = net_model.step(batch, 'test')[1] 63 | if i == 0: 64 | pts0 = batch['pcs'][0][0].cpu().numpy() 65 | pcd = o3d.geometry.PointCloud() 66 | pcd.points = o3d.utility.Vector3dVector(pts0) 67 | pcd.paint_uniform_color(cmap(0.0)[:3]) 68 | res_flow_vis += pcd 69 | cls0 = pred[0]['cls0'].cpu().numpy() 70 | colors = cmap(cls0 / 13)[:,:3] 71 | pcd.colors = o3d.utility.Vector3dVector(colors) 72 | res_cls_vis += pcd 73 | 74 | pts1 = batch['pcs'][1][0].cpu().numpy() 75 | cls1 = pred[0]['cls1'].cpu().numpy() 76 | flow10 = pred[0]['flow10'].cpu().numpy() 77 | 78 | pcd = o3d.geometry.PointCloud() 79 | pcd.points = o3d.utility.Vector3dVector(pts1+flow10) 80 | pcd.paint_uniform_color(cmap((i+1)/(n_frames-1))[:3]) 81 | res_flow_vis += pcd 82 | colors = cmap(cls1 / 13)[:,:3] 83 | pcd.colors = o3d.utility.Vector3dVector(colors) 84 | res_cls_vis += pcd 85 | 86 | o3d.io.write_point_cloud('flow_vis.ply', res_flow_vis) 87 | o3d.io.write_point_cloud('cls_vis.ply', res_cls_vis) -------------------------------------------------------------------------------- /dataset/human4d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import MinkowskiEngine as ME 5 | 6 | from .common import DataAugmentor 7 | from .load import calculate_flow, 
calculate_labels 8 | from torch.utils.data import Dataset 9 | 10 | class Human4dDataset(Dataset): 11 | def __init__(self, 12 | data_dir: str, 13 | ped_ids: list, 14 | intervals: list = [], 15 | voxel_size: float = 0.01, 16 | random_seed: int = 0, 17 | augmentation: dict = None): 18 | self.data_dir = data_dir 19 | self.ped_ids = ped_ids 20 | self.intervals = intervals 21 | self.voxel_size = voxel_size 22 | self.rng = np.random.RandomState(random_seed) 23 | if augmentation is None: 24 | self.augmentor = None 25 | else: 26 | self.augmentor = DataAugmentor(**augmentation) 27 | self.generate_cases() 28 | 29 | def generate_cases(self): 30 | self.cases = [] 31 | self.lengths = [0] * 10 32 | for ped_id in self.ped_ids: 33 | length = len(os.listdir(os.path.join(self.data_dir, str(ped_id)))) 34 | self.lengths[ped_id] = length 35 | for i in range(length): 36 | if not self.intervals: 37 | self.cases.append((ped_id, i)) 38 | else: 39 | for j in self.intervals: 40 | if i + j < length: 41 | self.cases.append((ped_id, i, i + j)) 42 | 43 | def __len__(self): 44 | return len(self.cases) 45 | 46 | def __getitem__(self, idx): 47 | data = self.cases[idx] 48 | if len(data) == 3: 49 | ped_id, i, j = data 50 | elif len(data) == 2: 51 | ped_id, i = data 52 | j = self.rng.randint(0, self.lengths[ped_id] - 1) 53 | else: 54 | raise ValueError("Invalid data") 55 | 56 | datai = np.load(os.path.join(self.data_dir, str(ped_id), f"{i:05d}.npz")) 57 | dataj = np.load(os.path.join(self.data_dir, str(ped_id), f"{j:05d}.npz")) 58 | ret = {} 59 | 60 | # Load point clouds 61 | ret["pcs"] = [datai['pc'], dataj['pc']] 62 | 63 | # Load labels 64 | labels1 = calculate_labels(datai['pc'], datai['joints']) 65 | labels2 = calculate_labels(dataj['pc'], dataj['joints']) 66 | ret["labels"] = [labels1, labels2] 67 | 68 | # Load flows 69 | flow12 = calculate_flow(datai['pc'], datai['meshid'], dataj['mesh']) 70 | flow21 = calculate_flow(dataj['pc'], dataj['meshid'], datai['mesh']) 71 | ret["flows"] = [flow12, flow21] 72 | 73 | if self.augmentor is not None: 74 | self.augmentor.process(ret, self.rng) 75 | 76 | # Quantize point clouds 77 | quan_coords1 = self.quantize(ret["pcs"][0]) 78 | quan_coords2 = self.quantize(ret["pcs"][1]) 79 | ret["coords"] = [quan_coords1, quan_coords2] 80 | 81 | return ret 82 | 83 | def quantize(self, points): 84 | coords = np.floor(points / self.voxel_size) 85 | inds = ME.utils.sparse_quantize(coords, return_index=True, return_maps_only=True) 86 | return coords[inds], inds 87 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class PairwiseMetric: 4 | def __init__(self, batch_mean: bool = False, compute_epe3d: bool = True, compute_acc3d_outlier: bool = False, scene_level: bool = False): 5 | """ 6 | :param batch_mean: Whether to return an array with size (B, ) or a single scalar (mean) 7 | :param compute_epe3d: compute EPE3D metric 8 | :param compute_acc3d_outlier: compute Acc3d-strict, Acc3d-relax and outlier metric 9 | :param scene_level: whether use the scene threshold as proposed in FlowNet3D. 10 | """ 11 | self.batch_mean = batch_mean 12 | self.compute_epe3d = compute_epe3d 13 | self.compute_acc3d_outlier = compute_acc3d_outlier 14 | self.scene_level = scene_level 15 | 16 | def evaluate(self, gt_flow: torch.Tensor, pd_flow: torch.Tensor, valid_mask: torch.Tensor = None): 17 | """ 18 | Compute the pairwise flow metric; batch dimension will not be reduced. 
(Unit will be the same as input) 19 | :param gt_flow: (..., N, 3) 20 | :param pd_flow: (..., N, 3) 21 | :param valid_mask: (..., N) 22 | :return: metrics dict. 23 | """ 24 | result_dict = {} 25 | assert gt_flow.size(-1) == pd_flow.size(-1) == 3 26 | assert gt_flow.size(-2) == pd_flow.size(-2) 27 | 28 | n_point = gt_flow.size(-2) 29 | gt_flow = gt_flow.reshape(-1, n_point, 3) 30 | pd_flow = pd_flow.reshape(-1, n_point, 3) 31 | if valid_mask is None: 32 | valid_mask = torch.ones((gt_flow.size(0), n_point), dtype=bool, device=gt_flow.device) 33 | else: 34 | valid_mask = valid_mask.reshape(-1, n_point) 35 | 36 | l2_norm = torch.norm(pd_flow - gt_flow, dim=-1) # (B, N) 37 | 38 | if self.compute_epe3d: 39 | result_dict['epe3d'] = (l2_norm * valid_mask).sum(-1) / (valid_mask.sum(-1) + 1e-6) 40 | 41 | if self.compute_acc3d_outlier: 42 | sf_norm = torch.norm(gt_flow, dim=-1) # (B, N) 43 | rel_err = l2_norm / (sf_norm + 1e-4) # (B, N) 44 | 45 | if self.scene_level: 46 | acc3d_strict_mask = torch.logical_or(l2_norm < 0.05, rel_err < 0.05).float() 47 | acc3d_relax_mask = torch.logical_or(l2_norm < 0.1, rel_err < 0.1).float() 48 | outlier_mask = torch.logical_or(l2_norm > 0.3, rel_err > 0.1).float() 49 | else: 50 | # acc3d_strict_mask = torch.logical_or(l2_norm < 0.02, rel_err < 0.05).float() 51 | # acc3d_relax_mask = torch.logical_or(l2_norm < 0.05, rel_err < 0.1).float() 52 | # outlier_mask = (rel_err > 0.5).float() 53 | acc3d_strict_mask = (l2_norm < 0.05).float() 54 | acc3d_relax_mask = (l2_norm < 0.1).float() 55 | outlier_mask = (l2_norm > 0.2).float() 56 | 57 | result_dict['acc3d_strict'] = (acc3d_strict_mask * valid_mask).sum(-1) / (valid_mask.sum(-1) + 1e-6) 58 | result_dict['acc3d_relax'] = (acc3d_relax_mask * valid_mask).sum(-1) / (valid_mask.sum(-1) + 1e-6) 59 | result_dict['outlier'] = (outlier_mask * valid_mask).sum(-1) / (valid_mask.sum(-1) + 1e-6) 60 | 61 | if self.batch_mean: 62 | for ckey in list(result_dict.keys()): 63 | result_dict[ckey] = torch.mean(result_dict[ckey]) 64 | 65 | return result_dict 66 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | from pathlib import Path 4 | from typing import Mapping, Any 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from tools.exp import AverageMeter, parse_config_yaml 11 | 12 | 13 | def lambda_lr_wrapper(it, lr_config, batch_size): 14 | return max( 15 | lr_config['decay_mult'] ** (int(it * batch_size / lr_config['decay_step'])), 16 | lr_config['clip'] / lr_config['init']) 17 | 18 | 19 | class BaseModel(nn.Module): 20 | def __init__(self, hparams): 21 | super().__init__() 22 | self.hparams = hparams 23 | self.log_cache = AverageMeter() 24 | 25 | @staticmethod 26 | def load_module(spec_path): 27 | """ 28 | Load a module given spec_path 29 | :param spec_path: Path to a model ckpt. 30 | :return: the module class, possibly with weight loaded. 31 | """ 32 | spec_path = Path(spec_path) 33 | config_args = parse_config_yaml(spec_path.parent / "config.yaml") 34 | net_module = importlib.import_module("models." 
+ config_args.model).Model 35 | net_model = net_module(config_args) 36 | if "none.pth" not in spec_path.name: 37 | ckpt_data = torch.load(spec_path) 38 | net_model.load_state_dict(ckpt_data['state_dict']) 39 | print(f"Checkpoint loaded from {spec_path}.") 40 | return net_model 41 | 42 | def configure_optimizers(self): 43 | lr_config = self.hparams.learning_rate 44 | if self.hparams.optimizer == 'SGD': 45 | optimizer = torch.optim.SGD(self.parameters(), lr=lr_config['init'], momentum=0.9, 46 | weight_decay=self.hparams.weight_decay) 47 | elif self.hparams.optimizer == 'Adam': 48 | # The learning rate here is the maximum rate we can reach for each parameter. 49 | optimizer = torch.optim.AdamW(self.parameters(), lr=lr_config['init'], 50 | weight_decay=self.hparams.weight_decay, amsgrad=True) 51 | else: 52 | raise NotImplementedError 53 | scheduler = LambdaLR(optimizer, 54 | lr_lambda=functools.partial( 55 | lambda_lr_wrapper, lr_config=lr_config, batch_size=self.hparams.batch_size)) 56 | return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}] 57 | 58 | def on_after_backward(self): 59 | grad_clip_val = self.hparams.get('grad_clip', 1000.) 60 | torch.nn.utils.clip_grad_value_(self.parameters(), clip_value=grad_clip_val) 61 | 62 | # Also remove nan values if any. 63 | has_nan_value = False 64 | for p in filter(lambda p: p.grad is not None, self.parameters()): 65 | pdata = p.grad.data 66 | grad_is_nan = pdata != pdata 67 | if torch.any(grad_is_nan): 68 | has_nan_value = True 69 | pdata[grad_is_nan] = 0. 70 | if has_nan_value: 71 | print(f"Warning: Gets a nan-gradient but set to 0.") 72 | 73 | def log(self, key, value): 74 | # if self.hparams.is_training: 75 | # assert key not in self.log_cache.loss_dict 76 | self.log_cache.append_loss({ 77 | key: value.item() if isinstance(value, torch.Tensor) else value 78 | }) 79 | 80 | def log_dict(self, dictionary: Mapping[str, Any]): 81 | for k, v in dictionary.items(): 82 | self.log(str(k), v) 83 | 84 | # def write_log(self, writer, it): 85 | # logs_written = {} 86 | # if not self.hparams.is_training or it % 10 == 0: 87 | # for k, v in self.log_cache.get_mean_loss_dict().items(): 88 | # writer.add_scalar(k, v, it) 89 | # logs_written[k] = v 90 | # self.log_cache.clear() 91 | # return logs_written 92 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | import requests, json 4 | import wandb 5 | import torch 6 | import importlib 7 | 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | from tools import exp 11 | from omegaconf import OmegaConf 12 | from dataset import get_dataloader 13 | 14 | def train_epoch(): 15 | global epoch_idx 16 | 17 | net_model.train() 18 | net_model.hparams.is_training = True 19 | 20 | pbar = tqdm(train_loader, desc=f'[Ep:{epoch_idx}] Training', ncols=120) 21 | for batch_idx, data in enumerate(pbar): 22 | data = exp.to_target_device(data, net_model.device) 23 | optimizer.zero_grad() 24 | loss, _ = net_model.step(data, 'train') 25 | loss.backward() 26 | net_model.on_after_backward() 27 | optimizer.step() 28 | scheduler.step() 29 | pbar.set_postfix_str("Loss=%.4f" % loss.item()) 30 | 31 | def validate_epoch(): 32 | global metric_val_best, epoch_idx 33 | 34 | net_model.eval() 35 | net_model.hparams.is_training = False 36 | 37 | pbar = tqdm(val_loader, desc=f'[Ep:{epoch_idx}] Validation', ncols=120) 38 | for batch_idx, data in 
enumerate(pbar): 39 | data = exp.to_target_device(data, net_model.device) 40 | with torch.no_grad(): 41 | loss, _ = net_model.step(data, 'valid') 42 | pbar.set_postfix_str(f"Loss = {loss.item():.4f}") 43 | 44 | metrics = net_model.log_cache.get_mean_loss_dict() 45 | net_model.log_cache.print_format_loss() 46 | net_model.log_cache.clear() 47 | 48 | wandb.log({ 49 | "learning_rate": scheduler.get_last_lr()[0], 50 | **metrics 51 | }) 52 | 53 | model_state = { 54 | 'state_dict': net_model.state_dict(), 55 | 'epoch': epoch_idx, 'metrics': metrics 56 | } 57 | 58 | if metrics["valid/total_loss"] < metric_val_best: 59 | print("* Best Loss: %.4f" % metrics["valid/total_loss"]) 60 | metric_val_best = metrics["valid/total_loss"] 61 | torch.save(model_state, Path(wandb.run.dir) / "best.pth") 62 | torch.save(model_state, Path(wandb.run.dir) / "latest.pth") 63 | 64 | 65 | if __name__ == '__main__': 66 | NET_CFG_PATH = 'configs/desc_net_self.yaml' 67 | DATA_CFG_PATH = 'configs/cape128.yaml' 68 | CKPT_PATH = 'weights/pretrain-model.pth' 69 | EPOCHS = 200 70 | 71 | model_args = OmegaConf.load(NET_CFG_PATH) 72 | data_args = OmegaConf.load(DATA_CFG_PATH) 73 | model_args.batch_size = data_args.batch_size 74 | wandb.init(project='HumanReg', name='cape128', config=model_args, 75 | notes='Self-supervised fine-tuning on cape128 dataset.') 76 | 77 | net_module = importlib.import_module("models." + model_args.model).Model 78 | net_model = net_module(model_args) 79 | if CKPT_PATH: 80 | ckpt_data = torch.load(CKPT_PATH) 81 | net_model.load_state_dict(ckpt_data['state_dict']) 82 | print(f"Pretrained model loaded from {CKPT_PATH}.") 83 | 84 | # Load dataset 85 | train_loader = get_dataloader(data_args, 'train') 86 | val_loader = get_dataloader(data_args, 'valid') 87 | 88 | # Load training specs 89 | optimizers, schedulers = net_model.configure_optimizers() 90 | assert len(optimizers) == 1 and len(schedulers) == 1 91 | optimizer, scheduler = optimizers[0], schedulers[0] 92 | assert scheduler['interval'] == 'step' 93 | scheduler = scheduler['scheduler'] 94 | 95 | # Move to target device 96 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 97 | net_model = exp.to_target_device(net_model, device) 98 | net_model.device = device 99 | 100 | # Train and validate within a protected loop. 101 | metric_val_best = 1e6 102 | # for epoch_idx in range(EPOCHS): 103 | # train_epoch() 104 | # validate_epoch() 105 | try: 106 | for epoch_idx in range(EPOCHS): 107 | train_epoch() 108 | validate_epoch() 109 | # send_message("Training Finished", "Best Flow Loss: %.4f" % metric_val_best) 110 | except Exception as ex: 111 | print(ex) 112 | # send_message("Training Error", str(ex)) 113 | 114 | wandb.finish() 115 | -------------------------------------------------------------------------------- /dataset/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | import numpy as np 4 | from pyquaternion.quaternion import Quaternion 5 | 6 | def list_collate(batch): 7 | """ 8 | This collation does not stack batch dimension, but instead output only lists. 
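E.g. a batch of sample dicts {'pcs': [pc1, pc2], ...} becomes one dict whose leaves are length-B lists of tensors, so point clouds with different point counts need no padding.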
9 | """ 10 | elem = None 11 | for e in batch: 12 | if e is not None: 13 | elem = e 14 | break 15 | elem_type = type(elem) 16 | if isinstance(elem, torch.Tensor): 17 | return batch 18 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 19 | and elem_type.__name__ != 'string_': 20 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 21 | return list_collate([torch.as_tensor(b) if b is not None else None for b in batch]) 22 | elif elem.shape == (): # scalars 23 | return torch.as_tensor(batch) 24 | elif isinstance(elem, float): 25 | return torch.tensor(batch, dtype=torch.float64) 26 | elif isinstance(elem, int): 27 | return torch.tensor(batch) 28 | elif isinstance(elem, str): 29 | return batch 30 | elif isinstance(elem, collections.abc.Mapping): 31 | return {key: list_collate([d[key] for d in batch]) for key in elem} 32 | elif isinstance(elem, collections.abc.Sequence): 33 | # check to make sure that the elements in batch have consistent size 34 | it = iter(batch) 35 | elem_size = len(next(it)) 36 | if not all(len(elem) == elem_size for elem in it): 37 | raise RuntimeError('each element in list of batch should be of equal size') 38 | transposed = zip(*batch) 39 | return [list_collate(samples) for samples in transposed] 40 | elif elem is None: 41 | return batch 42 | 43 | raise NotImplementedError 44 | 45 | 46 | class DataAugmentor: 47 | """ 48 | Will apply data augmentation to pairwise point clouds, by applying random transformations 49 | to the point clouds (or one of them), or adding noise. 50 | """ 51 | def __init__(self, 52 | centralize: dict = None, 53 | together: dict = None, 54 | pc2: dict = None): 55 | self.centralize_args = centralize 56 | self.together_args = together 57 | self.pc2_args = pc2 58 | 59 | def process(self, data_dict: dict, rng: np.random.RandomState): 60 | pcs = data_dict["pcs"] 61 | assert len(pcs) == 2 62 | 63 | pc1, pc2 = pcs[0], pcs[1] 64 | if 'flows' in data_dict: 65 | pc1_virtual = pc2 + data_dict["flows"][1] 66 | pc2_virtual = pc1 + data_dict["flows"][0] 67 | 68 | if self.centralize_args is not None: 69 | pc1_center = np.mean(pc1, axis=0) 70 | pc2_center = np.mean(pc2, axis=0) 71 | pc1 -= pc1_center 72 | pc2 -= pc2_center 73 | if 'flows' in data_dict: 74 | pc1_virtual -= pc1_center 75 | pc2_virtual -= pc2_center 76 | 77 | if self.together_args is not None: 78 | scale = np.diag(rng.uniform(self.together_args.scale_low, 79 | self.together_args.scale_high, 3).astype(np.float32)) 80 | angle = rng.uniform(-self.together_args.degree_range, self.together_args.degree_range) / 180.0 * np.pi 81 | rot_matrix = np.array([ 82 | [np.cos(angle), np.sin(angle), 0.], 83 | [-np.sin(angle), np.cos(angle), 0.], 84 | [0., 0., 1.] 
85 | ], dtype=np.float32) 86 | matrix = scale.dot(rot_matrix.T) 87 | 88 | pc1 = pc1.dot(matrix) 89 | pc2 = pc2.dot(matrix) 90 | if 'flows' in data_dict: 91 | pc1_virtual = pc1_virtual.dot(matrix) 92 | pc2_virtual = pc2_virtual.dot(matrix) 93 | 94 | if self.pc2_args is not None: 95 | angle2 = rng.uniform(-self.pc2_args.degree_range, self.pc2_args.degree_range) / 180.0 * np.pi 96 | rot_axis = np.array([0.0, 0.0, 1.0]) 97 | matrix2 = Quaternion(axis=rot_axis, radians=angle2).rotation_matrix.astype(np.float32) 98 | 99 | jitter1 = np.clip(self.pc2_args.jitter_sigma * rng.randn(pc1.shape[0], 3), 100 | -self.pc2_args.jitter_clip, self.pc2_args.jitter_clip).astype(np.float32) 101 | jitter2 = np.clip(self.pc2_args.jitter_sigma * rng.randn(pc2.shape[0], 3), 102 | -self.pc2_args.jitter_clip, self.pc2_args.jitter_clip).astype(np.float32) 103 | 104 | pc1 = pc1 + jitter1 105 | pc2 = pc2.dot(matrix2) + jitter2 106 | if 'flows' in data_dict: 107 | pc2_virtual = pc2_virtual.dot(matrix2) 108 | 109 | 110 | data_dict["pcs"] = [pc1, pc2] 111 | if 'flows' in data_dict: 112 | data_dict["flows"] = [pc2_virtual - pc1, pc1_virtual - pc2] 113 | -------------------------------------------------------------------------------- /models/spconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | import MinkowskiEngine.MinkowskiFunctional as MEF 5 | 6 | 7 | class BasicBlockBase(nn.Module): 8 | """ 9 | A double-conv ResBlock with relu activation, with residual connection. 10 | """ 11 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, D=3): 12 | super(BasicBlockBase, self).__init__() 13 | self.conv1 = ME.MinkowskiConvolution( 14 | inplanes, planes, kernel_size=3, stride=stride, dimension=D) 15 | self.norm1 = ME.MinkowskiInstanceNorm(planes) 16 | self.conv2 = ME.MinkowskiConvolution( 17 | planes, planes, kernel_size=3, stride=1, dilation=dilation, bias=False, dimension=D) 18 | self.norm2 = ME.MinkowskiInstanceNorm(planes) 19 | self.downsample = downsample 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.norm1(out) 26 | out = MEF.relu(out) 27 | 28 | out = self.conv2(out) 29 | out = self.norm2(out) 30 | 31 | if self.downsample is not None: 32 | residual = self.downsample(x) 33 | 34 | out += residual 35 | out = MEF.relu(out) 36 | 37 | return out 38 | 39 | 40 | class ResUNet(ME.MinkowskiNetwork): 41 | """ 42 | Our main network structure - a U-Net with residual double-conv blocks. 43 | Please refer to the appendix of our paper for illustration of the model. 
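The decoder ends in two 1x1 heads: `final` produces a per-voxel descriptor (L2-normalized when normalize_feature is set), and `classifier` produces per-voxel body-part logits over n_classes.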
44 | """ 45 | def __init__(self, network_config, 46 | in_channels=3, out_channels=32, n_classes=14, 47 | normalize_feature=None, conv1_kernel_size=None, D=3): 48 | super().__init__(D) 49 | channels = network_config.channels 50 | tr_channels = list(network_config.tr_channels) 51 | 52 | assert len(channels) == len(tr_channels) 53 | channels[0] = in_channels 54 | tr_channels.append(0) 55 | 56 | self.normalize_feature = normalize_feature 57 | 58 | self.in_convs, self.in_norms, self.in_blocks = nn.ModuleList(), nn.ModuleList(), nn.ModuleList() 59 | self.out_convs, self.out_norms, self.out_blocks = nn.ModuleList(), nn.ModuleList(), nn.ModuleList() 60 | 61 | for layer_id in range(len(channels) - 1): 62 | self.in_convs.append(ME.MinkowskiConvolution( 63 | in_channels=channels[layer_id], 64 | out_channels=channels[layer_id + 1], 65 | kernel_size=conv1_kernel_size if layer_id == 0 else 3, 66 | stride=1 if layer_id == 0 else 2, 67 | dilation=1, bias=False, dimension=D)) 68 | self.in_norms.append(ME.MinkowskiInstanceNorm(channels[layer_id + 1])) 69 | self.in_blocks.append(BasicBlockBase( 70 | channels[layer_id + 1], channels[layer_id + 1], D=D)) 71 | self.out_convs.append(ME.MinkowskiConvolutionTranspose( 72 | in_channels=channels[layer_id + 1] + tr_channels[layer_id + 2], 73 | out_channels=tr_channels[layer_id + 1], 74 | kernel_size=1 if layer_id == 0 else 3, 75 | stride=1 if layer_id == 0 else 2, 76 | dilation=1, 77 | bias=False, 78 | dimension=D)) 79 | if layer_id > 0: 80 | self.out_norms.append(ME.MinkowskiInstanceNorm(tr_channels[layer_id + 1])) 81 | self.out_blocks.append(BasicBlockBase( 82 | tr_channels[layer_id + 1], tr_channels[layer_id + 1], D=D)) 83 | 84 | self.final = ME.MinkowskiConvolution( 85 | in_channels=tr_channels[1], out_channels=out_channels, 86 | kernel_size=1, stride=1, dilation=1, bias=True, dimension=D) 87 | 88 | self.classifier = ME.MinkowskiConvolution( 89 | in_channels=tr_channels[1], out_channels=n_classes, 90 | kernel_size=1, stride=1, dilation=1, bias=True, dimension=D) 91 | 92 | 93 | def forward(self, x): 94 | skip_outputs = [] 95 | for layer_id in range(len(self.in_convs)): 96 | out_skip = self.in_convs[layer_id](x) 97 | out_skip = self.in_norms[layer_id](out_skip) 98 | out_skip = self.in_blocks[layer_id](out_skip) 99 | x = MEF.relu(out_skip) 100 | skip_outputs.append(out_skip) 101 | 102 | for layer_id in range(len(self.in_convs) - 1, -1, -1): 103 | x = self.out_convs[layer_id](x) 104 | if layer_id > 0: 105 | x = self.out_norms[layer_id - 1](x) 106 | x = self.out_blocks[layer_id - 1](x) 107 | x_tr = MEF.relu(x) 108 | if layer_id > 0: 109 | x = ME.cat(x_tr, skip_outputs[layer_id - 1]) 110 | 111 | feat = self.final(x) 112 | cls = self.classifier(x) 113 | 114 | if self.normalize_feature: 115 | return ME.SparseTensor( 116 | feat.F / torch.norm(feat.F, p=2, dim=1, keepdim=True), 117 | coordinate_map_key=feat.coordinate_map_key, 118 | coordinate_manager=feat.coordinate_manager), cls 119 | else: 120 | return feat, cls 121 | 122 | -------------------------------------------------------------------------------- /tools/exp.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | from collections import OrderedDict 4 | 5 | import sys 6 | import numpy as np 7 | import torch 8 | import functools 9 | from pathlib import Path 10 | from omegaconf import OmegaConf 11 | 12 | 13 | def seed_everything(seed: int): 14 | """ 15 | Setup global seed to ensure reproducibility. 
16 | :param seed: integer value 17 | """ 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | 24 | def parse_config_yaml(yaml_path: Path, args: OmegaConf = None, override: bool = True) -> OmegaConf: 25 | """ 26 | Load yaml file, and optionally merge it with existing ones. 27 | This supports a light-weight (recursive) inclusion scheme. 28 | :param yaml_path: path to the yaml file 29 | :param args: previous config 30 | :param override: if option clashes, whether or not to overwrite previous ones. 31 | :return: new config. 32 | """ 33 | if args is None: 34 | args = OmegaConf.create() 35 | 36 | configs = OmegaConf.load(yaml_path) 37 | if "include_configs" in configs: 38 | base_configs = configs["include_configs"] 39 | del configs["include_configs"] 40 | if isinstance(base_configs, str): 41 | base_configs = [base_configs] 42 | # Update the config from top to down. 43 | for base_config in base_configs: 44 | base_config_path = yaml_path.parent / Path(base_config) 45 | configs = parse_config_yaml(base_config_path, configs, override=False) 46 | 47 | if "assign" in configs: 48 | overlays = configs["assign"] 49 | del configs["assign"] 50 | assign_config = OmegaConf.from_dotlist([f"{k}={v}" for k, v in overlays.items()]) 51 | configs = OmegaConf.merge(configs, assign_config) 52 | 53 | if override: 54 | return OmegaConf.merge(args, configs) 55 | else: 56 | return OmegaConf.merge(configs, args) 57 | 58 | 59 | def to_target_device(obj, device): 60 | if isinstance(obj, tuple): 61 | return tuple(map(functools.partial(to_target_device, device=device), obj)) 62 | if isinstance(obj, list): 63 | return list(map(functools.partial(to_target_device, device=device), obj)) 64 | if isinstance(obj, dict): 65 | return dict(map(functools.partial(to_target_device, device=device), obj.items())) 66 | if isinstance(obj, torch.Tensor): 67 | return obj.to(device) 68 | if isinstance(obj, torch.nn.Module): 69 | return obj.to(device) 70 | return obj 71 | 72 | 73 | class AverageMeter: 74 | """ 75 | Maintain named lists of numbers. Compute their average to evaluate dataset statistics. 76 | This can not only used for loss, but also for progressive training logging, supporting import/export data. 
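Typical usage: append_loss({'train/loss': 0.5}) after every step, then get_mean_loss_dict() and clear() once per epoch.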
77 | """ 78 | def __init__(self): 79 | self.loss_dict = OrderedDict() 80 | 81 | def clear(self): 82 | self.loss_dict.clear() 83 | 84 | def export(self, f): 85 | if isinstance(f, str): 86 | f = open(f, 'wb') 87 | pickle.dump(self.loss_dict, f) 88 | 89 | def load(self, f): 90 | if isinstance(f, str): 91 | f = open(f, 'rb') 92 | self.loss_dict = pickle.load(f) 93 | return self 94 | 95 | def append_loss(self, losses): 96 | for loss_name, loss_val in losses.items(): 97 | if loss_val is None: 98 | continue 99 | if loss_name not in self.loss_dict.keys(): 100 | self.loss_dict.update({loss_name: [loss_val]}) 101 | else: 102 | self.loss_dict[loss_name].append(loss_val) 103 | 104 | def get_mean_loss_dict(self): 105 | loss_dict = {} 106 | for loss_name, loss_arr in self.loss_dict.items(): 107 | loss_dict[loss_name] = sum(loss_arr) / len(loss_arr) 108 | return loss_dict 109 | 110 | def get_mean_loss(self): 111 | mean_loss_dict = self.get_mean_loss_dict() 112 | if len(mean_loss_dict) == 0: 113 | return 0.0 114 | else: 115 | return sum(mean_loss_dict.values()) / len(mean_loss_dict) 116 | 117 | def get_printable_mean(self): 118 | text = "" 119 | loss_dict = self.get_mean_loss_dict() 120 | for i, loss_name in enumerate(sorted(loss_dict.keys())): 121 | loss_mean = loss_dict[loss_name] 122 | text += "%s: %.4f | " % (loss_name, loss_mean) 123 | if i == 2: text += "\n" 124 | return text 125 | 126 | def get_newest_loss_dict(self, return_count=False): 127 | loss_dict = {} 128 | loss_count_dict = {} 129 | for loss_name, loss_arr in self.loss_dict.items(): 130 | if len(loss_arr) > 0: 131 | loss_dict[loss_name] = loss_arr[-1] 132 | loss_count_dict[loss_name] = len(loss_arr) 133 | if return_count: 134 | return loss_dict, loss_count_dict 135 | else: 136 | return loss_dict 137 | 138 | def get_printable_newest(self): 139 | nloss_val, nloss_count = self.get_newest_loss_dict(return_count=True) 140 | return ", ".join([f"{loss_name}[{nloss_count[loss_name] - 1}]: {nloss_val[loss_name]}" 141 | for loss_name in nloss_val.keys()]) 142 | 143 | def print_format_loss(self, color=None): 144 | if hasattr(sys.stdout, "terminal"): 145 | color_device = sys.stdout.terminal 146 | else: 147 | color_device = sys.stdout 148 | if color == "y": 149 | color_device.write('\033[93m') 150 | elif color == "g": 151 | color_device.write('\033[92m') 152 | elif color == "b": 153 | color_device.write('\033[94m') 154 | print(self.get_printable_mean(), flush=True) 155 | if color is not None: 156 | color_device.write('\033[0m') 157 | 158 | -------------------------------------------------------------------------------- /models/desc_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | from torch.nn import Parameter 4 | from torch.utils.data import DataLoader 5 | 6 | from dataset.human4d import * 7 | from metric import PairwiseMetric 8 | 9 | from typing import * 10 | from collections import defaultdict 11 | from models.spconv import ResUNet 12 | from models.base_model import BaseModel 13 | import numpy as np 14 | from tools.point import propagate_features 15 | from sklearn.metrics import confusion_matrix 16 | 17 | 18 | class Model(BaseModel): 19 | """ 20 | This model trains the descriptor network. This is the 1st stage of training. 
21 | """ 22 | def __init__(self, hparams): 23 | super().__init__(hparams) 24 | self.backbone_args = self.hparams.backbone_args 25 | self.backbone = ResUNet(self.backbone_args, 26 | in_channels=3, 27 | out_channels=self.backbone_args.feat_channels, 28 | n_classes=self.backbone_args.n_classes, 29 | normalize_feature=True, 30 | conv1_kernel_size=3) 31 | self.td = Parameter(torch.tensor(np.float32(self.hparams.td_init)), requires_grad=True) 32 | 33 | def forward(self, batch): 34 | """ 35 | Forward descriptor network. 36 | As the backbone quantized point cloud into voxels (by selecting one point for each voxel), 37 | we also return the selected point indices. 38 | """ 39 | num_batches = len(batch["pcs"][0]) 40 | num_views = len(batch["pcs"]) 41 | all_coords, all_feats, all_sels = [], [], [] 42 | for batch_idx in range(num_batches): 43 | for view_idx in range(num_views): 44 | all_coords.append(batch["coords"][view_idx][0][batch_idx]) 45 | cur_sel = batch["coords"][view_idx][1][batch_idx] 46 | all_sels.append(cur_sel) 47 | all_feats.append(batch["pcs"][view_idx][batch_idx][cur_sel]) 48 | coords_batch, feats_batch = ME.utils.sparse_collate(all_coords, all_feats, device=self.device) 49 | sinput = ME.SparseTensor(feats_batch, coordinates=coords_batch) 50 | desc_output, cls_output = self.backbone(sinput) 51 | 52 | # Compute loss and metrics 53 | num_batches = len(batch["pcs"][0]) 54 | losses, metrics, pd = [], [], [] 55 | has_labels = "labels" in batch 56 | has_flows = "flows" in batch 57 | for batch_idx in range(num_batches): 58 | cur_pc0, cur_pc1 = batch["pcs"][0][batch_idx], batch["pcs"][1][batch_idx] 59 | cur_sel0, cur_sel1 = all_sels[batch_idx * 2 + 0], all_sels[batch_idx * 2 + 1] 60 | cur_gt0 = cur_gt1 = cur_labels0 = cur_labels1 = None 61 | if has_flows: 62 | cur_gt0, cur_gt1 = batch["flows"][0][batch_idx], batch["flows"][1][batch_idx] 63 | if has_labels: 64 | cur_labels0, cur_labels1 = batch["labels"][0][batch_idx], batch["labels"][1][batch_idx] 65 | 66 | cur_cls0 = cls_output.features_at(batch_idx * 2 + 0) 67 | cur_cls1 = cls_output.features_at(batch_idx * 2 + 1) 68 | cur_feat0 = desc_output.features_at(batch_idx * 2 + 0) 69 | cur_feat1 = desc_output.features_at(batch_idx * 2 + 1) 70 | 71 | dist_mat = torch.cdist(cur_feat0, cur_feat1) / torch.maximum( 72 | torch.tensor(np.float32(self.hparams.td_min), device=self.device), self.td) 73 | cur_pd0 = torch.softmax(-dist_mat, dim=1) @ cur_pc1[cur_sel1] - cur_pc0[cur_sel0] 74 | cur_pd1 = torch.softmax(-dist_mat, dim=0).transpose(-1, -2) @ cur_pc0[cur_sel0] - cur_pc1[cur_sel1] 75 | 76 | # Compute Metrics 77 | loss_dict = self.compute_sup_loss(self.hparams.sup_loss, 78 | cur_cls0, cur_cls1, cur_pd0, cur_pd1, cur_sel0, cur_sel1, 79 | cur_labels0, cur_labels1, cur_gt0, cur_gt1) 80 | losses.append(loss_dict) 81 | 82 | pd_full = {} 83 | if not self.hparams.is_training: 84 | pd_full_cls0 = propagate_features(cur_pc0[cur_sel0], cur_pc0, cur_cls0, batched=False) 85 | pd_full_cls1 = propagate_features(cur_pc1[cur_sel1], cur_pc1, cur_cls1, batched=False) 86 | pd_full_flow01 = propagate_features(cur_pc0[cur_sel0], cur_pc0, cur_pd0, batched=False) 87 | pd_full_flow10 = propagate_features(cur_pc1[cur_sel1], cur_pc1, cur_pd1, batched=False) 88 | metric = self.compute_metric(pd_full_cls0, pd_full_cls1, pd_full_flow01, pd_full_flow10, 89 | cur_labels0, cur_labels1, cur_gt0, cur_gt1) 90 | metrics.append(metric) 91 | pd_full = {'cls0': torch.max(pd_full_cls0, 1)[1], 92 | 'cls1': torch.max(pd_full_cls1, 1)[1], 93 | 'flow01': pd_full_flow01.cpu().numpy(), 94 | 
'flow10': pd_full_flow10.cpu().numpy(), 95 | } 96 | pd.append(pd_full) 97 | 98 | return losses, metrics, pd 99 | 100 | @staticmethod 101 | def compute_sup_loss(loss_config, cls0, cls1, fpd0, fpd1, sel0, sel1, 102 | labels0=None, labels1=None, 103 | fgt0=None, fgt1=None) -> Tuple[torch.Tensor, torch.Tensor]: 104 | res = defaultdict(list) 105 | cls_criterion = torch.nn.CrossEntropyLoss() 106 | if labels0 is not None: 107 | loss = (1 - loss_config.lmda) * cls_criterion(cls0, labels0[sel0]) 108 | res['cls_loss'].append(loss) 109 | if labels1 is not None: 110 | loss = (1 - loss_config.lmda) * cls_criterion(cls1, labels1[sel1]) 111 | res['cls_loss'].append(loss) 112 | if fgt0 is not None: 113 | loss = loss_config.lmda * torch.linalg.norm(fpd0 - fgt0[sel0], dim=-1).mean() 114 | res['flow_loss'].append(loss) 115 | if fgt1 is not None: 116 | loss = loss_config.lmda * torch.linalg.norm(fpd1 - fgt1[sel1], dim=-1).mean() 117 | res['flow_loss'].append(loss) 118 | for k, v in res.items(): 119 | res[k] = torch.stack(v).mean() 120 | return res 121 | 122 | @staticmethod 123 | def compute_metric(cls0, cls1, fpd01, fpd10, 124 | labels0=None, labels1=None, 125 | fgt01=None, fgt10=None): 126 | metric = PairwiseMetric(compute_epe3d=True, compute_acc3d_outlier=True) 127 | res = defaultdict(list) 128 | if labels0 is not None: 129 | _, pred0 = torch.max(cls0, 1) 130 | correct = (pred0 == labels0).sum().item() 131 | res['accuracy'].append(correct / pred0.shape[0]) 132 | if labels1 is not None: 133 | _, pred1 = torch.max(cls1, 1) 134 | correct = (pred1 == labels1).sum().item() 135 | res['accuracy'].append(correct / pred1.shape[0]) 136 | if fgt01 is not None: 137 | m = metric.evaluate(fgt01, fpd01) 138 | for k, v in m.items(): 139 | res[k].append(v.item()) 140 | if fgt10 is not None: 141 | m = metric.evaluate(fgt10, fpd10) 142 | for k, v in m.items(): 143 | res[k].append(v.item()) 144 | for k, v in res.items(): 145 | res[k] = np.mean(v) 146 | return res 147 | 148 | def step(self, batch, mode): 149 | assert mode in ['train', 'valid', 'test'] 150 | losses, metrics, pd = self(batch) 151 | loss_sum = 0.0 152 | for loss_dict in losses: 153 | for name, val in loss_dict.items(): 154 | self.log(f'{mode}/{name}', val) 155 | loss_sum += val 156 | for metric_dict in metrics: 157 | for name, val in metric_dict.items(): 158 | self.log(f'{mode}/{name}', val) 159 | total_loss = loss_sum / len(losses) 160 | self.log(f'{mode}/total_loss', total_loss) 161 | return total_loss, pd -------------------------------------------------------------------------------- /models/desc_net_self.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | import torch_scatter 3 | import MinkowskiEngine as ME 4 | from torch.nn import Parameter 5 | 6 | from dataset.human4d import * 7 | from metric import PairwiseMetric 8 | 9 | from typing import * 10 | from collections import defaultdict 11 | from models.spconv import ResUNet 12 | from models.base_model import BaseModel 13 | import numpy as np 14 | from functools import reduce 15 | from tools.point import propagate_features 16 | 17 | class Model(BaseModel): 18 | """ 19 | This model trains the descriptor network. This is the 1st stage of training. 
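(More precisely, this self-supervised variant is the one fine-tuned in the 2nd stage on unlabeled pairs; see train.py. It swaps the supervised objective for the chamfer, smoothness, class and translation terms configured under self_sup_loss.)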
106 | 
107 |     def refine_flow(self, pc0, cls0, flow01):
108 |         pc0_warped = pc0 + flow01
109 |         for i in range(self.backbone_args.n_classes):  # one rigid fit per body part
110 |             idx = cls0 == i
111 |             if idx.sum() < 2: continue  # skip parts with fewer than two points
112 |             src = pc0[idx]
113 |             dst = pc0_warped[idx]
114 |             src_mean = src.mean(0)
115 |             dst_mean = dst.mean(0)
116 |             # Compute covariance, then the closed-form rotation via SVD
117 |             H = (src - src_mean).T @ (dst - dst_mean)
118 |             u, _, vt = torch.linalg.svd(H)
119 |             rot_pd_T = u @ vt  # transposed rotation; the det sign (reflection) is not corrected
120 |             t_pd = - src_mean @ rot_pd_T + dst_mean
121 |             flow01[idx] = src @ rot_pd_T + t_pd - src  # flow01 is updated in place and returned
122 |         return flow01
123 | 
124 | 
125 |     @staticmethod
126 |     def compute_self_sup_loss(pc0, pc1, pd_flow01, lbl_out0, lbl_out1, loss_config):
127 |         pc0_warped = pc0 + pd_flow01
128 |         dist01 = torch.cdist(pc0_warped, pc1)
129 |         loss_dict = {}
130 |         k_neigh = loss_config.k_neigh
131 | 
132 |         if loss_config.chamfer_weight > 0.0:
133 |             chamfer01 = torch.min(dist01, dim=-1).values
134 |             chamfer10 = torch.min(dist01, dim=-2).values
135 |             loss_dict['chamfer'] = loss_config.chamfer_weight * (chamfer01.mean() + chamfer10.mean())
136 | 
137 |         if loss_config.smooth_weight > 0.0:
138 |             dist00 = torch.cdist(pc0, pc0)
139 |             _, kidx0 = torch.topk(dist00, k_neigh, dim=-1, largest=False, sorted=False)  # neighbours include the point itself
140 | 
141 |             grouped_flow = pd_flow01[kidx0]  # (N, K, 3)
142 |             loss_dict['smooth'] = loss_config.smooth_weight * \
143 |                 (((grouped_flow - pd_flow01.unsqueeze(1)) ** 2).sum(-1).sum(-1) / (k_neigh - 1.0)).mean()
144 | 
145 |         if loss_config.class_weight > 0.0:
146 |             lbl0 = torch.argmax(lbl_out0, dim=-1)  # (N)
147 |             grouped_lbl_gt = lbl0.repeat(k_neigh, 1).T.flatten()  # (N * K)
148 |             _, knn01 = torch.topk(dist01, k_neigh, dim=-1, largest=False, sorted=False)  # (N, K)
149 |             grouped_lbl_pd = lbl_out1[knn01].reshape(-1, lbl_out1.shape[-1])  # (N * K, n_classes)
150 |             criterion = torch.nn.CrossEntropyLoss(reduction='mean')
151 |             loss_dict['class'] = loss_config.class_weight * \
152 |                 criterion(grouped_lbl_pd, grouped_lbl_gt) / k_neigh
153 |         else:
154 |             lbl0 = torch.argmax(lbl_out0, dim=-1)  # (N); still needed by the translate loss
155 | 
156 |         if loss_config.translate_weight > 0.0:
157 |             cnts = torch_scatter.scatter_add(torch.ones_like(lbl0), lbl0, dim=0)
158 |             idx = torch.where(cnts[lbl0] > 1)[0]  # drop parts with fewer than two points
159 |             lbl0 = lbl0[idx]
160 |             pc0_mean = torch_scatter.scatter_mean(pc0[idx], lbl0, dim=0)
161 |             pc1_mean = torch_scatter.scatter_mean(pc0_warped[idx], lbl0, dim=0)
162 |             pts0 = pc0[idx] - pc0_mean[lbl0]
163 |             pts1 = pc0_warped[idx] - pc1_mean[lbl0]
164 | 
165 |             n_points = pts0.shape[0]
166 |             n_class = lbl_out0.shape[-1]  # zero-pad per-part point sets so all covariances come from one bmm
167 |             pts0_scatter = torch.zeros(n_class, n_points, 3, device=pc0.device).scatter_(0, lbl0[None, :, None].expand(-1, -1, 3), pts0.unsqueeze(0))
168 |             pts1_scatter = torch.zeros(n_class, n_points, 3, device=pc0.device).scatter_(0, lbl0[None, :, None].expand(-1, -1, 3), pts1.unsqueeze(0))
169 |             H = torch.bmm(pts0_scatter.transpose(1, 2), pts1_scatter)  # (n_class, 3, 3) covariances
170 |             U, _, Vt = torch.linalg.svd(H)
171 |             rot_pd_T = torch.bmm(U, Vt)  # (n_class, 3, 3)
172 |             error = pts1 - torch.bmm(pts0.unsqueeze(1), rot_pd_T[lbl0]).squeeze(1)  # residual after the best per-part rotation
173 |             tl_loss_sum = torch.sum(error ** 2)
174 |             loss_dict['translate'] = loss_config.translate_weight * tl_loss_sum / n_points
175 | 
176 |         return loss_dict
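177 | 
178 |     # Note: the translate loss above and refine_flow() share one closed-form
179 |     # step, a per-part rigid (Kabsch/Procrustes) fit. A minimal single-part
180 |     # sketch, assuming `src`/`dst` are matched (M, 3) tensors (the names here
181 |     # are illustrative, not part of the original code):
182 |     #
183 |     #     H = (src - src.mean(0)).T @ (dst - dst.mean(0))
184 |     #     U, _, Vt = torch.linalg.svd(H)
185 |     #     R_T = U @ Vt                           # transposed rotation
186 |     #     t = dst.mean(0) - src.mean(0) @ R_T    # aligning translation
187 |     #     fitted = src @ R_T + t                 # ~ dst, up to a possible reflection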
188 | 
189 |     @staticmethod
190 |     def compute_metric(cls0, cls1, fpd01, fpd10,
191 |                        labels0=None, labels1=None,
192 |                        fgt01=None, fgt10=None):
193 |         metric = PairwiseMetric(compute_epe3d=True, compute_acc3d_outlier=True)
194 |         res = defaultdict(list)
195 |         if labels0 is not None:
196 |             _, pred0 = torch.max(cls0, 1)
197 |             correct = (pred0 == labels0).sum().item()
198 |             res['accuracy'].append(correct / pred0.shape[0])
199 |         if labels1 is not None:
200 |             _, pred1 = torch.max(cls1, 1)
201 |             correct = (pred1 == labels1).sum().item()
202 |             res['accuracy'].append(correct / pred1.shape[0])
203 |         if fgt01 is not None:
204 |             m = metric.evaluate(fgt01, fpd01)
205 |             for k, v in m.items():
206 |                 res[k].append(v.item())
207 |         if fgt10 is not None:
208 |             m = metric.evaluate(fgt10, fpd10)
209 |             for k, v in m.items():
210 |                 res[k].append(v.item())
211 |         for k, v in res.items():
212 |             res[k] = np.mean(v)
213 |         return res
214 | 
215 |     def step(self, batch, mode):
216 |         assert mode in ['train', 'valid', 'test']
217 |         losses, metrics, pd = self(batch)
218 |         loss_sum = 0.0
219 |         for loss_dict in losses:
220 |             for name, val in loss_dict.items():
221 |                 self.log(f'{mode}/{name}-loss', val)
222 |                 loss_sum += val
223 |         for metric_dict in metrics:
224 |             for name, val in metric_dict.items():
225 |                 self.log(f'{mode}/{name}', val)
226 |         total_loss = loss_sum / len(losses)
227 |         self.log(f'{mode}/total_loss', total_loss)
228 |         return total_loss, pd
--------------------------------------------------------------------------------
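
A note on the voxel selection that both `forward()` docstrings mention: the dataloader is expected to hand each model `batch["coords"][view] = (coords_list, sels_list)`, i.e. integer voxel coordinates plus, for every cloud, the indices of the single point kept per voxel. That preprocessing lives in the dataset code, not in this section; a minimal way to produce one such pair — assuming the 0.01 `voxel_size` from the configs, with `ME.utils.sparse_quantize(..., return_index=True)` as the MinkowskiEngine-native alternative — could look like:

```python
import numpy as np

def voxelize(pc: np.ndarray, voxel_size: float = 0.01):
    """Keep one representative point per occupied voxel.

    Returns the kept points' integer voxel coordinates and `sel`, their
    indices into `pc` (the same role as batch["coords"][view][1]).
    """
    coords = np.floor(pc / voxel_size).astype(np.int32)
    _, sel = np.unique(coords, axis=0, return_index=True)  # first point of each voxel
    sel = np.sort(sel)  # preserve the original point order
    return coords[sel], sel

pc = np.random.rand(4096, 3).astype(np.float32)
coords, sel = voxelize(pc)
assert coords.shape == (len(sel), 3)
```
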