├── LICENSE
├── README.md
├── cfgs
│   ├── config_partseg_gpus.yaml
│   └── config_partseg_test.yaml
├── data
│   ├── ShapeNetPartLoader.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Indoor3DSemSegLoader.cpython-36.pyc
│   │   ├── ModelNet40Loader.cpython-36.pyc
│   │   ├── ShapeNetPartLoader.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   └── data_utils.cpython-36.pyc
│   └── data_utils.py
├── figures
│   ├── cls.png
│   ├── compare.png
│   ├── intro.png
│   ├── modelnet.png
│   ├── new_network.png
│   ├── partseg.png
│   ├── scanobjectnn.png
│   ├── shapenet.png
│   └── visual.png
├── models
│   ├── __init__.py
│   └── drnet.py
├── pointnet2
│   ├── _ext-src
│   │   ├── include
│   │   │   ├── ball_query.h
│   │   │   ├── cuda_utils.h
│   │   │   ├── group_points.h
│   │   │   ├── interpolate.h
│   │   │   ├── sampling.h
│   │   │   └── utils.h
│   │   └── src
│   │       ├── ball_query.cpp
│   │       ├── ball_query_gpu.cu
│   │       ├── bindings.cpp
│   │       ├── group_points.cpp
│   │       ├── group_points_gpu.cu
│   │       ├── interpolate.cpp
│   │       ├── interpolate_gpu.cu
│   │       ├── sampling.cpp
│   │       └── sampling_gpu.cu
│   ├── _ext.cpython-36m-x86_64-linux-gnu.so
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── pointnet2_modules.cpython-36.pyc
│       │   └── pointnet2_utils.cpython-36.pyc
│       ├── linalg_utils.py
│       ├── pointnet2_modules.py
│       └── pointnet2_utils.py
├── train_partseg_gpus.py
├── train_partseg_gpus.sh
├── voting_test.py
└── voting_test.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 ShiQiu0419
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense-Resolution Network for Point Cloud Classification and Segmentation
2 | [](https://paperswithcode.com/sota/3d-point-cloud-classification-on-scanobjectnn?p=dense-resolution-network-for-point-cloud)
3 | [](https://paperswithcode.com/sota/3d-part-segmentation-on-shapenet-part?p=dense-resolution-network-for-point-cloud)
4 | [](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=dense-resolution-network-for-point-cloud)
5 |
6 | This repository is for the Dense-Resolution Network (DRNet) introduced in the following paper:
7 |
8 | [Shi Qiu](https://shiqiu0419.github.io/), [Saeed Anwar](https://saeed-anwar.github.io/), [Nick Barnes](http://users.cecs.anu.edu.au/~nmb/),
9 | "Dense-Resolution Network for Point Cloud Classification and Segmentation"
10 | IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2021)
11 |
12 | ## Paper
13 | The paper can be downloaded from [here (arXiv)](https://arxiv.org/abs/2005.06734) or [here (CVF)](https://openaccess.thecvf.com/content/WACV2021/papers/Qiu_Dense-Resolution_Network_for_Point_Cloud_Classification_and_Segmentation_WACV_2021_paper.pdf), together with the [supplementary material](https://openaccess.thecvf.com/content/WACV2021/supplemental/Qiu_Dense-Resolution_Network_for_WACV_2021_supplemental.pdf).
14 |
15 | ## Motivation
16 | ![motivation](figures/intro.png)
17 |
18 |
19 |
20 | ## Implementation
21 | * Python 3.6
22 | * PyTorch 1.3.0
23 | * CUDA 10.0
24 |
25 | ## Dataset
26 | Download the [ShapeNet Part Dataset](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and unzip it into your root path. Alternatively, you can modify the path of your dataset in `cfgs/config_partseg_gpus.yaml` and `cfgs/config_partseg_test.yaml`.
27 |
28 | ## CUDA Kernel Building
29 | For PyTorch version <= 0.4.0, please refer to [Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN).
30 | For PyTorch version >= 1.0.0, please refer to [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch).
31 |
32 | **Note:**
33 | In our DRNet, we use Farthest Point Sampling (i.e., `pointnet2_utils.furthest_point_sample`) to down-sample the point cloud. Also, we adopt Feature Propagation (i.e., `pointnet2_utils.three_nn` and `pointnet2_utils.three_interpolate`) to up-sample the feature maps.
34 |
35 | ## Training
36 |
37 | sh train_partseg_gpus.sh
38 |
39 | Due to the complexity of DRNet, we support multi-GPU training via `nn.DataParallel`. You can also adjust other parameters such as the batch size or the number of input points in `cfgs/config_partseg_gpus.yaml` in order to fit the memory limit of your device.
40 |
41 | ## Voting Evaluation
42 | You can set the path of your pre-trained model in `cfgs/config_partseg_test.yaml`, then run:
43 |
44 | sh voting_test.sh
45 |
46 | ## Citation
47 |
48 | If you find our paper useful, please cite:
49 |
50 | @inproceedings{qiu2021dense,
51 | title={Dense-Resolution Network for Point Cloud Classification and Segmentation},
52 | author={Qiu, Shi and Anwar, Saeed and Barnes, Nick},
53 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
54 | month={January},
55 | year={2021},
56 | pages={3813-3822}
57 | }
58 |
59 | ## Acknowledgement
60 | The code is built on [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch), [Relation-Shape-CNN](https://github.com/Yochengliu/Relation-Shape-CNN), and [DGCNN](https://github.com/WangYueFt/dgcnn/tree/master/pytorch). We thank the authors for sharing their code.
61 |
--------------------------------------------------------------------------------
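
A minimal end-to-end sketch (not part of the repository) of how the pieces below fit together. The dataset path follows `data_root` in the configs, and `DRNET_Seg` is the alias exported by `models/__init__.py`; everything else here is illustrative:

```python
import torch
from torch.utils.data import DataLoader
from data import ShapeNetPart
from models import DRNET_Seg

# dataset path as in cfgs/config_partseg_gpus.yaml (data_root)
dataset = ShapeNetPart(root='shapenetcore_partanno_segmentation_benchmark_v0_normal',
                       num_points=2048, split='trainval')
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)

model = DRNET_Seg(num_classes=50).cuda()
points, seg, cls = next(iter(loader))                          # (B,N,3), (B,N), (B,1)
one_hot = torch.zeros(points.size(0), 16).scatter_(1, cls, 1)  # B,16 category one-hot
pred, d1, d2, d3, d4 = model(points.float().cuda(), one_hot.cuda())
print(pred.shape)  # (B, N, 50) per-point part logits, plus four error-minimizing terms
```
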
/cfgs/config_partseg_gpus.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | workers: 8
3 |
4 | num_points: 2048
5 | num_classes: 50
6 | batch_size: 32
7 |
8 | base_lr: 0.001
9 | lr_clip: 0.00001
10 | lr_decay: 0.5
11 | decay_step: 21
12 | epochs: 200
13 |
14 | weight_decay: 0
15 | bn_momentum: 0.9
16 | bnm_clip: 0.01
17 | bn_decay: 0.5
18 |
19 |     evaluate: 1 # run validation during training
20 |     val_freq_epoch: 0.8 # validation frequency in epochs, can be fractional
21 |     print_freq_iter: 20 # printing frequency in iterations
22 |
23 | input_channels: 0 # feature channels except (x, y, z)
24 |
25 | checkpoint: '' # the model to start from
26 | save_path: seg_own
27 | data_root: shapenetcore_partanno_segmentation_benchmark_v0_normal
28 |
--------------------------------------------------------------------------------
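
For reference, a sketch of how such a config can be consumed (the actual parsing lives in `train_partseg_gpus.py`, which is not shown here); all keys sit under the single `common` section:

```python
import yaml

with open('cfgs/config_partseg_gpus.yaml') as f:
    cfg = yaml.safe_load(f)['common']

print(cfg['num_points'], cfg['batch_size'], cfg['base_lr'])  # 2048 32 0.001
```
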
/cfgs/config_partseg_test.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | workers: 8
3 |
4 | num_points: 2048
5 | num_classes: 50
6 | batch_size: 4
7 |
8 | base_lr: 0.001
9 | lr_clip: 0.00001
10 | lr_decay: 0.5
11 | decay_step: 21
12 | epochs: 200
13 |
14 | weight_decay: 0
15 | bn_momentum: 0.9
16 | bnm_clip: 0.01
17 | bn_decay: 0.5
18 |
19 |     evaluate: 1 # run validation during training
20 |     val_freq_epoch: 0.8 # validation frequency in epochs, can be fractional
21 |     print_freq_iter: 20 # printing frequency in iterations
22 |
23 | input_channels: 0 # feature channels except (x, y, z)
24 |
25 | checkpoint: 'seg/your_trained_model.pth' # the model to start from
26 | save_path: seg_own
27 | data_root: shapenetcore_partanno_segmentation_benchmark_v0_normal
28 |
--------------------------------------------------------------------------------
/data/ShapeNetPartLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch
4 | import json
5 | import numpy as np
6 | import sys
7 | import torchvision.transforms as transforms
8 |
9 | def pc_normalize(pc):
10 |     # center the cloud at the origin and scale it into the unit sphere
11 | centroid = np.mean(pc, axis=0)
12 | pc = pc - centroid
13 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
14 | pc = pc / m
15 | return pc
16 |
17 | class ShapeNetPart():
18 | def __init__(self, root, num_points = 2048, split='train', normalize=True, transforms = None):
19 | self.transforms = transforms
20 | self.num_points = num_points
21 | self.root = root
22 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
23 | self.normalize = normalize
24 |
25 | self.cat = {}
26 | with open(self.catfile, 'r') as f:
27 | for line in f:
28 | ls = line.strip().split()
29 | self.cat[ls[0]] = ls[1]
30 | self.cat = {k:v for k,v in self.cat.items()}
31 |
32 | self.meta = {}
33 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
34 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
35 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
36 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
37 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
38 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
39 | for item in self.cat:
40 | self.meta[item] = []
41 | dir_point = os.path.join(self.root, self.cat[item])
42 | fns = sorted(os.listdir(dir_point))
43 | if split=='trainval':
44 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
45 | elif split=='train':
46 | fns = [fn for fn in fns if fn[0:-4] in train_ids]
47 | elif split=='val':
48 | fns = [fn for fn in fns if fn[0:-4] in val_ids]
49 | elif split=='test':
50 | fns = [fn for fn in fns if fn[0:-4] in test_ids]
51 | else:
52 | print('Unknown split: %s. Exiting..'%(split))
53 | exit(-1)
54 |
55 | for fn in fns:
56 | token = (os.path.splitext(os.path.basename(fn))[0])
57 | self.meta[item].append(os.path.join(dir_point, token + '.txt'))
58 |
59 | self.datapath = []
60 | for item in self.cat:
61 | for fn in self.meta[item]:
62 | self.datapath.append((item, fn))
63 |
64 | self.classes = dict(zip(self.cat, range(len(self.cat))))
65 |         # Mapping from category (e.g., 'Chair') to its list of part labels (e.g., Chair -> [12, 13, 14, 15])
66 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
67 |
68 | self.cache = {}
69 | self.cache_size = 20000
70 |
71 | def __getitem__(self, index):
72 | if index in self.cache:
73 | point_set, seg, cls = self.cache[index]
74 | else:
75 | fn = self.datapath[index]
76 | cat = self.datapath[index][0]
77 | cls = self.classes[cat]
78 | cls = np.array([cls]).astype(np.int64)
79 | data = np.loadtxt(fn[1]).astype(np.float32)
80 | point_set = data[:,0:3]
81 | if self.normalize:
82 | point_set = pc_normalize(point_set)
83 | seg = data[:,-1].astype(np.int64)
84 | if len(self.cache) < self.cache_size:
85 | self.cache[index] = (point_set, seg, cls)
86 |
87 | choice = np.random.choice(len(seg), self.num_points, replace=True)
88 | #resample
89 | point_set = point_set[choice, :]
90 | seg = seg[choice]
91 | if self.transforms is not None:
92 | point_set = self.transforms(point_set)
93 |
94 | return point_set, torch.from_numpy(seg), torch.from_numpy(cls)
95 |
96 | def __len__(self):
97 | return len(self.datapath)
98 |
99 |
--------------------------------------------------------------------------------
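
A quick sanity check of the loader above (illustrative; it assumes the dataset has been unzipped to the default path):

```python
from data.ShapeNetPartLoader import ShapeNetPart

dset = ShapeNetPart(root='shapenetcore_partanno_segmentation_benchmark_v0_normal',
                    num_points=2048, split='test')
points, seg, cls = dset[0]
print(len(dset))            # number of shapes in the test split
print(points.shape)         # (2048, 3): xyz, normalized into the unit sphere
print(seg.shape, int(cls))  # (2048,) part labels in [0, 50); category id in [0, 16)
```
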
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # from .ModelNet40Loader import ModelNet40Cls
2 | from .ShapeNetPartLoader import ShapeNetPart
3 | # from .Indoor3DSemSegLoader import Indoor3DSemSeg
4 |
--------------------------------------------------------------------------------
/data/__pycache__/Indoor3DSemSegLoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/Indoor3DSemSegLoader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/ModelNet40Loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/ModelNet40Loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/ShapeNetPartLoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/ShapeNetPartLoader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/data_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/data/__pycache__/data_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | class PointcloudToTensor(object):
5 | def __call__(self, points):
6 | return torch.from_numpy(points).float()
7 |
8 | def angle_axis(angle: float, axis: np.ndarray):
9 |     r"""Returns a 3x3 rotation matrix that performs a rotation around axis by angle
10 |
11 | Parameters
12 | ----------
13 | angle : float
14 | Angle to rotate by
15 | axis: np.ndarray
16 | Axis to rotate about
17 |
18 | Returns
19 | -------
20 | torch.Tensor
21 | 3x3 rotation matrix
22 | """
23 | u = axis / np.linalg.norm(axis)
24 | cosval, sinval = np.cos(angle), np.sin(angle)
25 |
26 | # yapf: disable
27 | cross_prod_mat = np.array([[0.0, -u[2], u[1]],
28 | [u[2], 0.0, -u[0]],
29 | [-u[1], u[0], 0.0]])
30 |
31 | R = torch.from_numpy(
32 | cosval * np.eye(3)
33 | + sinval * cross_prod_mat
34 | + (1.0 - cosval) * np.outer(u, u)
35 | )
36 | # yapf: enable
37 | return R.float()
38 |
39 | class PointcloudRotatebyAngle(object):
40 | def __init__(self, rotation_angle = 0.0):
41 | self.rotation_angle = rotation_angle
42 |
43 | def __call__(self, pc):
44 | normals = pc.size(2) > 3
45 | bsize = pc.size()[0]
46 | for i in range(bsize):
47 | cosval = np.cos(self.rotation_angle)
48 | sinval = np.sin(self.rotation_angle)
49 | rotation_matrix = np.array([[cosval, 0, sinval],
50 | [0, 1, 0],
51 | [-sinval, 0, cosval]])
52 | rotation_matrix = torch.from_numpy(rotation_matrix).float().cuda()
53 |
54 | cur_pc = pc[i, :, :]
55 | if not normals:
56 | cur_pc = cur_pc @ rotation_matrix
57 | else:
58 | pc_xyz = cur_pc[:, 0:3]
59 | pc_normals = cur_pc[:, 3:]
60 | cur_pc[:, 0:3] = pc_xyz @ rotation_matrix
61 | cur_pc[:, 3:] = pc_normals @ rotation_matrix
62 |
63 | pc[i, :, :] = cur_pc
64 |
65 | return pc
66 |
67 | class PointcloudJitter(object):
68 | def __init__(self, std=0.01, clip=0.05):
69 | self.std, self.clip = std, clip
70 |
71 | def __call__(self, pc):
72 | bsize = pc.size()[0]
73 | for i in range(bsize):
74 | jittered_data = pc.new(pc.size(1), 3).normal_(
75 | mean=0.0, std=self.std
76 | ).clamp_(-self.clip, self.clip)
77 | pc[i, :, 0:3] += jittered_data
78 |
79 | return pc
80 |
81 | class PointcloudScaleAndTranslate(object):
82 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2):
83 | self.scale_low = scale_low
84 | self.scale_high = scale_high
85 | self.translate_range = translate_range
86 |
87 | def __call__(self, pc):
88 | bsize = pc.size()[0]
89 | for i in range(bsize):
90 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])
91 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])
92 |
93 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda()
94 |
95 | return pc
96 |
97 | class PointcloudScale(object):
98 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
99 | self.scale_low = scale_low
100 | self.scale_high = scale_high
101 |
102 | def __call__(self, pc):
103 | bsize = pc.size()[0]
104 | for i in range(bsize):
105 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])
106 |
107 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda())
108 |
109 | return pc
110 |
111 | class PointcloudTranslate(object):
112 | def __init__(self, translate_range=0.2):
113 | self.translate_range = translate_range
114 |
115 | def __call__(self, pc):
116 | bsize = pc.size()[0]
117 | for i in range(bsize):
118 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])
119 |
120 | pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda()
121 |
122 | return pc
123 |
124 | class PointcloudRandomInputDropout(object):
125 | def __init__(self, max_dropout_ratio=0.875):
126 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
127 | self.max_dropout_ratio = max_dropout_ratio
128 |
129 | def __call__(self, pc):
130 | bsize = pc.size()[0]
131 | for i in range(bsize):
132 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875
133 | drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0]
134 | if len(drop_idx) > 0:
135 | cur_pc = pc[i, :, :]
136 | cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point
137 | pc[i, :, :] = cur_pc
138 |
139 | return pc
140 |
--------------------------------------------------------------------------------
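
Note the asymmetry among these transforms: `PointcloudToTensor` converts a single `(N, 3)` numpy cloud, while the remaining transforms index the batch dimension and build CUDA tensors internally, so they are meant for `(B, N, 3)` batches already on the GPU. A usage sketch:

```python
import torchvision.transforms as transforms
from data import data_utils as d_utils

# dataset-side: numpy cloud -> float tensor
to_tensor = transforms.Compose([d_utils.PointcloudToTensor()])

# batch-side augmentation, applied inside the training loop
batch_aug = transforms.Compose([
    d_utils.PointcloudScaleAndTranslate(),   # random anisotropic scale and shift
    d_utils.PointcloudRandomInputDropout(),  # dropped points collapse onto the first point
])

# points: (B, N, 3) CUDA tensor from the DataLoader
# points = batch_aug(points.cuda())
```
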
/figures/cls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/cls.png
--------------------------------------------------------------------------------
/figures/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/compare.png
--------------------------------------------------------------------------------
/figures/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/intro.png
--------------------------------------------------------------------------------
/figures/modelnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/modelnet.png
--------------------------------------------------------------------------------
/figures/new_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/new_network.png
--------------------------------------------------------------------------------
/figures/partseg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/partseg.png
--------------------------------------------------------------------------------
/figures/scanobjectnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/scanobjectnn.png
--------------------------------------------------------------------------------
/figures/shapenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/shapenet.png
--------------------------------------------------------------------------------
/figures/visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/figures/visual.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .drnet import DRNET as DRNET_Seg
--------------------------------------------------------------------------------
/models/drnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import sys
6 | import copy
7 | import math
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from pointnet2.utils import pointnet2_utils
13 |
14 |
15 | def knn(x, k):
16 | inner = -2*torch.matmul(x.transpose(2, 1), x)
17 | xx = torch.sum(x**2, dim=1, keepdim=True)
18 | pairwise_distance = -xx - inner - xx.transpose(2, 1)
19 |
20 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
21 | return idx
22 |
23 |
24 | def pw_dist(x):
25 | inner = -2*torch.matmul(x.transpose(2, 1), x)
26 | xx = torch.sum(x**2, dim=1, keepdim=True)
27 | pairwise_distance = -xx - inner - xx.transpose(2, 1) # (batch_size, num_points, n)
28 |
29 | return -pairwise_distance
30 |
31 |
32 | def knn_metric(x, d, conv_op1, conv_op2, conv_op11, k):
33 | batch_size = x.size(0)
34 | num_points = x.size(2)
35 | inner = -2*torch.matmul(x.transpose(2, 1), x)
36 | xx = torch.sum(x**2, dim=1, keepdim=True)
37 | pairwise_distance = -xx - inner - xx.transpose(2, 1)
38 |
39 |     metric, metric_idx = (-pairwise_distance).topk(k=d*k, dim=-1, largest=False) # B,N,d*k values and indices in one call
40 |     # metric_idx is unused below; the k dilated neighbours are re-selected from pairwise_distance at the end
41 | metric_trans = metric.permute(0, 2, 1) # B,100,N
42 | metric = conv_op1(metric_trans) # B,50,N
43 | metric = torch.squeeze(conv_op11(metric).permute(0, 2, 1), -1) # B,N
44 | # normalize function
45 | metric = torch.sigmoid(-metric)
46 | # projection function
47 | metric = 5 * metric + 0.5
48 | # scaling function
49 |
50 | value1 = torch.where((metric>=0.5)&(metric<1.5), torch.full_like(metric, 1), torch.full_like(metric, 0))
51 | value2 = torch.where((metric>=1.5)&(metric<2.5), torch.full_like(metric, 2), torch.full_like(metric, 0))
52 | value3 = torch.where((metric>=2.5)&(metric<3.5), torch.full_like(metric, 3), torch.full_like(metric, 0))
53 | value4 = torch.where((metric>=3.5)&(metric<4.5), torch.full_like(metric, 4), torch.full_like(metric, 0))
54 | value5 = torch.where((metric>=4.5)&(metric<=5.5), torch.full_like(metric, 5), torch.full_like(metric, 0))
55 |
56 | value = value1 + value2 + value3 + value4 + value5 # B,N
57 |
58 | select_idx = torch.cuda.LongTensor(np.arange(k)) # k
59 | select_idx = torch.unsqueeze(select_idx, 0).repeat(num_points, 1) # N,k
60 | select_idx = torch.unsqueeze(select_idx, 0).repeat(batch_size, 1, 1) # B,N,k
61 | value = torch.unsqueeze(value, -1).repeat(1, 1, k) # B,N,k
62 | select_idx = select_idx * value
63 | select_idx = select_idx.long()
64 | idx = pairwise_distance.topk(k=k*d, dim=-1)[1] # (batch_size, num_points, k*d)
65 | # dilatedly selecting k from k*d idx
66 | idx = torch.gather(idx, dim=-1, index=select_idx) # B,N,k
67 | return idx
68 |
69 |
70 | def get_adaptive_dilated_graph_feature(x, conv_op1, conv_op2, conv_op11, d=5, k=20, idx=None):
71 | batch_size = x.size(0)
72 | num_points = x.size(2)
73 | x = x.view(batch_size, -1, num_points)
74 | if idx is None:
75 | idx = knn_metric(x, d, conv_op1, conv_op2, conv_op11, k=k) # (batch_size, num_points, k)
76 | device = torch.device('cuda')
77 | idx_base = torch.arange(0, batch_size, device=device)
78 | idx_base = idx_base.view(-1, 1, 1)*num_points
79 | idx_base = idx_base.type(torch.cuda.LongTensor)
80 | idx = idx.type(torch.cuda.LongTensor)
81 | idx = idx + idx_base
82 | idx = idx.view(-1)
83 | _, num_dims, _ = x.size()
84 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points)
85 | feature = x.view(batch_size*num_points, -1)[idx, :]
86 | feature = feature.view(batch_size, num_points, k, num_dims)
87 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
88 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2)
89 |
90 | return feature
91 |
92 |
93 | def get_graph_feature(x, k=20, idx=None):
94 | batch_size = x.size(0)
95 | num_points = x.size(2)
96 | x = x.view(batch_size, -1, num_points)
97 | if idx is None:
98 | idx = knn(x, k=k) # (batch_size, num_points, k)
99 | device = torch.device('cuda')
100 |
101 | idx_base = torch.arange(0, batch_size, device=device)
102 | idx_base = idx_base.view(-1, 1, 1)*num_points
103 | idx_base = idx_base.type(torch.cuda.LongTensor)
104 | idx = idx.type(torch.cuda.LongTensor)
105 | idx = idx + idx_base
106 | idx = idx.view(-1)
107 | _, num_dims, _ = x.size()
108 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points)
109 | feature = x.view(batch_size*num_points, -1)[idx, :]
110 | feature = feature.view(batch_size, num_points, k, num_dims)
111 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
112 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2)
113 |
114 | return feature
115 |
116 |
117 | class Channel_fusion(nn.Module):
118 | def __init__(self, in_dim):
119 | super(Channel_fusion, self).__init__()
120 | self.channel_in = in_dim
121 |
122 | self.bn1 = nn.BatchNorm1d(in_dim//8)
123 | self.bn2 = nn.BatchNorm1d(in_dim)
124 |
125 | self.squeeze_conv = nn.Sequential(nn.Conv1d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1, bias=False),
126 | self.bn1)
127 | self.excite_conv = nn.Sequential(nn.Conv1d(in_channels=in_dim//8, out_channels=in_dim, kernel_size=1, bias=False),
128 | self.bn2)
129 |
130 | def forward(self, x):
131 | """
132 | inputs :
133 | x : B, C, N
134 |             returns :
135 |                 fusion_score : B, C, N (channel-wise gating scores in (0, 1))
136 |         """
137 |         batch_size, channel_num, num_points = x.size()
138 | fusion_score = F.adaptive_avg_pool1d(x, 1) # B, C, 1
139 | fusion_score = self.squeeze_conv(fusion_score) # B, C', 1
140 | fusion_score = F.relu(fusion_score)
141 | fusion_score = self.excite_conv(fusion_score) # B, C, 1
142 | fusion_score = torch.sigmoid(fusion_score).expand_as(x) # B, C, N
143 |
144 | return fusion_score
145 |
146 |
147 | class DRNET(nn.Module):
148 | def __init__(self, num_classes):
149 | super().__init__()
150 | self.k = 20
151 |
152 | self.bn1 = nn.BatchNorm2d(64)
153 | self.bn11 = nn.BatchNorm2d(3)
154 | self.bn12 = nn.BatchNorm1d(512)
155 | self.bn13 = nn.BatchNorm1d(1024)
156 | self.bn14 = nn.BatchNorm2d(64)
157 | self.bn2 = nn.BatchNorm2d(64)
158 | self.bn21 = nn.BatchNorm1d(256)
159 | self.bn22 = nn.BatchNorm1d(128)
160 | self.bn23 = nn.BatchNorm1d(128)
161 | self.bn24 = nn.BatchNorm2d(128)
162 | self.bn3 = nn.BatchNorm2d(256)
163 | self.bn32 = nn.BatchNorm1d(256)
164 | self.bn33 = nn.BatchNorm1d(256)
165 | self.bn34 = nn.BatchNorm1d(128)
166 | self.bn35 = nn.BatchNorm1d(128)
167 | self.bn5 = nn.BatchNorm1d(64)
168 | self.bn53 = nn.BatchNorm1d(512)
169 | self.bn54 = nn.BatchNorm1d(1024)
170 | self.bn63 = nn.BatchNorm1d(128)
171 | self.bn64 = nn.BatchNorm1d(50)
172 | self.bn7 = nn.BatchNorm1d(1024)
173 |
174 | self.bn8 = nn.BatchNorm1d(1024)
175 | self.bn9 = nn.BatchNorm1d(1024)
176 |
177 | self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
178 | self.bn1,
179 | nn.LeakyReLU(negative_slope=0.2))
180 | self.conv11 = nn.Sequential(nn.Conv2d(64, 3, kernel_size=[1,20], bias=False),
181 | self.bn11,
182 | nn.LeakyReLU(negative_slope=0.2))
183 | self.conv12 = nn.Sequential(nn.Conv1d(448, 512, kernel_size=1, bias=False),
184 | self.bn12,
185 | nn.LeakyReLU(negative_slope=0.2))
186 | self.conv13 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
187 | self.bn13,
188 | nn.LeakyReLU(negative_slope=0.2))
189 | self.conv14 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
190 | self.bn14,
191 | nn.LeakyReLU(negative_slope=0.2))
192 | self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
193 | self.bn2,
194 | nn.LeakyReLU(negative_slope=0.2))
195 | self.conv21 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
196 | self.bn21,
197 | nn.LeakyReLU(negative_slope=0.2))
198 | self.conv22 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
199 | self.bn22,
200 | nn.LeakyReLU(negative_slope=0.2))
201 | self.conv23 = nn.Sequential(nn.Conv1d(192, 128, kernel_size=1, bias=False),
202 | self.bn23,
203 | nn.LeakyReLU(negative_slope=0.2))
204 | self.conv24 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False),
205 | self.bn24,
206 | nn.LeakyReLU(negative_slope=0.2))
207 | self.conv3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False),
208 | self.bn3,
209 | nn.LeakyReLU(negative_slope=0.2))
210 | self.conv32 = nn.Sequential(nn.Conv1d(1280, 256, kernel_size=1, bias=False),
211 | self.bn32,
212 | nn.LeakyReLU(negative_slope=0.2))
213 | self.conv33 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
214 | self.bn33,
215 | nn.LeakyReLU(negative_slope=0.2))
216 | self.conv34 = nn.Sequential(nn.Conv1d(384, 128, kernel_size=1, bias=False),
217 | self.bn34,
218 | nn.LeakyReLU(negative_slope=0.2))
219 | self.conv35 = nn.Sequential(nn.Conv1d(320, 128, kernel_size=1, bias=False),
220 | self.bn35,
221 | nn.LeakyReLU(negative_slope=0.2))
222 | self.conv5 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
223 | self.bn5,
224 | nn.LeakyReLU(negative_slope=0.2))
225 | self.conv53 = nn.Sequential(nn.Conv1d(256, 512, kernel_size=1, bias=False),
226 | self.bn53,
227 | nn.LeakyReLU(negative_slope=0.2))
228 | self.conv54 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
229 | self.bn54,
230 | nn.LeakyReLU(negative_slope=0.2))
231 | self.conv63 = nn.Sequential(nn.Conv1d(1024, 128, kernel_size=1, bias=False),
232 | self.bn63,
233 | nn.LeakyReLU(negative_slope=0.2))
234 | self.conv64 = nn.Sequential(nn.Conv1d(128, 50, kernel_size=1))
235 |
236 | self.conv7 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
237 | self.bn7,
238 | nn.LeakyReLU(negative_slope=0.2))
239 | self.conv8 = nn.Sequential(nn.Conv1d(2112, 1024, kernel_size=1, bias=False),
240 | self.bn8,
241 | nn.LeakyReLU(negative_slope=0.2))
242 | self.conv9 = nn.Sequential(nn.Conv1d(1024, 1024, kernel_size=1, bias=False),
243 | self.bn9,
244 | nn.LeakyReLU(negative_slope=0.2))
245 |
246 | self.bnfc2 = nn.BatchNorm2d(64)
247 | self.bnfc21 = nn.BatchNorm2d(64)
248 | self.bnfc24 = nn.BatchNorm2d(64)
249 | self.bnfc3 = nn.BatchNorm2d(128)
250 | self.bnfc31 = nn.BatchNorm2d(64)
251 | self.bnfc4 = nn.BatchNorm2d(256)
252 | self.bnfc41 = nn.BatchNorm2d(128)
253 | self.convfc2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
254 | self.bnfc2,
255 | nn.LeakyReLU(negative_slope=0.2))
256 | self.convfc21 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=[1,20], bias=False),
257 | self.bnfc21,
258 | nn.LeakyReLU(negative_slope=0.2))
259 | self.convfc24 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
260 | self.bnfc24,
261 | nn.LeakyReLU(negative_slope=0.2))
262 | self.convfc3 = nn.Sequential(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
263 | self.bnfc3,
264 | nn.LeakyReLU(negative_slope=0.2))
265 | self.convfc31 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=[1,20], bias=False),
266 | self.bnfc31,
267 | nn.LeakyReLU(negative_slope=0.2))
268 | self.convfc4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False),
269 | self.bnfc4,
270 | nn.LeakyReLU(negative_slope=0.2))
271 | self.convfc41 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=[1,20], bias=False),
272 | self.bnfc41,
273 | nn.LeakyReLU(negative_slope=0.2))
274 |
275 | self.bnop1 = nn.BatchNorm1d(50)
276 | self.bnop2 = nn.BatchNorm1d(50)
277 | self.bnop3 = nn.BatchNorm1d(50)
278 | self.bnop4 = nn.BatchNorm1d(50)
279 | self.bnop11 = nn.BatchNorm1d(1)
280 | self.bnop21 = nn.BatchNorm1d(1)
281 | self.bnop31 = nn.BatchNorm1d(1)
282 | self.bnop41 = nn.BatchNorm1d(1)
283 | self.bnop12 = nn.BatchNorm1d(1)
284 | self.bnop22 = nn.BatchNorm1d(1)
285 | self.bnop32 = nn.BatchNorm1d(1)
286 | self.bnop42 = nn.BatchNorm1d(1)
287 | self.conv_op1 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop1, nn.LeakyReLU(negative_slope=0.2))
288 | self.conv_op2 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop2, nn.LeakyReLU(negative_slope=0.2))
289 | self.conv_op3 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop3, nn.LeakyReLU(negative_slope=0.2))
290 | self.conv_op4 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=50, kernel_size=1), self.bnop4, nn.LeakyReLU(negative_slope=0.2))
291 | self.conv_op11 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop11)
292 | self.conv_op21 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop21)
293 | self.conv_op31 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop31)
294 | self.conv_op41 = nn.Sequential(nn.Conv1d(in_channels=100, out_channels=1, kernel_size=1), self.bnop41)
295 | self.conv_op12 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop12)
296 | self.conv_op22 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop22)
297 | self.conv_op32 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop32)
298 | self.conv_op42 = nn.Sequential(nn.Conv1d(in_channels=50, out_channels=1, kernel_size=1), self.bnop42)
299 |
300 | self.fuse = Channel_fusion(1024)
301 |
302 | self.dp = nn.Dropout(p=0.5)
303 |
304 | def _break_up_pc(self, pc):
305 | xyz = pc[..., 0:3].contiguous()
306 | features = (
307 | pc[..., 3:].transpose(1, 2).contiguous()
308 | if pc.size(-1) > 3 else None
309 | )
310 |
311 | return xyz, features
312 |
313 | def forward(self, pointcloud: torch.cuda.FloatTensor, cls):
314 |         # pointcloud: B,N,3 (xyz plus optional extra feature channels)
315 |
316 | xyz, features = self._break_up_pc(pointcloud)
317 | num_pts = xyz.size(1)
318 | batch_size = xyz.size(0)
319 | # FPS to find different point subsets and their relations
320 |         subset1_idx = pointnet2_utils.furthest_point_sample(xyz, num_pts//4).long() # B,N/4
321 |         subset1_xyz = torch.unsqueeze(subset1_idx, -1).repeat(1, 1, 3) # B,N/4,3 gather index
322 |         subset1_xyz = torch.gather(xyz, 1, subset1_xyz) # B,N/4,3 (per-batch gather; torch.take would index the flattened tensor)
323 |
324 | dist, idx1 = pointnet2_utils.three_nn(xyz, subset1_xyz)
325 | dist_recip = 1.0 / (dist + 1e-8)
326 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
327 | weight1 = dist_recip / norm
328 |
329 |         subset12_idx = pointnet2_utils.furthest_point_sample(subset1_xyz, num_pts//16).long() # B,N/16
330 |         subset12_xyz = torch.unsqueeze(subset12_idx, -1).repeat(1, 1, 3) # B,N/16,3 gather index
331 |         subset12_xyz = torch.gather(subset1_xyz, 1, subset12_xyz) # B,N/16,3
332 |
333 | dist, idx12 = pointnet2_utils.three_nn(subset1_xyz, subset12_xyz)
334 | dist_recip = 1.0 / (dist + 1e-8)
335 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
336 | weight12 = dist_recip / norm
337 |
338 | device = torch.device('cuda')
339 | centroid = torch.zeros([batch_size, 1, 3], device=device)
340 | dist, idx0 = pointnet2_utils.three_nn(subset12_xyz, centroid)
341 | dist_recip = 1.0 / (dist + 1e-8)
342 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
343 | weight0 = dist_recip / norm
344 | #######################################
345 | # Error-minimizing module 1:
346 | # Encoding
347 | x = xyz.transpose(2, 1) # x: B,3,N
348 | x1_1 = x
349 |         x = get_adaptive_dilated_graph_feature(x, self.conv_op1, self.conv_op11, self.conv_op12, d=5, k=20)
350 | x = self.conv1(x) # B,64,N,k
351 | x = self.conv14(x) # B,64,N,k
352 | x1_2 = x
353 | # Back-projection
354 | x = self.conv11(x) # B,3,N,1
355 | x = torch.squeeze(x, -1) # B,3,N
356 | x1_3 = x
357 | # Calculating Error
358 | delta_1 = x1_3 - x1_1 # B,3,N
359 | # Output
360 | x = x1_2 # B,64,N,k
361 | x1 = x.max(dim=-1, keepdim=False)[0] # B,64,N
362 | #######################################
363 |
364 | #######################################
365 | # Multi-resolution (MR) Branch
366 | # Down-scaling 1
367 |         subset1_feat = torch.unsqueeze(subset1_idx, -1).repeat(1, 1, 64) # B,N/4,64 gather index
368 |         x1_subset1 = torch.gather(x1.transpose(1, 2).contiguous(), 1, subset1_feat).transpose(1, 2).contiguous() # B,64,N/4
369 |
370 |         x2_1 = x1_subset1 # B,64,N/4
371 |         x = get_graph_feature(x1_subset1, k=self.k//2)
372 |         x = self.conv2(x) # B,64,N/4,k
373 |         x = self.conv24(x) # B,128,N/4,k
374 |         x2 = x.max(dim=-1, keepdim=False)[0] # B,128,N/4
375 |
376 | # Dense-connection
377 | x12 = pointnet2_utils.three_interpolate(x2, idx1, weight1) # B,128,N
378 | x12 = torch.cat((x12, x1), dim=1) # B,192,N
379 | x12 = self.conv23(x12) # B,128,N
380 |
381 | # Down-scaling 2
382 |         subset12_feat = torch.unsqueeze(subset12_idx, -1).repeat(1, 1, 128) # B,N/16,128 gather index
383 |         x2_subset12 = torch.gather(x2.transpose(1, 2).contiguous(), 1, subset12_feat).transpose(1, 2).contiguous() # B,128,N/16
384 |
385 |         x3_1 = x2_subset12 # B,128,N/16
386 |         x = get_graph_feature(x2_subset12, k=self.k//4)
387 |         x = self.conv3(x) # B,256,N/16,k
388 |         x3 = x.max(dim=-1, keepdim=False)[0] # B,256,N/16
389 |
390 | # Dense-connection
391 |         x23 = pointnet2_utils.three_interpolate(x3, idx12, weight12) # B,256,N/4
392 |         x23 = torch.cat((x23, x2), dim=1) # B,384,N/4
393 |         x23 = self.conv34(x23) # B,128,N/4
394 | x123 = pointnet2_utils.three_interpolate(x23, idx1, weight1) # B,128,N
395 | x123 = torch.cat((x123, x12, x1), dim=1) # B,320,N
396 | x123 = self.conv35(x123) # B,128,N
397 |
398 | # Down-scaling 3
399 |         x_bot = self.conv53(x3) # B,512,N/16
400 |         x_bot = self.conv54(x_bot) # B,1024,N/16
401 | x_bot = F.adaptive_max_pool1d(x_bot, 1) # B,1024,1
402 |
403 | # Upsampling 3:
404 |         interpolated_feats1 = pointnet2_utils.three_interpolate(x_bot, idx0, weight0) # B,1024,N/16
405 |         interpolated_feats2 = x3 # B,256,N/16
406 |         x3_up = torch.cat((interpolated_feats1, interpolated_feats2), dim=1) # B,1280,N/16
407 |         x3_up = self.conv32(x3_up) # B,256,N/16
408 |         x3_up = self.conv33(x3_up) # B,256,N/16
409 |
410 | # Upsampling 2:
411 |         interpolated_feats1 = pointnet2_utils.three_interpolate(x3_up, idx12, weight12) # B,256,N/4
412 |         interpolated_feats2 = x2 # B,128,N/4
413 |         interpolated_feats3 = x23 # B,128,N/4
414 |         x2_up = torch.cat((interpolated_feats1, interpolated_feats3, interpolated_feats2), dim=1) # B,512,N/4
415 |         x2_up = self.conv21(x2_up) # B,256,N/4
416 |         x2_up = self.conv22(x2_up) # B,128,N/4
417 |
418 | # Upsampling 1:
419 | interpolated_feats1 = pointnet2_utils.three_interpolate(x2_up, idx1, weight1) # B,128,N
420 | interpolated_feats2 = x1 # B,64,N
421 | interpolated_feats3 = x12 # B,128,N
422 | interpolated_feats4 = x123 # B,128,N
423 | x1_up = torch.cat((interpolated_feats1, interpolated_feats4, interpolated_feats3, interpolated_feats2), dim=1) # B,448,N
424 | x1_up = self.conv12(x1_up) # B,512,N
425 | x1_up = self.conv13(x1_up) # B,1024,N
426 |
427 | x_mr = x1_up
428 | #############################################################################
429 |
430 | #############################################################################
431 | # Full-resolution Branch
432 | # Error-minimizing module 2:
433 | # Encoding
434 | x2_1 = x1 # B,64,N
435 |         x = get_adaptive_dilated_graph_feature(x1, self.conv_op2, self.conv_op21, self.conv_op22, d=5, k=20)
436 | x = self.convfc2(x) # B,64,N,k
437 | x = self.convfc24(x) # B,64,N,k
438 | x2_2 = x
439 | # Back-projection
440 | x = self.convfc21(x) # B,64,N,1
441 | x = torch.squeeze(x, -1) # B,64,N
442 | x2_3 = x
443 | # Calculating Error
444 | delta_2 = x2_3 - x2_1 # B,64,N
445 | # Output
446 | x = x2_2 # B,64,N,k
447 | x2 = x.max(dim=-1, keepdim=False)[0] # B,64,N
448 | #######################################
449 | # Error-minimizing module 3:
450 | # Encoding
451 | x3_1 = x2 # B,64,N
452 |         x = get_adaptive_dilated_graph_feature(x2, self.conv_op3, self.conv_op31, self.conv_op32, d=5, k=20)
453 | x = self.convfc3(x) # B,128,N,k
454 | x3_2 = x
455 | # Back-projection
456 | x = self.convfc31(x) # B,64,N,1
457 | x = torch.squeeze(x, -1) # B,64,N
458 | x3_3 = x
459 | # Calculating Error
460 | delta_3 = x3_3 - x3_1 # B,64,N
461 | # Output
462 | x = x3_2 # B,128,N,k
463 | x3 = x.max(dim=-1, keepdim=False)[0] # B,128,N
464 | #######################################
465 | # Error-minimizing module 4:
466 | # Encoding
467 | x4_1 = x3 # B,128,N
468 |         x = get_adaptive_dilated_graph_feature(x3, self.conv_op4, self.conv_op41, self.conv_op42, d=5, k=20)
469 | x = self.convfc4(x) # B,256,N,k
470 | x4_2 = x
471 | # Back-projection
472 | x = self.convfc41(x) # B,128,N,1
473 | x = torch.squeeze(x, -1) # B,128,N
474 | x4_3 = x
475 | # Calculating Error
476 | delta_4 = x4_3 - x4_1 # B,128,N
477 | # Output
478 | x = x4_2 # B,256,N,k
479 | x4 = x.max(dim=-1, keepdim=False)[0] # B,256,N
480 |
481 | x = torch.cat((x1, x2, x3, x4), dim=1) # B,512,N
482 | x_fr = self.conv7(x) # B,1024,N
483 |
484 | # Fusing FR and MR outputs
485 | fusion_score = self.fuse(x_mr)
486 | x = x_fr + x_fr * fusion_score
487 | x_all = self.conv9(x) # B,1024,N
488 |
489 | # Collecting global feature
490 | one_hot_label = cls.view(-1, 16, 1) # B,16,1
491 | one_hot_label = self.conv5(one_hot_label) # B,64,1
492 | x_max = F.adaptive_max_pool1d(x_all, 1) # B,1024,1
493 | x_global = torch.cat((x_max, one_hot_label), dim=1) # B,1088,1
494 |
495 | x_global = x_global.repeat(1, 1, num_pts) # B,1088,N
496 | x = torch.cat((x_all, x_global), dim=1) # B,2112,N
497 |
498 | x = self.conv8(x) # B,1024,N
499 |
500 | x = self.conv63(x) # B,128,N
501 | x = self.dp(x)
502 | x = self.conv64(x) # B,50,N
503 |
504 | return (x.transpose(2, 1).contiguous(),
505 | delta_1.transpose(2, 1).contiguous(),
506 | delta_2.transpose(2, 1).contiguous(),
507 | delta_3.transpose(2, 1).contiguous(),
508 | delta_4.transpose(2, 1).contiguous())
509 |
--------------------------------------------------------------------------------
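
The core of `knn_metric` above is the learned dilation: each point predicts an integer factor `value` in {1..5}, and its k neighbours are then picked from the k*d nearest as indices 0, v, 2v, ..., (k-1)v. The selection step in isolation, as a toy sketch with one point:

```python
import torch

k, d = 4, 5
candidates = torch.arange(k * d).view(1, 1, -1)  # B=1, N=1: the k*d sorted neighbour ids
value = torch.tensor([[3]])                      # dilation factor predicted for this point
select_idx = torch.arange(k).view(1, 1, k) * value.unsqueeze(-1)
picked = torch.gather(candidates, -1, select_idx)
print(picked)  # tensor([[[0, 3, 6, 9]]]) -> every 3rd neighbour, i.e. dilation 3
```
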
/pointnet2/_ext-src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5 | const int nsample);
6 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <cmath>
5 | #include <cstdio>
6 | #include <cstdlib>
7 |
8 | #include <cuda.h>
9 | #include <cuda_runtime.h>
10 |
11 | #include <vector>
12 |
13 | #define TOTAL_THREADS 512
14 |
15 | inline int opt_n_threads(int work_size) {
16 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
17 |
18 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
19 | }
20 |
21 | inline dim3 opt_block_config(int x, int y) {
22 | const int x_threads = opt_n_threads(x);
23 | const int y_threads =
24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25 | dim3 block_config(x_threads, y_threads, 1);
26 |
27 | return block_config;
28 | }
29 |
30 | #define CUDA_CHECK_ERRORS() \
31 | do { \
32 | cudaError_t err = cudaGetLastError(); \
33 | if (cudaSuccess != err) { \
34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
36 | __FILE__); \
37 | exit(-1); \
38 | } \
39 | } while (0)
40 |
41 | #endif
42 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/include/group_points.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <vector>
5 |
6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
8 | at::Tensor weight);
9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
10 | at::Tensor weight, const int m);
11 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/include/sampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
7 |
--------------------------------------------------------------------------------
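
For reference, what `furthest_point_sampling` computes, written as a naive PyTorch sketch (the CUDA kernel in `sampling_gpu.cu`, not shown above, is the fast equivalent; it is conventionally seeded at the first point):

```python
import torch

def fps_reference(xyz: torch.Tensor, nsamples: int) -> torch.Tensor:
    """xyz: (B, N, 3) -> (B, nsamples) indices; greedily pick the point
    furthest from the already-chosen set, starting from point 0."""
    B, N, _ = xyz.shape
    idx = torch.zeros(B, nsamples, dtype=torch.long)
    dist = torch.full((B, N), float('inf'))
    for i in range(1, nsamples):
        last = xyz[torch.arange(B), idx[:, i - 1]]                 # (B, 3) last chosen point
        dist = torch.min(dist, ((xyz - last.unsqueeze(1)) ** 2).sum(-1))
        idx[:, i] = dist.argmax(-1)                                # furthest from chosen set
    return idx
```
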
/pointnet2/_ext-src/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <torch/extension.h>
4 |
5 | #define CHECK_CUDA(x) \
6 | do { \
7 | AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
8 | } while (0)
9 |
10 | #define CHECK_CONTIGUOUS(x) \
11 | do { \
12 | AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_IS_INT(x) \
16 | do { \
17 | AT_CHECK(x.scalar_type() == at::ScalarType::Int, \
18 | #x " must be an int tensor"); \
19 | } while (0)
20 |
21 | #define CHECK_IS_FLOAT(x) \
22 | do { \
23 | AT_CHECK(x.scalar_type() == at::ScalarType::Float, \
24 | #x " must be a float tensor"); \
25 | } while (0)
26 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "utils.h"
3 |
4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5 | int nsample, const float *new_xyz,
6 | const float *xyz, int *idx);
7 |
8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
9 | const int nsample) {
10 | CHECK_CONTIGUOUS(new_xyz);
11 | CHECK_CONTIGUOUS(xyz);
12 | CHECK_IS_FLOAT(new_xyz);
13 | CHECK_IS_FLOAT(xyz);
14 |
15 | if (new_xyz.type().is_cuda()) {
16 | CHECK_CUDA(xyz);
17 | }
18 |
19 | at::Tensor idx =
20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
22 |
23 | if (new_xyz.type().is_cuda()) {
24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
25 |                                     radius, nsample, new_xyz.data<float>(),
26 |                                     xyz.data<float>(), idx.data<int>());
27 | } else {
28 | AT_CHECK(false, "CPU not supported");
29 | }
30 |
31 | return idx;
32 | }
33 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
8 | // output: idx(b, m, nsample)
9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10 | int nsample,
11 | const float *__restrict__ new_xyz,
12 | const float *__restrict__ xyz,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | xyz += batch_index * n * 3;
16 | new_xyz += batch_index * m * 3;
17 | idx += m * nsample * batch_index;
18 |
19 | int index = threadIdx.x;
20 | int stride = blockDim.x;
21 |
22 | float radius2 = radius * radius;
23 | for (int j = index; j < m; j += stride) {
24 | float new_x = new_xyz[j * 3 + 0];
25 | float new_y = new_xyz[j * 3 + 1];
26 | float new_z = new_xyz[j * 3 + 2];
27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
28 | float x = xyz[k * 3 + 0];
29 | float y = xyz[k * 3 + 1];
30 | float z = xyz[k * 3 + 2];
31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
32 | (new_z - z) * (new_z - z);
33 | if (d2 < radius2) {
34 | if (cnt == 0) {
35 | for (int l = 0; l < nsample; ++l) {
36 | idx[j * nsample + l] = k;
37 | }
38 | }
39 | idx[j * nsample + cnt] = k;
40 | ++cnt;
41 | }
42 | }
43 | }
44 | }
45 |
46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47 | int nsample, const float *new_xyz,
48 | const float *xyz, int *idx) {
49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
51 | b, n, m, radius, nsample, new_xyz, xyz, idx);
52 |
53 | CUDA_CHECK_ERRORS();
54 | }
55 |
--------------------------------------------------------------------------------
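
The same semantics in plain PyTorch, for a single batch element: keep the first `nsample` points within `radius` of each query, pre-filling all slots with the first hit so under-populated balls repeat it:

```python
import torch

def ball_query_reference(new_xyz, xyz, radius, nsample):
    # new_xyz: (m, 3) query centres, xyz: (n, 3) support -> (m, nsample) indices
    d2 = ((new_xyz.unsqueeze(1) - xyz.unsqueeze(0)) ** 2).sum(-1)  # (m, n) squared dists
    idx = torch.zeros(new_xyz.size(0), nsample, dtype=torch.long)
    for j in range(new_xyz.size(0)):
        hits = (d2[j] < radius * radius).nonzero().flatten()
        if hits.numel() > 0:
            idx[j, :] = hits[0]                    # pad with the first neighbour
            take = min(nsample, hits.numel())
            idx[j, :take] = hits[:take]
    return idx
```
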
/pointnet2/_ext-src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "group_points.h"
3 | #include "interpolate.h"
4 | #include "sampling.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("gather_points", &gather_points);
8 | m.def("gather_points_grad", &gather_points_grad);
9 | m.def("furthest_point_sampling", &furthest_point_sampling);
10 |
11 | m.def("three_nn", &three_nn);
12 | m.def("three_interpolate", &three_interpolate);
13 | m.def("three_interpolate_grad", &three_interpolate_grad);
14 |
15 | m.def("ball_query", &ball_query);
16 |
17 | m.def("group_points", &group_points);
18 | m.def("group_points_grad", &group_points_grad);
19 | }
20 |
--------------------------------------------------------------------------------
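
These bindings are compiled into the `_ext` module imported by `pointnet2/utils`. A sketch of a `setup.py` that would build them with PyTorch's extension machinery (paths follow the tree above; the layout mirrors Pointnet2_PyTorch and the details are illustrative):

```python
import glob
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

sources = (glob.glob('pointnet2/_ext-src/src/*.cpp')
           + glob.glob('pointnet2/_ext-src/src/*.cu'))

setup(
    name='pointnet2',
    ext_modules=[
        CUDAExtension(
            name='pointnet2._ext',
            sources=sources,
            include_dirs=['pointnet2/_ext-src/include'],
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)
```
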
/pointnet2/_ext-src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include "group_points.h"
2 | #include "utils.h"
3 |
4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
5 | const float *points, const int *idx,
6 | float *out);
7 |
8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
9 | int nsample, const float *grad_out,
10 | const int *idx, float *grad_points);
11 |
12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
13 | CHECK_CONTIGUOUS(points);
14 | CHECK_CONTIGUOUS(idx);
15 | CHECK_IS_FLOAT(points);
16 | CHECK_IS_INT(idx);
17 |
18 | if (points.type().is_cuda()) {
19 | CHECK_CUDA(idx);
20 | }
21 |
22 | at::Tensor output =
23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
24 | at::device(points.device()).dtype(at::ScalarType::Float));
25 |
26 | if (points.type().is_cuda()) {
27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28 |                                 idx.size(1), idx.size(2), points.data<float>(),
29 |                                 idx.data<int>(), output.data<float>());
30 | } else {
31 | AT_CHECK(false, "CPU not supported");
32 | }
33 |
34 | return output;
35 | }
36 |
37 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
38 | CHECK_CONTIGUOUS(grad_out);
39 | CHECK_CONTIGUOUS(idx);
40 | CHECK_IS_FLOAT(grad_out);
41 | CHECK_IS_INT(idx);
42 |
43 | if (grad_out.type().is_cuda()) {
44 | CHECK_CUDA(idx);
45 | }
46 |
47 | at::Tensor output =
48 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
49 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
50 |
51 | if (grad_out.type().is_cuda()) {
52 | group_points_grad_kernel_wrapper(
53 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
54 |         grad_out.data<float>(), idx.data<int>(), output.data<float>());
55 | } else {
56 | AT_CHECK(false, "CPU not supported");
57 | }
58 |
59 | return output;
60 | }
61 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, npoints, nsample)
7 | // output: out(b, c, npoints, nsample)
8 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
9 | int nsample,
10 | const float *__restrict__ points,
11 | const int *__restrict__ idx,
12 | float *__restrict__ out) {
13 | int batch_index = blockIdx.x;
14 | points += batch_index * n * c;
15 | idx += batch_index * npoints * nsample;
16 | out += batch_index * npoints * nsample * c;
17 |
18 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
19 | const int stride = blockDim.y * blockDim.x;
20 | for (int i = index; i < c * npoints; i += stride) {
21 | const int l = i / npoints;
22 | const int j = i % npoints;
23 | for (int k = 0; k < nsample; ++k) {
24 | int ii = idx[j * nsample + k];
25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26 | }
27 | }
28 | }
29 |
30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31 | const float *points, const int *idx,
32 | float *out) {
33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 |
35 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36 | b, c, n, npoints, nsample, points, idx, out);
37 |
38 | CUDA_CHECK_ERRORS();
39 | }
40 |
41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42 | // output: grad_points(b, c, n)
43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44 | int nsample,
45 | const float *__restrict__ grad_out,
46 | const int *__restrict__ idx,
47 | float *__restrict__ grad_points) {
48 | int batch_index = blockIdx.x;
49 | grad_out += batch_index * npoints * nsample * c;
50 | idx += batch_index * npoints * nsample;
51 | grad_points += batch_index * n * c;
52 |
53 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
54 | const int stride = blockDim.y * blockDim.x;
55 | for (int i = index; i < c * npoints; i += stride) {
56 | const int l = i / npoints;
57 | const int j = i % npoints;
58 | for (int k = 0; k < nsample; ++k) {
59 | int ii = idx[j * nsample + k];
60 | atomicAdd(grad_points + l * n + ii,
61 | grad_out[(l * npoints + j) * nsample + k]);
62 | }
63 | }
64 | }
65 |
66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67 | int nsample, const float *grad_out,
68 | const int *idx, float *grad_points) {
69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70 |
71 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
73 |
74 | CUDA_CHECK_ERRORS();
75 | }
76 |
--------------------------------------------------------------------------------
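
The forward kernel is a gather, `out[b, c, j, k] = points[b, c, idx[b, j, k]]`, and the backward kernel is the matching scatter-add. The forward in plain PyTorch:

```python
import torch

def group_points_reference(points, idx):
    # points: (B, C, N); idx: (B, npoints, nsample) -> (B, C, npoints, nsample)
    B, C, N = points.shape
    flat = idx.long().view(B, 1, -1).expand(-1, C, -1)  # (B, C, npoints*nsample)
    return torch.gather(points, 2, flat).view(B, C, *idx.shape[1:])
```
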
/pointnet2/_ext-src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include "interpolate.h"
2 | #include "utils.h"
3 |
4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5 | const float *known, float *dist2, int *idx);
6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
7 | const float *points, const int *idx,
8 | const float *weight, float *out);
9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
10 | const float *grad_out,
11 | const int *idx, const float *weight,
12 | float *grad_points);
13 |
14 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
15 | CHECK_CONTIGUOUS(unknowns);
16 | CHECK_CONTIGUOUS(knows);
17 | CHECK_IS_FLOAT(unknowns);
18 | CHECK_IS_FLOAT(knows);
19 |
20 | if (unknowns.type().is_cuda()) {
21 | CHECK_CUDA(knows);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
26 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
27 | at::Tensor dist2 =
28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
29 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
30 |
31 | if (unknowns.type().is_cuda()) {
32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
33 | unknowns.data<float>(), knows.data<float>(),
34 | dist2.data<float>(), idx.data<int>());
35 | } else {
36 | AT_CHECK(false, "CPU not supported");
37 | }
38 |
39 | return {dist2, idx};
40 | }
41 |
42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
43 | at::Tensor weight) {
44 | CHECK_CONTIGUOUS(points);
45 | CHECK_CONTIGUOUS(idx);
46 | CHECK_CONTIGUOUS(weight);
47 | CHECK_IS_FLOAT(points);
48 | CHECK_IS_INT(idx);
49 | CHECK_IS_FLOAT(weight);
50 |
51 | if (points.type().is_cuda()) {
52 | CHECK_CUDA(idx);
53 | CHECK_CUDA(weight);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
58 | at::device(points.device()).dtype(at::ScalarType::Float));
59 |
60 | if (points.type().is_cuda()) {
61 | three_interpolate_kernel_wrapper(
62 | points.size(0), points.size(1), points.size(2), idx.size(1),
63 | points.data<float>(), idx.data<int>(), weight.data<float>(),
64 | output.data<float>());
65 | } else {
66 | AT_CHECK(false, "CPU not supported");
67 | }
68 |
69 | return output;
70 | }
71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
72 | at::Tensor weight, const int m) {
73 | CHECK_CONTIGUOUS(grad_out);
74 | CHECK_CONTIGUOUS(idx);
75 | CHECK_CONTIGUOUS(weight);
76 | CHECK_IS_FLOAT(grad_out);
77 | CHECK_IS_INT(idx);
78 | CHECK_IS_FLOAT(weight);
79 |
80 | if (grad_out.type().is_cuda()) {
81 | CHECK_CUDA(idx);
82 | CHECK_CUDA(weight);
83 | }
84 |
85 | at::Tensor output =
86 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
87 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
88 |
89 | if (grad_out.type().is_cuda()) {
90 | three_interpolate_grad_kernel_wrapper(
91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
92 | grad_out.data<float>(), idx.data<int>(), weight.data<float>(),
93 | output.data<float>());
94 | } else {
95 | AT_CHECK(false, "CPU not supported");
96 | }
97 |
98 | return output;
99 | }
100 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: unknown(b, n, 3) known(b, m, 3)
8 | // output: dist2(b, n, 3), idx(b, n, 3)
9 | __global__ void three_nn_kernel(int b, int n, int m,
10 | const float *__restrict__ unknown,
11 | const float *__restrict__ known,
12 | float *__restrict__ dist2,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | unknown += batch_index * n * 3;
16 | known += batch_index * m * 3;
17 | dist2 += batch_index * n * 3;
18 | idx += batch_index * n * 3;
19 |
20 | int index = threadIdx.x;
21 | int stride = blockDim.x;
22 | for (int j = index; j < n; j += stride) {
23 | float ux = unknown[j * 3 + 0];
24 | float uy = unknown[j * 3 + 1];
25 | float uz = unknown[j * 3 + 2];
26 |
27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28 | int besti1 = 0, besti2 = 0, besti3 = 0;
29 | for (int k = 0; k < m; ++k) {
30 | float x = known[k * 3 + 0];
31 | float y = known[k * 3 + 1];
32 | float z = known[k * 3 + 2];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34 | if (d < best1) {
35 | best3 = best2;
36 | besti3 = besti2;
37 | best2 = best1;
38 | besti2 = besti1;
39 | best1 = d;
40 | besti1 = k;
41 | } else if (d < best2) {
42 | best3 = best2;
43 | besti3 = besti2;
44 | best2 = d;
45 | besti2 = k;
46 | } else if (d < best3) {
47 | best3 = d;
48 | besti3 = k;
49 | }
50 | }
51 | dist2[j * 3 + 0] = best1;
52 | dist2[j * 3 + 1] = best2;
53 | dist2[j * 3 + 2] = best3;
54 |
55 | idx[j * 3 + 0] = besti1;
56 | idx[j * 3 + 1] = besti2;
57 | idx[j * 3 + 2] = besti3;
58 | }
59 | }
60 |
61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62 | const float *known, float *dist2, int *idx) {
63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64 | three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65 | dist2, idx);
66 |
67 | CUDA_CHECK_ERRORS();
68 | }
69 |
70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
71 | // output: out(b, c, n)
72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
73 | const float *__restrict__ points,
74 | const int *__restrict__ idx,
75 | const float *__restrict__ weight,
76 | float *__restrict__ out) {
77 | int batch_index = blockIdx.x;
78 | points += batch_index * m * c;
79 |
80 | idx += batch_index * n * 3;
81 | weight += batch_index * n * 3;
82 |
83 | out += batch_index * n * c;
84 |
85 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
86 | const int stride = blockDim.y * blockDim.x;
87 | for (int i = index; i < c * n; i += stride) {
88 | const int l = i / n;
89 | const int j = i % n;
90 | float w1 = weight[j * 3 + 0];
91 | float w2 = weight[j * 3 + 1];
92 | float w3 = weight[j * 3 + 2];
93 |
94 | int i1 = idx[j * 3 + 0];
95 | int i2 = idx[j * 3 + 1];
96 | int i3 = idx[j * 3 + 2];
97 |
98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
99 | points[l * m + i3] * w3;
100 | }
101 | }
102 |
103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
104 | const float *points, const int *idx,
105 | const float *weight, float *out) {
106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
107 | three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
108 | b, c, m, n, points, idx, weight, out);
109 |
110 | CUDA_CHECK_ERRORS();
111 | }
112 |
113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
114 | // output: grad_points(b, c, m)
115 |
116 | __global__ void three_interpolate_grad_kernel(
117 | int b, int c, int n, int m, const float *__restrict__ grad_out,
118 | const int *__restrict__ idx, const float *__restrict__ weight,
119 | float *__restrict__ grad_points) {
120 | int batch_index = blockIdx.x;
121 | grad_out += batch_index * n * c;
122 | idx += batch_index * n * 3;
123 | weight += batch_index * n * 3;
124 | grad_points += batch_index * m * c;
125 |
126 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
127 | const int stride = blockDim.y * blockDim.x;
128 | for (int i = index; i < c * n; i += stride) {
129 | const int l = i / n;
130 | const int j = i % n;
131 | float w1 = weight[j * 3 + 0];
132 | float w2 = weight[j * 3 + 1];
133 | float w3 = weight[j * 3 + 2];
134 |
135 | int i1 = idx[j * 3 + 0];
136 | int i2 = idx[j * 3 + 1];
137 | int i3 = idx[j * 3 + 2];
138 |
139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
142 | }
143 | }
144 |
145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
146 | const float *grad_out,
147 | const int *idx, const float *weight,
148 | float *grad_points) {
149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
150 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
151 | b, c, n, m, grad_out, idx, weight, grad_points);
152 |
153 | CUDA_CHECK_ERRORS();
154 | }
155 |
--------------------------------------------------------------------------------
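The three_interpolate_kernel above is an inverse-distance-weighted gather. A minimal PyTorch sketch of the same computation (illustrative only, using the (B, c, m) / (B, n, 3) shapes from the kernel comments):

import torch

def three_interpolate_reference(points, idx, weight):
    # points: (B, c, m); idx: (B, n, 3) neighbor indices; weight: (B, n, 3)
    B, c, m = points.shape
    n = idx.shape[1]
    flat_idx = idx.long().view(B, 1, n * 3).expand(B, c, n * 3)
    gathered = points.gather(2, flat_idx).view(B, c, n, 3)  # features of the 3 neighbors
    return (gathered * weight.view(B, 1, n, 3)).sum(-1)     # weighted sum -> (B, c, n)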
/pointnet2/_ext-src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "sampling.h"
2 | #include "utils.h"
3 |
4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5 | const float *points, const int *idx,
6 | float *out);
7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8 | const float *grad_out, const int *idx,
9 | float *grad_points);
10 |
11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12 | const float *dataset, float *temp,
13 | int *idxs);
14 |
15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16 | CHECK_CONTIGUOUS(points);
17 | CHECK_CONTIGUOUS(idx);
18 | CHECK_IS_FLOAT(points);
19 | CHECK_IS_INT(idx);
20 |
21 | if (points.type().is_cuda()) {
22 | CHECK_CUDA(idx);
23 | }
24 |
25 | at::Tensor output =
26 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
27 | at::device(points.device()).dtype(at::ScalarType::Float));
28 |
29 | if (points.type().is_cuda()) {
30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31 | idx.size(1), points.data<float>(),
32 | idx.data<int>(), output.data<float>());
33 | } else {
34 | AT_CHECK(false, "CPU not supported");
35 | }
36 |
37 | return output;
38 | }
39 |
40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41 | const int n) {
42 | CHECK_CONTIGUOUS(grad_out);
43 | CHECK_CONTIGUOUS(idx);
44 | CHECK_IS_FLOAT(grad_out);
45 | CHECK_IS_INT(idx);
46 |
47 | if (grad_out.type().is_cuda()) {
48 | CHECK_CUDA(idx);
49 | }
50 |
51 | at::Tensor output =
52 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
53 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
54 |
55 | if (grad_out.type().is_cuda()) {
56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57 | idx.size(1), grad_out.data<float>(),
58 | idx.data<int>(), output.data<float>());
59 | } else {
60 | AT_CHECK(false, "CPU not supported");
61 | }
62 |
63 | return output;
64 | }
65 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
66 | CHECK_CONTIGUOUS(points);
67 | CHECK_IS_FLOAT(points);
68 |
69 | at::Tensor output =
70 | torch::zeros({points.size(0), nsamples},
71 | at::device(points.device()).dtype(at::ScalarType::Int));
72 |
73 | at::Tensor tmp =
74 | torch::full({points.size(0), points.size(1)}, 1e10,
75 | at::device(points.device()).dtype(at::ScalarType::Float));
76 |
77 | if (points.type().is_cuda()) {
78 | furthest_point_sampling_kernel_wrapper(
79 | points.size(0), points.size(1), nsamples, points.data<float>(),
80 | tmp.data<float>(), output.data<int>());
81 | } else {
82 | AT_CHECK(false, "CPU not supported");
83 | }
84 |
85 | return output;
86 | }
87 |
--------------------------------------------------------------------------------
/pointnet2/_ext-src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, m)
7 | // output: out(b, c, m)
8 | __global__ void gather_points_kernel(int b, int c, int n, int m,
9 | const float *__restrict__ points,
10 | const int *__restrict__ idx,
11 | float *__restrict__ out) {
12 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
13 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
14 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
15 | int a = idx[i * m + j];
16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17 | }
18 | }
19 | }
20 | }
21 |
22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23 | const float *points, const int *idx,
24 | float *out) {
25 | gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26 | at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27 | points, idx, out);
28 |
29 | CUDA_CHECK_ERRORS();
30 | }
31 |
32 | // input: grad_out(b, c, m) idx(b, m)
33 | // output: grad_points(b, c, n)
34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35 | const float *__restrict__ grad_out,
36 | const int *__restrict__ idx,
37 | float *__restrict__ grad_points) {
38 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
39 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
40 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
41 | int a = idx[i * m + j];
42 | atomicAdd(grad_points + (i * c + l) * n + a,
43 | grad_out[(i * c + l) * m + j]);
44 | }
45 | }
46 | }
47 | }
48 |
49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50 | const float *grad_out, const int *idx,
51 | float *grad_points) {
52 | gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53 | at::cuda::getCurrentCUDAStream()>>>(
54 | b, c, n, npoints, grad_out, idx, grad_points);
55 |
56 | CUDA_CHECK_ERRORS();
57 | }
58 |
59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60 | int idx1, int idx2) {
61 | const float v1 = dists[idx1], v2 = dists[idx2];
62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63 | dists[idx1] = max(v1, v2);
64 | dists_i[idx1] = v2 > v1 ? i2 : i1;
65 | }
66 |
67 | // Input dataset: (b, n, 3), tmp: (b, n)
68 | // Output idxs (b, m)
69 | template <unsigned int block_size>
70 | __global__ void furthest_point_sampling_kernel(
71 | int b, int n, int m, const float *__restrict__ dataset,
72 | float *__restrict__ temp, int *__restrict__ idxs) {
73 | if (m <= 0) return;
74 | __shared__ float dists[block_size];
75 | __shared__ int dists_i[block_size];
76 |
77 | int batch_index = blockIdx.x;
78 | dataset += batch_index * n * 3;
79 | temp += batch_index * n;
80 | idxs += batch_index * m;
81 |
82 | int tid = threadIdx.x;
83 | const int stride = block_size;
84 |
85 | int old = 0;
86 | if (threadIdx.x == 0) idxs[0] = old;
87 |
88 | __syncthreads();
89 | for (int j = 1; j < m; j++) {
90 | int besti = 0;
91 | float best = -1;
92 | float x1 = dataset[old * 3 + 0];
93 | float y1 = dataset[old * 3 + 1];
94 | float z1 = dataset[old * 3 + 2];
95 | for (int k = tid; k < n; k += stride) {
96 | float x2, y2, z2;
97 | x2 = dataset[k * 3 + 0];
98 | y2 = dataset[k * 3 + 1];
99 | z2 = dataset[k * 3 + 2];
100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101 | if (mag <= 1e-3) continue;
102 |
103 | float d =
104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105 |
106 | float d2 = min(d, temp[k]);
107 | temp[k] = d2;
108 | besti = d2 > best ? k : besti;
109 | best = d2 > best ? d2 : best;
110 | }
111 | dists[tid] = best;
112 | dists_i[tid] = besti;
113 | __syncthreads();
114 |
115 | if (block_size >= 512) {
116 | if (tid < 256) {
117 | __update(dists, dists_i, tid, tid + 256);
118 | }
119 | __syncthreads();
120 | }
121 | if (block_size >= 256) {
122 | if (tid < 128) {
123 | __update(dists, dists_i, tid, tid + 128);
124 | }
125 | __syncthreads();
126 | }
127 | if (block_size >= 128) {
128 | if (tid < 64) {
129 | __update(dists, dists_i, tid, tid + 64);
130 | }
131 | __syncthreads();
132 | }
133 | if (block_size >= 64) {
134 | if (tid < 32) {
135 | __update(dists, dists_i, tid, tid + 32);
136 | }
137 | __syncthreads();
138 | }
139 | if (block_size >= 32) {
140 | if (tid < 16) {
141 | __update(dists, dists_i, tid, tid + 16);
142 | }
143 | __syncthreads();
144 | }
145 | if (block_size >= 16) {
146 | if (tid < 8) {
147 | __update(dists, dists_i, tid, tid + 8);
148 | }
149 | __syncthreads();
150 | }
151 | if (block_size >= 8) {
152 | if (tid < 4) {
153 | __update(dists, dists_i, tid, tid + 4);
154 | }
155 | __syncthreads();
156 | }
157 | if (block_size >= 4) {
158 | if (tid < 2) {
159 | __update(dists, dists_i, tid, tid + 2);
160 | }
161 | __syncthreads();
162 | }
163 | if (block_size >= 2) {
164 | if (tid < 1) {
165 | __update(dists, dists_i, tid, tid + 1);
166 | }
167 | __syncthreads();
168 | }
169 |
170 | old = dists_i[0];
171 | if (tid == 0) idxs[j] = old;
172 | }
173 | }
174 |
175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176 | const float *dataset, float *temp,
177 | int *idxs) {
178 | unsigned int n_threads = opt_n_threads(n);
179 |
180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181 |
182 | switch (n_threads) {
183 | case 512:
184 | furthest_point_sampling_kernel<512>
185 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186 | break;
187 | case 256:
188 | furthest_point_sampling_kernel<256>
189 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190 | break;
191 | case 128:
192 | furthest_point_sampling_kernel<128>
193 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194 | break;
195 | case 64:
196 | furthest_point_sampling_kernel<64>
197 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198 | break;
199 | case 32:
200 | furthest_point_sampling_kernel<32>
201 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202 | break;
203 | case 16:
204 | furthest_point_sampling_kernel<16>
205 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206 | break;
207 | case 8:
208 | furthest_point_sampling_kernel<8>
209 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210 | break;
211 | case 4:
212 | furthest_point_sampling_kernel<4>
213 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214 | break;
215 | case 2:
216 | furthest_point_sampling_kernel<2>
217 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218 | break;
219 | case 1:
220 | furthest_point_sampling_kernel<1>
221 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222 | break;
223 | default:
224 | furthest_point_sampling_kernel<512>
225 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226 | }
227 |
228 | CUDA_CHECK_ERRORS();
229 | }
230 |
--------------------------------------------------------------------------------
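For intuition, the greedy loop in furthest_point_sampling_kernel corresponds to this pure-PyTorch reference (illustrative only; it omits the kernel's mag <= 1e-3 skip for near-origin points and runs far slower than the parallel reduction above):

import torch

def furthest_point_sample_reference(xyz, m):
    # xyz: (B, N, 3); returns (B, m) indices, seeded at index 0 like the kernel
    B, N, _ = xyz.shape
    idxs = torch.zeros(B, m, dtype=torch.long)
    dist = torch.full((B, N), 1e10)  # running min distance to the chosen set (cf. temp)
    farthest = torch.zeros(B, dtype=torch.long)
    for j in range(m):
        idxs[:, j] = farthest
        centroid = xyz[torch.arange(B), farthest].unsqueeze(1)   # (B, 1, 3)
        dist = torch.min(dist, ((xyz - centroid) ** 2).sum(-1))  # update running min
        farthest = dist.argmax(-1)  # next pick: the point furthest from the chosen set
    return idxs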
/pointnet2/_ext.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/_ext.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/pointnet2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | from . import pointnet2_utils
9 | from . import pointnet2_modules
10 |
--------------------------------------------------------------------------------
/pointnet2/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pointnet2/utils/__pycache__/pointnet2_modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/utils/__pycache__/pointnet2_modules.cpython-36.pyc
--------------------------------------------------------------------------------
/pointnet2/utils/__pycache__/pointnet2_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShiQiu0419/DRNet/edd9adceefbf8f6871abc565626d5f5cfb9571e0/pointnet2/utils/__pycache__/pointnet2_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/pointnet2/utils/linalg_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import torch
9 | from enum import Enum
10 | import numpy as np
11 |
12 | PDist2Order = Enum("PDist2Order", "d_first d_second")
13 |
14 |
15 | def pdist2(X, Z=None, order=PDist2Order.d_second):
16 | # type: (torch.Tensor, torch.Tensor, PDist2Order) -> torch.Tensor
17 | r""" Calculates the pairwise distance between X and Z
18 |
19 | D[b, i, j] = squared l2 distance between X[b, i] and Z[b, j]
20 |
21 | Parameters
22 | ---------
23 | X : torch.Tensor
24 | X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d
25 | Z: torch.Tensor
26 | Z is a (B, M, d) tensor. If Z is None, then Z = X
27 |
28 | Returns
29 | -------
30 | torch.Tensor
31 | Distance matrix is size (B, N, M)
32 | """
33 |
34 | if order == PDist2Order.d_second:
35 | if X.dim() == 2:
36 | X = X.unsqueeze(0)
37 | if Z is None:
38 | Z = X
39 | G = torch.matmul(X, Z.transpose(-2, -1))
40 | S = (X * X).sum(-1, keepdim=True)
41 | R = S.transpose(-2, -1)
42 | else:
43 | if Z.dim() == 2:
44 | Z = Z.unsqueeze(0)
45 | G = torch.matmul(X, Z.transpose(-2, -1))
46 | S = (X * X).sum(-1, keepdim=True)
47 | R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1)
48 | else:
49 | if X.dim() == 2:
50 | X = X.unsqueeze(0)
51 | if Z is None:
52 | Z = X
53 | G = torch.matmul(X.transpose(-2, -1), Z)
54 | R = (X * X).sum(-2, keepdim=True)
55 | S = R.transpose(-2, -1)
56 | else:
57 | if Z.dim() == 2:
58 | Z = Z.unsqueeze(0)
59 | G = torch.matmul(X.transpose(-2, -1), Z)
60 | S = (X * X).sum(-2, keepdim=True).transpose(-2, -1)
61 | R = (Z * Z).sum(-2, keepdim=True)
62 |
63 | return torch.abs(R + S - 2 * G).squeeze(0)
64 |
65 |
66 | def pdist2_slow(X, Z=None):
67 | if Z is None:
68 | Z = X
69 | D = torch.zeros(X.size(0), X.size(2), Z.size(2))
70 |
71 | for b in range(D.size(0)):
72 | for i in range(D.size(1)):
73 | for j in range(D.size(2)):
74 | D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j])
75 | return D
76 |
77 |
78 | if __name__ == "__main__":
79 | X = torch.randn(2, 3, 5)
80 | Z = torch.randn(2, 3, 3)
81 |
82 | print(pdist2(X, order=PDist2Order.d_first))
83 | print(pdist2_slow(X))
84 | print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X)))
85 |
--------------------------------------------------------------------------------
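pdist2 relies on the expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x.z, so it needs only one matmul plus two squared-norm vectors instead of an explicit (B, N, M, d) difference tensor. A quick sanity check (illustrative):

import torch

X = torch.randn(2, 5, 3)                                  # (B, N, d), d_second layout
brute = (X.unsqueeze(2) - X.unsqueeze(1)).pow(2).sum(-1)  # explicit squared distances
# pdist2(X) should match brute up to floating-point error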
/pointnet2/utils/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import etw_pytorch_utils as pt_utils
12 |
13 | from pointnet2.utils import pointnet2_utils
14 |
15 | if False:
16 | # Workaround for type hints without depending on the `typing` module
17 | from typing import *
18 |
19 |
20 | class _PointnetSAModuleBase(nn.Module):
21 | def __init__(self):
22 | super(_PointnetSAModuleBase, self).__init__()
23 | self.npoint = None
24 | self.groupers = None
25 | self.mlps = None
26 |
27 | def forward(self, xyz, features=None):
28 | # type: (_PointnetSAModuleBase, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
29 | r"""
30 | Parameters
31 | ----------
32 | xyz : torch.Tensor
33 | (B, N, 3) tensor of the xyz coordinates of the features
34 | features : torch.Tensor
35 | (B, N, C) tensor of the descriptors of the features
36 |
37 | Returns
38 | -------
39 | new_xyz : torch.Tensor
40 | (B, npoint, 3) tensor of the new features' xyz
41 | new_features : torch.Tensor
42 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
43 | """
44 |
45 | new_features_list = []
46 |
47 | xyz_flipped = xyz.transpose(1, 2).contiguous()
48 | new_xyz = (
49 | pointnet2_utils.gather_operation(
50 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
51 | )
52 | .transpose(1, 2)
53 | .contiguous()
54 | if self.npoint is not None
55 | else None
56 | )
57 |
58 | for i in range(len(self.groupers)):
59 | new_features = self.groupers[i](
60 | xyz, new_xyz, features
61 | ) # (B, C, npoint, nsample)
62 |
63 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
64 | new_features = F.max_pool2d(
65 | new_features, kernel_size=[1, new_features.size(3)]
66 | ) # (B, mlp[-1], npoint, 1)
67 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
68 |
69 | new_features_list.append(new_features)
70 |
71 | return new_xyz, torch.cat(new_features_list, dim=1)
72 |
73 |
74 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
75 | r"""Pointnet set abstrction layer with multiscale grouping
76 |
77 | Parameters
78 | ----------
79 | npoint : int
80 | Number of points to sample as centroids
81 | radii : list of float32
82 | list of radii to group with
83 | nsamples : list of int32
84 | Number of samples in each ball query
85 | mlps : list of list of int32
86 | Spec of the pointnet before the global max_pool for each scale
87 | bn : bool
88 | Use batchnorm
89 | """
90 |
91 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
92 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
93 | super(PointnetSAModuleMSG, self).__init__()
94 |
95 | assert len(radii) == len(nsamples) == len(mlps)
96 |
97 | self.npoint = npoint
98 | self.groupers = nn.ModuleList()
99 | self.mlps = nn.ModuleList()
100 | for i in range(len(radii)):
101 | radius = radii[i]
102 | nsample = nsamples[i]
103 | self.groupers.append(
104 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
105 | if npoint is not None
106 | else pointnet2_utils.GroupAll(use_xyz)
107 | )
108 | mlp_spec = mlps[i]
109 | if use_xyz:
110 | mlp_spec[0] += 3
111 |
112 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn))
113 |
114 |
115 | class PointnetSAModule(PointnetSAModuleMSG):
116 | r"""Pointnet set abstrction layer
117 |
118 | Parameters
119 | ----------
120 | npoint : int
121 | Number of points to sample as centroids
122 | radius : float
123 | Radius of ball
124 | nsample : int
125 | Number of samples in the ball query
126 | mlp : list
127 | Spec of the pointnet before the global max_pool
128 | bn : bool
129 | Use batchnorm
130 | """
131 |
132 | def __init__(
133 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
134 | ):
135 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
136 | super(PointnetSAModule, self).__init__(
137 | mlps=[mlp],
138 | npoint=npoint,
139 | radii=[radius],
140 | nsamples=[nsample],
141 | bn=bn,
142 | use_xyz=use_xyz,
143 | )
144 |
145 |
146 | class PointnetFPModule(nn.Module):
147 | r"""Propigates the features of one set to another
148 |
149 | Parameters
150 | ----------
151 | mlp : list
152 | Pointnet module parameters
153 | bn : bool
154 | Use batchnorm
155 | """
156 |
157 | def __init__(self, mlp, bn=True):
158 | # type: (PointnetFPModule, List[int], bool) -> None
159 | super(PointnetFPModule, self).__init__()
160 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
161 |
162 | def forward(self, unknown, known, unknow_feats, known_feats):
163 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
164 | r"""
165 | Parameters
166 | ----------
167 | unknown : torch.Tensor
168 | (B, n, 3) tensor of the xyz positions of the unknown features
169 | known : torch.Tensor
170 | (B, m, 3) tensor of the xyz positions of the known features
171 | unknow_feats : torch.Tensor
172 | (B, C1, n) tensor of the features to be propagated to
173 | known_feats : torch.Tensor
174 | (B, C2, m) tensor of features to be propagated
175 |
176 | Returns
177 | -------
178 | new_features : torch.Tensor
179 | (B, mlp[-1], n) tensor of the features of the unknown features
180 | """
181 |
182 | if known is not None:
183 | dist, idx = pointnet2_utils.three_nn(unknown, known)
184 | dist_recip = 1.0 / (dist + 1e-8)
185 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
186 | weight = dist_recip / norm
187 |
188 | interpolated_feats = pointnet2_utils.three_interpolate(
189 | known_feats, idx, weight
190 | )
191 | else:
192 | interpolated_feats = known_feats.expand(
193 | *(list(known_feats.size()[0:2]) + [unknown.size(1)])
194 | )
195 |
196 | if unknow_feats is not None:
197 | new_features = torch.cat(
198 | [interpolated_feats, unknow_feats], dim=1
199 | ) # (B, C2 + C1, n)
200 | else:
201 | new_features = interpolated_feats
202 |
203 | new_features = new_features.unsqueeze(-1)
204 | new_features = self.mlp(new_features)
205 |
206 | return new_features.squeeze(-1)
207 |
208 |
209 | if __name__ == "__main__":
210 | from torch.autograd import Variable
211 |
212 | torch.manual_seed(1)
213 | torch.cuda.manual_seed_all(1)
214 | xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True)
215 | xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True)
216 |
217 | test_module = PointnetSAModuleMSG(
218 | npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]]
219 | )
220 | test_module.cuda()
221 | print(test_module(xyz, xyz_feats))
222 |
223 | # test_module = PointnetFPModule(mlp=[6, 6])
224 | # test_module.cuda()
225 | # from torch.autograd import gradcheck
226 | # inputs = (xyz, xyz, None, xyz_feats)
227 | # test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4)
228 | # print(test)
229 |
230 | for _ in range(1):
231 | _, new_features = test_module(xyz, xyz_feats)
232 | new_features.backward(torch.cuda.FloatTensor(*new_features.size()).fill_(1))
233 | print(new_features)
234 | print(xyz.grad)
235 |
--------------------------------------------------------------------------------
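A minimal usage sketch for PointnetFPModule (hypothetical shapes; requires the compiled _ext extension, etw_pytorch_utils, and a GPU, since three_nn / three_interpolate are CUDA-only here):

import torch
from pointnet2.utils.pointnet2_modules import PointnetFPModule

fp = PointnetFPModule(mlp=[64 + 32, 128]).cuda()     # input channels = C2 + C1
unknown = torch.randn(2, 256, 3).cuda()              # (B, n, 3) dense positions
known = torch.randn(2, 64, 3).cuda()                 # (B, m, 3) sparse positions
unknow_feats = torch.randn(2, 32, 256).cuda()        # (B, C1, n) dense-side features
known_feats = torch.randn(2, 64, 64).cuda()          # (B, C2, m) features to propagate
out = fp(unknown, known, unknow_feats, known_feats)  # (B, 128, n)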
/pointnet2/utils/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import torch
9 | from torch.autograd import Function
10 | import torch.nn as nn
11 | import etw_pytorch_utils as pt_utils
12 | import sys
13 |
14 | try:
15 | import builtins
16 | except ImportError:
17 | import __builtin__ as builtins
18 |
19 | try:
20 | import pointnet2._ext as _ext
21 | except ImportError:
22 | if not getattr(builtins, "__POINTNET2_SETUP__", False):
23 | raise ImportError(
24 | "Could not import _ext module.\n"
25 | "Please see the setup instructions in the README: "
26 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst"
27 | )
28 |
29 | if False:
30 | # Workaround for type hints without depending on the `typing` module
31 | from typing import *
32 |
33 |
34 | class RandomDropout(nn.Module):
35 | def __init__(self, p=0.5, inplace=False):
36 | super(RandomDropout, self).__init__()
37 | self.p = p
38 | self.inplace = inplace
39 |
40 | def forward(self, X):
41 | theta = torch.Tensor(1).uniform_(0, self.p)[0]
42 | return pt_utils.feature_dropout_no_scaling(X, theta, self.training, self.inplace)
43 |
44 |
45 | class FurthestPointSampling(Function):
46 | @staticmethod
47 | def forward(ctx, xyz, npoint):
48 | # type: (Any, torch.Tensor, int) -> torch.Tensor
49 | r"""
50 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
51 | minimum distance
52 |
53 | Parameters
54 | ----------
55 | xyz : torch.Tensor
56 | (B, N, 3) tensor where N > npoint
57 | npoint : int32
58 | number of features in the sampled set
59 |
60 | Returns
61 | -------
62 | torch.Tensor
63 | (B, npoint) tensor containing the set
64 | """
65 | return _ext.furthest_point_sampling(xyz, npoint)
66 |
67 | @staticmethod
68 | def backward(xyz, a=None):
69 | return None, None
70 |
71 |
72 | furthest_point_sample = FurthestPointSampling.apply
73 |
74 |
75 | class GatherOperation(Function):
76 | @staticmethod
77 | def forward(ctx, features, idx):
78 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
79 | r"""
80 |
81 | Parameters
82 | ----------
83 | features : torch.Tensor
84 | (B, C, N) tensor
85 |
86 | idx : torch.Tensor
87 | (B, npoint) tensor of the features to gather
88 |
89 | Returns
90 | -------
91 | torch.Tensor
92 | (B, C, npoint) tensor
93 | """
94 |
95 | _, C, N = features.size()
96 |
97 | ctx.for_backwards = (idx, C, N)
98 |
99 | return _ext.gather_points(features, idx)
100 |
101 | @staticmethod
102 | def backward(ctx, grad_out):
103 | idx, C, N = ctx.for_backwards
104 |
105 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
106 | return grad_features, None
107 |
108 |
109 | gather_operation = GatherOperation.apply
110 |
111 |
112 | class ThreeNN(Function):
113 | @staticmethod
114 | def forward(ctx, unknown, known):
115 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
116 | r"""
117 | Find the three nearest neighbors of unknown in known
118 | Parameters
119 | ----------
120 | unknown : torch.Tensor
121 | (B, n, 3) tensor of unknown features
122 | known : torch.Tensor
123 | (B, m, 3) tensor of known features
124 |
125 | Returns
126 | -------
127 | dist : torch.Tensor
128 | (B, n, 3) l2 distance to the three nearest neighbors
129 | idx : torch.Tensor
130 | (B, n, 3) index of 3 nearest neighbors
131 | """
132 | dist2, idx = _ext.three_nn(unknown, known)
133 |
134 | return torch.sqrt(dist2), idx
135 |
136 | @staticmethod
137 | def backward(ctx, a=None, b=None):
138 | return None, None
139 |
140 |
141 | three_nn = ThreeNN.apply
142 |
143 |
144 | class ThreeInterpolate(Function):
145 | @staticmethod
146 | def forward(ctx, features, idx, weight):
147 | # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
148 | r"""
149 | Performs weighted linear interpolation over 3 features
150 | Parameters
151 | ----------
152 | features : torch.Tensor
153 | (B, c, m) Features descriptors to be interpolated from
154 | idx : torch.Tensor
155 | (B, n, 3) three nearest neighbors of the target features in features
156 | weight : torch.Tensor
157 | (B, n, 3) weights
158 |
159 | Returns
160 | -------
161 | torch.Tensor
162 | (B, c, n) tensor of the interpolated features
163 | """
164 | B, c, m = features.size()
165 | n = idx.size(1)
166 |
167 | ctx.three_interpolate_for_backward = (idx, weight, m)
168 |
169 | return _ext.three_interpolate(features, idx, weight)
170 |
171 | @staticmethod
172 | def backward(ctx, grad_out):
173 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
174 | r"""
175 | Parameters
176 | ----------
177 | grad_out : torch.Tensor
178 | (B, c, n) tensor with gradients of outputs
179 |
180 | Returns
181 | -------
182 | grad_features : torch.Tensor
183 | (B, c, m) tensor with gradients of features
184 |
185 | None
186 |
187 | None
188 | """
189 | idx, weight, m = ctx.three_interpolate_for_backward
190 |
191 | grad_features = _ext.three_interpolate_grad(
192 | grad_out.contiguous(), idx, weight, m
193 | )
194 |
195 | return grad_features, None, None
196 |
197 |
198 | three_interpolate = ThreeInterpolate.apply
199 |
200 |
201 | class GroupingOperation(Function):
202 | @staticmethod
203 | def forward(ctx, features, idx):
204 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
205 | r"""
206 |
207 | Parameters
208 | ----------
209 | features : torch.Tensor
210 | (B, C, N) tensor of features to group
211 | idx : torch.Tensor
212 | (B, npoint, nsample) tensor containing the indices of features to group with
213 |
214 | Returns
215 | -------
216 | torch.Tensor
217 | (B, C, npoint, nsample) tensor
218 | """
219 | B, nfeatures, nsample = idx.size()
220 | _, C, N = features.size()
221 |
222 | ctx.for_backwards = (idx, N)
223 |
224 | return _ext.group_points(features, idx)
225 |
226 | @staticmethod
227 | def backward(ctx, grad_out):
228 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
229 | r"""
230 |
231 | Parameters
232 | ----------
233 | grad_out : torch.Tensor
234 | (B, C, npoint, nsample) tensor of the gradients of the output from forward
235 |
236 | Returns
237 | -------
238 | torch.Tensor
239 | (B, C, N) gradient of the features
240 | None
241 | """
242 | idx, N = ctx.for_backwards
243 |
244 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
245 |
246 | return grad_features, None
247 |
248 |
249 | grouping_operation = GroupingOperation.apply
250 |
251 |
252 | class BallQuery(Function):
253 | @staticmethod
254 | def forward(ctx, radius, nsample, xyz, new_xyz):
255 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
256 | r"""
257 |
258 | Parameters
259 | ----------
260 | radius : float
261 | radius of the balls
262 | nsample : int
263 | maximum number of features in the balls
264 | xyz : torch.Tensor
265 | (B, N, 3) xyz coordinates of the features
266 | new_xyz : torch.Tensor
267 | (B, npoint, 3) centers of the ball query
268 |
269 | Returns
270 | -------
271 | torch.Tensor
272 | (B, npoint, nsample) tensor with the indices of the features that form the query balls
273 | """
274 | return _ext.ball_query(new_xyz, xyz, radius, nsample)
275 |
276 | @staticmethod
277 | def backward(ctx, a=None):
278 | return None, None, None, None
279 |
280 |
281 | ball_query = BallQuery.apply
282 |
283 |
284 | class QueryAndGroup(nn.Module):
285 | r"""
286 | Groups with a ball query of radius
287 |
288 | Parameters
289 | ---------
290 | radius : float32
291 | Radius of ball
292 | nsample : int32
293 | Maximum number of features to gather in the ball
294 | """
295 |
296 | def __init__(self, radius, nsample, use_xyz=True):
297 | # type: (QueryAndGroup, float, int, bool) -> None
298 | super(QueryAndGroup, self).__init__()
299 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
300 |
301 | def forward(self, xyz, new_xyz, features=None):
302 | # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
303 | r"""
304 | Parameters
305 | ----------
306 | xyz : torch.Tensor
307 | xyz coordinates of the features (B, N, 3)
308 | new_xyz : torch.Tensor
309 | centroids (B, npoint, 3)
310 | features : torch.Tensor
311 | Descriptors of the features (B, C, N)
312 |
313 | Returns
314 | -------
315 | new_features : torch.Tensor
316 | (B, 3 + C, npoint, nsample) tensor
317 | """
318 |
319 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
320 | xyz_trans = xyz.transpose(1, 2).contiguous()
321 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
322 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
323 |
324 | if features is not None:
325 | grouped_features = grouping_operation(features, idx)
326 | if self.use_xyz:
327 | new_features = torch.cat(
328 | [grouped_xyz, grouped_features], dim=1
329 | ) # (B, C + 3, npoint, nsample)
330 | else:
331 | new_features = grouped_features
332 | else:
333 | assert (
334 | self.use_xyz
335 | ), "Cannot have not features and not use xyz as a feature!"
336 | new_features = grouped_xyz
337 |
338 | return new_features
339 |
340 |
341 | class GroupAll(nn.Module):
342 | r"""
343 | Groups all features
344 |
345 | Parameters
346 | ---------
347 | """
348 |
349 | def __init__(self, use_xyz=True):
350 | # type: (GroupAll, bool) -> None
351 | super(GroupAll, self).__init__()
352 | self.use_xyz = use_xyz
353 |
354 | def forward(self, xyz, new_xyz, features=None):
355 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
356 | r"""
357 | Parameters
358 | ----------
359 | xyz : torch.Tensor
360 | xyz coordinates of the features (B, N, 3)
361 | new_xyz : torch.Tensor
362 | Ignored
363 | features : torch.Tensor
364 | Descriptors of the features (B, C, N)
365 |
366 | Returns
367 | -------
368 | new_features : torch.Tensor
369 | (B, C + 3, 1, N) tensor
370 | """
371 |
372 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
373 | if features is not None:
374 | grouped_features = features.unsqueeze(2)
375 | if self.use_xyz:
376 | new_features = torch.cat(
377 | [grouped_xyz, grouped_features], dim=1
378 | ) # (B, 3 + C, 1, N)
379 | else:
380 | new_features = grouped_features
381 | else:
382 | new_features = grouped_xyz
383 |
384 | return new_features
385 |
--------------------------------------------------------------------------------
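Putting the pieces together, a typical sample-and-group call chain built from these ops looks like this (illustrative shapes; CUDA-only for the same reason as above):

import torch
from pointnet2.utils.pointnet2_utils import (
    QueryAndGroup, furthest_point_sample, gather_operation)

xyz = torch.randn(8, 1024, 3).cuda()     # (B, N, 3) coordinates
feats = torch.randn(8, 64, 1024).cuda()  # (B, C, N) per-point features
idx = furthest_point_sample(xyz, 256)    # (B, 256) FPS centroid indices
new_xyz = gather_operation(
    xyz.transpose(1, 2).contiguous(), idx).transpose(1, 2).contiguous()  # (B, 256, 3)
grouper = QueryAndGroup(radius=0.2, nsample=32, use_xyz=True)
grouped = grouper(xyz, new_xyz, feats)   # (B, 64 + 3, 256, 32)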
/train_partseg_gpus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.optim.lr_scheduler as lr_sched
4 | from torch.optim.lr_scheduler import CosineAnnealingLR
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 | from torch.autograd import Variable
9 | import numpy as np
10 | import os
11 | from torchvision import transforms
12 | from models import DRNET_Seg as DRNET
13 | from data import ShapeNetPart
14 | import data.data_utils as d_utils
15 | import argparse
16 | import random
17 | import yaml
18 |
19 | torch.backends.cudnn.enabled = True
20 | torch.backends.cudnn.benchmark = True
21 | torch.backends.cudnn.deterministic = True
22 |
23 | seed = 123
24 | random.seed(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.cuda.manual_seed_all(seed)
29 |
30 | parser = argparse.ArgumentParser(description='DRNET Shape Part Segmentation Training')
31 | parser.add_argument('--config', default='cfgs/config_partseg_gpus.yaml', type=str)
32 |
33 |
34 | def set_bn_momentum_default(bn_momentum):
35 |
36 | def fn(m):
37 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
38 | m.momentum = bn_momentum
39 |
40 | return fn
41 |
42 |
43 | class BNMomentumScheduler(object):
44 |
45 | def __init__(
46 | self, model, bn_lambda, last_epoch=-1,
47 | setter=set_bn_momentum_default
48 | ):
49 | if not isinstance(model, nn.Module):
50 | raise RuntimeError(
51 | "Class '{}' is not a PyTorch nn Module".format(
52 | type(model).__name__
53 | )
54 | )
55 |
56 | self.model = model
57 | self.setter = setter
58 | self.lmbd = bn_lambda
59 |
60 | self.step(last_epoch + 1)
61 | self.last_epoch = last_epoch
62 |
63 | def step(self, epoch=None):
64 | if epoch is None:
65 | epoch = self.last_epoch + 1
66 |
67 | self.last_epoch = epoch
68 | self.model.apply(self.setter(self.lmbd(epoch)))
69 |
70 | def get_momentum(self, epoch=None):
71 | if epoch is None:
72 | epoch = self.last_epoch + 1
73 | return self.lmbd(epoch)
74 |
75 | def main():
76 | args = parser.parse_args()
77 | with open(args.config) as f:
78 | config = yaml.load(f, Loader=yaml.FullLoader)
79 | print("\n**************************")
80 | for k, v in config['common'].items():
81 | setattr(args, k, v)
82 | print('\n[%s]:'%(k), v)
83 | print("\n**************************\n")
84 |
85 | try:
86 | os.makedirs(args.save_path)
87 | except OSError:
88 | pass
89 |
90 | train_transforms = transforms.Compose([
91 | d_utils.PointcloudToTensor()
92 | ])
93 | test_transforms = transforms.Compose([
94 | d_utils.PointcloudToTensor()
95 | ])
96 |
97 | train_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'trainval', normalize = True, transforms = train_transforms)
98 | train_dataloader = DataLoader(
99 | train_dataset,
100 | batch_size=args.batch_size,
101 | shuffle=True,
102 | num_workers=int(args.workers),
103 | pin_memory=True
104 | )
105 |
106 | global test_dataset
107 | test_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'test', normalize = True, transforms = test_transforms)
108 | test_dataloader = DataLoader(
109 | test_dataset,
110 | batch_size=args.batch_size,
111 | shuffle=False,
112 | num_workers=int(args.workers),
113 | pin_memory=True
114 | )
115 |
116 | device = torch.device("cuda")
117 | model = DRNET(num_classes = args.num_classes).to(device)
118 | model = nn.DataParallel(model)
119 | print("Let's use", torch.cuda.device_count(), "GPUs!")
120 |
121 | '''
122 | optimizer = optim.SGD(
123 | model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
124 | scheduler = CosineAnnealingLR(optimizer, args.epochs, eta_min=0.001)
125 | '''
126 | optimizer = optim.Adam(
127 | model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
128 |
129 | lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip / args.base_lr)
130 | bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), args.bnm_clip)
131 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
132 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
133 |
134 |
135 | if args.checkpoint != '':
136 | model.load_state_dict(torch.load(args.checkpoint))
137 | print('Load model successfully: %s' % (args.checkpoint))
138 |
139 | criterion = nn.CrossEntropyLoss()
140 | num_batch = len(train_dataset)/args.batch_size
141 |
142 | # training
143 | train(train_dataloader, test_dataloader, model, device, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
144 |
145 |
146 | def train(train_dataloader, test_dataloader, model, device, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch):
147 | PointcloudAug = d_utils.PointcloudScaleAndTranslate() # initialize augmentation
148 | global Class_mIoU, Inst_mIoU
149 | Class_mIoU, Inst_mIoU = 0.82, 0.85
150 | batch_count = 0
151 | model.train()
152 | for epoch in range(args.epochs):
153 | # scheduler.step()
154 | for i, data in enumerate(train_dataloader, 0):
155 |
156 | if lr_scheduler is not None:
157 | lr_scheduler.step(epoch)
158 | if bnm_scheduler is not None:
159 | bnm_scheduler.step(epoch-1)
160 |
161 | points, target, cls = data
162 | points, target = Variable(points), Variable(target)
163 | points, target = points.to(device), target.to(device)
164 | # augmentation
165 | points.data = PointcloudAug(points.data)
166 |
167 | optimizer.zero_grad()
168 |
169 | batch_one_hot_cls = np.zeros((len(cls), 16)) # 16 object classes
170 | for b in range(len(cls)):
171 | batch_one_hot_cls[b, int(cls[b])] = 1
172 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
173 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
174 |
175 | pred, error1, error2, error3, error4 = model(points, batch_one_hot_cls)
176 | loss1 = torch.norm(error1, dim=-1) # B, N
177 | loss1 = torch.mean(torch.mean(loss1, dim=-1), dim=-1)
178 | loss2 = torch.norm(error2, dim=-1) # B, N
179 | loss2 = torch.mean(torch.mean(loss2, dim=-1), dim=-1)
180 | loss3 = torch.norm(error3, dim=-1) # B, N
181 | loss3 = torch.mean(torch.mean(loss3, dim=-1), dim=-1)
182 | loss4 = torch.norm(error4, dim=-1) # B, N
183 | loss4 = torch.mean(torch.mean(loss4, dim=-1), dim=-1)
184 | pred = pred.view(-1, args.num_classes)
185 | target = target.view(-1,1)[:,0]
186 | loss = criterion(pred, target) + 0.1*loss1
187 | # + 0.01*loss2 + 0.01*loss3 + 0.01*loss4
188 | loss.backward()
189 | optimizer.step()
190 |
191 | if i % args.print_freq_iter == 0:
192 | print('[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %(epoch+1, i, num_batch, loss.data.clone(), lr_scheduler.get_lr()[0]))
193 | batch_count += 1
194 |
195 | # validation in between an epoch
196 | if (epoch > 60) and args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0:
197 | print('testing..')
198 | validate(test_dataloader, model, device, criterion, args, batch_count)
199 |
200 | def validate(test_dataloader, model, device, criterion, args, iter):
201 | global Class_mIoU, Inst_mIoU, test_dataset
202 | model.eval()
203 |
204 | seg_classes = test_dataset.seg_classes
205 | shape_ious = {cat:[] for cat in seg_classes.keys()}
206 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
207 | for cat in seg_classes.keys():
208 | for label in seg_classes[cat]:
209 | seg_label_to_cat[label] = cat
210 |
211 | losses = []
212 | temp_device = torch.device("cpu")
213 | for _, data in enumerate(test_dataloader, 0):
214 | points, target, cls = data
215 | with torch.no_grad():
216 | points = Variable(points)
217 | with torch.no_grad():
218 | target = Variable(target)
219 |
220 | points = points.to(device)
221 | target = target.to(device)
222 |
223 | batch_one_hot_cls = np.zeros((len(cls), 16)) # 16 object classes
224 | for b in range(len(cls)):
225 | batch_one_hot_cls[b, int(cls[b])] = 1
226 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
227 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
228 |
229 | with torch.no_grad():
230 | pred, error1, error2, error3, error4 = model(points, batch_one_hot_cls)
231 |
232 | loss1 = torch.norm(error1, dim=-1) # B, N
233 | loss1 = torch.mean(torch.mean(loss1, dim=-1), dim=-1)
234 | loss2 = torch.norm(error2, dim=-1) # B, N
235 | loss2 = torch.mean(torch.mean(loss2, dim=-1), dim=-1)
236 | loss3 = torch.norm(error3, dim=-1) # B, N
237 | loss3 = torch.mean(torch.mean(loss3, dim=-1), dim=-1)
238 | loss4 = torch.norm(error4, dim=-1) # B, N
239 | loss4 = torch.mean(torch.mean(loss4, dim=-1), dim=-1)
240 | loss = criterion(pred.view(-1, args.num_classes), target.view(-1,1)[:,0]) + 0.1*loss1
241 | # + 0.01*loss2 + 0.01*loss3 + 0.01*loss4
242 | losses.append(loss.data.clone())
243 |
244 | pred = pred.data.cpu()
245 | target = target.data.cpu()
246 |
247 | pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor)
248 | # pred to the groundtruth classes (selected by seg_classes[cat])
249 | for b in range(len(cls)):
250 | cat = seg_label_to_cat[target[b, 0].item()]
251 | logits = pred[b, :, :] # (num_points, num_classes)
252 | pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0]
253 |
254 | for b in range(len(cls)):
255 | segp = pred_val[b, :].to(temp_device)
256 | segl = target[b, :].to(temp_device)
257 | cat = seg_label_to_cat[segl[0].item()]
258 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
259 | for l in seg_classes[cat]:
260 | if torch.sum((segl == l) | (segp == l)) == 0:
261 | # part is not present in this shape
262 | part_ious[l - seg_classes[cat][0]] = 1.0
263 | else:
264 | part_ious[l - seg_classes[cat][0]] = torch.sum((segl == l) & (segp == l)) / float(torch.sum((segl == l) | (segp == l)))
265 | shape_ious[cat].append(np.mean(part_ious)) # torch.mean(torch.stack(part_ious))
266 |
267 | instance_ious = []
268 | for cat in shape_ious.keys():
269 | for iou in shape_ious[cat]:
270 | instance_ious.append(iou)
271 | shape_ious[cat] = np.mean(shape_ious[cat])
272 | mean_class_ious = np.mean(list(shape_ious.values()))
273 |
274 | for cat in sorted(shape_ious.keys()):
275 | print('****** %s: %0.6f'%(cat, shape_ious[cat]))
276 | print('************ Test Loss: %0.6f' % (torch.mean(torch.stack(losses)).cpu().numpy())) #torch.mean(torch.stack(losses)).numpy() np.array(losses).mean())
277 | print('************ Class_mIoU: %0.6f' % (mean_class_ious))
278 | print('************ Instance_mIoU: %0.6f' % (np.mean(instance_ious)))
279 |
280 | if mean_class_ious > Class_mIoU or np.mean(instance_ious) > Inst_mIoU:
281 | if mean_class_ious > Class_mIoU:
282 | Class_mIoU = mean_class_ious
283 | if np.mean(instance_ious) > Inst_mIoU:
284 | Inst_mIoU = np.mean(instance_ious)
285 | torch.save(model.state_dict(), '%s/seg_drnet_iter_%d_ins_%0.6f_cls_%0.6f.pth' % (args.save_path, iter, np.mean(instance_ious), mean_class_ious))
286 | model.train()
287 |
288 | if __name__ == "__main__":
289 | main()
290 |
--------------------------------------------------------------------------------
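As an aside, the one-hot class encoding that the training and validation loops build with a Python loop has a one-line vectorized equivalent (same output, assuming cls holds integer class ids):

import torch
import torch.nn.functional as F

cls = torch.tensor([0, 5, 12])                              # per-shape object class ids
batch_one_hot_cls = F.one_hot(cls, num_classes=16).float()  # (B, 16), matches the loop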
/train_partseg_gpus.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 | mkdir -p log
3 | now=$(date +"%Y%m%d_%H%M%S")
4 | log_name="PartSeg_LOG_"$now""
5 | CUDA_VISIBLE_DEVICES=2,3 python -u train_partseg_gpus.py \
6 | --config cfgs/config_partseg_gpus.yaml \
7 | 2>&1|tee log/$log_name.log &
8 |
--------------------------------------------------------------------------------
/voting_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.optim.lr_scheduler as lr_sched
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | import numpy as np
9 | import os
10 | from torchvision import transforms
11 | from models import DRNET_Seg as DRNET
12 | from data import ShapeNetPart
13 | import data.data_utils as d_utils
14 | import argparse
15 | import random
16 | import yaml
17 |
18 | torch.backends.cudnn.enabled = True
19 | torch.backends.cudnn.benchmark = True
20 | torch.backends.cudnn.deterministic = True
21 |
22 | seed = 123
23 | random.seed(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed(seed)
27 | torch.cuda.manual_seed_all(seed)
28 |
29 | parser = argparse.ArgumentParser(description='DRNET Shape Part Segmentation Test')
30 | parser.add_argument('--config', default='cfgs/config_partseg_gpus.yaml', type=str)
31 |
32 | NUM_REPEAT = 300
33 | NUM_VOTE = 10
34 |
35 | def main():
36 | args = parser.parse_args()
37 | with open(args.config) as f:
38 | config = yaml.load(f, Loader=yaml.FullLoader)
39 | for k, v in config['common'].items():
40 | setattr(args, k, v)
41 |
42 | test_transforms = transforms.Compose([
43 | d_utils.PointcloudToTensor()
44 | ])
45 |
46 | test_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'test', normalize = True, transforms = test_transforms)
47 | test_dataloader = DataLoader(
48 | test_dataset,
49 | batch_size=args.batch_size,
50 | shuffle=False,
51 | num_workers=int(args.workers),
52 | pin_memory=True
53 | )
54 |
55 | device = torch.device("cuda")
56 | model = DRNET(num_classes = args.num_classes).to(device)
57 | model = nn.DataParallel(model)
58 | print("Let's use", torch.cuda.device_count(), "GPUs!")
59 |
60 | if args.checkpoint != '':
61 | model.load_state_dict(torch.load(args.checkpoint))
62 | print('Load model successfully: %s' % (args.checkpoint))
63 |
64 | # evaluate
65 | PointcloudScale = d_utils.PointcloudScale(scale_low=0.87, scale_high=1.15) # initialize random scaling
66 | model.eval()
67 | global_Class_mIoU, global_Inst_mIoU = 0, 0
68 | seg_classes = test_dataset.seg_classes
69 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
70 | for cat in seg_classes.keys():
71 | for label in seg_classes[cat]:
72 | seg_label_to_cat[label] = cat
73 |
74 | for i in range(NUM_REPEAT):
75 | shape_ious = {cat:[] for cat in seg_classes.keys()}
76 | for _, data in enumerate(test_dataloader, 0):
77 | points, target, cls = data
78 | with torch.no_grad():
79 | points = Variable(points)
80 | with torch.no_grad():
81 | target = Variable(target)
82 |
83 | points = points.to(device)
84 | target = target.to(device)
85 |
86 | batch_one_hot_cls = np.zeros((len(cls), 16)) # 16 object classes
87 | for b in range(len(cls)):
88 | batch_one_hot_cls[b, int(cls[b])] = 1
89 | batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
90 | batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
91 |
92 | pred = 0
93 | with torch.no_grad():
94 | new_points = Variable(torch.zeros(points.size()[0], points.size()[1], points.size()[2]).cuda())
95 | for v in range(NUM_VOTE):
96 | if v > 0:
97 | new_points.data = PointcloudScale(points.data)
98 | # temp_pred, error1, error2, error3, error4, _ = model(new_points, batch_one_hot_cls)
99 | temp_pred, _, _, _, _ = model(new_points, batch_one_hot_cls)  # unpack the 5 outputs as in train_partseg_gpus.py
100 | pred += F.softmax(temp_pred, dim = 2)
101 | pred /= NUM_VOTE
102 |
103 | pred = pred.data.cpu()
104 | target = target.data.cpu()
105 | pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor)
106 | # pred to the groundtruth classes (selected by seg_classes[cat])
107 | for b in range(len(cls)):
108 | cat = seg_label_to_cat[target[b, 0].item()]
109 | logits = pred[b, :, :] # (num_points, num_classes)
110 | pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0]
111 |
112 | for b in range(len(cls)):
113 | segp = pred_val[b, :]
114 | segl = target[b, :]
115 | cat = seg_label_to_cat[segl[0].item()]
116 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
117 | for l in seg_classes[cat]:
118 | if torch.sum((segl == l) | (segp == l)) == 0:
119 | # part is not present in this shape
120 | part_ious[l - seg_classes[cat][0]] = 1.0
121 | else:
122 | part_ious[l - seg_classes[cat][0]] = torch.sum((segl == l) & (segp == l)) / float(torch.sum((segl == l) | (segp == l)))
123 | shape_ious[cat].append(np.mean(part_ious))
124 |
125 | instance_ious = []
126 | for cat in shape_ious.keys():
127 | for iou in shape_ious[cat]:
128 | instance_ious.append(iou)
129 | shape_ious[cat] = np.mean(shape_ious[cat])
130 | mean_class_ious = np.mean(list(shape_ious.values()))
131 |
132 | print('\n------ Repeat %3d ------' % (i + 1))
133 | for cat in sorted(shape_ious.keys()):
134 | print('%s: %0.6f'%(cat, shape_ious[cat]))
135 | print('Class_mIoU: %0.6f' % (mean_class_ious))
136 | print('Instance_mIoU: %0.6f' % (np.mean(instance_ious)))
137 |
138 | if np.mean(instance_ious) > global_Inst_mIoU:
139 | global_Class_mIoU = mean_class_ious
140 | global_Inst_mIoU = np.mean(instance_ious)
141 |
142 | print('\nBest voting Class_mIoU = %0.6f, Instance_mIoU = %0.6f' % (global_Class_mIoU, global_Inst_mIoU))
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
--------------------------------------------------------------------------------
/voting_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 | mkdir -p log
3 | now=$(date +"%Y%m%d_%H%M%S")
4 | log_name="PartSeg_VOTE_"$now""
5 | CUDA_VISIBLE_DEVICES=2,3 python -u voting_test.py \
6 | --config cfgs/config_partseg_test.yaml \
7 | 2>&1|tee log/$log_name.log &
8 |
--------------------------------------------------------------------------------