├── LICENSE ├── README.md ├── config └── IntrA │ └── IntrA_pointtransformer_seg_repro.yaml ├── dataset ├── IntrADataset.py ├── __init__.py └── data_utils.py ├── models ├── __init__.py ├── base.py ├── point_transformer_cls.py └── point_transformer_seg.py ├── point_transformer_lib ├── MANIFEST.in ├── point_transformer_ops.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── point_transformer_ops │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── point_transformer_modules.cpython-37.pyc │ │ └── point_transformer_utils.cpython-37.pyc │ ├── _ext-src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── _version.py │ ├── point_transformer_modules.py │ └── point_transformer_utils.py └── setup.py ├── tool ├── ept.sh ├── ept_1024.sh ├── ept_2048.sh ├── test.py └── train.py └── utils ├── config.py ├── logger.py ├── timer.py └── tools.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yifan LIU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Edge-oriented Point-cloud Transformer for 3D Intracranial Aneurysm Segmentation 2 | by [Yifan Liu](https://github.com/yifliu3) 3 | 4 | 5 | ## 1. Introduction 6 | This repository is for our MICCAI 2022 paper "Edge-oriented Point-cloud Transformer for 3D Intracranial Aneurysm Segmentation" 7 | 8 | ## 2. Data Preparation 9 | Download `fileSplit`, `geo.zip` and `IntrA.zip` from the [IntrA repository](https://github.com/intra3d2019/IntrA) 10 | 11 | Unzip `geo.zip` and `IntrA.zip` into the `geo` and `IntrA` folders 12 | 13 | Move the unzipped `geo` folder into `IntrA/annotated/geo` 14 | 15 | Move `fileSplit` into `IntrA/split` 16 | 17 | Create one folder `data` in the code repository and add one symbolic link 18 | 19 | `mkdir data && ln -s Yourpath/IntrA data/IntrA` 20 | 21 | ## 3. Installation
22 | ### Requirements 23 | - python 3.7 24 | - pytorch 1.7 25 | - h5py 26 | - pyyaml 27 | - tensorboardx 28 | 29 | ### Step-by-step installation 30 | ```bash 31 | # create python environment 32 | conda create -n ept python=3.7 33 | conda activate ept 34 | 35 | # install dependencies 36 | conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch 37 | conda install -c anaconda h5py pyyaml -y 38 | pip install tensorboardx 39 | 40 | # clone this repository in your own workspace 41 | git clone https://github.com/CityU-AIM-Group/EPT.git 42 | cd EPT 43 | mkdir data && ln -s Yourpath/IntrA data/IntrA 44 | 45 | # compile cuda operations 46 | cd point_transformer_lib 47 | python3 setup.py build_ext install 48 | 49 | ``` 50 | 51 | ## 4. Train/test the Model 52 | To train and test separately, you can use the commands below (taking 512 sampled points as an example): 53 | Train: 54 | `python -m tool.train --config config/IntrA/IntrA_pointtransformer_seg_repro sample_points 512` 55 | Test: 56 | `python -m tool.test --config config/IntrA/IntrA_pointtransformer_seg_repro sample_points 512` 57 | 58 | 59 | Or you can use the bash script to run train.py and test.py sequentially: 60 | `sh tool/ept.sh IntrA pointtransformer_seg_repro` 61 | 62 | The trained models are provided in [Google Drive](https://drive.google.com/drive/folders/1wThn1dBmQk36-suSJOq5T8UJq3GPQ6QF?usp=sharing) 63 | 64 | ## 5. Citation 65 | If you find this work useful for your research, please cite our paper: 66 | ``` 67 | @inproceedings{liu2022, 68 | title={Edge-oriented Point-cloud Transformer for 3D Intracranial Aneurysm Segmentation}, 69 | author={Yifan Liu and Jie Liu and Yixuan Yuan}, 70 | booktitle={MICCAI}, 71 | year={2022} 72 | } 73 | ``` 74 | 75 | ## 6. Acknowledgement 76 | This work is based on [point-transformer](https://github.com/POSTECH-CVLab/point-transformer).
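The separate train/test commands in Section 4 can also be chained over several sampling resolutions. The loop below is a minimal sketch that relies only on the commands documented above; the point counts are just examples, and the bundled `tool/ept_1024.sh` and `tool/ept_2048.sh` scripts presumably cover the corresponding settings:

```bash
# sketch: train and then test at several sampling resolutions
for n in 512 1024 2048; do
    python -m tool.train --config config/IntrA/IntrA_pointtransformer_seg_repro sample_points $n
    python -m tool.test --config config/IntrA/IntrA_pointtransformer_seg_repro sample_points $n
done
```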
77 | 78 | -------------------------------------------------------------------------------- /config/IntrA/IntrA_pointtransformer_seg_repro.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: IntrA 3 | data_root: data/IntrA/ 4 | sample_points: 512 5 | use_uniform_sample: True 6 | use_normals: True 7 | classes: 2 8 | fea_dim: 6 9 | loop: 1 10 | 11 | MODEL: 12 | npoints: 16 13 | downsampling_ratio: 0.5 14 | 15 | TRAIN: 16 | arch: IntrA_pointtransformer_seg_repro 17 | use_xyz: True 18 | sync_bn: False 19 | folds: [0, 1, 2, 3, 4] 20 | train_gpu: [0] 21 | num_workers: 8 # data loader workers 22 | batch_size_train: 8 # batch size for training 23 | batch_size_val: 8 # batch size for validation during training, memory and speed tradeoff 24 | base_lr: 0.001 25 | epochs: 400 26 | start_epoch: 0 27 | momentum: 0.9 28 | weight_decay: 0.0 29 | drop_rate: 0.5 30 | manual_seed: 666 31 | print_freq: 1 32 | save_freq: 1 33 | save_path: 34 | weight: # path to initial weight (default: none) 35 | weight_edge: 1 36 | weight_contra: 0.2 37 | weight_refine: 1 38 | resume: # path to latest checkpoint (default: none) 39 | evaluate: True # evaluate on validation set, extra gpu memory needed and small batch_size_val is recommend 40 | eval_freq: 1 41 | num_votes: 3 42 | num_edge_neighbor: 4 43 | temp: 1.0 44 | n_neighbors: 8 45 | 46 | TEST: 47 | test_points: 512 48 | test_times: 10 -------------------------------------------------------------------------------- /dataset/IntrADataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import warnings 4 | import torch 5 | import copy 6 | from pathlib import Path 7 | 8 | 9 | from torch.utils.data import Dataset 10 | from point_transformer_lib.point_transformer_ops.point_transformer_utils import FPS 11 | 12 | 13 | warnings.filterwarnings('ignore') 14 | 15 | 16 | def pc_normalize(pc): 17 | centroid = np.mean(pc, axis=0) 18 | pc = pc - centroid 19 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 20 | pc = pc / m 21 | return pc 22 | 23 | 24 | def load_txtfile(path): 25 | file_list = [] 26 | with open(path, 'r') as f: 27 | for line in f.readlines(): 28 | line = line.strip('\n') 29 | file_list.append(line) 30 | return file_list 31 | 32 | 33 | def load_adfile(path): 34 | points = [] 35 | normals = [] 36 | labels = [] 37 | with open(path, 'r') as f: 38 | for line in f.readlines(): 39 | s_line = line.split() 40 | points.append([float(s_line[0]), float(s_line[1]), float(s_line[2])]) 41 | normals.append([float(s_line[3]), float(s_line[4]), float(s_line[5])]) 42 | labels.append(int(s_line[6])) 43 | return points, normals, labels 44 | 45 | 46 | def load_gdfile(path): 47 | matrix = [] 48 | with open(path, 'r') as f: 49 | for line in f.readlines(): 50 | s_line = line.split() 51 | numbers_float = list(map(float, s_line)) 52 | matrix.append(numbers_float) 53 | return np.array(matrix).astype(np.float16) 54 | 55 | 56 | class IntrADataset(Dataset): 57 | def __init__(self, data_root, num_points, use_uniform_sample, use_normals, test_fold, \ 58 | num_edge_neighbor, mode='train', transform=None): 59 | self.data_root = Path(data_root) 60 | self.npoints = num_points 61 | self.uniform = use_uniform_sample 62 | self.use_normals = use_normals 63 | self.test_fold = test_fold 64 | self.split_path = self.data_root / "split/seg/" 65 | self.mode = mode 66 | self.num_edge_neighbor = num_edge_neighbor 67 | self.transform =transform 68 | self.data_path = [] 69 | 
self.geo_data_path = [] 70 | self.whole_data_path = [] 71 | self.preloaded_gmatrix = [] 72 | 73 | for i in range(5): 74 | file_path = "annSplit_%d.txt" % i 75 | geo_file_path = "geoSplit_%d.txt" % i 76 | self.whole_data_path.extend(load_txtfile(self.split_path/file_path)) 77 | if mode == 'train' and i != test_fold: 78 | self.data_path.extend(load_txtfile(self.split_path/file_path)) 79 | self.geo_data_path.extend(load_txtfile(self.split_path/geo_file_path)) 80 | if mode == 'test' and i == test_fold: 81 | self.data_path.extend(load_txtfile(self.split_path/file_path)) 82 | self.geo_data_path.extend(load_txtfile(self.split_path/geo_file_path)) 83 | else: 84 | continue 85 | 86 | segweights = np.zeros(2) 87 | label_list = [] 88 | for path in self.whole_data_path: 89 | _, _, labels = load_adfile(self.data_root/path) 90 | label_list.extend(labels) 91 | label_list[label_list==2] == 1 92 | tmp, _ = np.histogram(label_list, range(3)) 93 | segweights += tmp 94 | segweights = segweights.astype(np.float32) 95 | segweights = segweights / np.sum(segweights) 96 | self.segweights = torch.from_numpy(np.power(np.amax(segweights)/segweights, 1 / 3.0)) 97 | 98 | # preload geodesic matrix 99 | for path in self.geo_data_path: 100 | gmatrix = load_gdfile(self.data_root/path) 101 | self.preloaded_gmatrix.append(gmatrix) 102 | 103 | 104 | def __len__(self): 105 | return len(self.data_path) 106 | 107 | def __getitem__(self, index): 108 | points, normals, labels = load_adfile(self.data_root/self.data_path[index]) 109 | gmatrix = self.preloaded_gmatrix[index].astype(np.float32) 110 | 111 | num_avail_points = len(points) 112 | points, normals, labels = np.array(points), np.array(normals), np.array(labels) 113 | labels[labels==2] = 1 114 | point_idxs = range(num_avail_points) 115 | npoints = self.npoints 116 | 117 | if num_avail_points >= npoints: 118 | if self.uniform: 119 | points_cuda = torch.from_numpy(points).float().unsqueeze(0) 120 | selected_points_idxs = FPS(points_cuda, npoints).squeeze().numpy().astype(np.int64) 121 | else: 122 | if self.uniform: 123 | points_cuda = torch.from_numpy(points).float().unsqueeze(0) 124 | scale = npoints // num_avail_points 125 | extra = npoints % num_avail_points 126 | extra_idxs = FPS(points_cuda, extra).squeeze().numpy().astype(np.int64) 127 | selected_points_idxs = np.concatenate((np.array(list(point_idxs)*scale).astype(np.int64), extra_idxs)) 128 | 129 | selected_points = points[selected_points_idxs] 130 | selected_normals = normals[selected_points_idxs] 131 | selected_gmatrix = gmatrix[selected_points_idxs, :][:, selected_points_idxs] 132 | selected_labels = labels[selected_points_idxs] 133 | 134 | selected_points = pc_normalize(selected_points) 135 | if self.use_normals: 136 | selected_points = np.concatenate((selected_points, selected_normals), axis=1) 137 | if self.transform is not None: 138 | selected_points = self.transform(selected_points).float() 139 | else: 140 | selected_points = torch.from_numpy(selected_points).float() 141 | selected_labels = torch.from_numpy(selected_labels).long() 142 | selected_gmatrix = torch.from_numpy(selected_gmatrix).float() 143 | 144 | selected_edge_labels, edgeweights = self.get_edge_label(selected_points_idxs, selected_labels, selected_gmatrix, self.num_edge_neighbor) 145 | 146 | return selected_points, selected_labels, selected_edge_labels, edgeweights, selected_gmatrix, selected_points_idxs 147 | 148 | 149 | def get_edge_label(self, idxs, labels, gmatrix, k): 150 | _, indices, reverse_indices = np.unique(idxs, return_index=True, 
return_inverse=True) 151 | unique_labels = labels[indices] 152 | unique_gmatrix = gmatrix[indices, :][:, indices] 153 | edge_labels = torch.zeros(unique_labels.shape[0]) 154 | idxs_neighbor = unique_gmatrix.argsort(dim=-1)[:, :k] # (N, K) 155 | gts_neighbor = torch.gather(unique_labels[None, :].repeat(idxs_neighbor.shape[0], 1), 1, idxs_neighbor) # (N, K) 156 | gts_neighbor_sum = gts_neighbor.sum(dim=-1) 157 | edge_mask = torch.logical_and(gts_neighbor_sum!=0, gts_neighbor_sum!=k) 158 | edge_labels[edge_mask] = 1 159 | edge_labels = edge_labels[reverse_indices] 160 | edgeweights = torch.histc(edge_labels, bins=2, min=0, max=1) 161 | edgeweights = edgeweights / torch.sum(edgeweights) 162 | edgeweights = (edgeweights.max() / edgeweights) ** (1/3) 163 | edge_labels = edge_labels.long() 164 | return edge_labels, edgeweights 165 | 166 | 167 | if __name__ == '__main__': 168 | import torch 169 | np.random.seed(666) 170 | 171 | data = IntrADataset(data_root='/home/yifliu3/data/IntrA/', num_points=512, use_normals=True,\ 172 | test_fold=0, num_edge_neighbor=4, mode='train', use_uniform_sample=True) 173 | 174 | data_val = IntrADataset(data_root='/home/yifliu3/data/IntrA/', num_points=512, use_normals=True,\ 175 | test_fold=0, num_edge_neighbor=4, mode='test', use_uniform_sample=True) 176 | 177 | DataLoader = torch.utils.data.DataLoader(data, batch_size=8, shuffle=False) 178 | DataLoader_val = torch.utils.data.DataLoader(data_val, batch_size=8, shuffle=False) 179 | for point, label in DataLoader_val: 180 | print(torch.sum(label==2, dim=-1)) 181 | 182 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .IntrADataset import IntrADataset 2 | -------------------------------------------------------------------------------- /dataset/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def angle_axis(angle, axis): 6 | # type: (float, np.ndarray) -> float 7 | r"""Returns a 4x4 rotation matrix that performs a rotation around axis by angle 8 | 9 | Parameters 10 | ---------- 11 | angle : float 12 | Angle to rotate by 13 | axis: np.ndarray 14 | Axis to rotate about 15 | 16 | Returns 17 | ------- 18 | torch.Tensor 19 | 3x3 rotation matrix 20 | """ 21 | u = axis / np.linalg.norm(axis) 22 | cosval, sinval = np.cos(angle), np.sin(angle) 23 | 24 | # yapf: disable 25 | cross_prod_mat = np.array([[0.0, -u[2], u[1]], 26 | [u[2], 0.0, -u[0]], 27 | [-u[1], u[0], 0.0]]) 28 | 29 | R = torch.from_numpy( 30 | cosval * np.eye(3) 31 | + sinval * cross_prod_mat 32 | + (1.0 - cosval) * np.outer(u, u) 33 | ) 34 | # yapf: enable 35 | return R.float() 36 | 37 | 38 | class PointcloudScale(object): 39 | def __init__(self, lo=0.8, hi=1.25): 40 | self.lo, self.hi = lo, hi 41 | 42 | def __call__(self, points): 43 | scaler = np.random.uniform(self.lo, self.hi) 44 | points[:, 0:3] *= scaler 45 | return points 46 | 47 | 48 | class PointcloudRotate(object): 49 | def __init__(self, axis=np.array([0.0, 1.0, 0.0])): 50 | self.axis = axis 51 | 52 | def __call__(self, points): 53 | rotation_angle = np.random.uniform() * 2 * np.pi 54 | rotation_matrix = angle_axis(rotation_angle, self.axis) 55 | 56 | normals = points.size(1) > 3 57 | if not normals: 58 | return torch.matmul(points, rotation_matrix.t()) 59 | else: 60 | pc_xyz = points[:, 0:3] 61 | pc_normals = points[:, 3:] 62 | points[:, 0:3] = 
torch.matmul(pc_xyz, rotation_matrix.t()) 63 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 64 | 65 | return points 66 | 67 | 68 | class PointcloudRotatePerturbation(object): 69 | def __init__(self, angle_sigma=0.06, angle_clip=0.18): 70 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip 71 | 72 | def _get_angles(self): 73 | angles = np.clip( 74 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip 75 | ) 76 | 77 | return angles 78 | 79 | def __call__(self, points): 80 | angles = self._get_angles() 81 | Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0])) 82 | Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0])) 83 | Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0])) 84 | 85 | rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx) 86 | 87 | normals = points.size(1) > 3 88 | if not normals: 89 | return torch.matmul(points, rotation_matrix.t()) 90 | else: 91 | pc_xyz = points[:, 0:3] 92 | pc_normals = points[:, 3:] 93 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 94 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 95 | 96 | return points 97 | 98 | 99 | class PointcloudJitter(object): 100 | def __init__(self, std=0.01, clip=0.05): 101 | self.std, self.clip = std, clip 102 | 103 | def __call__(self, points): 104 | jittered_data = ( 105 | points.new(points.size(0), 3) 106 | .normal_(mean=0.0, std=self.std) 107 | .clamp_(-self.clip, self.clip) 108 | ) 109 | points[:, 0:3] += jittered_data 110 | return points 111 | 112 | 113 | class PointcloudTranslate(object): 114 | def __init__(self, translate_range=0.1): 115 | self.translate_range = translate_range 116 | 117 | def __call__(self, points): 118 | translation = np.random.uniform(-self.translate_range, self.translate_range) 119 | points[:, 0:3] += translation 120 | return points 121 | 122 | 123 | class PointcloudToTensor(object): 124 | def __call__(self, points): 125 | return torch.from_numpy(points).float() 126 | 127 | 128 | class PointcloudRandomInputDropout(object): 129 | def __init__(self, max_dropout_ratio=0.875): 130 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 131 | self.max_dropout_ratio = max_dropout_ratio 132 | 133 | def __call__(self, points): 134 | pc = points.numpy() 135 | 136 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 137 | drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0] 138 | if len(drop_idx) > 0: 139 | pc[drop_idx] = pc[0] # set to the first point 140 | 141 | return torch.from_numpy(pc).float() 142 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .point_transformer_seg import PointTransformerSemSegmentation 2 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim.lr_scheduler as lr_sched 7 | from torch.utils.data import DataLoader, DistributedSampler 8 | from torchvision import transforms 9 | 10 | import point_transformer.data.data_utils as d_utils 11 | from point_transformer.data.ModelNet40Loader import ModelNet40Cls 12 | 13 | 14 | def set_bn_momentum_default(bn_momentum): 15 | def fn(m): 16 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 17 | m.momentum = 
bn_momentum 18 | 19 | return fn 20 | 21 | 22 | class BNMomentumScheduler(lr_sched.LambdaLR): 23 | def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default): 24 | if not isinstance(model, nn.Module): 25 | raise RuntimeError( 26 | "Class '{}' is not a PyTorch nn Module".format(type(model)._name_) 27 | ) 28 | 29 | self.model = model 30 | self.setter = setter 31 | self.lmbd = bn_lambda 32 | 33 | self.step(last_epoch + 1) 34 | self.last_epoch = last_epoch 35 | 36 | def step(self, epoch=None): 37 | if epoch is None: 38 | epoch = self.last_epoch + 1 39 | 40 | self.last_epoch = epoch 41 | self.model.apply(self.setter(self.lmbd(epoch))) 42 | 43 | def state_dict(self): 44 | return dict(last_epoch=self.last_epoch) 45 | 46 | def load_state_dict(self, state): 47 | self.last_epoch = state["last_epoch"] 48 | self.step(self.last_epoch) 49 | 50 | 51 | lr_clip = 1e-5 52 | bnm_clip = 1e-2 53 | 54 | class BaseClassification(pl.LightningModule): 55 | def __init__(self, hparams): 56 | super().__init__() 57 | 58 | self.hparams = hparams 59 | self._build_model() 60 | 61 | def _build_model(self): 62 | raise NotImplementedError 63 | 64 | def _break_up_pc(self, pc): 65 | xyz = pc[..., 0:3].contiguous() 66 | features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None 67 | 68 | return xyz, features 69 | 70 | def forward(self, pointcloud): 71 | raise NotImplementedError 72 | 73 | def training_step(self, batch, batch_idx): 74 | pc, labels = batch 75 | logits = self.forward(pc) 76 | loss = F.cross_entropy(logits, labels) 77 | with torch.no_grad(): 78 | acc = (torch.argmax(logits, dim=1) == labels).float().mean() 79 | log = dict(train_loss=loss, train_acc=acc) 80 | return dict(loss=loss, log=log, progress_bar=dict(train_acc=acc)) 81 | 82 | def validation_step(self, batch, batch_idx): 83 | pc, labels = batch 84 | logits = self.forward(pc) 85 | loss = F.cross_entropy(logits, labels) 86 | acc = (torch.argmax(logits, dim=1) == labels).float().mean() 87 | return dict(val_loss=loss, val_acc=acc) 88 | 89 | def validation_epoch_end(self, outputs): 90 | reduced_outputs = {} 91 | for k in outputs[0]: 92 | for o in outputs: 93 | reduced_outputs[k] = reduced_outputs.get(k, []) + [o[k]] 94 | for k in reduced_outputs: 95 | reduced_outputs[k] = torch.stack(reduced_outputs[k]).mean() 96 | reduced_outputs.update(dict(log=reduced_outputs.copy(), progress_bar=reduced_outputs.copy())) 97 | return reduced_outputs 98 | 99 | def configure_optimizers(self): 100 | lr_lbmd = lambda _: max( 101 | self.hparams["optimizer.lr_decay"] 102 | ** ( 103 | int( 104 | self.global_step 105 | * self.hparams["batch_size"] 106 | / self.hparams["optimizer.decay_step"] 107 | ) 108 | ), 109 | lr_clip / self.hparams["optimizer.lr"], 110 | ) 111 | bn_lbmd = lambda _: max( 112 | self.hparams["optimizer.bn_momentum"] 113 | * self.hparams["optimizer.bnm_decay"] 114 | ** ( 115 | int( 116 | self.global_step 117 | * self.hparams["batch_size"] 118 | / self.hparams["optimizer.decay_step"] 119 | ) 120 | ), 121 | bnm_clip, 122 | ) 123 | 124 | optimizer = torch.optim.Adam( 125 | self.parameters(), 126 | lr=self.hparams["optimizer.lr"], 127 | weight_decay=self.hparams["optimizer.weight_decay"], 128 | ) 129 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd) 130 | bnm_scheduler = BNMomentumScheduler(self, bn_lambda=bn_lbmd) 131 | 132 | return [optimizer], [lr_scheduler, bnm_scheduler] 133 | 134 | def prepare_data(self): 135 | train_transforms = transforms.Compose( 136 | [ 137 | d_utils.PointcloudToTensor(), 138 | 
d_utils.PointcloudScale(), 139 | d_utils.PointcloudRotate(), 140 | d_utils.PointcloudRotatePerturbation(), 141 | d_utils.PointcloudTranslate(), 142 | d_utils.PointcloudJitter(), 143 | d_utils.PointcloudRandomInputDropout(), 144 | ] 145 | ) 146 | 147 | self.train_dset = ModelNet40Cls( 148 | self.hparams["num_points"], transforms=train_transforms, train=True 149 | ) 150 | self.val_dset = ModelNet40Cls( 151 | self.hparams["num_points"], transforms=None, train=False 152 | ) 153 | 154 | def _build_dataloader(self, dset, mode): 155 | return DataLoader( 156 | dset, 157 | batch_size=self.hparams["batch_size"], 158 | shuffle=mode == "train", 159 | num_workers=4, 160 | pin_memory=True, 161 | drop_last=mode == "train", 162 | ) 163 | 164 | def train_dataloader(self): 165 | return self._build_dataloader(self.train_dset, mode="train") 166 | 167 | def val_dataloader(self): 168 | return self._build_dataloader(self.val_dset, mode="val") -------------------------------------------------------------------------------- /models/point_transformer_cls.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim.lr_scheduler as lr_sched 6 | from point_transformer.utils.timer import Timer 7 | 8 | from point_transformer.models.base import BaseClassification 9 | from point_transformer_ops.point_transformer_modules import PointTransformerBlock, TransitionDown 10 | 11 | 12 | class PointTransformerClassification(BaseClassification): 13 | def _build_model(self): 14 | channels, k, sampling_ratio, num_points = ( 15 | self.hparams["model.channels"], 16 | self.hparams["model.k"], 17 | self.hparams["model.sampling_ratio"], 18 | self.hparams["num_points"], 19 | ) 20 | channels = list(map(int, channels.split("."))) 21 | assert len(channels) > 3 22 | 23 | self.prev_block = nn.Sequential( 24 | nn.Linear(3, channels[0]), 25 | nn.ReLU(True), 26 | nn.Linear(channels[0], channels[0]), 27 | ) 28 | self.prev_transformer = PointTransformerBlock(channels[0], k) 29 | 30 | self.trans_downs = nn.ModuleList() 31 | self.transformers = nn.ModuleList() 32 | 33 | for i in range(1, len(channels) - 2): 34 | self.trans_downs.append( 35 | TransitionDown( 36 | in_channels=channels[i - 1], 37 | out_channels=channels[i], 38 | k=k, 39 | sampling_ratio=sampling_ratio, 40 | ) 41 | ) 42 | self.transformers.append(PointTransformerBlock(channels[i], k)) 43 | 44 | self.final_block = nn.Sequential( 45 | nn.Linear(channels[-3], channels[-2]), 46 | nn.ReLU(True), 47 | nn.Linear(channels[-2], channels[-1]), 48 | ) 49 | 50 | def _break_up_pc(self, pc): 51 | xyz = pc[..., 0:3].contiguous() 52 | features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None 53 | 54 | return xyz, features 55 | 56 | def forward(self, pointcloud): 57 | r""" 58 | Forward pass of the network 59 | 60 | Parameters 61 | ---------- 62 | pointcloud: Variable(torch.cuda.FloatTensor) 63 | (B, N, 3 + input_channels) tensor 64 | Point cloud to run predicts on 65 | Each point in the point-cloud MUST 66 | be formated as (x, y, z, features...) 
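
        Returns
        -------
        torch.Tensor
            (B, channels[-1]) classification logits produced by final_block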
67 | """ 68 | xyz, features = self._break_up_pc(pointcloud) 69 | 70 | # Timers 71 | t_prev = Timer("prev_block") 72 | t_prev.tic() 73 | features = self.prev_block(xyz) 74 | t_prev.toc() 75 | 76 | t_prev_trs = Timer("prev_transformer") 77 | t_prev_trs.tic() 78 | features = self.prev_transformer(features, xyz) 79 | t_prev_trs.toc() 80 | 81 | t_td = Timer("transition_down") 82 | t_trs = Timer("transformer") 83 | for trans_down_layer, transformer_layer in zip( 84 | self.trans_downs, self.transformers 85 | ): 86 | t_td.tic() 87 | features, xyz = trans_down_layer(features, xyz) 88 | t_td.toc() 89 | 90 | t_trs.tic() 91 | features = transformer_layer(features, xyz) 92 | t_trs.toc() 93 | 94 | t_final = Timer("final_block") 95 | t_final.tic() 96 | out = self.final_block(features.mean(1)) 97 | t_final.toc() 98 | return out 99 | 100 | def configure_optimizers(self): 101 | """ 102 | SGD: momentum=0.9, weight_decay = 0.0001 103 | Max epoch: 200 104 | Initial learning rate = 0.05 105 | Drop 10x at epoch 120, 160 106 | """ 107 | optimizer = torch.optim.SGD( 108 | self.parameters(), 109 | lr=self.hparams["optimizer.lr"], 110 | weight_decay=self.hparams["optimizer.weight_decay"], 111 | momentum=self.hparams["optimizer.momentum"], 112 | ) 113 | milestones = list(map(int, self.hparams["optimizer.milestones"].split("."))) 114 | lr_scheduler = lr_sched.MultiStepLR( 115 | optimizer, 116 | milestones=milestones, 117 | gamma=self.hparams["optimizer.gamma"], 118 | ) 119 | 120 | return [optimizer], [lr_scheduler] -------------------------------------------------------------------------------- /models/point_transformer_seg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.timer import Timer 3 | import torch.nn.functional as F 4 | from point_transformer_lib.point_transformer_ops.point_transformer_modules import PointTransformerBlock, TransitionDown, TransitionUp, BFM_torch 5 | 6 | 7 | class PointTransformerSemSegmentation(nn.Module): 8 | def __init__(self, args): 9 | super().__init__() 10 | npoints = args.npoints 11 | dim = [args.fea_dim, 32, 64, 128, 256] 12 | sampling_ratio = args.downsampling_ratio 13 | output_dim = args.classes 14 | 15 | self.Encoder = nn.ModuleList() 16 | for i in range(len(dim)-1): 17 | if i == 0: 18 | self.Encoder.append(nn.Linear(dim[i], dim[i+1], bias=False)) 19 | else: 20 | self.Encoder.append(TransitionDown(dim[i], dim[i+1], npoints, sampling_ratio, fast=True)) 21 | self.Encoder.append(PointTransformerBlock(dim[i+1], npoints)) 22 | 23 | self.SegDecoder = nn.ModuleList() 24 | for i in range(len(dim)-1,0,-1): 25 | if i == len(dim)-1: 26 | self.SegDecoder.append(nn.Linear(dim[i], dim[i], bias=False)) 27 | else: 28 | self.SegDecoder.append(TransitionUp(dim[i+1], dim[i])) 29 | self.SegDecoder.append(PointTransformerBlock(dim[i], npoints)) 30 | 31 | self.EdgeDecoder = nn.ModuleList() 32 | for i in range(len(dim)-1,0,-1): 33 | if i == len(dim)-1: 34 | self.EdgeDecoder.append(nn.Linear(dim[i], dim[i], bias=False)) 35 | else: 36 | self.EdgeDecoder.append(TransitionUp(dim[i+1], dim[i])) 37 | self.EdgeDecoder.append(PointTransformerBlock(dim[i], npoints)) 38 | 39 | self.seg_fc_layer = nn.Sequential( 40 | nn.Conv1d(dim[1], dim[1], kernel_size=1, bias=False), 41 | nn.BatchNorm1d(dim[1]), 42 | nn.ReLU(inplace=True), 43 | nn.Dropout(0.5), 44 | nn.Conv1d(dim[1], output_dim, kernel_size=1), 45 | ) 46 | 47 | self.edge_fc_layer = nn.Sequential( 48 | nn.Conv1d(dim[1], dim[1], kernel_size=1, bias=False), 49 | nn.BatchNorm1d(dim[1]), 50 
| nn.ReLU(inplace=True), 51 | nn.Dropout(0.5), 52 | nn.Conv1d(dim[1], output_dim, kernel_size=1), 53 | ) 54 | 55 | self.proj_layer = nn.Sequential( 56 | nn.Conv1d(dim[1], dim[1], kernel_size=1, bias=False), 57 | nn.BatchNorm1d(dim[1]), 58 | nn.ReLU(inplace=True), 59 | nn.Conv1d(dim[1], dim[1], kernel_size=1), 60 | ) 61 | 62 | self.BFM = BFM_torch(dim[1], dim[1], args.n_neighbors) 63 | 64 | self.seg_refine_fc_layer = nn.Sequential( 65 | nn.Conv1d(dim[1], dim[1], kernel_size=1, bias=False), 66 | nn.BatchNorm1d(dim[1]), 67 | nn.ReLU(inplace=True), 68 | nn.Dropout(0.5), 69 | nn.Conv1d(dim[1], output_dim, kernel_size=1), 70 | ) 71 | 72 | def forward(self, features, gmatrix, idxs): 73 | 74 | xyz = features[...,0:3].contiguous() 75 | 76 | # Encoding period 77 | l_xyz, l_features = [xyz], [features] 78 | for i in range(int(len(self.Encoder)/2)): 79 | if i == 0: 80 | li_features = self.Encoder[2*i](l_features[i]) 81 | li_xyz = l_xyz[i] 82 | else: 83 | li_features, li_xyz = self.Encoder[2*i](l_features[i], l_xyz[i]) 84 | li_features = self.Encoder[2*i+1](li_features, li_xyz) 85 | 86 | l_features.append(li_features) 87 | l_xyz.append(li_xyz) 88 | del li_features, li_xyz 89 | 90 | e_features = [feature.clone() for feature in l_features] 91 | 92 | # Decoding period 93 | D_n = int(len(self.SegDecoder)/2) 94 | for i in range(D_n): 95 | if i == 0: 96 | l_features[D_n-i] = self.SegDecoder[2*i](l_features[D_n-i]) 97 | l_features[D_n-i] = self.SegDecoder[2*i+1](l_features[D_n-i], l_xyz[D_n-i]) 98 | e_features[D_n-i] = self.EdgeDecoder[2*i](e_features[D_n-i]) 99 | e_features[D_n-i] = self.EdgeDecoder[2*i+1](e_features[D_n-i], l_xyz[D_n-i]) 100 | else: 101 | l_features[D_n-i] = self.SegDecoder[2*i](l_features[D_n-i+1], l_xyz[D_n-i+1], l_features[D_n-i], l_xyz[D_n-i]) 102 | l_features[D_n-i] = self.SegDecoder[2*i+1](l_features[D_n-i], l_xyz[D_n-i]) 103 | e_features[D_n-i] = self.EdgeDecoder[2*i](e_features[D_n-i+1], l_xyz[D_n-i+1], e_features[D_n-i], l_xyz[D_n-i]) 104 | e_features[D_n-i] = self.EdgeDecoder[2*i+1](e_features[D_n-i], l_xyz[D_n-i]) 105 | 106 | del l_features[0], l_features[1:], e_features[0], e_features[1:], l_xyz 107 | 108 | # Final output 109 | seg_features = l_features[0].transpose(1, 2).contiguous() 110 | edge_features = e_features[0].transpose(1, 2).contiguous() 111 | seg_preds = self.seg_fc_layer(seg_features) 112 | edge_preds = self.edge_fc_layer(edge_features) 113 | 114 | seg_refine_features = self.BFM(seg_features, edge_preds, gmatrix, idxs) 115 | seg_refine_preds = self.seg_refine_fc_layer(seg_refine_features.transpose(1, 2).contiguous()) 116 | 117 | seg_embed = F.normalize(self.proj_layer(seg_features), p=2, dim=1) 118 | 119 | return seg_preds, seg_refine_preds, seg_embed, edge_preds 120 | -------------------------------------------------------------------------------- /point_transformer_lib/MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft point_transformer_ops/_ext-src 2 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: point-transformer-ops 3 | Version: 0.1.0 4 | Summary: UNKNOWN 5 | Author: Erik Wijmans 6 | License: UNKNOWN 7 | Platform: UNKNOWN 8 | 9 | UNKNOWN 10 | 11 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops.egg-info/SOURCES.txt: 
-------------------------------------------------------------------------------- 1 | MANIFEST.in 2 | setup.py 3 | point_transformer_ops/__init__.py 4 | point_transformer_ops/_version.py 5 | point_transformer_ops/point_transformer_modules.py 6 | point_transformer_ops/point_transformer_utils.py 7 | point_transformer_ops.egg-info/PKG-INFO 8 | point_transformer_ops.egg-info/SOURCES.txt 9 | point_transformer_ops.egg-info/dependency_links.txt 10 | point_transformer_ops.egg-info/requires.txt 11 | point_transformer_ops.egg-info/top_level.txt 12 | point_transformer_ops/_ext-src/include/ball_query.h 13 | point_transformer_ops/_ext-src/include/cuda_utils.h 14 | point_transformer_ops/_ext-src/include/group_points.h 15 | point_transformer_ops/_ext-src/include/interpolate.h 16 | point_transformer_ops/_ext-src/include/sampling.h 17 | point_transformer_ops/_ext-src/include/utils.h 18 | point_transformer_ops/_ext-src/src/ball_query.cpp 19 | point_transformer_ops/_ext-src/src/ball_query_gpu.cu 20 | point_transformer_ops/_ext-src/src/bindings.cpp 21 | point_transformer_ops/_ext-src/src/group_points.cpp 22 | point_transformer_ops/_ext-src/src/group_points_gpu.cu 23 | point_transformer_ops/_ext-src/src/interpolate.cpp 24 | point_transformer_ops/_ext-src/src/interpolate_gpu.cu 25 | point_transformer_ops/_ext-src/src/sampling.cpp 26 | point_transformer_ops/_ext-src/src/sampling_gpu.cu -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4 2 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | point_transformer_ops 2 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import point_transformer_ops.point_transformer_modules 2 | import point_transformer_ops.point_transformer_utils 3 | from point_transformer_ops._version import __version__ 4 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifliu3/EPT-Net/080c3040fceb02cda960f4143585c576712327fa/point_transformer_lib/point_transformer_ops/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/__pycache__/point_transformer_modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifliu3/EPT-Net/080c3040fceb02cda960f4143585c576712327fa/point_transformer_lib/point_transformer_ops/__pycache__/point_transformer_modules.cpython-37.pyc -------------------------------------------------------------------------------- 
/point_transformer_lib/point_transformer_ops/__pycache__/point_transformer_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifliu3/EPT-Net/080c3040fceb02cda960f4143585c576712327fa/point_transformer_lib/point_transformer_ops/__pycache__/point_transformer_utils.cpython-37.pyc -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor farthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- 
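`sampling.h` above declares `farthest_point_sampling`, the CUDA op behind the `FPS` helper that `dataset/IntrADataset.py` imports from `point_transformer_utils` to sub-sample each surface into `sample_points` points. As a rough illustration of what that op computes, below is a pure-PyTorch reference of the same greedy selection — a sketch for checking behaviour on CPU, not the library's kernel; the helper name `fps_reference` and the choice of point 0 as the seed are assumptions of this sketch:

```python
import torch

def fps_reference(xyz: torch.Tensor, nsamples: int) -> torch.Tensor:
    """Greedy farthest point sampling on (B, N, 3) coordinates.

    Returns (B, nsamples) long indices, the same kind of index tensor that
    IntrADataset turns into selected_points_idxs.
    """
    B, N, _ = xyz.shape
    device = xyz.device
    idxs = torch.zeros(B, nsamples, dtype=torch.long, device=device)
    # distance from every point to its closest already-selected point
    dist = torch.full((B, N), float("inf"), device=device)
    farthest = torch.zeros(B, dtype=torch.long, device=device)  # assumed seed: point 0
    batch = torch.arange(B, device=device)
    for i in range(nsamples):
        idxs[:, i] = farthest
        centroid = xyz[batch, farthest, :].unsqueeze(1)   # (B, 1, 3)
        d = ((xyz - centroid) ** 2).sum(dim=-1)           # squared distances, (B, N)
        dist = torch.minimum(dist, d)                     # update nearest-selected distance
        farthest = dist.argmax(dim=-1)                    # next pick: farthest from the selected set
    return idxs

# usage sketch: pick 512 of 2048 random points
sel = fps_reference(torch.rand(1, 2048, 3), 512)  # (1, 512)
```

The kernels under `_ext-src/src` implement the same selection on the GPU and are exposed to Python through `bindings.cpp`.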
/point_transformer_lib/point_transformer_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 
48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("farthest_point_sampling", &farthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, 
nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | 
at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 
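// new closest point found: shift the previous two best candidates down one slot and record k as the nearest neighbour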
35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t 
stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void farthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor farthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | farthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ 
out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void farthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? 
d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void farthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | farthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | farthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | farthest_point_sampling_kernel<128> 193 | <<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | farthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | farthest_point_sampling_kernel<32> 201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | farthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | farthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | farthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | farthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | farthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | farthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/point_transformer_modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 
import torch.nn.functional as F 6 | from torch import einsum 7 | import numpy as np 8 | import point_transformer_ops.point_transformer_utils as pt_utils 9 | 10 | 11 | 12 | class PointTransformerBlock(nn.Module): 13 | def __init__(self, dim, k): 14 | super().__init__() 15 | self.prev_linear = nn.Linear(dim, dim) 16 | self.k = k 17 | self.to_q = nn.Linear(dim, dim, bias=False) 18 | self.to_k = nn.Linear(dim, dim, bias=False) 19 | self.to_v = nn.Linear(dim, dim, bias=False) 20 | 21 | self.pos_mlp = nn.Sequential( 22 | nn.Conv2d(3, dim, kernel_size=1, bias=False), 23 | nn.BatchNorm2d(dim), 24 | nn.ReLU(True), 25 | nn.Conv2d(dim, dim, kernel_size=1, bias=False) 26 | ) 27 | 28 | self.attn_mlp = nn.Sequential( 29 | nn.Conv2d(dim, dim, kernel_size=1, bias=False), 30 | nn.BatchNorm2d(dim), 31 | nn.ReLU(True), 32 | nn.Conv2d(dim, dim, kernel_size=1, bias=False) 33 | ) 34 | 35 | self.final_linear = nn.Linear(dim, dim) 36 | 37 | def forward(self, x, pos): 38 | # queries, keys, values 39 | 40 | x_pre = x 41 | knn_idx = pt_utils.kNN_torch(pos, pos, self.k) 42 | knn_xyz = pt_utils.index_points(pos, knn_idx) 43 | 44 | q = self.to_q(x) 45 | k = pt_utils.index_points(self.to_k(x), knn_idx) 46 | v = pt_utils.index_points(self.to_v(x), knn_idx) 47 | 48 | pos_enc = (pos[:, :, None] - knn_xyz).permute(0, 3, 1, 2).contiguous() 49 | for i, layer in enumerate(self.pos_mlp): pos_enc = layer(pos_enc) 50 | pos_enc = pos_enc.permute(0, 2, 3, 1).contiguous() 51 | 52 | attn = (q[:, :, None] - k + pos_enc).permute(0, 3, 1, 2).contiguous() 53 | for i, layer in enumerate(self.attn_mlp): attn = layer(attn) 54 | attn = attn.permute(0, 2, 3, 1).contiguous() 55 | 56 | attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2) 57 | 58 | agg = einsum('b i j d, b i j d -> b i d', attn, v+pos_enc) 59 | agg = self.final_linear(agg) + x_pre 60 | 61 | return agg 62 | 63 | 64 | class TransitionDown(nn.Module): 65 | def __init__(self, in_channels, out_channels, k, sampling_ratio, fast=True): 66 | super().__init__() 67 | 68 | self.in_channels = in_channels 69 | self.out_channels = out_channels 70 | self.k = k 71 | self.sampling_ratio = sampling_ratio 72 | self.fast = fast 73 | self.mlp = nn.Sequential( 74 | nn.Conv1d(self.in_channels, self.out_channels, kernel_size=1, bias=False), 75 | nn.BatchNorm1d(self.out_channels), 76 | nn.ReLU(True), 77 | ) 78 | 79 | def forward(self, x, p1): 80 | """ 81 | inputs 82 | x: (B, N, in_channels) shaped torch Tensor (A set of feature vectors) 83 | p1: (B, N, 3) shaped torch Tensor (3D coordinates) 84 | outputs 85 | y: (B, M, out_channels) shaped torch Tensor 86 | p2: (B, M, 3) shaped torch Tensor 87 | M = N * sampling ratio 88 | """ 89 | B, N, _ = x.shape 90 | M = int(N * self.sampling_ratio) 91 | 92 | # 1: Farthest Point Sampling 93 | p1_flipped = p1.transpose(1, 2).contiguous() 94 | p2 = ( 95 | pt_utils.gather_operation( 96 | p1_flipped, pt_utils.farthest_point_sample(p1, M) 97 | ) 98 | .transpose(1, 2) 99 | .contiguous() 100 | ) # p2: (B, M, 3) 101 | 102 | # 2: kNN & MLP 103 | knn_fn = pt_utils.kNN_torch if self.fast else pt_utils.kNN 104 | neighbors = knn_fn(p2, p1, self.k) # neighbors: (B, M, k) 105 | 106 | # 2-1: Apply MLP onto each feature 107 | x_flipped = x.transpose(1, 2).contiguous() 108 | mlp_x = self.mlp(x_flipped).transpose(1, 2).contiguous() # mlp_x: (B, N, out_channels) 109 | 110 | # 2-2: Extract features based on neighbors 111 | features = pt_utils.index_points(mlp_x, neighbors) # features: (B, M, k, out_channels) 112 | 113 | # 3: Local Max Pooling 114 | y = torch.max(features, dim=2)[0] # y: 
(B, M, out_channels) 115 | 116 | return y, p2 117 | 118 | 119 | class TransitionUp(nn.Module): 120 | def __init__(self, in_channels, out_channels): 121 | super().__init__() 122 | 123 | self.up_mlp = nn.Sequential( 124 | nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False), 125 | nn.BatchNorm1d(out_channels), 126 | nn.ReLU(True) 127 | ) 128 | self.lateral_mlp = nn.Sequential( 129 | nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False), 130 | nn.BatchNorm1d(out_channels), 131 | nn.ReLU(True) 132 | ) 133 | 134 | def forward(self, x1, p1, x2, p2): 135 | """ 136 | x1: (B, N, in_channels) torch.Tensor 137 | p1: (B, N, 3) torch.Tensor 138 | x2: (B, M, out_channels) torch.Tensor 139 | p2: (B, M, 3) torch.Tensor 140 | Note that N is smaller than M because this module upsamples features. 141 | """ 142 | x1 = self.up_mlp(x1.transpose(1, 2).contiguous()) 143 | dist, idx= pt_utils.three_nn(p2, p1) 144 | dist_recip = 1.0 / (dist + 1e-8) 145 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 146 | weight = dist_recip / norm 147 | interpolated_feats = pt_utils.three_interpolate( 148 | x1, idx, weight 149 | ) 150 | x2 = self.lateral_mlp(x2.transpose(1, 2).contiguous()) 151 | y = interpolated_feats + x2 152 | return y.transpose(1, 2).contiguous() 153 | 154 | 155 | class BFM_torch(nn.Module): 156 | def __init__(self, in_channels, out_channels, n_neighbors): 157 | super().__init__() 158 | self.n_neighbors = n_neighbors 159 | self.GCN_1 = GCN(in_channels, out_channels) 160 | self.GCN_2 = GCN(out_channels, out_channels) 161 | 162 | 163 | def forward(self, seg_features, edge_preds, gmatrix, idxs): 164 | # import pdb; pdb.set_trace() 165 | B = idxs.shape[0] 166 | seg_features = seg_features.transpose(1, 2).contiguous() 167 | refined_features_list = [] 168 | 169 | # construct separated topology graph 170 | for i in range(B): 171 | edge_preds_this = edge_preds[i, ...].argmax(0) # N 172 | gmatrix_this = gmatrix[i, ...] # N, N 173 | idxs_this = idxs[i, ...] 174 | features_this = seg_features[i, ...] 
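# Per-sample refinement performed by the steps below:
#   (1) deduplicate the sampled point indices so each point enters the graph once;
#   (2) link every point to its n_neighbors geodesically nearest points (gmatrix holds
#       the pairwise geodesic distances) and symmetrize the adjacency matrix;
#   (3) isolate predicted edge points and drop any non-edge link whose geodesic distance
#       exceeds that point's nearest connected edge point, so that message passing does
#       not cross the predicted boundary;
#   (4) smooth the features with two GCN layers and scatter the result back to the
#       original (duplicated) point order via inverse_indices.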
175 | 176 | # get unique attributes 177 | unique_idxs, indices, inverse_indices = np.unique(idxs_this, return_index=True, return_inverse=True) 178 | gmatrix_unique = gmatrix_this[indices, :][:, indices] 179 | features_unique = features_this[indices, :] 180 | edge_preds_unique = edge_preds_this[indices] 181 | 182 | # find neighbors based on the geodesic distance 183 | adjacency_matrix = torch.zeros(gmatrix_unique.shape).cuda() 184 | neighbor_idxs_matrix = torch.argsort(gmatrix_unique, axis=-1)[:, 0:self.n_neighbors] 185 | seq_idxs_matrix = torch.arange(neighbor_idxs_matrix.shape[0])[:, None] 186 | adjacency_matrix[seq_idxs_matrix, neighbor_idxs_matrix] = 1 187 | adjacency_matrix[neighbor_idxs_matrix, seq_idxs_matrix] = 1 188 | 189 | # cutoff connections and check 190 | # edge points are disconnected with other points 191 | edge_index = torch.nonzero(edge_preds_unique==1, as_tuple=True)[0] 192 | non_edge_index = torch.nonzero(edge_preds_unique==0, as_tuple=True)[0] 193 | adjacency_matrix[edge_index, :] = torch.zeros(edge_preds_unique.shape).cuda() 194 | # nonedge points are disconnected with edge points and other rules 195 | if non_edge_index.shape[0] == 0: 196 | adjacency_matrix = torch.diag(torch.ones(adjacency_matrix.shape[0])).cuda() 197 | else: 198 | nonedge_adj = adjacency_matrix[non_edge_index, :] 199 | nonedge_geo = gmatrix_unique[non_edge_index, :] 200 | edge_mask = torch.zeros(nonedge_adj.shape[-1]).cuda() 201 | edge_mask[edge_index] = 1 202 | edge_mask = edge_mask[None, :].repeat(nonedge_adj.shape[0], 1) 203 | tmp_fill = (torch.ones_like(edge_mask)*1000).cuda() 204 | maxgeolimit, _ = torch.where(torch.logical_and(nonedge_adj==1, edge_mask==1), nonedge_geo, tmp_fill).min(dim=-1) 205 | zero_mask = torch.zeros_like(nonedge_adj).cuda() 206 | nonedge_adj = torch.where(nonedge_geo > maxgeolimit[:, None], zero_mask, nonedge_adj) 207 | nonedge_adj[:, edge_index] = 0 208 | adjacency_matrix[non_edge_index, :] = nonedge_adj 209 | adjacency_matrix_trans = adjacency_matrix.transpose(0, 1) 210 | adjacency_matrix = torch.logical_or(adjacency_matrix, adjacency_matrix_trans) 211 | adjacency_matrix = adjacency_matrix.type(torch.cuda.FloatTensor) 212 | 213 | # GCN layer 214 | refined_features_unique = self.GCN_1(features_unique, adjacency_matrix) 215 | refined_features_unique = self.GCN_2(refined_features_unique, adjacency_matrix) 216 | refined_features = refined_features_unique[inverse_indices, :] 217 | refined_features_list.append(refined_features) 218 | 219 | refined_features = torch.stack(refined_features_list, dim=0) 220 | refined_features = refined_features + seg_features 221 | 222 | return refined_features 223 | 224 | 225 | class GCN(nn.Module): 226 | def __init__(self, in_channels, out_channels, dropout=0.): 227 | super().__init__() 228 | self.trans_msg = nn.Linear(in_channels, out_channels) 229 | self.nonlinear = nn.ReLU() 230 | 231 | def forward(self, features, adj_matrix): 232 | N, C = features.shape 233 | identity = features 234 | features = self.nonlinear(self.trans_msg(features)) 235 | row_degree = torch.sum(adj_matrix, dim=-1, keepdim=True) # (N, 1) 236 | col_degree = torch.sum(adj_matrix, dim=-2, keepdim=True) # (1, N) 237 | degree = torch.mm(torch.sqrt(row_degree), torch.sqrt(col_degree)) # (N, N) 238 | if degree[degree==0].shape[0] != 0: 239 | return identity 240 | else: 241 | refined_features = torch.sparse.mm(adj_matrix / degree, features) # (B, N, C) 242 | refined_features = refined_features + identity 243 | return refined_features 
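The three building blocks above exchange tensors in (B, N, C) layout with matching (B, N, 3) coordinates. As a quick orientation, the sketch below chains them once and prints the resulting shapes. It is an illustrative smoke test rather than code from this repository, and it assumes the `point_transformer_ops` CUDA extension has been compiled and installed and that a GPU is available (both `kNN_torch` and `farthest_point_sample` require CUDA tensors); the batch size, point count, and channel widths are arbitrary.

```python
import torch

from point_transformer_ops.point_transformer_modules import (
    PointTransformerBlock,
    TransitionDown,
    TransitionUp,
)

B, N, C = 2, 512, 32                         # arbitrary batch/point/channel sizes
feats = torch.randn(B, N, C).cuda()          # per-point features (B, N, C)
xyz = torch.randn(B, N, 3).cuda()            # per-point coordinates (B, N, 3)

block = PointTransformerBlock(dim=C, k=16).cuda()
down = TransitionDown(C, 2 * C, k=16, sampling_ratio=0.25).cuda()
up = TransitionUp(2 * C, C).cuda()

feats = block(feats, xyz)                    # k-NN self-attention: (B, N, C)
feats_ds, xyz_ds = down(feats, xyz)          # FPS + local max pool: (B, N//4, 2C), (B, N//4, 3)
feats_up = up(feats_ds, xyz_ds, feats, xyz)  # three-NN interpolation back to (B, N, C)

print(feats_up.shape)                        # torch.Size([2, 512, 32])
```

Note that `TransitionUp` takes the coarse tensors first (`x1`, `p1`) and the finer-resolution skip tensors second (`x2`, `p2`); the interpolated coarse features are added to the `lateral_mlp` projection of the skip features, so the decoder recovers the original resolution while mixing in lateral information.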
-------------------------------------------------------------------------------- /point_transformer_lib/point_transformer_ops/point_transformer_utils.py: -------------------------------------------------------------------------------- 1 | from email.policy import default 2 | import open3d as o3d 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.dlpack 6 | import warnings 7 | from torch.autograd import Function 8 | from typing import * 9 | 10 | try: 11 | import point_transformer_ops._ext as _ext 12 | except ImportError: 13 | from torch.utils.cpp_extension import load 14 | import glob 15 | import os.path as osp 16 | import os 17 | 18 | warnings.warn("Unable to load point_transformer_ops cpp extension. JIT Compiling.") 19 | 20 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") 21 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 22 | osp.join(_ext_src_root, "src", "*.cu") 23 | ) 24 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 25 | 26 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 27 | _ext = load( 28 | "_ext", 29 | sources=_ext_sources, 30 | extra_include_paths=[osp.join(_ext_src_root, "include")], 31 | extra_cflags=["-O3"], 32 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], 33 | with_cuda=True, 34 | ) 35 | 36 | 37 | def FPS(xyz, npoint): 38 | """ 39 | Input: 40 | xyz: pointcloud data, [B, N, 3] 41 | npoint: number of samples 42 | Return: 43 | centroids: sampled pointcloud index, [B, npoint] 44 | """ 45 | device = xyz.device 46 | B, N, C = xyz.shape 47 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 48 | distance = torch.ones(B, N).to(device) * 1e10 49 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 50 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 51 | for i in range(npoint): 52 | centroids[:, i] = farthest 53 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 54 | dist = torch.sum((xyz - centroid) ** 2, -1) 55 | mask = dist < distance 56 | distance[mask] = dist[mask] 57 | farthest = torch.max(distance, -1)[1] 58 | return centroids 59 | 60 | 61 | class FarthestPointSampling(Function): 62 | @staticmethod 63 | def forward(ctx, xyz, npoint): 64 | # type: (Any, torch.Tensor, int) -> torch.Tensor 65 | r""" 66 | Uses iterative farthest point sampling to select a set of npoint features that have the largest 67 | minimum distance 68 | 69 | Parameters 70 | ---------- 71 | xyz : torch.Tensor 72 | (B, N, 3) tensor where N > npoint 73 | npoint : int32 74 | number of features in the sampled set 75 | 76 | Returns 77 | ------- 78 | torch.Tensor 79 | (B, npoint) tensor containing the set 80 | """ 81 | out = _ext.farthest_point_sampling(xyz, npoint) 82 | 83 | ctx.mark_non_differentiable(out) 84 | 85 | return out 86 | 87 | @staticmethod 88 | def backward(ctx, grad_out): 89 | return () 90 | 91 | 92 | farthest_point_sample = FarthestPointSampling.apply 93 | 94 | 95 | class GatherOperation(Function): 96 | @staticmethod 97 | def forward(ctx, features, idx): 98 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 99 | r""" 100 | 101 | Parameters 102 | ---------- 103 | features : torch.Tensor 104 | (B, C, N) tensor 105 | 106 | idx : torch.Tensor 107 | (B, npoint) tensor of the features to gather 108 | 109 | Returns 110 | ------- 111 | torch.Tensor 112 | (B, C, npoint) tensor 113 | """ 114 | 115 | ctx.save_for_backward(idx, features) 116 | 117 | return _ext.gather_points(features, idx) 118 | 119 | @staticmethod 120 | def 
backward(ctx, grad_out): 121 | idx, features = ctx.saved_tensors 122 | N = features.size(2) 123 | 124 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 125 | return grad_features, None 126 | 127 | 128 | gather_operation = GatherOperation.apply 129 | 130 | 131 | class ThreeNN(Function): 132 | @staticmethod 133 | def forward(ctx, unknown, known): 134 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 135 | r""" 136 | Find the three nearest neighbors of unknown in known 137 | Parameters 138 | ---------- 139 | unknown : torch.Tensor 140 | (B, n, 3) tensor of known features 141 | known : torch.Tensor 142 | (B, m, 3) tensor of unknown features 143 | 144 | Returns 145 | ------- 146 | dist : torch.Tensor 147 | (B, n, 3) l2 distance to the three nearest neighbors 148 | idx : torch.Tensor 149 | (B, n, 3) index of 3 nearest neighbors 150 | """ 151 | dist2, idx = _ext.three_nn(unknown, known) 152 | dist = torch.sqrt(dist2) 153 | 154 | ctx.mark_non_differentiable(dist, idx) 155 | 156 | return dist, idx 157 | 158 | @staticmethod 159 | def backward(ctx, grad_dist, grad_idx): 160 | return () 161 | 162 | 163 | three_nn = ThreeNN.apply 164 | 165 | 166 | class ThreeInterpolate(Function): 167 | @staticmethod 168 | def forward(ctx, features, idx, weight): 169 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 170 | r""" 171 | Performs weight linear interpolation on 3 features 172 | Parameters 173 | ---------- 174 | features : torch.Tensor 175 | (B, c, m) Features descriptors to be interpolated from 176 | idx : torch.Tensor 177 | (B, n, 3) three nearest neighbors of the target features in features 178 | weight : torch.Tensor 179 | (B, n, 3) weights 180 | 181 | Returns 182 | ------- 183 | torch.Tensor 184 | (B, c, n) tensor of the interpolated features 185 | """ 186 | ctx.save_for_backward(idx, weight, features) 187 | 188 | return _ext.three_interpolate(features, idx, weight) 189 | 190 | @staticmethod 191 | def backward(ctx, grad_out): 192 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 193 | r""" 194 | Parameters 195 | ---------- 196 | grad_out : torch.Tensor 197 | (B, c, n) tensor with gradients of ouputs 198 | 199 | Returns 200 | ------- 201 | grad_features : torch.Tensor 202 | (B, c, m) tensor with gradients of features 203 | 204 | None 205 | 206 | None 207 | """ 208 | idx, weight, features = ctx.saved_tensors 209 | m = features.size(2) 210 | 211 | grad_features = _ext.three_interpolate_grad( 212 | grad_out.contiguous(), idx, weight, m 213 | ) 214 | 215 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) 216 | 217 | 218 | three_interpolate = ThreeInterpolate.apply 219 | 220 | 221 | class GroupingOperation(Function): 222 | @staticmethod 223 | def forward(ctx, features, idx): 224 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 225 | r""" 226 | 227 | Parameters 228 | ---------- 229 | features : torch.Tensor 230 | (B, C, N) tensor of features to group 231 | idx : torch.Tensor 232 | (B, npoint, nsample) tensor containing the indicies of features to group with 233 | 234 | Returns 235 | ------- 236 | torch.Tensor 237 | (B, C, npoint, nsample) tensor 238 | """ 239 | ctx.save_for_backward(idx, features) 240 | 241 | return _ext.group_points(features, idx) 242 | 243 | @staticmethod 244 | def backward(ctx, grad_out): 245 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 246 | r""" 247 | 248 | Parameters 249 | ---------- 250 | grad_out : torch.Tensor 251 | (B, C, 
npoint, nsample) tensor of the gradients of the output from forward 252 | 253 | Returns 254 | ------- 255 | torch.Tensor 256 | (B, C, N) gradient of the features 257 | None 258 | """ 259 | idx, features = ctx.saved_tensors 260 | N = features.size(2) 261 | 262 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 263 | 264 | return grad_features, torch.zeros_like(idx) 265 | 266 | 267 | grouping_operation = GroupingOperation.apply 268 | 269 | 270 | class BallQuery(Function): 271 | @staticmethod 272 | def forward(ctx, radius, nsample, xyz, new_xyz): 273 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 274 | r""" 275 | 276 | Parameters 277 | ---------- 278 | radius : float 279 | radius of the balls 280 | nsample : int 281 | maximum number of features in the balls 282 | xyz : torch.Tensor 283 | (B, N, 3) xyz coordinates of the features 284 | new_xyz : torch.Tensor 285 | (B, npoint, 3) centers of the ball query 286 | 287 | Returns 288 | ------- 289 | torch.Tensor 290 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 291 | """ 292 | output = _ext.ball_query(new_xyz, xyz, radius, nsample) 293 | 294 | ctx.mark_non_differentiable(output) 295 | 296 | return output 297 | 298 | @staticmethod 299 | def backward(ctx, grad_out): 300 | return () 301 | 302 | 303 | ball_query = BallQuery.apply 304 | 305 | 306 | class QueryAndGroup(nn.Module): 307 | r""" 308 | Groups with a ball query of radius 309 | 310 | Parameters 311 | --------- 312 | radius : float32 313 | Radius of ball 314 | nsample : int32 315 | Maximum number of features to gather in the ball 316 | """ 317 | 318 | def __init__(self, radius, nsample, use_xyz=True): 319 | # type: (QueryAndGroup, float, int, bool) -> None 320 | super(QueryAndGroup, self).__init__() 321 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 322 | 323 | def forward(self, xyz, new_xyz, features=None): 324 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 325 | r""" 326 | Parameters 327 | ---------- 328 | xyz : torch.Tensor 329 | xyz coordinates of the features (B, N, 3) 330 | new_xyz : torch.Tensor 331 | centriods (B, npoint, 3) 332 | features : torch.Tensor 333 | Descriptors of the features (B, C, N) 334 | 335 | Returns 336 | ------- 337 | new_features : torch.Tensor 338 | (B, 3 + C, npoint, nsample) tensor 339 | """ 340 | 341 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 342 | xyz_trans = xyz.transpose(1, 2).contiguous() 343 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 344 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 345 | 346 | if features is not None: 347 | grouped_features = grouping_operation(features, idx) 348 | if self.use_xyz: 349 | new_features = torch.cat( 350 | [grouped_xyz, grouped_features], dim=1 351 | ) # (B, C + 3, npoint, nsample) 352 | else: 353 | new_features = grouped_features 354 | else: 355 | assert ( 356 | self.use_xyz 357 | ), "Cannot have not features and not use xyz as a feature!" 
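# With no feature descriptors supplied, the recentered neighborhood coordinates
# themselves become the grouped output features.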
358 | new_features = grouped_xyz 359 | 360 | return new_features 361 | 362 | 363 | class GroupAll(nn.Module): 364 | r""" 365 | Groups all features 366 | 367 | Parameters 368 | --------- 369 | """ 370 | 371 | def __init__(self, use_xyz=True): 372 | # type: (GroupAll, bool) -> None 373 | super(GroupAll, self).__init__() 374 | self.use_xyz = use_xyz 375 | 376 | def forward(self, xyz, new_xyz, features=None): 377 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 378 | r""" 379 | Parameters 380 | ---------- 381 | xyz : torch.Tensor 382 | xyz coordinates of the features (B, N, 3) 383 | new_xyz : torch.Tensor 384 | Ignored 385 | features : torch.Tensor 386 | Descriptors of the features (B, C, N) 387 | 388 | Returns 389 | ------- 390 | new_features : torch.Tensor 391 | (B, C + 3, 1, N) tensor 392 | """ 393 | 394 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 395 | if features is not None: 396 | grouped_features = features.unsqueeze(2) 397 | if self.use_xyz: 398 | new_features = torch.cat( 399 | [grouped_xyz, grouped_features], dim=1 400 | ) # (B, 3 + C, 1, N) 401 | else: 402 | new_features = grouped_features 403 | else: 404 | new_features = grouped_xyz 405 | 406 | return new_features 407 | 408 | 409 | def square_distance(src, dst): 410 | """ 411 | Calculate Euclid distance between each two points. 412 | src^T * dst = xn * xm + yn * ym + zn * zm; 413 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 414 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 415 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 416 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 417 | Input: 418 | src: source points, [B, N, C] 419 | dst: target points, [B, M, C] 420 | Output: 421 | dist: per-point square distance, [B, N, M] 422 | """ 423 | return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1) 424 | 425 | 426 | def kNN(query, dataset, k): 427 | """ 428 | inputs 429 | query: (B, N0, D) shaped torch gpu Tensor. 430 | dataset: (B, N1, D) shaped torch gpu Tensor. 431 | k: int 432 | outputs 433 | neighbors: (B * N0, k) shaped torch Tensor. 434 | Each row is the indices of a neighboring points. 435 | It is flattened along batch dimension. 436 | """ 437 | assert query.is_cuda and dataset.is_cuda, "Input tensors should be gpu tensors." 438 | assert query.dim() == 3 and dataset.dim() == 3, "Input tensors should be 3D." 439 | assert ( 440 | query.shape[0] == dataset.shape[0] 441 | ), "Input tensors should have same batch size." 442 | assert ( 443 | query.shape[2] == dataset.shape[2] 444 | ), "Input tensors should have same dimension." 445 | 446 | B, N1, _ = dataset.shape 447 | 448 | query_o3d = o3d.core.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(query)) 449 | dataset_o3d = o3d.core.Tensor.from_dlpack(torch.utils.dlpack.to_dlpack(dataset)) 450 | 451 | indices = [] 452 | for i in range(query_o3d.shape[0]): 453 | _query = query_o3d[i] 454 | _dataset = dataset_o3d[i] 455 | nns = o3d.core.nns.NearestNeighborSearch(_dataset) 456 | status = nns.knn_index() 457 | if not status: 458 | raise Exception("Index failed.") 459 | neighbors, _ = nns.knn_search(_query, k) 460 | # calculate prefix sum of indices 461 | # neighbors += N1 * i 462 | indices.append(torch.utils.dlpack.from_dlpack(neighbors.to_dlpack())) 463 | 464 | # flatten indices 465 | indices = torch.stack(indices) 466 | return indices 467 | 468 | 469 | def kNN_torch(query, dataset, k): 470 | """ 471 | inputs 472 | query: (B, N0, D) shaped torch gpu Tensor. 473 | dataset: (B, N1, D) shaped torch gpu Tensor. 
474 | k: int 475 | outputs 476 | neighbors: (B * N0, k) shaped torch Tensor. 477 | Each row is the indices of a neighboring points. 478 | It is flattened along batch dimension. 479 | """ 480 | assert query.is_cuda and dataset.is_cuda, "Input tensors should be gpu tensors." 481 | assert query.dim() == 3 and dataset.dim() == 3, "Input tensors should be 3D." 482 | assert ( 483 | query.shape[0] == dataset.shape[0] 484 | ), "Input tensors should have same batch size." 485 | assert ( 486 | query.shape[2] == dataset.shape[2] 487 | ), "Input tensors should have same dimension." 488 | 489 | dists = square_distance(query, dataset) # dists: [B, N0, N1] 490 | neighbors = dists.argsort()[:, :, :k] # neighbors: [B, N0, k] 491 | torch.cuda.empty_cache() 492 | return neighbors 493 | 494 | 495 | def index_points(points, idx): 496 | """ 497 | Input: 498 | points: input points data, [B, N, C] 499 | idx: sample index data, [B, S, [K]] 500 | Return: 501 | new_points:, indexed points data, [B, S, [K], C] 502 | """ 503 | raw_size = idx.size() 504 | idx = idx.reshape(raw_size[0], -1) 505 | res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1))) 506 | return res.reshape(*raw_size, -1) -------------------------------------------------------------------------------- /point_transformer_lib/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = osp.join("point_transformer_ops", "_ext-src") 10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 11 | osp.join(_ext_src_root, "src", "*.cu") 12 | ) 13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | exec(open(osp.join("point_transformer_ops", "_version.py")).read()) 18 | 19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 20 | setup( 21 | name="point_transformer_ops", 22 | version=__version__, 23 | author="Erik Wijmans", 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name="point_transformer_ops._ext", 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /tool/ept.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH -J pt4pc 4 | #SBATCH --partition=gpu1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH -w node6 7 | #SBATCH -c 4 8 | #SBATCH --nodes=1 9 | 10 | source /home/yifliu3/.bashrc 11 | nvidia-smi 12 | nvcc -V 13 | 14 | export PYTHONPATH=./ 15 | 16 | conda activate pt4pc 17 | 18 | TRAIN_CODE=train.py 19 | TEST_CODE=test.py 20 | 21 | dataset=$1 22 | exp_name=$2 23 | exp_dir=exp/${dataset}/${exp_name} 24 | config=config/${dataset}/${dataset}_${exp_name}.yaml 25 | 26 | mkdir -p ${exp_dir} 27 | cp tool/train.sh tool/${TRAIN_CODE} tool/${TEST_CODE} ${config} ${exp_dir} 28 | 29 | python ${exp_dir}/${TRAIN_CODE} \ 30 | --config=${config} \ 31 | save_path ${exp_dir} \ 32 | num_edge_neighbor 4 \ 33 | 2>&1 | tee 
${exp_dir}/train-$now.log 34 | 35 | python ${exp_dir}/${TEST_CODE} \ 36 | --config=${config} \ 37 | save_path ${exp_dir} \ 38 | num_edge_neighbor 4 \ 39 | 2>&1 | tee ${exp_dir}/test-$now.log -------------------------------------------------------------------------------- /tool/ept_1024.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH -J pt4pc 4 | #SBATCH --partition=gpu1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH -w node6 7 | #SBATCH -c 4 8 | #SBATCH --nodes=1 9 | 10 | source /home/yifliu3/.bashrc 11 | nvidia-smi 12 | nvcc -V 13 | 14 | export PYTHONPATH=./ 15 | 16 | conda activate pt4pc 17 | 18 | TRAIN_CODE=train.py 19 | TEST_CODE=test.py 20 | 21 | dataset=$1 22 | exp_name=$2 23 | exp_dir=exp_1024/${dataset}/${exp_name} 24 | config=config/${dataset}/${dataset}_${exp_name}.yaml 25 | 26 | mkdir -p ${exp_dir} 27 | cp tool/train.sh tool/${TRAIN_CODE} tool/${TEST_CODE} ${config} ${exp_dir} 28 | 29 | python ${exp_dir}/${TRAIN_CODE} \ 30 | --config=${config} \ 31 | sample_points 1024 \ 32 | save_path ${exp_dir} \ 33 | num_edge_neighbor 6 \ 34 | 2>&1 | tee ${exp_dir}/train-$now.log 35 | 36 | python ${exp_dir}/${TEST_CODE} \ 37 | --config=${config} \ 38 | save_path ${exp_dir} \ 39 | num_edge_neighbor 6 \ 40 | test_points 1024 \ 41 | 2>&1 | tee ${exp_dir}/test-$now.log 42 | -------------------------------------------------------------------------------- /tool/ept_2048.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH -J pt4pc 4 | #SBATCH --partition=gpu1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH -w node6 7 | #SBATCH -c 4 8 | #SBATCH --nodes=1 9 | 10 | source /home/yifliu3/.bashrc 11 | nvidia-smi 12 | nvcc -V 13 | 14 | export PYTHONPATH=./ 15 | 16 | conda activate pt4pc 17 | 18 | TRAIN_CODE=train.py 19 | TEST_CODE=test.py 20 | 21 | dataset=$1 22 | exp_name=$2 23 | exp_dir=exp_2048/${dataset}/${exp_name} 24 | config=config/${dataset}/${dataset}_${exp_name}.yaml 25 | 26 | mkdir -p ${exp_dir} 27 | cp tool/train.sh tool/${TRAIN_CODE} tool/${TEST_CODE} ${config} ${exp_dir} 28 | 29 | python ${exp_dir}/${TRAIN_CODE} \ 30 | --config=${config} \ 31 | sample_points 2048 \ 32 | num_edge_neighbor 8 \ 33 | save_path ${exp_dir} \ 34 | 2>&1 | tee ${exp_dir}/train-$now.log 35 | 36 | python ${exp_dir}/${TEST_CODE} \ 37 | --config=${config} \ 38 | num_edge_neighbor 8 \ 39 | test_points 2048 40 | save_path ${exp_dir} \ 41 | 2>&1 | tee ${exp_dir}/test-$now.log 42 | -------------------------------------------------------------------------------- /tool/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import logging 6 | import argparse 7 | import shutil 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | 13 | from dataset import IntrADataset 14 | import dataset.data_utils as d_utils 15 | 16 | from utils import config 17 | from utils.tools import record_statistics, cal_IoU_Acc_batch, get_contra_loss 18 | 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser(description='PyTorch Point Cloud Semantic Segmentation') 22 | parser.add_argument('--config', type=str, default='config/IntrA/IntrA_pointtransformer_seg_repro.yaml', help='config file') 23 | parser.add_argument('opts', help='see config/IntrA/IntrA_segmentation.yaml for all options', default=None, nargs=argparse.REMAINDER) 24 | args = parser.parse_args() 25 | assert 
args.config is not None 26 | cfg = config.load_cfg_from_cfg_file(args.config) 27 | if args.opts is not None: 28 | cfg = config.merge_cfg_from_list(cfg, args.opts) 29 | return cfg 30 | 31 | 32 | def get_logger(): 33 | logger_name = "main-logger" 34 | logger = logging.getLogger(logger_name) 35 | logger.setLevel(logging.INFO) 36 | handler = logging.StreamHandler() 37 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 38 | handler.setFormatter(logging.Formatter(fmt)) 39 | logger.addHandler(handler) 40 | return logger 41 | 42 | 43 | def main(): 44 | global args, logger 45 | args = get_parser() 46 | if args.manual_seed is not None: 47 | random.seed(args.manual_seed) 48 | np.random.seed(args.manual_seed) 49 | torch.manual_seed(args.manual_seed) 50 | torch.cuda.manual_seed(args.manual_seed) 51 | torch.cuda.manual_seed_all(args.manual_seed) 52 | torch.backends.cudnn.benchmark = False 53 | torch.backends.cudnn.deterministic = True 54 | args.ngpus_per_node = len(args.train_gpu) 55 | 56 | logger = get_logger() 57 | logger.info(args) 58 | 59 | loss_list, loss_seg_list, loss_edge_list, loss_seg_refine_list = [], [], [], [] 60 | iou_list, inner_iou_list, outer_iou_list = [], [], [] 61 | iou_refine_list, inner_iou_refine_list, outer_iou_refine_list = [], [], [] 62 | 63 | for fold in args.folds: 64 | record = main_worker(args.train_gpu, args.ngpus_per_node, test_fold=fold, test_times=args.test_times) 65 | loss_list.append(record['loss_avg']) 66 | loss_seg_list.append(record['loss_seg']) 67 | loss_edge_list.append(record['loss_edge']) 68 | loss_seg_refine_list.append(record['loss_seg_refine']) 69 | iou_list.append(record['iou_list'].cpu().numpy()) 70 | inner_iou_list.append(record['inner_iou_list'].cpu().numpy()) 71 | outer_iou_list.append(record['outer_iou_list'].cpu().numpy()) 72 | iou_refine_list.append(record['iou_refine_list'].cpu().numpy()) 73 | inner_iou_refine_list.append(record['inner_iou_refine_list'].cpu().numpy()) 74 | outer_iou_refine_list.append(record['outer_iou_refine_list'].cpu().numpy()) 75 | 76 | loss, loss_seg, loss_edge, loss_seg_refine = np.mean(loss_list), np.mean(loss_seg_list), np.mean(loss_edge_list), np.mean(loss_seg_refine_list) 77 | iou = np.mean(np.stack(iou_list, axis=0), axis=0) 78 | miou = np.mean(iou) 79 | inner_miou = np.mean(np.stack(inner_iou_list, axis=0)) 80 | outer_miou = np.mean(np.stack(outer_iou_list, axis=0)) 81 | iou_refine = np.mean(np.stack(iou_refine_list, axis=0), axis=0) 82 | miou_refine = np.mean(iou_refine) 83 | inner_miou_refine = np.mean(np.stack(inner_iou_refine_list, axis=0)) 84 | outer_miou_refine = np.mean(np.stack(outer_iou_refine_list, axis=0)) 85 | logger.info("=> Final mIoU is {:.4f}, vIoU is {:.4f}, aIoU is {:.4f}, inner mIoU is {:.4f}, outer mIoU is {:.4f}"\ 86 | .format(miou, iou[0], iou[1], inner_miou, outer_miou)) 87 | logger.info("=> Final mIoU_refine is {:.4f}, vIoU_refine is {:.4f}, aIoU_refine is {:.4f}, inner mIoU refine is {:.4f}, outer mIoU refine is {:.4f}"\ 88 | .format(miou_refine, iou_refine[0], iou_refine[1], inner_miou_refine, outer_miou_refine)) 89 | logger.info("=> Final loss is {:.4f}, loss_seg is {:.4f}, loss_edge is {:.4f}, loss_seg_refine is {:.4f}".format(loss, loss_seg, loss_edge, loss_seg_refine)) 90 | 91 | 92 | def main_worker(gpu, ngpus_per_node, test_fold, test_times=3): 93 | global args, logger 94 | logger.info("===============Test fold {}===============".format(test_fold)) 95 | 96 | 97 | logger.info("=> Creating model ...") 98 | if args.arch == 
'IntrA_pointtransformer_seg_repro': 99 | from models.point_transformer_seg import PointTransformerSemSegmentation as Model 100 | else: 101 | raise Exception('architecture {} not supported yet'.format(args.arch)) 102 | model = Model(args=args).cuda() 103 | default_ckpt_path = 'exp/IntrA/pointtransformer_seg_repro/' 104 | if args.test_points == 512: 105 | ckpt_path = default_ckpt_path 106 | args.num_edge_neighbor = 4 107 | elif args.test_points == 1024: 108 | ckpt_path = default_ckpt_path.replace('exp', 'exp_'+str(args.test_points)) 109 | args.num_edge_neighbor = 6 110 | else: 111 | ckpt_path = default_ckpt_path.replace('exp', 'exp_'+str(args.test_points)) 112 | args.num_edge_neighbor = 8 113 | 114 | ckpt_path = os.path.join(ckpt_path, 'fold'+str(test_fold), 'model_best.pth') 115 | ckpt = torch.load(ckpt_path)['state_dict'] 116 | model.load_state_dict(ckpt) 117 | 118 | 119 | logger.info("=> Loading data ...") 120 | if args.data_name == "IntrA": 121 | val_data = IntrADataset(args.data_root, args.test_points, args.use_uniform_sample, args.use_normals, 122 | test_fold=test_fold, num_edge_neighbor=args.num_edge_neighbor, mode='test', transform=None, test_all=False) 123 | val_loader = DataLoader(val_data, batch_size=args.batch_size_val, 124 | shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) 125 | 126 | 127 | logger.info("=> Testing ...") 128 | record_val = validation(val_loader, model) 129 | print(record_val) 130 | 131 | return record_val 132 | 133 | 134 | def validation(val_loader, model): 135 | global args 136 | model.eval() 137 | 138 | loss_avg_list, loss_seg_list, loss_seg_refine_list, loss_edge_list, loss_contra_list = [], [], [], [], [] 139 | iou_avg_list, inner_iou_avg_list, outer_iou_avg_list = [], [], [] 140 | iou_refine_avg_list, inner_iou_refine_avg_list, outer_iou_refine_avg_list = [], [], [] 141 | 142 | for i in range(args.num_votes): 143 | loss_avg, loss_seg_avg, loss_seg_refine_avg, loss_edge_avg, loss_contra_avg = 0.0, 0.0, 0.0, 0.0, 0.0 144 | iou_avg, inner_iou_avg, outer_iou_avg = [], [], [] 145 | iou_refine_avg, inner_iou_refine_avg, outer_iou_refine_avg = [], [], [] 146 | with torch.no_grad(): 147 | for batch_idx, (pts, gts, egts, eweights, gmatrix, idxs) in enumerate(val_loader): 148 | pts, gts, egts, eweights, gmatrix = pts.cuda(), gts.cuda(), egts.cuda(), eweights.mean(dim=0).cuda(), gmatrix.cuda() 149 | seg_preds, seg_refine_preds, seg_embed, edge_preds = model(pts, gmatrix, idxs) 150 | loss_seg = F.cross_entropy(seg_preds, gts, weight=val_loader.dataset.segweights.cuda()) 151 | loss_seg_refine = F.cross_entropy(seg_refine_preds, gts, weight=val_loader.dataset.segweights.cuda()) 152 | loss_edge = F.cross_entropy(edge_preds, egts, weight=eweights) 153 | loss_contra = get_contra_loss(egts, gts, seg_embed, gmatrix, num_class=args.classes, temp=args.temp) 154 | loss = loss_seg + args.weight_edge * loss_edge + args.weight_contra * loss_contra + args.weight_refine * loss_seg_refine 155 | 156 | loss_avg += loss.item() 157 | loss_seg_avg += loss_seg.item() 158 | loss_seg_refine_avg += loss_seg_refine.item() 159 | loss_edge_avg += loss_edge.item() 160 | loss_contra_avg += loss_contra.item() 161 | 162 | iou, inner_iou, outer_iou = cal_IoU_Acc_batch(seg_preds, gts, egts) 163 | iou_avg.append(iou) 164 | inner_iou_avg.append(inner_iou) 165 | outer_iou_avg.append(outer_iou) 166 | 167 | iou_refine, inner_iou_refine, outer_iou_refine = cal_IoU_Acc_batch(seg_refine_preds, gts, egts) 168 | iou_refine_avg.append(iou_refine) 169 | 
inner_iou_refine_avg.append(inner_iou_refine) 170 | outer_iou_refine_avg.append(outer_iou_refine) 171 | 172 | dataset_len = len(val_loader.dataset) 173 | loss_seg_list.append(loss_seg_avg/dataset_len) 174 | loss_seg_refine_list.append(loss_seg_refine_avg/dataset_len) 175 | loss_edge_list.append(loss_edge_avg/dataset_len) 176 | loss_contra_list.append(loss_contra_avg/dataset_len) 177 | loss_avg_list.append(loss_avg/dataset_len) 178 | 179 | iou_avg_list.append(torch.cat(iou_avg, dim=0).mean(dim=0)) 180 | inner_iou_avg_list.append(torch.cat(inner_iou_avg, dim=0).mean(dim=0)) 181 | outer_iou_avg_list.append(torch.cat(outer_iou_avg, dim=0).mean(dim=0)) 182 | 183 | iou_refine_avg_list.append(torch.cat(iou_refine_avg, dim=0).mean(dim=0)) 184 | inner_iou_refine_avg_list.append(torch.cat(inner_iou_refine_avg, dim=0).mean(dim=0)) 185 | outer_iou_refine_avg_list.append(torch.cat(outer_iou_refine_avg, dim=0).mean(dim=0)) 186 | 187 | record = {} 188 | record['loss_avg'] = np.mean(loss_avg_list) 189 | record['loss_seg'] = np.mean(loss_seg_list) 190 | record['loss_seg_refine'] = np.mean(loss_seg_refine_list) 191 | record['loss_edge'] = np.mean(loss_edge_list) 192 | record['loss_contra'] = np.mean(loss_contra_list) 193 | record['iou_list'] = torch.stack(iou_avg_list, dim=0).mean(dim=0) 194 | record['inner_iou_list'] = torch.stack(inner_iou_avg_list, dim=0).mean(dim=0) 195 | record['outer_iou_list'] = torch.stack(outer_iou_avg_list, dim=0).mean(dim=0) 196 | record['iou_refine_list'] = torch.stack(iou_refine_avg_list, dim=0).mean(dim=0) 197 | record['inner_iou_refine_list'] = torch.stack(inner_iou_refine_avg_list, dim=0).mean(dim=0) 198 | record['outer_iou_refine_list'] = torch.stack(outer_iou_refine_avg_list, dim=0).mean(dim=0) 199 | return record 200 | 201 | 202 | if __name__ == "__main__": 203 | main() 204 | -------------------------------------------------------------------------------- /tool/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import numpy as np 5 | import logging 6 | import argparse 7 | import shutil 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | import torch.optim.lr_scheduler as lr_scheduler 13 | import torch.nn as nn 14 | from torchvision import transforms 15 | from tensorboardX import SummaryWriter 16 | 17 | from dataset import IntrADataset 18 | import dataset.data_utils as d_utils 19 | 20 | from utils import config 21 | from utils.tools import cal_IoU_Acc_batch, get_contra_loss, record_statistics 22 | 23 | 24 | def get_parser(): 25 | parser = argparse.ArgumentParser(description='PyTorch Point Cloud Semantic Segmentation') 26 | parser.add_argument('--config', type=str, default='config/IntrA/IntrA_pointtransformer_seg_repro.yaml', help='config file') 27 | parser.add_argument('opts', help='see config/IntrA/IntrA_segmentation.yaml for all options', default=None, nargs=argparse.REMAINDER) 28 | args = parser.parse_args() 29 | assert args.config is not None 30 | cfg = config.load_cfg_from_cfg_file(args.config) 31 | if args.opts is not None: 32 | cfg = config.merge_cfg_from_list(cfg, args.opts) 33 | return cfg 34 | 35 | 36 | def get_logger(): 37 | logger_name = "main-logger" 38 | logger = logging.getLogger(logger_name) 39 | logger.setLevel(logging.INFO) 40 | handler = logging.StreamHandler() 41 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 42 | handler.setFormatter(logging.Formatter(fmt)) 43 | 
logger.addHandler(handler) 44 | return logger 45 | 46 | 47 | def main(): 48 | global args, logger 49 | args = get_parser() 50 | if args.manual_seed is not None: 51 | random.seed(args.manual_seed) 52 | np.random.seed(args.manual_seed) 53 | torch.manual_seed(args.manual_seed) 54 | torch.cuda.manual_seed(args.manual_seed) 55 | torch.cuda.manual_seed_all(args.manual_seed) 56 | torch.backends.cudnn.benchmark = False 57 | torch.backends.cudnn.deterministic = True 58 | args.ngpus_per_node = len(args.train_gpu) 59 | 60 | logger = get_logger() 61 | logger.info(args) 62 | 63 | iou_list = [] 64 | for fold in args.folds: 65 | best_iou_list = main_worker(args.train_gpu, args.ngpus_per_node, test_fold=fold) 66 | iou_list.append(best_iou_list) 67 | iou = torch.stack(iou_list, dim=0).mean(dim=0) 68 | miou = torch.mean(iou) 69 | logger.info("=> Final mIoU is {:.4f}, vIoU is {:.4f}, aIoU is {:.4f}".format(miou, iou[0], iou[1])) 70 | 71 | 72 | def main_worker(gpu, ngpus_per_node, test_fold): 73 | global args, logger, writer 74 | 75 | fold_path = os.path.join(args.save_path, "fold{}".format(test_fold)) 76 | if not os.path.exists(fold_path): 77 | os.makedirs(fold_path) 78 | writer = SummaryWriter(fold_path) 79 | logger.info("===============test fold {}===============".format(test_fold)) 80 | 81 | logger.info("=====================> Creating model ...") 82 | if args.arch == 'IntrA_pointtransformer_seg_repro': 83 | from models.point_transformer_seg import PointTransformerSemSegmentation as Model 84 | else: 85 | raise Exception('architecture {} not supported yet'.format(args.arch)) 86 | model = Model(args=args).cuda() 87 | optimizer = torch.optim.Adam( 88 | model.parameters(), 89 | lr=args.base_lr, 90 | weight_decay=args.weight_decay 91 | ) 92 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=args.base_lr*0.01) 93 | logger.info("=> Features:{}, Classes: {}".format(args.fea_dim, args.classes)) 94 | 95 | logger.info("=====================> Loading data ...") 96 | if args.data_name == "IntrA": 97 | train_transforms = transforms.Compose( 98 | [ 99 | d_utils.PointcloudToTensor(), 100 | d_utils.PointcloudScale(), 101 | d_utils.PointcloudRotate(), 102 | d_utils.PointcloudRotatePerturbation(), 103 | d_utils.PointcloudTranslate(), 104 | d_utils.PointcloudJitter(), 105 | d_utils.PointcloudRandomInputDropout(), 106 | ] 107 | ) 108 | train_data = IntrADataset(args.data_root, args.sample_points, args.use_uniform_sample, args.use_normals, 109 | test_fold=test_fold, num_edge_neighbor=args.num_edge_neighbor, mode='train', transform=train_transforms) 110 | train_loader = DataLoader(train_data, batch_size=args.batch_size_train, 111 | shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True) 112 | val_data = IntrADataset(args.data_root, args.sample_points, args.use_uniform_sample, args.use_normals, 113 | test_fold=test_fold, num_edge_neighbor=args.num_edge_neighbor, mode='test', transform=None) 114 | val_loader = DataLoader(val_data, batch_size=args.batch_size_val, 115 | shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) 116 | logger.info("=> Loaded {} training samples, {} testing samples".format(len(train_data), len(val_data))) 117 | 118 | logger.info("=====================> Training loop...") 119 | best_miou = 0 120 | best_iou_list = [] 121 | for epoch in range(args.start_epoch, args.epochs): 122 | record_train = train_one_epoch(train_loader, model, optimizer) 123 | writer = record_statistics(writer, record_train, mode='train', epoch=epoch) 124 | 125 | 
scheduler.step() 126 | 127 | if args.evaluate and (epoch % args.eval_freq == 0): 128 | is_best = False 129 | record_val = val_one_epoch(val_loader, model) 130 | writer = record_statistics(writer, record_val, mode='val', epoch=epoch) 131 | 132 | iou_list = record_val['iou_list'] 133 | miou_val = torch.mean(iou_list) 134 | is_best = miou_val > best_miou 135 | best_miou = miou_val if is_best else best_miou 136 | best_iou_list = iou_list if is_best else best_iou_list 137 | filename = os.path.join(fold_path, 'model_last.pth') 138 | torch.save({'epoch': epoch, 'state_dict' : model.state_dict(), 'optimizer': optimizer.state_dict(), 139 | 'scheduler': scheduler.state_dict(), 'best_miou': best_miou, 140 | 'best_viou': best_iou_list[0], 'best_aiou': best_iou_list[1]}, filename) 141 | if is_best: 142 | logger.info('Epoch{}: best validation mIoU updated to {:.4f}, vIoU is {:.4f} and aIoU is {:.4f}'.format( 143 | epoch, best_miou, best_iou_list[0], best_iou_list[1])) 144 | shutil.copyfile(filename, os.path.join(fold_path, 'model_best.pth')) 145 | 146 | writer.close() 147 | logger.info("===============test fold {} training done===============\n\ 148 | Best mIoU is {:.4f}, vIoU is {:.4f}, aIoU is {:.4f}".format(test_fold, best_miou, best_iou_list[0], best_iou_list[1])) 149 | return best_iou_list 150 | 151 | 152 | def train_one_epoch(train_loader, model, optimizer): 153 | global args 154 | model.train() 155 | 156 | loss_avg, loss_seg_avg, loss_seg_refine_avg, loss_edge_avg, loss_contra_avg = 0.0, 0.0, 0.0, 0.0, 0.0 157 | iou_list, iou_refine_list = [], [] 158 | for batch_i, (pts, gts, egts, eweights, gmatrix, idxs) in enumerate(train_loader): 159 | pts, gts, egts, eweights, gmatrix = pts.cuda(), gts.cuda(), egts.cuda(), eweights.mean(dim=0).cuda(), gmatrix.cuda() 160 | seg_preds, seg_refine_preds, seg_embed, edge_preds = model(pts, gmatrix, idxs) 161 | loss_seg = F.cross_entropy(seg_preds, gts, weight=train_loader.dataset.segweights.cuda()) 162 | loss_seg_refine = F.cross_entropy(seg_refine_preds, gts, weight=train_loader.dataset.segweights.cuda()) 163 | loss_edge = F.cross_entropy(edge_preds, egts, weight=eweights) 164 | loss_contra = get_contra_loss(egts, gts, seg_embed, gmatrix, num_class=args.classes, temp=args.temp) 165 | loss = loss_seg + args.weight_refine * loss_seg_refine + args.weight_edge * loss_edge + args.weight_contra * loss_contra 166 | 167 | loss_avg += loss.item() 168 | loss_seg_avg += loss_seg.item() 169 | loss_seg_refine_avg += loss_seg_refine.item() 170 | loss_edge_avg += loss_edge.item() 171 | loss_contra_avg += loss_contra.item() 172 | iou_list.append(cal_IoU_Acc_batch(seg_preds, gts)) 173 | iou_refine_list.append(cal_IoU_Acc_batch(seg_refine_preds, gts)) 174 | 175 | optimizer.zero_grad() 176 | loss.backward() 177 | optimizer.step() 178 | 179 | record = {} 180 | dataset_len = len(train_loader.dataset) 181 | record['loss_all'] = loss_avg / dataset_len 182 | record['loss_seg'] = loss_seg_avg / dataset_len 183 | record['loss_seg_refine'] = loss_seg_refine_avg / dataset_len 184 | record['loss_edge'] = loss_edge_avg / dataset_len 185 | record['loss_contra'] = loss_contra_avg / dataset_len 186 | record['iou_list'] = torch.cat(iou_list, dim=0).mean(dim=0) 187 | record['iou_refine_list'] = torch.cat(iou_refine_list, dim=0).mean(dim=0) 188 | return record 189 | 190 | 191 | def val_one_epoch(val_loader, model): 192 | global args 193 | model.eval() 194 | 195 | loss_avg_list, loss_seg_avg_list, loss_seg_refine_avg_list, loss_edge_avg_list, loss_contra_avg_list = [], [], [], [], [] 196 | 
iou_avg_list, iou_refine_avg_list = [], [] 197 | for i in range(args.num_votes): 198 | loss_avg, loss_seg_avg, loss_seg_refine_avg, loss_edge_avg, loss_contra_avg = 0.0, 0.0, 0.0, 0.0, 0.0 199 | iou_avg, iou_refine_avg = [], [] 200 | with torch.no_grad(): 201 | for batch_idx, (pts, gts, egts, eweights, gmatrix, idxs) in enumerate(val_loader): 202 | pts, gts, egts, eweights, gmatrix = pts.cuda(), gts.cuda(), egts.cuda(), eweights.mean(dim=0).cuda(), gmatrix.cuda() 203 | seg_preds, seg_refine_preds, seg_embed, edge_preds = model(pts, gmatrix, idxs) 204 | loss_seg = F.cross_entropy(seg_preds, gts, weight=val_loader.dataset.segweights.cuda()) 205 | loss_seg_refine = F.cross_entropy(seg_refine_preds, gts, weight=val_loader.dataset.segweights.cuda()) 206 | loss_edge = F.cross_entropy(edge_preds, egts, weight=eweights) 207 | loss_contra = get_contra_loss(egts, gts, seg_embed, gmatrix, num_class=args.classes, temp=args.temp) 208 | loss = loss_seg + args.weight_refine * loss_seg_refine + args.weight_edge * loss_edge + args.weight_contra * loss_contra 209 | 210 | loss_avg += loss.item() 211 | loss_seg_avg += loss_seg.item() 212 | loss_seg_refine_avg += loss_seg_refine.item() 213 | loss_edge_avg += loss_edge.item() 214 | loss_contra_avg += loss_contra.item() 215 | iou_avg.append(cal_IoU_Acc_batch(seg_preds, gts)) 216 | iou_refine_avg.append(cal_IoU_Acc_batch(seg_refine_preds, gts)) 217 | 218 | dataset_len = len(val_loader.dataset) 219 | loss_avg_list.append(loss_avg / dataset_len) 220 | loss_seg_avg_list.append(loss_seg_avg / dataset_len) 221 | loss_seg_refine_avg_list.append(loss_seg_refine_avg / dataset_len) 222 | loss_edge_avg_list.append(loss_edge_avg / dataset_len) 223 | loss_contra_avg_list.append(loss_contra_avg / dataset_len) 224 | iou_avg_list.append(torch.cat(iou_avg, dim=0).mean(dim=0)) 225 | iou_refine_avg_list.append(torch.cat(iou_refine_avg, dim=0).mean(dim=0)) 226 | 227 | record = {} 228 | record['loss_all'] = np.mean(loss_avg_list) 229 | record['loss_seg'] = np.mean(loss_seg_avg_list) 230 | record['loss_seg_refine'] = np.mean(loss_seg_refine_avg_list) 231 | record['loss_edge'] = np.mean(loss_edge_avg_list) 232 | record['loss_contra'] = np.mean(loss_contra_avg_list) 233 | record['iou_list'] = torch.stack(iou_avg_list, dim=0).mean(dim=0) 234 | record['iou_refine_list'] = torch.stack(iou_refine_avg_list, dim=0).mean(dim=0) 235 | return record 236 | 237 | 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Functions for parsing args 3 | # ----------------------------------------------------------------------------- 4 | import yaml 5 | import os 6 | from ast import literal_eval 7 | import copy 8 | 9 | 10 | class CfgNode(dict): 11 | """ 12 | CfgNode represents an internal node in the configuration tree. It's a simple 13 | dict-like container that allows for attribute-based access to keys. 
14 | """ 15 | 16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 17 | # Recursively convert nested dictionaries in init_dict into CfgNodes 18 | init_dict = {} if init_dict is None else init_dict 19 | key_list = [] if key_list is None else key_list 20 | for k, v in init_dict.items(): 21 | if type(v) is dict: 22 | # Convert dict to CfgNode 23 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 24 | super(CfgNode, self).__init__(init_dict) 25 | 26 | def __getattr__(self, name): 27 | if name in self: 28 | return self[name] 29 | else: 30 | raise AttributeError(name) 31 | 32 | def __setattr__(self, name, value): 33 | self[name] = value 34 | 35 | def __str__(self): 36 | def _indent(s_, num_spaces): 37 | s = s_.split("\n") 38 | if len(s) == 1: 39 | return s_ 40 | first = s.pop(0) 41 | s = [(num_spaces * " ") + line for line in s] 42 | s = "\n".join(s) 43 | s = first + "\n" + s 44 | return s 45 | 46 | r = "" 47 | s = [] 48 | for k, v in sorted(self.items()): 49 | seperator = "\n" if isinstance(v, CfgNode) else " " 50 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 51 | attr_str = _indent(attr_str, 2) 52 | s.append(attr_str) 53 | r += "\n".join(s) 54 | return r 55 | 56 | def __repr__(self): 57 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 58 | 59 | 60 | def load_cfg_from_cfg_file(file): 61 | cfg = {} 62 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 63 | '{} is not a yaml file'.format(file) 64 | 65 | with open(file, 'r') as f: 66 | cfg_from_file = yaml.safe_load(f) 67 | 68 | for key in cfg_from_file: 69 | for k, v in cfg_from_file[key].items(): 70 | cfg[k] = v 71 | 72 | cfg = CfgNode(cfg) 73 | return cfg 74 | 75 | 76 | def merge_cfg_from_list(cfg, cfg_list): 77 | new_cfg = copy.deepcopy(cfg) 78 | assert len(cfg_list) % 2 == 0 79 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 80 | subkey = full_key.split('.')[-1] 81 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 82 | value = _decode_cfg_value(v) 83 | value = _check_and_coerce_cfg_value_type( 84 | value, cfg[subkey], subkey, full_key 85 | ) 86 | setattr(new_cfg, subkey, value) 87 | 88 | return new_cfg 89 | 90 | 91 | def _decode_cfg_value(v): 92 | """Decodes a raw config value (e.g., from a yaml config files or command 93 | line argument) into a Python object. 94 | """ 95 | # All remaining processing is only applied to strings 96 | if not isinstance(v, str): 97 | return v 98 | # Try to interpret `v` as a: 99 | # string, number, tuple, list, dict, boolean, or None 100 | try: 101 | v = literal_eval(v) 102 | # The following two excepts allow v to pass through when it represents a 103 | # string. 104 | # 105 | # Longer explanation: 106 | # The type of v is always a string (before calling literal_eval), but 107 | # sometimes it *represents* a string and other times a data structure, like 108 | # a list. In the case that v represents a string, what we got back from the 109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 112 | # will raise a SyntaxError. 113 | except ValueError: 114 | pass 115 | except SyntaxError: 116 | pass 117 | return v 118 | 119 | 120 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 121 | """Checks that `replacement`, which is intended to replace `original` is of 122 | the right type. 
The type is correct if it matches exactly or is one of a few 123 | cases in which the type can be easily coerced. 124 | """ 125 | original_type = type(original) 126 | replacement_type = type(replacement) 127 | 128 | # The types must match (with some exceptions) 129 | if replacement_type == original_type or original is None: 130 | return replacement 131 | 132 | # Cast replacement from from_type to to_type if the replacement and original 133 | # types match from_type and to_type 134 | def conditional_cast(from_type, to_type): 135 | if replacement_type == from_type and original_type == to_type: 136 | return True, to_type(replacement) 137 | else: 138 | return False, None 139 | 140 | # Conditionally casts 141 | # list <-> tuple 142 | casts = [(tuple, list), (list, tuple)] 143 | # For py2: allow converting from str (bytes) to a unicode string 144 | try: 145 | casts.append((str, unicode)) # noqa: F821 146 | except Exception: 147 | pass 148 | 149 | for (from_type, to_type) in casts: 150 | converted, converted_value = conditional_cast(from_type, to_type) 151 | if converted: 152 | return converted_value 153 | 154 | raise ValueError( 155 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 156 | "key: {}".format( 157 | original_type, replacement_type, original, replacement, full_key 158 | ) 159 | ) 160 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | 5 | class logger(object): 6 | def __init__(self, root_path) -> None: 7 | timenow = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") 8 | self.logger_path = os.path.join(root_path, timenow) 9 | if not os.path.exists(self.logger_path): 10 | os.makedirs(self.logger_path) 11 | self.record_path = os.path.join(self.logger_path, "record.log") 12 | self.config_path = os.path.join(self.logger_path, "config.log") 13 | 14 | def printdir(self, dir): 15 | with open(self.config_path, 'a') as f: 16 | for k, v in dir.items(): 17 | item = str(k) + ':' + str(v) 18 | f.write(item) 19 | f.write('\n') 20 | 21 | def print(self, str): 22 | with open(self.record_path, 'a') as f: 23 | f.write(str) 24 | f.write('\n') 25 | f.close() 26 | -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | class Timer: 4 | def __init__(self, tag, print=False): 5 | self.tag = tag 6 | self.ts = None 7 | self.print = print 8 | 9 | def tic(self): 10 | self.ts = time() 11 | 12 | def toc(self): 13 | if self.print: 14 | print("{}: {}s".format(self.tag, time() - self.ts)) 15 | return time() 16 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cal_IoU_Acc_batch(preds, labels): 5 | B, C, _ = preds.shape 6 | preds = torch.argmax(preds, dim=1) 7 | IoU = torch.zeros(B, C).cuda() 8 | for j in range(C): 9 | tmp_and_num = torch.sum(torch.bitwise_and(preds==j, labels==j), dim=1, keepdim=False) 10 | tmp_or_num = torch.sum(torch.bitwise_or(preds==j, labels==j), dim=1, keepdim=False) 11 | IoU[..., j] = tmp_and_num / tmp_or_num 12 | 13 | return IoU 14 | 15 | 16 | def record_statistics(writer, record, mode, epoch): 17 | for k, v in record.items(): 18 | if k == 'iou_list': 19 | 
writer.add_scalar('miou_{}'.format(mode), v.mean(), epoch) 20 | writer.add_scalar('viou_{}'.format(mode), v[0], epoch) 21 | writer.add_scalar('aiou_{}'.format(mode), v[1], epoch) 22 | elif k == 'iou_refine_list': 23 | writer.add_scalar('miou_refine_{}'.format(mode), v.mean(), epoch) 24 | writer.add_scalar('viou_refine_{}'.format(mode), v[0], epoch) 25 | writer.add_scalar('aiou_refine_{}'.format(mode), v[1], epoch) 26 | else: 27 | writer.add_scalar(k+'_{}'.format(mode), v, epoch) 28 | return writer 29 | 30 | 31 | def get_contra_loss(egts, gts, seg_emb, gmatrix, num_class, temp): 32 | B, D, N = seg_emb.shape 33 | seg_emb = seg_emb.transpose(1, 2).contiguous() 34 | detach_emb = seg_emb.clone().detach() # keys are built from detached embeddings, so gradients flow only through the edge-point queries 35 | loss_contra = 0.0 36 | # accumulate the contrastive loss sample by sample 37 | for i in range(B): 38 | egts_this, gts_this, seg_emb_this, gmatrix_this = egts[i, :], gts[i, :], seg_emb[i, ...], gmatrix[i, ...] 39 | detach_emb_this = detach_emb[i, ...] 40 | nonedge_idxs = torch.nonzero(egts_this==0, as_tuple=True)[0] # indices of non-edge points 41 | edge_idxs = torch.nonzero(egts_this==1, as_tuple=True)[0] # indices of edge points 42 | 43 | edge_gts = gts_this[edge_idxs] 44 | edge_emb = seg_emb_this[edge_idxs, :] 45 | nonedge_gts = gts_this[nonedge_idxs] 46 | nonedge_detach_emb = detach_emb_this[nonedge_idxs, :] 47 | nonedge_gmatrix = gmatrix_this[nonedge_idxs, :] 48 | 49 | keys = [] 50 | for j in range(num_class): 51 | jclass_nonedge_emb = nonedge_detach_emb[nonedge_gts==j, :] 52 | jclass_nonedge_gmatrix = nonedge_gmatrix[nonedge_gts==j, :][:, edge_idxs].mean(dim=-1) # mean gmatrix value from each class-j non-edge point to the edge points 53 | jclass_nonedge_neighbor_idxs = jclass_nonedge_gmatrix.argsort(dim=-1)[:16] # keep the 16 class-j non-edge points with the smallest mean value as keys 54 | keys.append(jclass_nonedge_emb[jclass_nonedge_neighbor_idxs, :]) 55 | 56 | for j in range(num_class): 57 | positive_key = keys[j].view(-1, D) # (K, D), keys of the same class 58 | negative_key = torch.stack(keys[:j] + keys[j+1:], dim=0).view(-1, D) # ((num_class-1)*K, D), keys of the other classes 59 | jclass_emb = edge_emb[edge_gts==j, :] # (M, D), class-j edge-point embeddings 60 | jclass_pos_logits = - torch.mm(jclass_emb, positive_key.transpose(0, 1)) / temp # (M, K) 61 | jclass_neg_logits = torch.log(torch.sum(torch.exp(torch.mm(jclass_emb, negative_key.transpose(0, 1)) / temp), dim=-1)) # (M,) 62 | loss_contra_tmp = torch.mean(jclass_pos_logits + jclass_neg_logits[:, None]) 63 | loss_contra += loss_contra_tmp 64 | 65 | loss_contra = loss_contra / (B * num_class) 66 | return loss_contra 67 | --------------------------------------------------------------------------------
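
For readers who want to sanity-check `utils/tools.py` in isolation, the snippet below is a minimal sketch of how `cal_IoU_Acc_batch` and `get_contra_loss` are called from `tool/train.py`. It only illustrates the expected tensor shapes: the random tensors, the batch size, point count, embedding width, and the `temp=0.1` value are illustrative assumptions rather than values from the repository config, and a CUDA device is required because `cal_IoU_Acc_batch` allocates its result on the GPU.

```python
import torch
from utils.tools import cal_IoU_Acc_batch, get_contra_loss  # run from the repo root

# Illustrative sizes only: batch, classes (vessel/aneurysm), points, embedding width
B, C, N, D = 2, 2, 512, 32

seg_preds = torch.randn(B, C, N).cuda()    # per-point class logits, as produced by the model
gts = torch.randint(0, C, (B, N)).cuda()   # per-point segmentation labels
egts = torch.randint(0, 2, (B, N)).cuda()  # per-point edge labels (0 = non-edge, 1 = edge)
seg_embed = torch.randn(B, D, N).cuda()    # per-point embeddings, shape (B, D, N)
gmatrix = torch.rand(B, N, N).cuda()       # pairwise point-to-point matrix used to pick key points

iou = cal_IoU_Acc_batch(seg_preds, gts)    # (B, C) per-class IoU over the batch
loss = get_contra_loss(egts, gts, seg_embed, gmatrix, num_class=C, temp=0.1)
print(iou.mean(dim=0), loss.item())
```

In the actual training loop these tensors come from `IntrADataset` batches and the segmentation network, and the temperature is taken from `args.temp`.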