├── LICENSE
├── README.md
├── cfgs
│   ├── config_partseg_gpus.yaml
│   └── config_partseg_test.yaml
├── data
│   ├── ShapeNetPartLoader.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Indoor3DSemSegLoader.cpython-36.pyc
│   │   ├── ModelNet40Loader.cpython-36.pyc
│   │   ├── ShapeNetPartLoader.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── data_utils.cpython-36.pyc
│   └── data_utils.py
├── figures
│   ├── cls.png
│   ├── compare.png
│   ├── intro.png
│   ├── modelnet.png
│   ├── new_network.png
│   ├── partseg.png
│   ├── scanobjectnn.png
│   ├── shapenet.png
│   └── visual.png
├── models
│   ├── __init__.py
│   └── drnet.py
├── pointnet2
│   ├── _ext-src
│   │   ├── include
│   │   │   ├── ball_query.h
│   │   │   ├── cuda_utils.h
│   │   │   ├── group_points.h
│   │   │   ├── interpolate.h
│   │   │   ├── sampling.h
│   │   │   └── utils.h
│   │   └── src
│   │       ├── ball_query.cpp
│   │       ├── ball_query_gpu.cu
│   │       ├── bindings.cpp
│   │       ├── group_points.cpp
│   │       ├── group_points_gpu.cu
│   │       ├── interpolate.cpp
│   │       ├── interpolate_gpu.cu
│   │       ├── sampling.cpp
│   │       └── sampling_gpu.cu
│   ├── _ext.cpython-36m-x86_64-linux-gnu.so
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── pointnet2_modules.cpython-36.pyc
│       │   └── pointnet2_utils.cpython-36.pyc
│       ├── linalg_utils.py
│       ├── pointnet2_modules.py
│       └── pointnet2_utils.py
├── train_partseg_gpus.py
├── train_partseg_gpus.sh
├── voting_test.py
└── voting_test.sh

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 ShiQiu0419
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense-Resolution Network for Point Cloud Classification and Segmentation
2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dense-resolution-network-for-point-cloud/3d-point-cloud-classification-on-scanobjectnn)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=dense-resolution-network-for-point-cloud)
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dense-resolution-network-for-point-cloud/3d-part-segmentation-on-shapenet-part)](https://paperswithcode.com/sota/3d-part-segmentation-on-shapenet-part?p=dense-resolution-network-for-point-cloud)
4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/dense-resolution-network-for-point-cloud/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=dense-resolution-network-for-point-cloud)
5 | 
6 | This repository is for the Dense-Resolution Network (DRNet) introduced in the following paper:
7 | 
8 | [Shi Qiu](https://shiqiu0419.github.io/), [Saeed Anwar](https://saeed-anwar.github.io/), [Nick Barnes](http://users.cecs.anu.edu.au/~nmb/)
9 | "Dense-Resolution Network for Point Cloud Classification and Segmentation"
10 | IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2021)
11 | 
12 | ## Paper
13 | The paper can be downloaded from [here (arXiv)](https://arxiv.org/abs/2005.06734) or [here (CVF)](https://openaccess.thecvf.com/content/WACV2021/papers/Qiu_Dense-Resolution_Network_for_Point_Cloud_Classification_and_Segmentation_WACV_2021_paper.pdf), together with the [supplementary material](https://openaccess.thecvf.com/content/WACV2021/supplemental/Qiu_Dense-Resolution_Network_for_WACV_2021_supplemental.pdf).
14 | 
15 | ## Motivation
16 | 

17 | *(figure not rendered in this dump; see figures/intro.png)*
18 | 

19 | 
20 | ## Implementation
21 | * Python 3.6
22 | * PyTorch 1.3.0
23 | * CUDA 10.0
24 | 
25 | ## Dataset
26 | Download the [ShapeNet Part Dataset](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and unzip it to your root path. Alternatively, modify the dataset path in `cfgs/config_partseg_gpus.yaml` and `cfgs/config_partseg_test.yaml` to point to wherever you placed it.
27 | 
28 | ## CUDA Kernel Building
29 | For PyTorch versions <= 0.4.0, please refer to [Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN).
30 | For PyTorch versions >= 1.0.0, please refer to [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch).
31 | 
32 | **Note:**
33 | In our DRNet, we use Farthest Point Sampling (e.g., `pointnet2_utils.furthest_point_sample`) to down-sample the point cloud, and we adopt Feature Propagation (e.g., `pointnet2_utils.three_nn` and `pointnet2_utils.three_interpolate`) to up-sample the feature maps. A usage sketch of these operators is given at the end of this README.
34 | 
35 | ## Training
36 | 
37 |     sh train_partseg_gpus.sh
38 | 
39 | Due to the complexity of DRNet, we support multi-GPU training via `nn.DataParallel`. You can also adjust other parameters, such as the batch size or the number of input points, in `cfgs/config_partseg_gpus.yaml` to fit the memory limit of your device.
40 | 
41 | ## Voting Evaluation
42 | Set the path of your pre-trained model in `cfgs/config_partseg_test.yaml`, then run:
43 | 
44 |     sh voting_test.sh
45 | 
46 | ## Citation
47 | 
48 | If you find our paper useful, please cite:
49 | 
50 |     @inproceedings{qiu2021dense,
51 |       title={Dense-Resolution Network for Point Cloud Classification and Segmentation},
52 |       author={Qiu, Shi and Anwar, Saeed and Barnes, Nick},
53 |       booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
54 |       month={January},
55 |       year={2021},
56 |       pages={3813-3822}
57 |     }
58 | 
59 | ## Acknowledgement
60 | The code is built on [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch), [Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN), and [DGCNN](https://github.com/WangYueFt/dgcnn/tree/master/pytorch). We thank the authors for sharing their code.
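## Usage Sketch
Below is a minimal sketch of how the CUDA operators above are called. It mirrors the calls made in `models/drnet.py`, but the tensor shapes and variable names are illustrative only, and it assumes the kernels are built and a GPU is available:

    import torch
    from pointnet2.utils import pointnet2_utils

    B, N, C = 8, 2048, 64
    xyz = torch.rand(B, N, 3).cuda()                          # B,N,3 point coordinates

    # Down-sampling: FPS returns the indices of N/4 well-spread points
    idx = pointnet2_utils.furthest_point_sample(xyz, N // 4)  # B,N/4 (int)
    sub_xyz = torch.gather(xyz, 1, idx.long().unsqueeze(-1).expand(-1, -1, 3))  # B,N/4,3

    # Up-sampling (Feature Propagation): 3-NN search plus inverse-distance
    # interpolation, using the same weighting as models/drnet.py
    feats = torch.rand(B, C, N // 4).cuda()                   # B,C,N/4 features on the subset
    dist, nn_idx = pointnet2_utils.three_nn(xyz, sub_xyz)     # both B,N,3
    dist_recip = 1.0 / (dist + 1e-8)
    weight = dist_recip / torch.sum(dist_recip, dim=2, keepdim=True)
    dense_feats = pointnet2_utils.three_interpolate(feats, nn_idx, weight)  # B,C,N

For multi-GPU training, the segmentation model exported by `models/__init__.py` can be wrapped in `nn.DataParallel` as below; the one-hot label format is an assumption based on the `cls.view(-1, 16, 1)` reshape in the forward pass:

    import torch
    import torch.nn as nn
    from models import DRNET_Seg

    model = nn.DataParallel(DRNET_Seg(num_classes=50)).cuda()
    points = torch.rand(8, 2048, 3).cuda()  # B,N,3 input point clouds
    one_hot = torch.zeros(8, 16).cuda()     # B,16 one-hot object-category labels (assumed format)
    one_hot[:, 0] = 1.0
    seg_logits, d1, d2, d3, d4 = model(points, one_hot)  # seg_logits: B,N,50; d*: per-point error terms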
61 | 
--------------------------------------------------------------------------------
/cfgs/config_partseg_gpus.yaml:
--------------------------------------------------------------------------------
1 | common:
2 |     workers: 8
3 | 
4 |     num_points: 2048
5 |     num_classes: 50
6 |     batch_size: 32
7 | 
8 |     base_lr: 0.001
9 |     lr_clip: 0.00001
10 |     lr_decay: 0.5
11 |     decay_step: 21
12 |     epochs: 200
13 | 
14 |     weight_decay: 0
15 |     bn_momentum: 0.9
16 |     bnm_clip: 0.01
17 |     bn_decay: 0.5
18 | 
19 |     evaluate: 1  # run validation during training
20 |     val_freq_epoch: 0.8  # validation frequency in epochs (can be a decimal)
21 |     print_freq_iter: 20  # frequency in iterations for printing information
22 | 
23 |     input_channels: 0  # feature channels except (x, y, z)
24 | 
25 |     checkpoint: ''  # the model to start from
26 |     save_path: seg_own
27 |     data_root: shapenetcore_partanno_segmentation_benchmark_v0_normal
--------------------------------------------------------------------------------
/cfgs/config_partseg_test.yaml:
--------------------------------------------------------------------------------
1 | common:
2 |     workers: 8
3 | 
4 |     num_points: 2048
5 |     num_classes: 50
6 |     batch_size: 4
7 | 
8 |     base_lr: 0.001
9 |     lr_clip: 0.00001
10 |     lr_decay: 0.5
11 |     decay_step: 21
12 |     epochs: 200
13 | 
14 |     weight_decay: 0
15 |     bn_momentum: 0.9
16 |     bnm_clip: 0.01
17 |     bn_decay: 0.5
18 | 
19 |     evaluate: 1  # run validation during training
20 |     val_freq_epoch: 0.8  # validation frequency in epochs (can be a decimal)
21 |     print_freq_iter: 20  # frequency in iterations for printing information
22 | 
23 |     input_channels: 0  # feature channels except (x, y, z)
24 | 
25 |     checkpoint: 'seg/your_trained_model.pth'  # the model to start from
26 |     save_path: seg_own
27 |     data_root: shapenetcore_partanno_segmentation_benchmark_v0_normal
--------------------------------------------------------------------------------
/data/ShapeNetPartLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch
4 | import json
5 | import numpy as np
6 | import sys
7 | import torchvision.transforms as transforms
8 | 
9 | def pc_normalize(pc):
10 |     l = pc.shape[0]
11 |     centroid = np.mean(pc, axis=0)
12 |     pc = pc - centroid
13 |     m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
14 |     pc = pc / m
15 |     return pc
16 | 
17 | class ShapeNetPart():
18 |     def __init__(self, root, num_points = 2048, split='train', normalize=True, transforms = None):
19 |         self.transforms = transforms
20 |         self.num_points = num_points
21 |         self.root = root
22 |         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
23 |         self.normalize = normalize
24 | 
25 |         self.cat = {}
26 |         with open(self.catfile, 'r') as f:
27 |             for line in f:
28 |                 ls = line.strip().split()
29 |                 self.cat[ls[0]] = ls[1]
30 |         self.cat = {k:v for k,v in self.cat.items()}
31 | 
32 |         self.meta = {}
33 |         with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
34 |             train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
35 |         with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
36 |             val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
37 |         with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
38 |             test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
39 |         for item in self.cat:
40 |             self.meta[item] = []
41 |             dir_point = os.path.join(self.root, self.cat[item])
42 |             fns = sorted(os.listdir(dir_point))
43 |             if split=='trainval':
44 |                 fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
45 |             elif split=='train':
46 |                 fns = [fn for fn in fns if fn[0:-4] in train_ids]
47 |             elif split=='val':
48 |                 fns = [fn for fn in fns if fn[0:-4] in val_ids]
49 |             elif split=='test':
50 |                 fns = [fn for fn in fns if fn[0:-4] in test_ids]
51 |             else:
52 |                 print('Unknown split: %s. Exiting...' % (split))
53 |                 exit(-1)
54 | 
55 |             for fn in fns:
56 |                 token = (os.path.splitext(os.path.basename(fn))[0])
57 |                 self.meta[item].append(os.path.join(dir_point, token + '.txt'))
58 | 
59 |         self.datapath = []
60 |         for item in self.cat:
61 |             for fn in self.meta[item]:
62 |                 self.datapath.append((item, fn))
63 | 
64 |         self.classes = dict(zip(self.cat, range(len(self.cat))))
65 |         # Mapping from category (e.g., 'Chair') to its list of segmentation labels (e.g., [12, 13, 14, 15])
66 |         self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
67 | 
68 |         self.cache = {}
69 |         self.cache_size = 20000
70 | 
71 |     def __getitem__(self, index):
72 |         if index in self.cache:
73 |             point_set, seg, cls = self.cache[index]
74 |         else:
75 |             fn = self.datapath[index]
76 |             cat = self.datapath[index][0]
77 |             cls = self.classes[cat]
78 |             cls = np.array([cls]).astype(np.int64)
79 |             data = np.loadtxt(fn[1]).astype(np.float32)
80 |             point_set = data[:,0:3]
81 |             if self.normalize:
82 |                 point_set = pc_normalize(point_set)
83 |             seg = data[:,-1].astype(np.int64)
84 |             if len(self.cache) < self.cache_size:
85 |                 self.cache[index] = (point_set, seg, cls)
86 | 
87 |         choice = np.random.choice(len(seg), self.num_points, replace=True)
88 |         # resample to a fixed number of points
89 |         point_set = point_set[choice, :]
90 |         seg = seg[choice]
91 |         if self.transforms is not None:
92 |             point_set = self.transforms(point_set)
93 | 
94 |         return point_set, torch.from_numpy(seg), torch.from_numpy(cls)
95 | 
96 |     def __len__(self):
97 |         return len(self.datapath)
98 | 
99 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # from .ModelNet40Loader import ModelNet40Cls
2 | from .ShapeNetPartLoader import ShapeNetPart
3 | # from .Indoor3DSemSegLoader import Indoor3DSemSeg
4 | 
--------------------------------------------------------------------------------
/data/__pycache__/Indoor3DSemSegLoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/Indoor3DSemSegLoader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/ModelNet40Loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/ModelNet40Loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/ShapeNetPartLoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/ShapeNetPartLoader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/data_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/data_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | class PointcloudToTensor(object):
5 |     def __call__(self, points):
6 |         return torch.from_numpy(points).float()
7 | 
8 | def angle_axis(angle: float, axis: np.ndarray):
9 |     r"""Returns a 3x3 rotation matrix that performs a rotation around axis by angle
10 | 
11 |     Parameters
12 |     ----------
13 |     angle : float
14 |         Angle to rotate by
15 |     axis: np.ndarray
16 |         Axis to rotate about
17 | 
18 |     Returns
19 |     -------
20 |     torch.Tensor
21 |         3x3 rotation matrix
22 |     """
23 |     u = axis / np.linalg.norm(axis)
24 |     cosval, sinval = np.cos(angle), np.sin(angle)
25 | 
26 |     # yapf: disable
27 |     cross_prod_mat = np.array([[0.0, -u[2], u[1]],
28 |                                [u[2], 0.0, -u[0]],
29 |                                [-u[1], u[0], 0.0]])
30 | 
31 |     R = torch.from_numpy(
32 |         cosval * np.eye(3)
33 |         + sinval * cross_prod_mat
34 |         + (1.0 - cosval) * np.outer(u, u)
35 |     )
36 |     # yapf: enable
37 |     return R.float()
38 | 
39 | class PointcloudRotatebyAngle(object):
40 |     def __init__(self, rotation_angle = 0.0):
41 |         self.rotation_angle = rotation_angle
42 | 
43 |     def __call__(self, pc):
44 |         normals = pc.size(2) > 3
45 |         bsize = pc.size()[0]
46 |         for i in range(bsize):
47 |             cosval = np.cos(self.rotation_angle)
48 |             sinval = np.sin(self.rotation_angle)
49 |             rotation_matrix = np.array([[cosval, 0, sinval],
50 |                                         [0, 1, 0],
51 |                                         [-sinval, 0, cosval]])
52 |             rotation_matrix = torch.from_numpy(rotation_matrix).float().cuda()
53 | 
54 |             cur_pc = pc[i, :, :]
55 |             if not normals:
56 |                 cur_pc = cur_pc @ rotation_matrix
57 |             else:
58 |                 pc_xyz = cur_pc[:, 0:3]
59 |                 pc_normals = cur_pc[:, 3:]
60 |                 cur_pc[:, 0:3] = pc_xyz @ rotation_matrix
61 |                 cur_pc[:, 3:] = pc_normals @ rotation_matrix
62 | 
63 |             pc[i, :, :] = cur_pc
64 | 
65 |         return pc
66 | 
67 | class PointcloudJitter(object):
68 |     def __init__(self, std=0.01, clip=0.05):
69 |         self.std, self.clip = std, clip
70 | 
71 |     def __call__(self, pc):
72 |         bsize = pc.size()[0]
73 |         for i in range(bsize):
74 |             jittered_data = pc.new(pc.size(1), 3).normal_(
75 |                 mean=0.0, std=self.std
76 |             ).clamp_(-self.clip, self.clip)
77 |             pc[i, :, 0:3] += jittered_data
78 | 
79 |         return pc
80 | 
81 | class PointcloudScaleAndTranslate(object):
82 |     def __init__(self, scale_low=2. / 3., scale_high=3.
/ 2., translate_range=0.2): 83 | self.scale_low = scale_low 84 | self.scale_high = scale_high 85 | self.translate_range = translate_range 86 | 87 | def __call__(self, pc): 88 | bsize = pc.size()[0] 89 | for i in range(bsize): 90 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 91 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 92 | 93 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda() 94 | 95 | return pc 96 | 97 | class PointcloudScale(object): 98 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2.): 99 | self.scale_low = scale_low 100 | self.scale_high = scale_high 101 | 102 | def __call__(self, pc): 103 | bsize = pc.size()[0] 104 | for i in range(bsize): 105 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 106 | 107 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) 108 | 109 | return pc 110 | 111 | class PointcloudTranslate(object): 112 | def __init__(self, translate_range=0.2): 113 | self.translate_range = translate_range 114 | 115 | def __call__(self, pc): 116 | bsize = pc.size()[0] 117 | for i in range(bsize): 118 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 119 | 120 | pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda() 121 | 122 | return pc 123 | 124 | class PointcloudRandomInputDropout(object): 125 | def __init__(self, max_dropout_ratio=0.875): 126 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 127 | self.max_dropout_ratio = max_dropout_ratio 128 | 129 | def __call__(self, pc): 130 | bsize = pc.size()[0] 131 | for i in range(bsize): 132 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 133 | drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0] 134 | if len(drop_idx) > 0: 135 | cur_pc = pc[i, :, :] 136 | cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point 137 | pc[i, :, :] = cur_pc 138 | 139 | return pc 140 | -------------------------------------------------------------------------------- /figures/cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/cls.png -------------------------------------------------------------------------------- /figures/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/compare.png -------------------------------------------------------------------------------- /figures/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/intro.png -------------------------------------------------------------------------------- /figures/modelnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/modelnet.png -------------------------------------------------------------------------------- /figures/new_network.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/new_network.png -------------------------------------------------------------------------------- /figures/partseg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/partseg.png -------------------------------------------------------------------------------- /figures/scanobjectnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/scanobjectnn.png -------------------------------------------------------------------------------- /figures/shapenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/shapenet.png -------------------------------------------------------------------------------- /figures/visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/visual.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .drnet import DRNET as DRNET_Seg -------------------------------------------------------------------------------- /models/drnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import sys 6 | import copy 7 | import math 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from pointnet2.utils import pointnet2_utils 13 | 14 | 15 | def knn(x, k): 16 | inner = -2*torch.matmul(x.transpose(2, 1), x) 17 | xx = torch.sum(x**2, dim=1, keepdim=True) 18 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 19 | 20 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 21 | return idx 22 | 23 | 24 | def pw_dist(x): 25 | inner = -2*torch.matmul(x.transpose(2, 1), x) 26 | xx = torch.sum(x**2, dim=1, keepdim=True) 27 | pairwise_distance = -xx - inner - xx.transpose(2, 1) # (batch_size, num_points, n) 28 | 29 | return -pairwise_distance 30 | 31 | 32 | def knn_metric(x, d, conv_op1, conv_op2, conv_op11, k): 33 | batch_size = x.size(0) 34 | num_points = x.size(2) 35 | inner = -2*torch.matmul(x.transpose(2, 1), x) 36 | xx = torch.sum(x**2, dim=1, keepdim=True) 37 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 38 | 39 | metric = (-pairwise_distance).topk(k=d*k, dim=-1, largest=False)[0] # B,N,100 40 | metric_idx = (-pairwise_distance).topk(k=d*k, dim=-1, largest=False)[1] # B,N,100 41 | metric_trans = metric.permute(0, 2, 1) # B,100,N 42 | metric = conv_op1(metric_trans) # B,50,N 43 | metric = torch.squeeze(conv_op11(metric).permute(0, 2, 1), -1) # B,N 44 | # normalize function 45 | metric = torch.sigmoid(-metric) 46 | # projection function 47 | metric = 5 * metric + 0.5 48 | # scaling function 49 | 50 | value1 = torch.where((metric>=0.5)&(metric<1.5), torch.full_like(metric, 1), torch.full_like(metric, 0)) 51 | value2 = torch.where((metric>=1.5)&(metric<2.5), torch.full_like(metric, 2), torch.full_like(metric, 0)) 52 | value3 = 
torch.where((metric>=2.5)&(metric<3.5), torch.full_like(metric, 3), torch.full_like(metric, 0)) 53 | value4 = torch.where((metric>=3.5)&(metric<4.5), torch.full_like(metric, 4), torch.full_like(metric, 0)) 54 | value5 = torch.where((metric>=4.5)&(metric<=5.5), torch.full_like(metric, 5), torch.full_like(metric, 0)) 55 | 56 | value = value1 + value2 + value3 + value4 + value5 # B,N 57 | 58 | select_idx = torch.cuda.LongTensor(np.arange(k)) # k 59 | select_idx = torch.unsqueeze(select_idx, 0).repeat(num_points, 1) # N,k 60 | select_idx = torch.unsqueeze(select_idx, 0).repeat(batch_size, 1, 1) # B,N,k 61 | value = torch.unsqueeze(value, -1).repeat(1, 1, k) # B,N,k 62 | select_idx = select_idx * value 63 | select_idx = select_idx.long() 64 | idx = pairwise_distance.topk(k=k*d, dim=-1)[1] # (batch_size, num_points, k*d) 65 | # dilatedly selecting k from k*d idx 66 | idx = torch.gather(idx, dim=-1, index=select_idx) # B,N,k 67 | return idx 68 | 69 | 70 | def get_adptive_dilated_graph_feature(x, conv_op1, conv_op2, conv_op11, d=5, k=20, idx=None): 71 | batch_size = x.size(0) 72 | num_points = x.size(2) 73 | x = x.view(batch_size, -1, num_points) 74 | if idx is None: 75 | idx = knn_metric(x, d, conv_op1, conv_op2, conv_op11, k=k) # (batch_size, num_points, k) 76 | device = torch.device('cuda') 77 | idx_base = torch.arange(0, batch_size, device=device) 78 | idx_base = idx_base.view(-1, 1, 1)*num_points 79 | idx_base = idx_base.type(torch.cuda.LongTensor) 80 | idx = idx.type(torch.cuda.LongTensor) 81 | idx = idx + idx_base 82 | idx = idx.view(-1) 83 | _, num_dims, _ = x.size() 84 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 85 | feature = x.view(batch_size*num_points, -1)[idx, :] 86 | feature = feature.view(batch_size, num_points, k, num_dims) 87 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 88 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2) 89 | 90 | return feature 91 | 92 | 93 | def get_graph_feature(x, k=20, idx=None): 94 | batch_size = x.size(0) 95 | num_points = x.size(2) 96 | x = x.view(batch_size, -1, num_points) 97 | if idx is None: 98 | idx = knn(x, k=k) # (batch_size, num_points, k) 99 | device = torch.device('cuda') 100 | 101 | idx_base = torch.arange(0, batch_size, device=device) 102 | idx_base = idx_base.view(-1, 1, 1)*num_points 103 | idx_base = idx_base.type(torch.cuda.LongTensor) 104 | idx = idx.type(torch.cuda.LongTensor) 105 | idx = idx + idx_base 106 | idx = idx.view(-1) 107 | _, num_dims, _ = x.size() 108 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 109 | feature = x.view(batch_size*num_points, -1)[idx, :] 110 | feature = feature.view(batch_size, num_points, k, num_dims) 111 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 112 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2) 113 | 114 | return feature 115 | 116 | 117 | class Channel_fusion(nn.Module): 118 | def __init__(self, in_dim): 119 | super(Channel_fusion, self).__init__() 120 | self.channel_in = in_dim 121 | 122 | self.bn1 = nn.BatchNorm1d(in_dim//8) 123 | self.bn2 = nn.BatchNorm1d(in_dim) 124 | 125 | self.squeeze_conv = nn.Sequential(nn.Conv1d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1, bias=False), 126 | self.bn1) 127 | self.excite_conv = 
nn.Sequential(nn.Conv1d(in_channels=in_dim//8, out_channels=in_dim, kernel_size=1, bias=False), 128 | self.bn2) 129 | 130 | def forward(self, x): 131 | """ 132 | inputs : 133 | x : B, C, N 134 | returns : 135 | out : B, C, N 136 | """ 137 | batch_size, channel_num, num_points= x.size() 138 | fusion_score = F.adaptive_avg_pool1d(x, 1) # B, C, 1 139 | fusion_score = self.squeeze_conv(fusion_score) # B, C', 1 140 | fusion_score = F.relu(fusion_score) 141 | fusion_score = self.excite_conv(fusion_score) # B, C, 1 142 | fusion_score = torch.sigmoid(fusion_score).expand_as(x) # B, C, N 143 | 144 | return fusion_score 145 | 146 | 147 | class DRNET(nn.Module): 148 | def __init__(self, num_classes): 149 | super().__init__() 150 | self.k = 20 151 | 152 | self.bn1 = nn.BatchNorm2d(64) 153 | self.bn11 = nn.BatchNorm2d(3) 154 | self.bn12 = nn.BatchNorm1d(512) 155 | self.bn13 = nn.BatchNorm1d(1024) 156 | self.bn14 = nn.BatchNorm2d(64) 157 | self.bn2 = nn.BatchNorm2d(64) 158 | self.bn21 = nn.BatchNorm1d(256) 159 | self.bn22 = nn.BatchNorm1d(128) 160 | self.bn23 = nn.BatchNorm1d(128) 161 | self.bn24 = nn.BatchNorm2d(128) 162 | self.bn3 = nn.BatchNorm2d(256) 163 | self.bn32 = nn.BatchNorm1d(256) 164 | self.bn33 = nn.BatchNorm1d(256) 165 | self.bn34 = nn.BatchNorm1d(128) 166 | self.bn35 = nn.BatchNorm1d(128) 167 | self.bn5 = nn.BatchNorm1d(64) 168 | self.bn53 = nn.BatchNorm1d(512) 169 | self.bn54 = nn.BatchNorm1d(1024) 170 | self.bn63 = nn.BatchNorm1d(128) 171 | self.bn64 = nn.BatchNorm1d(50) 172 | self.bn7 = nn.BatchNorm1d(1024) 173 | 174 | self.bn8 = nn.BatchNorm1d(1024) 175 | self.bn9 = nn.BatchNorm1d(1024) 176 | 177 | self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), 178 | self.bn1, 179 | nn.LeakyReLU(negative_slope=0.2)) 180 | self.conv11 = nn.Sequential(nn.Conv2d(64, 3, kernel_size=[1,20], bias=False), 181 | self.bn11, 182 | nn.LeakyReLU(negative_slope=0.2)) 183 | self.conv12 = nn.Sequential(nn.Conv1d(448, 512, kernel_size=1, bias=False), 184 | self.bn12, 185 | nn.LeakyReLU(negative_slope=0.2)) 186 | self.conv13 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 187 | self.bn13, 188 | nn.LeakyReLU(negative_slope=0.2)) 189 | self.conv14 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), 190 | self.bn14, 191 | nn.LeakyReLU(negative_slope=0.2)) 192 | self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False), 193 | self.bn2, 194 | nn.LeakyReLU(negative_slope=0.2)) 195 | self.conv21 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False), 196 | self.bn21, 197 | nn.LeakyReLU(negative_slope=0.2)) 198 | self.conv22 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False), 199 | self.bn22, 200 | nn.LeakyReLU(negative_slope=0.2)) 201 | self.conv23 = nn.Sequential(nn.Conv1d(192, 128, kernel_size=1, bias=False), 202 | self.bn23, 203 | nn.LeakyReLU(negative_slope=0.2)) 204 | self.conv24 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), 205 | self.bn24, 206 | nn.LeakyReLU(negative_slope=0.2)) 207 | self.conv3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False), 208 | self.bn3, 209 | nn.LeakyReLU(negative_slope=0.2)) 210 | self.conv32 = nn.Sequential(nn.Conv1d(1280, 256, kernel_size=1, bias=False), 211 | self.bn32, 212 | nn.LeakyReLU(negative_slope=0.2)) 213 | self.conv33 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False), 214 | self.bn33, 215 | nn.LeakyReLU(negative_slope=0.2)) 216 | self.conv34 = nn.Sequential(nn.Conv1d(384, 128, kernel_size=1, bias=False), 217 | self.bn34, 218 | 
nn.LeakyReLU(negative_slope=0.2)) 219 | self.conv35 = nn.Sequential(nn.Conv1d(320, 128, kernel_size=1, bias=False), 220 | self.bn35, 221 | nn.LeakyReLU(negative_slope=0.2)) 222 | self.conv5 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False), 223 | self.bn5, 224 | nn.LeakyReLU(negative_slope=0.2)) 225 | self.conv53 = nn.Sequential(nn.Conv1d(256, 512, kernel_size=1, bias=False), 226 | self.bn53, 227 | nn.LeakyReLU(negative_slope=0.2)) 228 | self.conv54 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 229 | self.bn54, 230 | nn.LeakyReLU(negative_slope=0.2)) 231 | self.conv63 = nn.Sequential(nn.Conv1d(1024, 128, kernel_size=1, bias=False), 232 | self.bn63, 233 | nn.LeakyReLU(negative_slope=0.2)) 234 | self.conv64 = nn.Sequential(nn.Conv1d(128, 50, kernel_size=1)) 235 | 236 | self.conv7 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 237 | self.bn7, 238 | nn.LeakyReLU(negative_slope=0.2)) 239 | self.conv8 = nn.Sequential(nn.Conv1d(2112, 1024, kernel_size=1, bias=False), 240 | self.bn8, 241 | nn.LeakyReLU(negative_slope=0.2)) 242 | self.conv9 = nn.Sequential(nn.Conv1d(1024, 1024, kernel_size=1, bias=False), 243 | self.bn9, 244 | nn.LeakyReLU(negative_slope=0.2)) 245 | 246 | self.bnfc2 = nn.BatchNorm2d(64) 247 | self.bnfc21 = nn.BatchNorm2d(64) 248 | self.bnfc24 = nn.BatchNorm2d(64) 249 | self.bnfc3 = nn.BatchNorm2d(128) 250 | self.bnfc31 = nn.BatchNorm2d(64) 251 | self.bnfc4 = nn.BatchNorm2d(256) 252 | self.bnfc41 = nn.BatchNorm2d(128) 253 | self.convfc2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False), 254 | self.bnfc2, 255 | nn.LeakyReLU(negative_slope=0.2)) 256 | self.convfc21 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=[1,20], bias=False), 257 | self.bnfc21, 258 | nn.LeakyReLU(negative_slope=0.2)) 259 | self.convfc24 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), 260 | self.bnfc24, 261 | nn.LeakyReLU(negative_slope=0.2)) 262 | self.convfc3 = nn.Sequential(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False), 263 | self.bnfc3, 264 | nn.LeakyReLU(negative_slope=0.2)) 265 | self.convfc31 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=[1,20], bias=False), 266 | self.bnfc31, 267 | nn.LeakyReLU(negative_slope=0.2)) 268 | self.convfc4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False), 269 | self.bnfc4, 270 | nn.LeakyReLU(negative_slope=0.2)) 271 | self.convfc41 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=[1,20], bias=False), 272 | self.bnfc41, 273 | nn.LeakyReLU(negative_slope=0.2)) 274 | 275 | self.bnop1 = nn.BatchNorm1d(50) 276 | self.bnop2 = nn.BatchNorm1d(50) 277 | self.bnop3 = nn.BatchNorm1d(50) 278 | self.bnop4 = nn.BatchNorm1d(50) 279 | self.bnop11 = nn.BatchNorm1d(1) 280 | self.bnop21 = nn.BatchNorm1d(1) 281 | self.bnop31 = nn.BatchNorm1d(1) 282 | self.bnop41 = nn.BatchNorm1d(1) 283 | self.bnop12 = nn.BatchNorm1d(1) 284 | self.bnop22 = nn.BatchNorm1d(1) 285 | self.bnop32 = nn.BatchNorm1d(1) 286 | self.bnop42 = nn.BatchNorm1d(1) 287 | self.conv_op1 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop1, nn.LeakyReLU(negative_slope=0.2)) 288 | self.conv_op2 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop2, nn.LeakyReLU(negative_slope=0.2)) 289 | self.conv_op3 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop3, nn.LeakyReLU(negative_slope=0.2)) 290 | self.conv_op4 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop4, nn.LeakyReLU(negative_slope=0.2)) 291 | 
        self.conv_op11 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop11)
292 |         self.conv_op21 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop21)
293 |         self.conv_op31 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop31)
294 |         self.conv_op41 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop41)
295 |         self.conv_op12 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop12)
296 |         self.conv_op22 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop22)
297 |         self.conv_op32 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop32)
298 |         self.conv_op42 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop42)
299 | 
300 |         self.fuse = Channel_fusion(1024)
301 | 
302 |         self.dp = nn.Dropout(p=0.5)
303 | 
304 |     def _break_up_pc(self, pc):
305 |         xyz = pc[..., 0:3].contiguous()
306 |         features = (
307 |             pc[..., 3:].transpose(1, 2).contiguous()
308 |             if pc.size(-1) > 3 else None
309 |         )
310 | 
311 |         return xyz, features
312 | 
313 |     def forward(self, pointcloud: torch.cuda.FloatTensor, cls):
314 |         # x: B,3,N
315 | 
316 |         xyz, features = self._break_up_pc(pointcloud)
317 |         num_pts = xyz.size(1)
318 |         batch_size = xyz.size(0)
319 |         # FPS to find different point subsets and their relations
320 |         subset1_idx = pointnet2_utils.furthest_point_sample(xyz, num_pts//4).long() # B,N/4
321 |         subset1_xyz = torch.unsqueeze(subset1_idx, -1).repeat(1, 1, 3) # B,N/4,3
322 |         subset1_xyz = torch.take(xyz, subset1_xyz) # B,N/4,3
323 | 
324 |         dist, idx1 = pointnet2_utils.three_nn(xyz, subset1_xyz)
325 |         dist_recip = 1.0 / (dist + 1e-8)
326 |         norm = torch.sum(dist_recip, dim=2, keepdim=True)
327 |         weight1 = dist_recip / norm
328 | 
329 |         subset12_idx = pointnet2_utils.furthest_point_sample(subset1_xyz, num_pts//16).long() # B,N/16
330 |         subset12_xyz = torch.unsqueeze(subset12_idx, -1).repeat(1, 1, 3) # B,N/16,3
331 |         subset12_xyz = torch.take(subset1_xyz, subset12_xyz) # B,N/16,3
332 | 
333 |         dist, idx12 = pointnet2_utils.three_nn(subset1_xyz, subset12_xyz)
334 |         dist_recip = 1.0 / (dist + 1e-8)
335 |         norm = torch.sum(dist_recip, dim=2, keepdim=True)
336 |         weight12 = dist_recip / norm
337 | 
338 |         device = torch.device('cuda')
339 |         centroid = torch.zeros([batch_size, 1, 3], device=device)
340 |         dist, idx0 = pointnet2_utils.three_nn(subset12_xyz, centroid)
341 |         dist_recip = 1.0 / (dist + 1e-8)
342 |         norm = torch.sum(dist_recip, dim=2, keepdim=True)
343 |         weight0 = dist_recip / norm
344 |         #######################################
345 |         # Error-minimizing module 1:
346 |         # Encoding
347 |         x = xyz.transpose(2, 1) # x: B,3,N
348 |         x1_1 = x
349 |         x = get_adptive_dilated_graph_feature(x, self.conv_op1, self.conv_op11, self.conv_op12, d=5, k=20)
350 |         x = self.conv1(x) # B,64,N,k
351 |         x = self.conv14(x) # B,64,N,k
352 |         x1_2 = x
353 |         # Back-projection
354 |         x = self.conv11(x) # B,3,N,1
355 |         x = torch.squeeze(x, -1) # B,3,N
356 |         x1_3 = x
357 |         # Calculating Error
358 |         delta_1 = x1_3 - x1_1 # B,3,N
359 |         # Output
360 |         x = x1_2 # B,64,N,k
361 |         x1 = x.max(dim=-1, keepdim=False)[0] # B,64,N
362 |         #######################################
363 | 
364 |         #######################################
365 |         # Multi-resolution (MR) Branch
366 |         # Down-scaling 1
367 |         subset1_feat = torch.unsqueeze(subset1_idx, -1).repeat(1, 1, 64) # B,N/4,64
368 |         x1_subset1 = torch.take(x1.transpose(1, 2).contiguous(),
subset1_feat).transpose(1, 2).contiguous() # B,64,N/4
369 | 
370 |         x2_1 = x1_subset1 # B,64,N/4
371 |         x = get_graph_feature(x1_subset1, k=self.k//2)
372 |         x = self.conv2(x) # B,64,N/4,k
373 |         x = self.conv24(x) # B,128,N/4,k
374 |         x2 = x.max(dim=-1, keepdim=False)[0] # B,128,N/4
375 | 
376 |         # Dense-connection
377 |         x12 = pointnet2_utils.three_interpolate(x2, idx1, weight1) # B,128,N
378 |         x12 = torch.cat((x12, x1), dim=1) # B,192,N
379 |         x12 = self.conv23(x12) # B,128,N
380 | 
381 |         # Down-scaling 2
382 |         subset12_feat = torch.unsqueeze(subset12_idx, -1).repeat(1, 1, 128) # B,N/16,128
383 |         x2_subset12 = torch.take(x2.transpose(1, 2).contiguous(), subset12_feat).transpose(1, 2).contiguous() # B,128,N/16
384 | 
385 |         x3_1 = x2_subset12 # B,128,N/16
386 |         x = get_graph_feature(x2_subset12, k=self.k//4)
387 |         x = self.conv3(x) # B,256,N/16,k
388 |         x3 = x.max(dim=-1, keepdim=False)[0] # B,256,N/16
389 | 
390 |         # Dense-connection
391 |         x23 = pointnet2_utils.three_interpolate(x3, idx12, weight12) # B,256,N/4
392 |         x23 = torch.cat((x23, x2), dim=1) # B,384,N/4
393 |         x23 = self.conv34(x23) # B,128,N/4
394 |         x123 = pointnet2_utils.three_interpolate(x23, idx1, weight1) # B,128,N
395 |         x123 = torch.cat((x123, x12, x1), dim=1) # B,320,N
396 |         x123 = self.conv35(x123) # B,128,N
397 | 
398 |         # Down-scaling 3
399 |         x_bot = self.conv53(x3)
400 |         x_bot = self.conv54(x_bot) # B,1024,N/16
401 |         x_bot = F.adaptive_max_pool1d(x_bot, 1) # B,1024,1
402 | 
403 |         # Upsampling 3:
404 |         interpolated_feats1 = pointnet2_utils.three_interpolate(x_bot, idx0, weight0) # B,1024,N/16
405 |         interpolated_feats2 = x3 # B,256,N/16
406 |         x3_up = torch.cat((interpolated_feats1, interpolated_feats2), dim=1) # B,1280,N/16
407 |         x3_up = self.conv32(x3_up) # B,256,N/16
408 |         x3_up = self.conv33(x3_up) # B,256,N/16
409 | 
410 |         # Upsampling 2:
411 |         interpolated_feats1 = pointnet2_utils.three_interpolate(x3_up, idx12, weight12) # B,256,N/4
412 |         interpolated_feats2 = x2 # B,128,N/4
413 |         interpolated_feats3 = x23 # B,128,N/4
414 |         x2_up = torch.cat((interpolated_feats1, interpolated_feats3, interpolated_feats2), dim=1) # B,512,N/4
415 |         x2_up = self.conv21(x2_up) # B,256,N/4
416 |         x2_up = self.conv22(x2_up) # B,128,N/4
417 | 
418 |         # Upsampling 1:
419 |         interpolated_feats1 = pointnet2_utils.three_interpolate(x2_up, idx1, weight1) # B,128,N
420 |         interpolated_feats2 = x1 # B,64,N
421 |         interpolated_feats3 = x12 # B,128,N
422 |         interpolated_feats4 = x123 # B,128,N
423 |         x1_up = torch.cat((interpolated_feats1, interpolated_feats4, interpolated_feats3, interpolated_feats2), dim=1) # B,448,N
424 |         x1_up = self.conv12(x1_up) # B,512,N
425 |         x1_up = self.conv13(x1_up) # B,1024,N
426 | 
427 |         x_mr = x1_up
428 |         #############################################################################
429 | 
430 |         #############################################################################
431 |         # Full-resolution Branch
432 |         # Error-minimizing module 2:
433 |         # Encoding
434 |         x2_1 = x1 # B,64,N
435 |         x = get_adptive_dilated_graph_feature(x1, self.conv_op2, self.conv_op21, self.conv_op22, d=5, k=20)
436 |         x = self.convfc2(x) # B,64,N,k
437 |         x = self.convfc24(x) # B,64,N,k
438 |         x2_2 = x
439 |         # Back-projection
440 |         x = self.convfc21(x) # B,64,N,1
441 |         x = torch.squeeze(x, -1) # B,64,N
442 |         x2_3 = x
443 |         # Calculating Error
444 |         delta_2 = x2_3 - x2_1 # B,64,N
445 |         # Output
446 |         x = x2_2 # B,64,N,k
447 |         x2 = x.max(dim=-1, keepdim=False)[0] # B,64,N
448 |         #######################################
449 |         # Error-minimizing module 3:
450 |         # Encoding
451 |         x3_1 = x2 #
B,64,N 452 | x = get_adptive_dilated_graph_feature(x2, self.conv_op3, self.conv_op31, self.conv_op32, d=5, k=20) 453 | x = self.convfc3(x) # B,128,N,k 454 | x3_2 = x 455 | # Back-projection 456 | x = self.convfc31(x) # B,64,N,1 457 | x = torch.squeeze(x, -1) # B,64,N 458 | x3_3 = x 459 | # Calculating Error 460 | delta_3 = x3_3 - x3_1 # B,64,N 461 | # Output 462 | x = x3_2 # B,128,N,k 463 | x3 = x.max(dim=-1, keepdim=False)[0] # B,128,N 464 | ####################################### 465 | # Error-minimizing module 4: 466 | # Encoding 467 | x4_1 = x3 # B,128,N 468 | x = get_adptive_dilated_graph_feature(x3, self.conv_op4, self.conv_op41, self.conv_op42, d=5, k=20) 469 | x = self.convfc4(x) # B,256,N,k 470 | x4_2 = x 471 | # Back-projection 472 | x = self.convfc41(x) # B,128,N,1 473 | x = torch.squeeze(x, -1) # B,128,N 474 | x4_3 = x 475 | # Calculating Error 476 | delta_4 = x4_3 - x4_1 # B,128,N 477 | # Output 478 | x = x4_2 # B,256,N,k 479 | x4 = x.max(dim=-1, keepdim=False)[0] # B,256,N 480 | 481 | x = torch.cat((x1, x2, x3, x4), dim=1) # B,512,N 482 | x_fr = self.conv7(x) # B,1024,N 483 | 484 | # Fusing FR and MR outputs 485 | fusion_score = self.fuse(x_mr) 486 | x = x_fr + x_fr * fusion_score 487 | x_all = self.conv9(x) # B,1024,N 488 | 489 | # Collecting global feature 490 | one_hot_label = cls.view(-1, 16, 1) # B,16,1 491 | one_hot_label = self.conv5(one_hot_label) # B,64,1 492 | x_max = F.adaptive_max_pool1d(x_all, 1) # B,1024,1 493 | x_global = torch.cat((x_max, one_hot_label), dim=1) # B,1088,1 494 | 495 | x_global = x_global.repeat(1, 1, num_pts) # B,1088,N 496 | x = torch.cat((x_all, x_global), dim=1) # B,2112,N 497 | 498 | x = self.conv8(x) # B,1024,N 499 | 500 | x = self.conv63(x) # B,128,N 501 | x = self.dp(x) 502 | x = self.conv64(x) # B,50,N 503 | 504 | return (x.transpose(2, 1).contiguous(), 505 | delta_1.transpose(2, 1).contiguous(), 506 | delta_2.transpose(2, 1).contiguous(), 507 | delta_3.transpose(2, 1).contiguous(), 508 | delta_4.transpose(2, 1).contiguous()) 509 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | 
#endif 42 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_CHECK(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_CHECK(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.type().is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.type().is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data(), 26 | xyz.data(), idx.data()); 27 | } else { 28 | AT_CHECK(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/ball_query_gpu.cu: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.type().is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.type().is_cuda()) { 27 | 
group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), points.data(), 29 | idx.data(), output.data()); 30 | } else { 31 | AT_CHECK(false, "CPU not supported"); 32 | } 33 | 34 | return output; 35 | } 36 | 37 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 38 | CHECK_CONTIGUOUS(grad_out); 39 | CHECK_CONTIGUOUS(idx); 40 | CHECK_IS_FLOAT(grad_out); 41 | CHECK_IS_INT(idx); 42 | 43 | if (grad_out.type().is_cuda()) { 44 | CHECK_CUDA(idx); 45 | } 46 | 47 | at::Tensor output = 48 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 49 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 50 | 51 | if (grad_out.type().is_cuda()) { 52 | group_points_grad_kernel_wrapper( 53 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 54 | grad_out.data(), idx.data(), output.data()); 55 | } else { 56 | AT_CHECK(false, "CPU not supported"); 57 | } 58 | 59 | return output; 60 | } 61 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | 
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.type().is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.type().is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data(), knows.data(), 34 | dist2.data(), idx.data()); 35 | } else { 36 | AT_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.type().is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.type().is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data(), idx.data(), weight.data(), 64 | output.data()); 65 | } else { 66 | AT_CHECK(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.type().is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.type().is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data(), idx.data(), weight.data(), 93 | output.data()); 94 | } else { 95 | AT_CHECK(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- 
/pointnet2/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include <math.h> 2 | #include <stdio.h> 3 | #include <stdlib.h> 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void
three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.type().is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.type().is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data<float>(), 32 | idx.data<int>(), output.data<float>()); 33 | } else { 34 | AT_CHECK(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.type().is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.type().is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data<float>(), 58 | idx.data<int>(), output.data<float>()); 59 | } else { 60 | AT_CHECK(false, "CPU not supported"); 61 | } 62 | 63 | return output; 64 | } 65 | at::Tensor
furthest_point_sampling(at::Tensor points, const int nsamples) { 66 | CHECK_CONTIGUOUS(points); 67 | CHECK_IS_FLOAT(points); 68 | 69 | at::Tensor output = 70 | torch::zeros({points.size(0), nsamples}, 71 | at::device(points.device()).dtype(at::ScalarType::Int)); 72 | 73 | at::Tensor tmp = 74 | torch::full({points.size(0), points.size(1)}, 1e10, 75 | at::device(points.device()).dtype(at::ScalarType::Float)); 76 | 77 | if (points.type().is_cuda()) { 78 | furthest_point_sampling_kernel_wrapper( 79 | points.size(0), points.size(1), nsamples, points.data<float>(), 80 | tmp.data<float>(), output.data<int>()); 81 | } else { 82 | AT_CHECK(false, "CPU not supported"); 83 | } 84 | 85 | return output; 86 | } 87 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0, 26 | at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0, 53 | at::cuda::getCurrentCUDAStream()>>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ?
i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Output idxs (b, m) 69 | template <unsigned int block_size> 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32>
201 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /pointnet2/_ext.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/_ext.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /pointnet2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | from . import pointnet2_utils 9 | from . import pointnet2_modules 10 | -------------------------------------------------------------------------------- /pointnet2/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pointnet2/utils/__pycache__/pointnet2_modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/utils/__pycache__/pointnet2_modules.cpython-36.pyc -------------------------------------------------------------------------------- /pointnet2/utils/__pycache__/pointnet2_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/utils/__pycache__/pointnet2_utils.cpython-36.pyc -------------------------------------------------------------------------------- /pointnet2/utils/linalg_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | from enum import Enum 10 | import numpy as np 11 | 12 | PDist2Order = Enum("PDist2Order", "d_first d_second") 13 | 14 | 15 | def pdist2(X, Z=None, order=PDist2Order.d_second): 16 | # type: (torch.Tensor, torch.Tensor, PDist2Order) -> torch.Tensor 17 | r""" Calculates the pairwise distance between X and Z 18 | 19 | D[b, i, j] = l2 distance between X[b, i] and Z[b, j] 20 | 21 | Parameters 22 | --------- 23 | X : torch.Tensor 24 | X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d 25 | Z: torch.Tensor 26 | Z is a (B, M, d) tensor.
If Z is None, then Z = X 27 | 28 | Returns 29 | ------- 30 | torch.Tensor 31 | Distance matrix is size (B, N, M) 32 | """ 33 | 34 | if order == PDist2Order.d_second: 35 | if X.dim() == 2: 36 | X = X.unsqueeze(0) 37 | if Z is None: 38 | Z = X 39 | G = np.matmul(X, Z.transpose(-2, -1)) 40 | S = (X * X).sum(-1, keepdim=True) 41 | R = S.transpose(-2, -1) 42 | else: 43 | if Z.dim() == 2: 44 | Z = Z.unsqueeze(0) 45 | G = np.matmul(X, Z.transpose(-2, -1)) 46 | S = (X * X).sum(-1, keepdim=True) 47 | R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1) 48 | else: 49 | if X.dim() == 2: 50 | X = X.unsqueeze(0) 51 | if Z is None: 52 | Z = X 53 | G = np.matmul(X.transpose(-2, -1), Z) 54 | R = (X * X).sum(-2, keepdim=True) 55 | S = R.transpose(-2, -1) 56 | else: 57 | if Z.dim() == 2: 58 | Z = Z.unsqueeze(0) 59 | G = np.matmul(X.transpose(-2, -1), Z) 60 | S = (X * X).sum(-2, keepdim=True).transpose(-2, -1) 61 | R = (Z * Z).sum(-2, keepdim=True) 62 | 63 | return torch.abs(R + S - 2 * G).squeeze(0) 64 | 65 | 66 | def pdist2_slow(X, Z=None): 67 | if Z is None: 68 | Z = X 69 | D = torch.zeros(X.size(0), X.size(2), Z.size(2)) 70 | 71 | for b in range(D.size(0)): 72 | for i in range(D.size(1)): 73 | for j in range(D.size(2)): 74 | D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j]) 75 | return D 76 | 77 | 78 | if __name__ == "__main__": 79 | X = torch.randn(2, 3, 5) 80 | Z = torch.randn(2, 3, 3) 81 | 82 | print(pdist2(X, order=PDist2Order.d_first)) 83 | print(pdist2_slow(X)) 84 | print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X))) 85 | -------------------------------------------------------------------------------- /pointnet2/utils/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import etw_pytorch_utils as pt_utils 12 | 13 | from pointnet2.utils import pointnet2_utils 14 | 15 | if False: 16 | # Workaround for type hints without depending on the `typing` module 17 | from typing import * 18 | 19 | 20 | class _PointnetSAModuleBase(nn.Module): 21 | def __init__(self): 22 | super(_PointnetSAModuleBase, self).__init__() 23 | self.npoint = None 24 | self.groupers = None 25 | self.mlps = None 26 | 27 | def forward(self, xyz, features=None): 28 | # type: (_PointnetSAModuleBase, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 29 | r""" 30 | Parameters 31 | ---------- 32 | xyz : torch.Tensor 33 | (B, N, 3) tensor of the xyz coordinates of the features 34 | features : torch.Tensor 35 | (B, N, C) tensor of the descriptors of the features 36 | 37 | Returns 38 | ------- 39 | new_xyz : torch.Tensor 40 | (B, npoint, 3) tensor of the new features' xyz 41 | new_features : torch.Tensor 42 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 43 | """ 44 | 45 | new_features_list = [] 46 | 47 | xyz_flipped = xyz.transpose(1, 2).contiguous() 48 | new_xyz = ( 49 | pointnet2_utils.gather_operation( 50 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 51 | ) 52 | .transpose(1, 2) 53 | .contiguous() 54 | if self.npoint is not None 55 | else None 56 | ) 57 | 58 | for i in range(len(self.groupers)): 59 | new_features = self.groupers[i]( 60 | xyz, new_xyz, features 61 | ) # (B, C, npoint, nsample) 62 | 63 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 64 | new_features
= F.max_pool2d( 65 | new_features, kernel_size=[1, new_features.size(3)] 66 | ) # (B, mlp[-1], npoint, 1) 67 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 68 | 69 | new_features_list.append(new_features) 70 | 71 | return new_xyz, torch.cat(new_features_list, dim=1) 72 | 73 | 74 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 75 | r"""Pointnet set abstraction layer with multiscale grouping 76 | 77 | Parameters 78 | ---------- 79 | npoint : int 80 | Number of features 81 | radii : list of float32 82 | list of radii to group with 83 | nsamples : list of int32 84 | Number of samples in each ball query 85 | mlps : list of list of int32 86 | Spec of the pointnet before the global max_pool for each scale 87 | bn : bool 88 | Use batchnorm 89 | """ 90 | 91 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 92 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 93 | super(PointnetSAModuleMSG, self).__init__() 94 | 95 | assert len(radii) == len(nsamples) == len(mlps) 96 | 97 | self.npoint = npoint 98 | self.groupers = nn.ModuleList() 99 | self.mlps = nn.ModuleList() 100 | for i in range(len(radii)): 101 | radius = radii[i] 102 | nsample = nsamples[i] 103 | self.groupers.append( 104 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 105 | if npoint is not None 106 | else pointnet2_utils.GroupAll(use_xyz) 107 | ) 108 | mlp_spec = mlps[i] 109 | if use_xyz: 110 | mlp_spec[0] += 3 111 | 112 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 113 | 114 | 115 | class PointnetSAModule(PointnetSAModuleMSG): 116 | r"""Pointnet set abstraction layer 117 | 118 | Parameters 119 | ---------- 120 | npoint : int 121 | Number of features 122 | radius : float 123 | Radius of ball 124 | nsample : int 125 | Number of samples in the ball query 126 | mlp : list 127 | Spec of the pointnet before the global max_pool 128 | bn : bool 129 | Use batchnorm 130 | """ 131 | 132 | def __init__( 133 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 134 | ): 135 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None 136 | super(PointnetSAModule, self).__init__( 137 | mlps=[mlp], 138 | npoint=npoint, 139 | radii=[radius], 140 | nsamples=[nsample], 141 | bn=bn, 142 | use_xyz=use_xyz, 143 | ) 144 | 145 | 146 | class PointnetFPModule(nn.Module): 147 | r"""Propagates the features of one set to another 148 | 149 | Parameters 150 | ---------- 151 | mlp : list 152 | Pointnet module parameters 153 | bn : bool 154 | Use batchnorm 155 | """ 156 | 157 | def __init__(self, mlp, bn=True): 158 | # type: (PointnetFPModule, List[int], bool) -> None 159 | super(PointnetFPModule, self).__init__() 160 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 161 | 162 | def forward(self, unknown, known, unknow_feats, known_feats): 163 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 164 | r""" 165 | Parameters 166 | ---------- 167 | unknown : torch.Tensor 168 | (B, n, 3) tensor of the xyz positions of the unknown features 169 | known : torch.Tensor 170 | (B, m, 3) tensor of the xyz positions of the known features 171 | unknow_feats : torch.Tensor 172 | (B, C1, n) tensor of the features to be propagated to 173 | known_feats : torch.Tensor 174 | (B, C2, m) tensor of features to be propagated 175 | 176 | Returns 177 | ------- 178 | new_features : torch.Tensor 179 | (B, mlp[-1], n) tensor of the features of the unknown features 180 | """ 181 | 182 | if known is not
None: 183 | dist, idx = pointnet2_utils.three_nn(unknown, known) 184 | dist_recip = 1.0 / (dist + 1e-8) 185 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 186 | weight = dist_recip / norm 187 | 188 | interpolated_feats = pointnet2_utils.three_interpolate( 189 | known_feats, idx, weight 190 | ) 191 | else: 192 | interpolated_feats = known_feats.expand( 193 | *(known_feats.size()[0:2] + [unknown.size(1)]) 194 | ) 195 | 196 | if unknow_feats is not None: 197 | new_features = torch.cat( 198 | [interpolated_feats, unknow_feats], dim=1 199 | ) # (B, C2 + C1, n) 200 | else: 201 | new_features = interpolated_feats 202 | 203 | new_features = new_features.unsqueeze(-1) 204 | new_features = self.mlp(new_features) 205 | 206 | return new_features.squeeze(-1) 207 | 208 | 209 | if __name__ == "__main__": 210 | from torch.autograd import Variable 211 | 212 | torch.manual_seed(1) 213 | torch.cuda.manual_seed_all(1) 214 | xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True) 215 | xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True) 216 | 217 | test_module = PointnetSAModuleMSG( 218 | npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]] 219 | ) 220 | test_module.cuda() 221 | print(test_module(xyz, xyz_feats)) 222 | 223 | # test_module = PointnetFPModule(mlp=[6, 6]) 224 | # test_module.cuda() 225 | # from torch.autograd import gradcheck 226 | # inputs = (xyz, xyz, None, xyz_feats) 227 | # test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4) 228 | # print(test) 229 | 230 | for _ in range(1): 231 | _, new_features = test_module(xyz, xyz_feats) 232 | new_features.backward(torch.cuda.FloatTensor(*new_features.size()).fill_(1)) 233 | print(new_features) 234 | print(xyz.grad) 235 | -------------------------------------------------------------------------------- /pointnet2/utils/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | from torch.autograd import Function 10 | import torch.nn as nn 11 | import etw_pytorch_utils as pt_utils 12 | import sys 13 | 14 | try: 15 | import builtins 16 | except: 17 | import __builtin__ as builtins 18 | 19 | try: 20 | import pointnet2._ext as _ext 21 | except ImportError: 22 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 23 | raise ImportError( 24 | "Could not import _ext module.\n" 25 | "Please see the setup instructions in the README: " 26 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 27 | ) 28 | 29 | if False: 30 | # Workaround for type hints without depending on the `typing` module 31 | from typing import * 32 | 33 | 34 | class RandomDropout(nn.Module): 35 | def __init__(self, p=0.5, inplace=False): 36 | super(RandomDropout, self).__init__() 37 | self.p = p 38 | self.inplace = inplace 39 | 40 | def forward(self, X): 41 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 42 | return pt_utils.feature_dropout_no_scaling(X, theta, self.training, self.inplace) # self.training: the module's boolean train/eval flag, not the train() method 43 | 44 | 45 | class FurthestPointSampling(Function): 46 | @staticmethod 47 | def forward(ctx, xyz, npoint): 48 | # type: (Any, torch.Tensor, int) -> torch.Tensor 49 | r""" 50 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 51 | minimum distance 52 | 53 | Parameters 54 | ---------- 55 | xyz : torch.Tensor 56 | (B, N, 3) tensor where N > npoint 57 | npoint : int32 58 | number of features in the
sampled set 59 | 60 | Returns 61 | ------- 62 | torch.Tensor 63 | (B, npoint) tensor containing the set 64 | """ 65 | return _ext.furthest_point_sampling(xyz, npoint) 66 | 67 | @staticmethod 68 | def backward(xyz, a=None): 69 | return None, None 70 | 71 | 72 | furthest_point_sample = FurthestPointSampling.apply 73 | 74 | 75 | class GatherOperation(Function): 76 | @staticmethod 77 | def forward(ctx, features, idx): 78 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 79 | r""" 80 | 81 | Parameters 82 | ---------- 83 | features : torch.Tensor 84 | (B, C, N) tensor 85 | 86 | idx : torch.Tensor 87 | (B, npoint) tensor of the features to gather 88 | 89 | Returns 90 | ------- 91 | torch.Tensor 92 | (B, C, npoint) tensor 93 | """ 94 | 95 | _, C, N = features.size() 96 | 97 | ctx.for_backwards = (idx, C, N) 98 | 99 | return _ext.gather_points(features, idx) 100 | 101 | @staticmethod 102 | def backward(ctx, grad_out): 103 | idx, C, N = ctx.for_backwards 104 | 105 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 106 | return grad_features, None 107 | 108 | 109 | gather_operation = GatherOperation.apply 110 | 111 | 112 | class ThreeNN(Function): 113 | @staticmethod 114 | def forward(ctx, unknown, known): 115 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 116 | r""" 117 | Find the three nearest neighbors of unknown in known 118 | Parameters 119 | ---------- 120 | unknown : torch.Tensor 121 | (B, n, 3) tensor of unknown features 122 | known : torch.Tensor 123 | (B, m, 3) tensor of known features 124 | 125 | Returns 126 | ------- 127 | dist : torch.Tensor 128 | (B, n, 3) l2 distance to the three nearest neighbors 129 | idx : torch.Tensor 130 | (B, n, 3) index of 3 nearest neighbors 131 | """ 132 | dist2, idx = _ext.three_nn(unknown, known) 133 | 134 | return torch.sqrt(dist2), idx 135 | 136 | @staticmethod 137 | def backward(ctx, a=None, b=None): 138 | return None, None 139 | 140 | 141 | three_nn = ThreeNN.apply 142 | 143 | 144 | class ThreeInterpolate(Function): 145 | @staticmethod 146 | def forward(ctx, features, idx, weight): 147 | # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 148 | r""" 149 | Performs weighted linear interpolation on 3 features 150 | Parameters 151 | ---------- 152 | features : torch.Tensor 153 | (B, c, m) Features descriptors to be interpolated from 154 | idx : torch.Tensor 155 | (B, n, 3) three nearest neighbors of the target features in features 156 | weight : torch.Tensor 157 | (B, n, 3) weights 158 | 159 | Returns 160 | ------- 161 | torch.Tensor 162 | (B, c, n) tensor of the interpolated features 163 | """ 164 | B, c, m = features.size() 165 | n = idx.size(1) 166 | 167 | ctx.three_interpolate_for_backward = (idx, weight, m) 168 | 169 | return _ext.three_interpolate(features, idx, weight) 170 | 171 | @staticmethod 172 | def backward(ctx, grad_out): 173 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 174 | r""" 175 | Parameters 176 | ---------- 177 | grad_out : torch.Tensor 178 | (B, c, n) tensor with gradients of outputs 179 | 180 | Returns 181 | ------- 182 | grad_features : torch.Tensor 183 | (B, c, m) tensor with gradients of features 184 | 185 | None 186 | 187 | None 188 | """ 189 | idx, weight, m = ctx.three_interpolate_for_backward 190 | 191 | grad_features = _ext.three_interpolate_grad( 192 | grad_out.contiguous(), idx, weight, m 193 | ) 194 | 195 | return grad_features, None, None 196 | 197 | 198 | three_interpolate = ThreeInterpolate.apply 199 | 200 |
201 | class GroupingOperation(Function): 202 | @staticmethod 203 | def forward(ctx, features, idx): 204 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 205 | r""" 206 | 207 | Parameters 208 | ---------- 209 | features : torch.Tensor 210 | (B, C, N) tensor of features to group 211 | idx : torch.Tensor 212 | (B, npoint, nsample) tensor containing the indices of features to group with 213 | 214 | Returns 215 | ------- 216 | torch.Tensor 217 | (B, C, npoint, nsample) tensor 218 | """ 219 | B, nfeatures, nsample = idx.size() 220 | _, C, N = features.size() 221 | 222 | ctx.for_backwards = (idx, N) 223 | 224 | return _ext.group_points(features, idx) 225 | 226 | @staticmethod 227 | def backward(ctx, grad_out): 228 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 229 | r""" 230 | 231 | Parameters 232 | ---------- 233 | grad_out : torch.Tensor 234 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 235 | 236 | Returns 237 | ------- 238 | torch.Tensor 239 | (B, C, N) gradient of the features 240 | None 241 | """ 242 | idx, N = ctx.for_backwards 243 | 244 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 245 | 246 | return grad_features, None 247 | 248 | 249 | grouping_operation = GroupingOperation.apply 250 | 251 | 252 | class BallQuery(Function): 253 | @staticmethod 254 | def forward(ctx, radius, nsample, xyz, new_xyz): 255 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 256 | r""" 257 | 258 | Parameters 259 | ---------- 260 | radius : float 261 | radius of the balls 262 | nsample : int 263 | maximum number of features in the balls 264 | xyz : torch.Tensor 265 | (B, N, 3) xyz coordinates of the features 266 | new_xyz : torch.Tensor 267 | (B, npoint, 3) centers of the ball query 268 | 269 | Returns 270 | ------- 271 | torch.Tensor 272 | (B, npoint, nsample) tensor with the indices of the features that form the query balls 273 | """ 274 | return _ext.ball_query(new_xyz, xyz, radius, nsample) 275 | 276 | @staticmethod 277 | def backward(ctx, a=None): 278 | return None, None, None, None 279 | 280 | 281 | ball_query = BallQuery.apply 282 | 283 | 284 | class QueryAndGroup(nn.Module): 285 | r""" 286 | Groups with a ball query of radius 287 | 288 | Parameters 289 | --------- 290 | radius : float32 291 | Radius of ball 292 | nsample : int32 293 | Maximum number of features to gather in the ball 294 | """ 295 | 296 | def __init__(self, radius, nsample, use_xyz=True): 297 | # type: (QueryAndGroup, float, int, bool) -> None 298 | super(QueryAndGroup, self).__init__() 299 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 300 | 301 | def forward(self, xyz, new_xyz, features=None): 302 | # type: (QueryAndGroup, torch.Tensor,
torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 303 | r""" 304 | Parameters 305 | ---------- 306 | xyz : torch.Tensor 307 | xyz coordinates of the features (B, N, 3) 308 | new_xyz : torch.Tensor 309 | centroids (B, npoint, 3) 310 | features : torch.Tensor 311 | Descriptors of the features (B, C, N) 312 | 313 | Returns 314 | ------- 315 | new_features : torch.Tensor 316 | (B, 3 + C, npoint, nsample) tensor 317 | """ 318 | 319 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 320 | xyz_trans = xyz.transpose(1, 2).contiguous() 321 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 322 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 323 | 324 | if features is not None: 325 | grouped_features = grouping_operation(features, idx) 326 | if self.use_xyz: 327 | new_features = torch.cat( 328 | [grouped_xyz, grouped_features], dim=1 329 | ) # (B, C + 3, npoint, nsample) 330 | else: 331 | new_features = grouped_features 332 | else: 333 | assert ( 334 | self.use_xyz 335 | ), "Cannot have no features and not use xyz as a feature!" 336 | new_features = grouped_xyz 337 | 338 | return new_features 339 | 340 | 341 | class GroupAll(nn.Module): 342 | r""" 343 | Groups all features 344 | 345 | Parameters 346 | --------- 347 | """ 348 | 349 | def __init__(self, use_xyz=True): 350 | # type: (GroupAll, bool) -> None 351 | super(GroupAll, self).__init__() 352 | self.use_xyz = use_xyz 353 | 354 | def forward(self, xyz, new_xyz, features=None): 355 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 356 | r""" 357 | Parameters 358 | ---------- 359 | xyz : torch.Tensor 360 | xyz coordinates of the features (B, N, 3) 361 | new_xyz : torch.Tensor 362 | Ignored 363 | features : torch.Tensor 364 | Descriptors of the features (B, C, N) 365 | 366 | Returns 367 | ------- 368 | new_features : torch.Tensor 369 | (B, C + 3, 1, N) tensor 370 | """ 371 | 372 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 373 | if features is not None: 374 | grouped_features = features.unsqueeze(2) 375 | if self.use_xyz: 376 | new_features = torch.cat( 377 | [grouped_xyz, grouped_features], dim=1 378 | ) # (B, 3 + C, 1, N) 379 | else: 380 | new_features = grouped_features 381 | else: 382 | new_features = grouped_xyz 383 | 384 | return new_features 385 | -------------------------------------------------------------------------------- /train_partseg_gpus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.optim.lr_scheduler as lr_sched 4 | from torch.optim.lr_scheduler import CosineAnnealingLR 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from torch.autograd import Variable 9 | import numpy as np 10 | import os 11 | from torchvision import transforms 12 | from models import DRNET_Seg as DRNET 13 | from data import ShapeNetPart 14 | import data.data_utils as d_utils 15 | import argparse 16 | import random 17 | import yaml 18 | 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.benchmark = True 21 | torch.backends.cudnn.deterministic = True 22 | 23 | seed = 123 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | 30 | parser = argparse.ArgumentParser(description='DRNET Shape Part Segmentation Training') 31 | parser.add_argument('--config', default='cfgs/config_partseg_gpus.yaml', type=str) 32 |
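Before the training code below, a quick illustration of the BatchNorm-momentum schedule it sets up: bnm_lmbd decays the momentum geometrically every decay_step epochs down to a floor. The hyper-parameter values here are made up for the example; the real ones come from cfgs/config_partseg_gpus.yaml.

# Illustrative values only (the real ones live in the yaml config).
bn_momentum, bn_decay, decay_step, bnm_clip = 0.9, 0.5, 20, 0.01
bnm_lmbd = lambda e: max(bn_momentum * bn_decay ** (e // decay_step), bnm_clip)
# epochs 0, 20, 40, 60 -> momentum 0.9, 0.45, 0.225, 0.1125: it halves every
# decay_step epochs until it bottoms out at bnm_clip.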
33 | 34 | def set_bn_momentum_default(bn_momentum): 35 | 36 | def fn(m): 37 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 38 | m.momentum = bn_momentum 39 | 40 | return fn 41 | 42 | 43 | class BNMomentumScheduler(object): 44 | 45 | def __init__( 46 | self, model, bn_lambda, last_epoch=-1, 47 | setter=set_bn_momentum_default 48 | ): 49 | if not isinstance(model, nn.Module): 50 | raise RuntimeError( 51 | "Class '{}' is not a PyTorch nn Module".format( 52 | type(model).__name__ 53 | ) 54 | ) 55 | 56 | self.model = model 57 | self.setter = setter 58 | self.lmbd = bn_lambda 59 | 60 | self.step(last_epoch + 1) 61 | self.last_epoch = last_epoch 62 | 63 | def step(self, epoch=None): 64 | if epoch is None: 65 | epoch = self.last_epoch + 1 66 | 67 | self.last_epoch = epoch 68 | self.model.apply(self.setter(self.lmbd(epoch))) 69 | 70 | def get_momentum(self, epoch=None): 71 | if epoch is None: 72 | epoch = self.last_epoch + 1 73 | return self.lmbd(epoch) 74 | 75 | def main(): 76 | args = parser.parse_args() 77 | with open(args.config) as f: 78 | config = yaml.load(f, Loader=yaml.FullLoader) 79 | print("\n**************************") 80 | for k, v in config['common'].items(): 81 | setattr(args, k, v) 82 | print('\n[%s]:'%(k), v) 83 | print("\n**************************\n") 84 | 85 | try: 86 | os.makedirs(args.save_path) 87 | except OSError: 88 | pass 89 | 90 | train_transforms = transforms.Compose([ 91 | d_utils.PointcloudToTensor() 92 | ]) 93 | test_transforms = transforms.Compose([ 94 | d_utils.PointcloudToTensor() 95 | ]) 96 | 97 | train_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'trainval', normalize = True, transforms = train_transforms) 98 | train_dataloader = DataLoader( 99 | train_dataset, 100 | batch_size=args.batch_size, 101 | shuffle=True, 102 | num_workers=int(args.workers), 103 | pin_memory=True 104 | ) 105 | 106 | global test_dataset 107 | test_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'test', normalize = True, transforms = test_transforms) 108 | test_dataloader = DataLoader( 109 | test_dataset, 110 | batch_size=args.batch_size, 111 | shuffle=False, 112 | num_workers=int(args.workers), 113 | pin_memory=True 114 | ) 115 | 116 | device = torch.device("cuda") 117 | model = DRNET(num_classes = args.num_classes).to(device) 118 | model = nn.DataParallel(model) 119 | print("Let's use", torch.cuda.device_count(), "GPUs!") 120 | 121 | ''' 122 | optimizer = optim.SGD( 123 | model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 124 | scheduler = CosineAnnealingLR(optimizer, args.epochs, eta_min=0.001) 125 | ''' 126 | optimizer = optim.Adam( 127 | model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay) 128 | 129 | lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip / args.base_lr) 130 | bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), args.bnm_clip) 131 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) 132 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd) 133 | 134 | 135 | if args.checkpoint != '': 136 | model.load_state_dict(torch.load(args.checkpoint)) 137 | print('Load model successfully: %s' % (args.checkpoint)) 138 | 139 | criterion = nn.CrossEntropyLoss() 140 | num_batch = len(train_dataset)/args.batch_size 141 | 142 | # training 143 | train(train_dataloader, test_dataloader, model, device, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch) 144 | 145 | 146 | def train(train_dataloader, test_dataloader,
model, device, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch): 147 | PointcloudAug = d_utils.PointcloudScaleAndTranslate() # initialize augmentation 148 | global Class_mIoU, Inst_mIoU 149 | Class_mIoU, Inst_mIoU = 0.82, 0.85 150 | batch_count = 0 151 | model.train() 152 | for epoch in range(args.epochs): 153 | # scheduler.step() 154 | for i, data in enumerate(train_dataloader, 0): 155 | 156 | if lr_scheduler is not None: 157 | lr_scheduler.step(epoch) 158 | if bnm_scheduler is not None: 159 | bnm_scheduler.step(epoch-1) 160 | 161 | points, target, cls = data 162 | points, target = Variable(points), Variable(target) 163 | points, target = points.to(device), target.to(device) 164 | # augmentation 165 | points.data = PointcloudAug(points.data) 166 | 167 | optimizer.zero_grad() 168 | 169 | batch_one_hot_cls = np.zeros((len(cls), 16)) # 16 object classes 170 | for b in range(len(cls)): 171 | batch_one_hot_cls[b, int(cls[b])] = 1 172 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls) 173 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda()) 174 | 175 | pred, error1, error2, error3, error4 = model(points, batch_one_hot_cls) 176 | loss1 = torch.norm(error1, dim=-1) # B, N 177 | loss1 = torch.mean(torch.mean(loss1, dim=-1), dim=-1) 178 | loss2 = torch.norm(error2, dim=-1) # B, N 179 | loss2 = torch.mean(torch.mean(loss2, dim=-1), dim=-1) 180 | loss3 = torch.norm(error3, dim=-1) # B, N 181 | loss3 = torch.mean(torch.mean(loss3, dim=-1), dim=-1) 182 | loss4 = torch.norm(error4, dim=-1) # B, N 183 | loss4 = torch.mean(torch.mean(loss4, dim=-1), dim=-1) 184 | pred = pred.view(-1, args.num_classes) 185 | target = target.view(-1,1)[:,0] 186 | loss = criterion(pred, target) + 0.1*loss1 187 | # + 0.01*loss2 + 0.01*loss3 + 0.01*loss4 188 | loss.backward() 189 | optimizer.step() 190 | 191 | if i % args.print_freq_iter == 0: 192 | print('[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %(epoch+1, i, num_batch, loss.data.clone(), lr_scheduler.get_lr()[0])) 193 | batch_count += 1 194 | 195 | # validation in between an epoch 196 | if (epoch >60) and args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0: 197 | print('testing..') 198 | validate(test_dataloader, model, device, criterion, args, batch_count) 199 | 200 | def validate(test_dataloader, model, device, criterion, args, iter): 201 | global Class_mIoU, Inst_mIoU, test_dataset 202 | model.eval() 203 | 204 | seg_classes = test_dataset.seg_classes 205 | shape_ious = {cat:[] for cat in seg_classes.keys()} 206 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 207 | for cat in seg_classes.keys(): 208 | for label in seg_classes[cat]: 209 | seg_label_to_cat[label] = cat 210 | 211 | losses = [] 212 | temp_device = torch.device("cpu") 213 | for _, data in enumerate(test_dataloader, 0): 214 | points, target, cls = data 215 | with torch.no_grad(): 216 | points = Variable(points) 217 | with torch.no_grad(): 218 | target = Variable(target) 219 | 220 | points = points.to(device) 221 | target = target.to(device) 222 | 223 | batch_one_hot_cls = np.zeros((len(cls), 16)) # 16 object classes 224 | for b in range(len(cls)): 225 | batch_one_hot_cls[b, int(cls[b])] = 1 226 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls) 227 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda()) 228 | 229 | with torch.no_grad(): 230 | pred, error1, error2, error3, error4 = model(points, batch_one_hot_cls) 231 | 232 | loss1 = torch.norm(error1, dim=-1) # B, N 233 | loss1 = 
torch.mean(torch.mean(loss1, dim=-1), dim=-1) 234 | loss2 = torch.norm(error2, dim=-1) # B, N 235 | loss2 = torch.mean(torch.mean(loss2, dim=-1), dim=-1) 236 | loss3 = torch.norm(error3, dim=-1) # B, N 237 | loss3 = torch.mean(torch.mean(loss3, dim=-1), dim=-1) 238 | loss4 = torch.norm(error4, dim=-1) # B, N 239 | loss4 = torch.mean(torch.mean(loss4, dim=-1), dim=-1) 240 | loss = criterion(pred.view(-1, args.num_classes), target.view(-1,1)[:,0]) + 0.1*loss1 241 | # + 0.01*loss2 + 0.01*loss3 + 0.01*loss4 242 | losses.append(loss.data.clone()) 243 | 244 | pred = pred.data.cpu() 245 | target = target.data.cpu() 246 | 247 | pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor) 248 | # pred to the groundtruth classes (selected by seg_classes[cat]) 249 | for b in range(len(cls)): 250 | cat = seg_label_to_cat[target[b, 0].item()] 251 | logits = pred[b, :, :] # (num_points, num_classes) 252 | pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0] 253 | 254 | for b in range(len(cls)): 255 | segp = pred_val[b, :].to(temp_device) 256 | segl = target[b, :].to(temp_device) 257 | cat = seg_label_to_cat[segl[0].item()] 258 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 259 | for l in seg_classes[cat]: 260 | if torch.sum((segl == l) | (segp == l)) == 0: 261 | # part is not present in this shape 262 | part_ious[l - seg_classes[cat][0]] = 1.0 263 | else: 264 | part_ious[l - seg_classes[cat][0]] = torch.sum((segl == l) & (segp == l)) / float(torch.sum((segl == l) | (segp == l))) 265 | shape_ious[cat].append(np.mean(part_ious)) # torch.mean(torch.stack(part_ious)) 266 | 267 | instance_ious = [] 268 | for cat in shape_ious.keys(): 269 | for iou in shape_ious[cat]: 270 | instance_ious.append(iou) 271 | shape_ious[cat] = np.mean(shape_ious[cat]) 272 | mean_class_ious = np.mean(list(shape_ious.values())) 273 | 274 | for cat in sorted(shape_ious.keys()): 275 | print('****** %s: %0.6f'%(cat, shape_ious[cat])) 276 | print('************ Test Loss: %0.6f' % (torch.mean(torch.stack(losses)).cpu().numpy())) #torch.mean(torch.stack(losses)).numpy() np.array(losses).mean()) 277 | print('************ Class_mIoU: %0.6f' % (mean_class_ious)) 278 | print('************ Instance_mIoU: %0.6f' % (np.mean(instance_ious))) 279 | 280 | if mean_class_ious > Class_mIoU or np.mean(instance_ious) > Inst_mIoU: 281 | if mean_class_ious > Class_mIoU: 282 | Class_mIoU = mean_class_ious 283 | if np.mean(instance_ious) > Inst_mIoU: 284 | Inst_mIoU = np.mean(instance_ious) 285 | torch.save(model.state_dict(), '%s/seg_drnet_iter_%d_ins_%0.6f_cls_%0.6f.pth' % (args.save_path, iter, np.mean(instance_ious), mean_class_ious)) 286 | model.train() 287 | 288 | if __name__ == "__main__": 289 | main() 290 | -------------------------------------------------------------------------------- /train_partseg_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | mkdir -p log 3 | now=$(date +"%Y%m%d_%H%M%S") 4 | log_name="PartSeg_LOG_"$now"" 5 | CUDA_VISIBLE_DEVICES=2,3 python -u train_partseg_gpus.py \ 6 | --config cfgs/config_partseg_gpus.yaml \ 7 | 2>&1|tee log/$log_name.log & 8 | -------------------------------------------------------------------------------- /voting_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.optim.lr_scheduler as lr_sched 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.autograd 
import Variable 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import os 10 | from torchvision import transforms 11 | from models import DRNET_Seg as DRNET 12 | from data import ShapeNetPart 13 | import data.data_utils as d_utils 14 | import argparse 15 | import random 16 | import yaml 17 | 18 | torch.backends.cudnn.enabled = True 19 | torch.backends.cudnn.benchmark = True 20 | torch.backends.cudnn.deterministic = True 21 | 22 | seed = 123 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | parser = argparse.ArgumentParser(description='DRNET Shape Part Segmentation Test') 30 | parser.add_argument('--config', default='cfgs/config_partseg_gpus.yaml', type=str) 31 | 32 | NUM_REPEAT = 300 33 | NUM_VOTE = 10 34 | 35 | def main(): 36 | args = parser.parse_args() 37 | with open(args.config) as f: 38 | config = yaml.load(f, Loader=yaml.FullLoader) 39 | for k, v in config['common'].items(): 40 | setattr(args, k, v) 41 | 42 | test_transforms = transforms.Compose([ 43 | d_utils.PointcloudToTensor() 44 | ]) 45 | 46 | test_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'test', normalize = True, transforms = test_transforms) 47 | test_dataloader = DataLoader( 48 | test_dataset, 49 | batch_size=args.batch_size, 50 | shuffle=False, 51 | num_workers=int(args.workers), 52 | pin_memory=True 53 | ) 54 | 55 | device = torch.device("cuda") 56 | model = DRNET(num_classes = args.num_classes).to(device) 57 | model = nn.DataParallel(model) 58 | print("Let's use", torch.cuda.device_count(), "GPUs!") 59 | 60 | if args.checkpoint != '': 61 | model.load_state_dict(torch.load(args.checkpoint)) 62 | print('Load model successfully: %s' % (args.checkpoint)) 63 | 64 | # evaluate 65 | PointcloudScale = d_utils.PointcloudScale(scale_low=0.87, scale_high=1.15) # initialize random scaling 66 | model.eval() 67 | global_Class_mIoU, global_Inst_mIoU = 0, 0 68 | seg_classes = test_dataset.seg_classes 69 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 70 | for cat in seg_classes.keys(): 71 | for label in seg_classes[cat]: 72 | seg_label_to_cat[label] = cat 73 | 74 | for i in range(NUM_REPEAT): 75 | shape_ious = {cat:[] for cat in seg_classes.keys()} 76 | for _, data in enumerate(test_dataloader, 0): 77 | points, target, cls = data 78 | with torch.no_grad(): 79 | points = Variable(points) 80 | with torch.no_grad(): 81 | target = Variable(target) 82 | 83 | points = points.to(device) 84 | target = target.to(device) 85 | 86 | batch_one_hot_cls = np.zeros((len(cls), 16)) # 16 object classes 87 | for b in range(len(cls)): 88 | batch_one_hot_cls[b, int(cls[b])] = 1 89 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls) 90 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda()) 91 | 92 | pred = 0 93 | with torch.no_grad(): 94 | new_points = Variable(points.data.clone()) # start from the unscaled points so that vote 0 is not evaluated on all-zeros 95 | for v in range(NUM_VOTE): 96 | if v > 0: 97 | new_points.data = PointcloudScale(points.data) 98 | # temp_pred, error1, error2, error3, error4, _ = model(new_points, batch_one_hot_cls) 99 | temp_pred = model(new_points, batch_one_hot_cls) 100 | pred += F.softmax(temp_pred, dim = 2) 101 | pred /= NUM_VOTE 102 | 103 | pred = pred.data.cpu() 104 | target = target.data.cpu() 105 | pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor) 106 | # pred to the groundtruth classes (selected by seg_classes[cat]) 107 | for b in range(len(cls)):
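# The shape's category is recovered from its first ground-truth part label, and the
# argmax is restricted to that category's own part range so a point can only be
# assigned a part that exists for its object class.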
108 | cat = seg_label_to_cat[target[b, 0].item()] 109 | logits = pred[b, :, :] # (num_points, num_classes) 110 | pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0] 111 | 112 | for b in range(len(cls)): 113 | segp = pred_val[b, :] 114 | segl = target[b, :] 115 | cat = seg_label_to_cat[segl[0].item()] 116 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 117 | for l in seg_classes[cat]: 118 | if torch.sum((segl == l) | (segp == l)) == 0: 119 | # part is not present in this shape 120 | part_ious[l - seg_classes[cat][0]] = 1.0 121 | else: 122 | part_ious[l - seg_classes[cat][0]] = torch.sum((segl == l) & (segp == l)) / float(torch.sum((segl == l) | (segp == l))) 123 | shape_ious[cat].append(np.mean(part_ious)) 124 | 125 | instance_ious = [] 126 | for cat in shape_ious.keys(): 127 | for iou in shape_ious[cat]: 128 | instance_ious.append(iou) 129 | shape_ious[cat] = np.mean(shape_ious[cat]) 130 | mean_class_ious = np.mean(list(shape_ious.values())) 131 | 132 | print('\n------ Repeat %3d ------' % (i + 1)) 133 | for cat in sorted(shape_ious.keys()): 134 | print('%s: %0.6f'%(cat, shape_ious[cat])) 135 | print('Class_mIoU: %0.6f' % (mean_class_ious)) 136 | print('Instance_mIoU: %0.6f' % (np.mean(instance_ious))) 137 | 138 | if np.mean(instance_ious) > global_Inst_mIoU: 139 | global_Class_mIoU = mean_class_ious 140 | global_Inst_mIoU = np.mean(instance_ious) 141 | 142 | print('\nBest voting Class_mIoU = %0.6f, Instance_mIoU = %0.6f' % (global_Class_mIoU, global_Inst_mIoU)) 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /voting_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | mkdir -p log 3 | now=$(date +"%Y%m%d_%H%M%S") 4 | log_name="PartSeg_VOTE_"$now"" 5 | CUDA_VISIBLE_DEVICES=2,3 python -u voting_test.py \ 6 | --config cfgs/config_partseg_test.yaml \ 7 | 2>&1|tee log/$log_name.log & 8 | --------------------------------------------------------------------------------
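For reference, the per-shape part-IoU that both validate() in train_partseg_gpus.py and the voting loop above accumulate reduces to the following standalone sketch (names are illustrative; the scripts operate on torch tensors rather than NumPy arrays):

import numpy as np

def shape_part_iou(segp, segl, part_labels):
    # segp/segl: (num_points,) predicted / ground-truth part labels for one
    # shape; part_labels: the labels belonging to this shape's category.
    # A part absent from both prediction and ground truth counts as IoU = 1.
    ious = []
    for l in part_labels:
        union = np.sum((segl == l) | (segp == l))
        inter = np.sum((segl == l) & (segp == l))
        ious.append(1.0 if union == 0 else inter / float(union))
    return np.mean(ious)

# Instance mIoU averages this over every test shape; class mIoU first averages
# within each category and then across the 16 categories.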