├── .gitignore ├── README.md ├── data ├── CustomDataset.py ├── ModelNet40.py ├── ShapeNet.py ├── __init__.py └── provider.py ├── evaluate.py ├── evaluate_custom.py ├── models ├── __init__.py ├── pointnet2_cls.py └── pointnet2_seg.py ├── train_clss.py ├── train_custom_cls.py ├── train_part_seg.py └── utils ├── IoU.py ├── __init__.py ├── common.py ├── feature_propagation.py ├── grouping.py ├── sampling.py └── set_abstraction.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | .idea/ 5 | test*.py 6 | checkpoints/ 7 | data/ILSVRC2012_img_val 8 | model/.DS_Store 9 | .DS_Store 10 | work_dirs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | An unofficial PyTorch implementation of [PointNet++: Deep Hierarchical Feature Learning on 4 | Point Sets in a Metric Space]()[NIPS 2017]. 5 | 6 | ### Requirements 7 | - PyTorch, Python3, TensorboardX, tqdm, fire 8 | 9 | ## Classification 10 | - **Start** 11 | - Dataset: [ModelNet40](https://modelnet.cs.princeton.edu/), download it from the [Official Site](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) or from [Baidu Disk](https://pan.baidu.com/s/1E0DqMLebg89IzrXlB-YVDA) with extraction code **hi1i**. 12 | - Train 13 | ``` 14 | python train_clss.py --data_root your_data_root --log_dir your_log_dir 15 | 16 | eg. 17 | python train_clss.py --data_root /root/modelnet40_normal_resampled --log_dir cls_ssg_1024 18 | ``` 19 | - Evaluate 20 | 21 | ``` 22 | python evaluate.py evaluate_cls model data_root checkpoint npoints 23 | 24 | eg.
25 | python evaluate.py evaluate_cls pointnet2_cls_ssg /root/modelnet40_normal_resampled \ 26 | checkpoints/pointnet2_cls_250.pth 1024 27 | 28 | python evaluate.py evaluate_cls pointnet2_cls_msg /root/modelnet40_normal_resampled \ 29 | checkpoints/pointnet2_cls_250.pth 1024 30 | ``` 31 | - **Performance** (the **first row** is the result reported in the paper; the **following rows** are results from this repo.) 32 | 33 | | Model | NPoints | Aug | Accuracy(%) | 34 | | :---: | :---: | :---: | :---: | 35 | | PointNet2(**official**) | 5000 | ✓ | 91.7 | 36 | | PointNet2_SSG | 1024 | ✗ | **91.8** | 37 | | PointNet2_SSG | 4096 | ✗ | 91.7 | 38 | | PointNet2_SSG | 4096 | ✓ | 90.5 | 39 | | PointNet2_MSG | 4096 | ✓ | 91.0 | 40 | 41 | | Model | Train_NPoints | DP | Test_NPoints | Accuracy(%) | 42 | | :---: | :---: | :---: | :---: | :---: | 43 | | PointNet2_SSG | 1024 | ✗ | 256 | 67.9 | 44 | | PointNet2_SSG | 1024 | ✓ | **256** | **90.8** | 45 | | PointNet2_SSG | 1024 | ✗ | 1024 | 91.8 | 46 | | PointNet2_SSG | 1024 | ✓ | 1024 | **91.9** | 47 | 48 | - **Train Your Own Dataset** 49 | - Prepare the dataset (n classes) in the `ModelNet40` structure 50 | ``` 51 | CustomData(dir) 52 | |- class1(dir) 53 | | - class1_name11.txt 54 | | - class1_name12.txt 55 | ... 56 | |- class2(dir) 57 | | - class2_name21.txt 58 | | - class2_name22.txt 59 | ... 60 | |- classn(dir) 61 | |- shape_names.txt 62 | | - class1(line1) 63 | | - class2(line2) 64 | | - ... 65 | | - classn(linen) 66 | |- train.txt 67 | | - class1_name11 68 | | - class2_name21 69 | | - class2_name22 70 | | - ... 71 | | - classn_namen1 72 | |- test.txt 73 | | - class1_name12 74 | | - class2_name22 75 | | - ...
76 | | - classn_namen2 77 | ``` 78 | - **Start to train** 79 | ``` 80 | python train_custom_cls.py --data_root your_datapath/CustomData --nclasses 2 --npoints 2048 81 | ``` 82 | - **Start to evaluate** 83 | ``` 84 | python evaluate_custom.py evaluate_cls pointnet2_cls_ssg your_datapath/CustomData work_dirs/checkpoints/pointnet2_cls_250.pth 2 85 | ``` 86 | 87 | ## Part Segmentation 88 | - **Start** 89 | - Dataset: [ShapeNet part](https://shapenet.cs.stanford.edu/iccv17/#dataset), download it from the [Official Site](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) or from [Baidu Disk](https://pan.baidu.com/s/18YoYMam3vVVqE5i6BXU5kw) with extraction code **3e5z**. 90 | - Train 91 | ``` 92 | python train_part_seg.py --data_root your_data_root --log_dir your_log_dir 93 | 94 | eg. 95 | python train_part_seg.py --data_root /root/shapenetcore_partanno_segmentation_benchmark_v0_normal \ 96 | --log_dir seg_ssg --batch_size 64 97 | ``` 98 | - Evaluate 99 | 100 | ``` 101 | python evaluate.py evaluate_seg data_root checkpoint 102 | 103 | eg.
104 | python evaluate.py evaluate_seg /root/shapenetcore_partanno_segmentation_benchmark_v0_normal \ 105 | seg_ssg/checkpoints/pointnet2_cls_250.pth 106 | ``` 107 | - **Metrics**: [Average IoU](https://shapenet.cs.stanford.edu/iccv17/#evaluation) 108 | 109 | | Model | Metrics | mean | aero | bag | cap | car | chair | ear phone | guitar | knife | lamp | laptop | motor | mug | pistol | rocket | skate board | table | 110 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 111 | | PointNet2(**official**) | IoU | 85.1 | 82.4 | 79.0 | 87.7 | 77.3 | 90.8 | 71.8 | 91.0 | 85.9 | 83.7 | 95.3 | 71.6 | 94.1 | 81.3 | 58.7 | 76.4 | 82.6 | 112 | | PointNet2_SSG | IoU | 84.1 | 82.3 | 75.0 | 80.1 | 77.8 | 90.2 | 73.7 | 90.7 | 84.1 | 82.9 | 95.0 | 69.3 | 93.3 | 80.3 | 55.6 | 76.3 | 80.7 | 113 | | PointNet2_SSG | Accuracy | 93.2 | 89.9 | 89.0 | 85.5 | 91.8 | 94.4 | 93.5 | 96.1 | 91.1 | 89.2 | 96.9 | 87.4 | 96.4 | 93.7 | 77.2 | 95.9 | 94.8 | 114 | 115 | 116 | ## Reference 117 | 118 | - [https://github.com/charlesq34/pointnet2](https://github.com/charlesq34/pointnet2) 119 | - [https://github.com/yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch) 120 | - [https://github.com/sshaoshuai/Pointnet2.PyTorch](https://github.com/sshaoshuai/Pointnet2.PyTorch) -------------------------------------------------------------------------------- /data/CustomDataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torch.utils.data import DataLoader, Dataset 5 | from data.provider import pc_normalize, rotate_point_cloud_with_normal, rotate_perturbation_point_cloud_with_normal, \ 6 | random_scale_point_cloud, shift_point_cloud, jitter_point_cloud, shuffle_points, random_point_dropout 7 | 8 | 9 | class CustomDataset(Dataset): 10 | def __init__(self, data_root, split, 
npoints, augment=False, dp=False, normalize=True): 11 | assert(split == 'train' or split == 'test') 12 | self.npoints = npoints 13 | self.augment = augment 14 | self.dp = dp 15 | self.normalize = normalize 16 | 17 | cls2name, name2cls = self.decode_classes(os.path.join(data_root, 'shape_names.txt')) 18 | train_list_path = os.path.join(data_root, 'train.txt') 19 | train_files_list = self.read_list_file(train_list_path, name2cls) 20 | test_list_path = os.path.join(data_root, 'test.txt') 21 | test_files_list = self.read_list_file(test_list_path, name2cls) 22 | self.files_list = train_files_list if split == 'train' else test_files_list 23 | self.caches = {} 24 | 25 | def read_list_file(self, file_path, name2cls): 26 | base = os.path.dirname(file_path) 27 | files_list = [] 28 | with open(file_path, 'r') as f: 29 | for line in f.readlines(): 30 | name = line.strip().split('_')[0] 31 | cur = os.path.join(base, name, '{}.txt'.format(line.strip())) 32 | files_list.append([cur, name2cls[name]]) 33 | return files_list 34 | 35 | def decode_classes(self, file_path): 36 | cls2name, name2cls = {}, {} 37 | with open(file_path, 'r') as f: 38 | for i, name in enumerate(f.readlines()): 39 | cls2name[i] = name.strip() 40 | name2cls[name.strip()] = i 41 | return cls2name, name2cls 42 | 43 | def augment_pc(self, pc_normal): 44 | rotated_pc_normal = rotate_point_cloud_with_normal(pc_normal) 45 | rotated_pc_normal = rotate_perturbation_point_cloud_with_normal(rotated_pc_normal) 46 | jittered_pc = random_scale_point_cloud(rotated_pc_normal[:, :3]) 47 | jittered_pc = shift_point_cloud(jittered_pc) 48 | jittered_pc = jitter_point_cloud(jittered_pc) 49 | rotated_pc_normal[:, :3] = jittered_pc 50 | return rotated_pc_normal 51 | 52 | def __getitem__(self, index): 53 | if index in self.caches: 54 | return self.caches[index] 55 | file, label = self.files_list[index] 56 | xyz_points = np.loadtxt(file, delimiter=',') 57 | if self.npoints > 0: 58 | inds = np.random.randint(0, len(xyz_points), 
size=(self.npoints, )) 59 | else: 60 | inds = np.arange(len(xyz_points)) 61 | np.random.shuffle(inds) 62 | xyz_points = xyz_points[inds, :] 63 | if self.normalize: 64 | xyz_points[:, :3] = pc_normalize(xyz_points[:, :3]) 65 | if self.augment: 66 | xyz_points = self.augment_pc(xyz_points) 67 | if self.dp: 68 | xyz_points = random_point_dropout(xyz_points) 69 | self.caches[index] = [xyz_points, label] 70 | return xyz_points, label 71 | 72 | def __len__(self): 73 | return len(self.files_list) -------------------------------------------------------------------------------- /data/ModelNet40.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torch.utils.data import DataLoader, Dataset 5 | from data.provider import pc_normalize, rotate_point_cloud_with_normal, rotate_perturbation_point_cloud_with_normal, \ 6 | random_scale_point_cloud, shift_point_cloud, jitter_point_cloud, shuffle_points, random_point_dropout 7 | 8 | 9 | class ModelNet40(Dataset): 10 | 11 | def __init__(self, data_root, split, npoints, augment=False, dp=False, normalize=True): 12 | assert(split == 'train' or split == 'test') 13 | self.npoints = npoints 14 | self.augment = augment 15 | self.dp = dp 16 | self.normalize = normalize 17 | 18 | cls2name, name2cls = self.decode_classes(os.path.join(data_root, 'modelnet40_shape_names.txt')) 19 | train_list_path = os.path.join(data_root, 'modelnet40_train.txt') 20 | train_files_list = self.read_list_file(train_list_path, name2cls) 21 | test_list_path = os.path.join(data_root, 'modelnet40_test.txt') 22 | test_files_list = self.read_list_file(test_list_path, name2cls) 23 | self.files_list = train_files_list if split == 'train' else test_files_list 24 | self.caches = {} 25 | 26 | def read_list_file(self, file_path, name2cls): 27 | base = os.path.dirname(file_path) 28 | files_list = [] 29 | with open(file_path, 'r') as f: 30 | for line in f.readlines(): 31 | name = 
'_'.join(line.strip().split('_')[:-1]) 32 | cur = os.path.join(base, name, '{}.txt'.format(line.strip())) 33 | files_list.append([cur, name2cls[name]]) 34 | return files_list 35 | 36 | def decode_classes(self, file_path): 37 | cls2name, name2cls = {}, {} 38 | with open(file_path, 'r') as f: 39 | for i, name in enumerate(f.readlines()): 40 | cls2name[i] = name.strip() 41 | name2cls[name.strip()] = i 42 | return cls2name, name2cls 43 | 44 | def augment_pc(self, pc_normal): 45 | rotated_pc_normal = rotate_point_cloud_with_normal(pc_normal) 46 | rotated_pc_normal = rotate_perturbation_point_cloud_with_normal(rotated_pc_normal) 47 | jittered_pc = random_scale_point_cloud(rotated_pc_normal[:, :3]) 48 | jittered_pc = shift_point_cloud(jittered_pc) 49 | jittered_pc = jitter_point_cloud(jittered_pc) 50 | rotated_pc_normal[:, :3] = jittered_pc 51 | return rotated_pc_normal 52 | 53 | def __getitem__(self, index): 54 | if index in self.caches: 55 | return self.caches[index] 56 | file, label = self.files_list[index] 57 | xyz_points = np.loadtxt(file, delimiter=',') 58 | #if self.npoints > 0: 59 | # inds = np.random.randint(0, len(xyz_points), size=(self.npoints, )) 60 | # xyz_points = xyz_points[inds, :] 61 | xyz_points = xyz_points[:self.npoints, :] 62 | if self.normalize: 63 | xyz_points[:, :3] = pc_normalize(xyz_points[:, :3]) 64 | if self.augment: 65 | xyz_points = self.augment_pc(xyz_points) 66 | if self.dp: 67 | xyz_points = random_point_dropout(xyz_points) 68 | self.caches[index] = [xyz_points, label] 69 | return xyz_points, label 70 | 71 | def __len__(self): 72 | return len(self.files_list) 73 | 74 | 75 | if __name__ == '__main__': 76 | modelnet40 = ModelNet40(data_root='/root/Pointnet_Pointnet2_pytorch/data/modelnet40_normal_resampled', split='test', npoints=1024) 77 | test_loader = DataLoader(dataset=modelnet40, 78 | batch_size=16, 79 | shuffle=True) 80 | for point, label in test_loader: 81 | print(point.shape) 82 | print(label.shape)
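The two dataset classes above derive the class name from a sample name in different ways: `ModelNet40.read_list_file` joins every underscore-separated token except the last, while `CustomDataset.read_list_file` keeps only the first token. A minimal sketch of the difference follows. The helper names `class_from_all_but_last` and `class_from_first_token` and the sample names are hypothetical, introduced only for illustration; only the two string expressions come from the files above.

```python
def class_from_all_but_last(sample_name: str) -> str:
    """Expression used in ModelNet40.read_list_file: drop only the trailing token."""
    return '_'.join(sample_name.strip().split('_')[:-1])


def class_from_first_token(sample_name: str) -> str:
    """Expression used in CustomDataset.read_list_file: keep only the first token."""
    return sample_name.strip().split('_')[0]


if __name__ == '__main__':
    # Illustrative sample names; the second has an underscore inside the class name.
    for sample in ['airplane_0627', 'night_stand_0201']:
        print(sample, class_from_all_but_last(sample), class_from_first_token(sample))
```

Under this reading, a custom dataset whose class names contain underscores would likely fail the `name2cls` lookup in `CustomDataset`, so such class names may need to be avoided or the parsing adapted.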
-------------------------------------------------------------------------------- /data/ShapeNet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset 6 | import sys 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from data.provider import pc_normalize, rotate_point_cloud_with_normal, rotate_perturbation_point_cloud_with_normal, \ 9 | random_scale_point_cloud, shift_point_cloud, jitter_point_cloud, shuffle_points, random_point_dropout 10 | 11 | 12 | class ShapeNet(Dataset): 13 | def __init__(self, data_root, split, npoints, augment=False, dp=False, normalize=True): 14 | assert(split == 'train' or split == 'test' or split == 'val' or split == 'trainval') 15 | self.npoints = npoints 16 | self.augment = augment 17 | self.dp = dp 18 | self.normalize = normalize 19 | self.cat = {} 20 | with open(os.path.join(data_root, 'synsetoffset2category.txt'), 'r') as f: 21 | for line in f.readlines(): 22 | self.cat[line.strip().split()[0]] = line.strip().split()[1] 23 | train_json_path = os.path.join(data_root, 'train_test_split', 'shuffled_train_file_list.json') 24 | val_json_path = os.path.join(data_root, 'train_test_split', 'shuffled_val_file_list.json') 25 | test_json_path = os.path.join(data_root, 'train_test_split', 'shuffled_test_file_list.json') 26 | train_lists = self.decode_json(data_root, train_json_path) 27 | val_lists = self.decode_json(data_root, val_json_path) 28 | test_lists = self.decode_json(data_root, test_json_path) 29 | 30 | self.file_lists = [] 31 | if split == 'train': 32 | self.file_lists.extend(train_lists) 33 | elif split == 'val': 34 | self.file_lists.extend(val_lists) 35 | elif split == 'trainval': 36 | self.file_lists.extend(train_lists) 37 | self.file_lists.extend(val_lists) 38 | elif split == 'test': 39 | self.file_lists.extend(test_lists) 40 | 41 | 
self.seg_classes = {'Earphone': [16, 17, 18], 42 | 'Motorbike': [30, 31, 32, 33, 34, 35], 43 | 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 44 | 'Laptop': [28, 29], 'Cap': [6, 7], 45 | 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 46 | 'Guitar': [19, 20, 21], 'Bag': [4, 5], 47 | 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 48 | 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 49 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 50 | 51 | self.caches = {} 52 | 53 | def decode_json(self, data_root, path): 54 | with open(path, 'r') as f: 55 | l = json.load(f) 56 | l = [os.path.join(data_root, item.split('/')[1], item.split('/')[2] + '.txt') for item in l] 57 | return l 58 | 59 | def augment_pc(self, pc_normal): 60 | rotated_pc_normal = rotate_point_cloud_with_normal(pc_normal) 61 | rotated_pc_normal = rotate_perturbation_point_cloud_with_normal(rotated_pc_normal) 62 | jittered_pc = random_scale_point_cloud(rotated_pc_normal[:, :3]) 63 | jittered_pc = shift_point_cloud(jittered_pc) 64 | jittered_pc = jitter_point_cloud(jittered_pc) 65 | rotated_pc_normal[:, :3] = jittered_pc 66 | return rotated_pc_normal 67 | 68 | def __getitem__(self, index): 69 | if index in self.caches: 70 | xyz_points, labels = self.caches[index] 71 | else: 72 | pc = np.loadtxt(self.file_lists[index]).astype(np.float32) 73 | xyz_points = pc[:, :6] 74 | labels = pc[:, -1].astype(np.int32) 75 | 76 | if self.normalize: 77 | xyz_points[:, :3] = pc_normalize(xyz_points[:, :3]) 78 | if self.augment: 79 | xyz_points = self.augment_pc(xyz_points) 80 | if self.dp: 81 | xyz_points = random_point_dropout(xyz_points) 82 | self.caches[index] = xyz_points, labels 83 | 84 | # resample 85 | choice = np.random.choice(len(xyz_points), self.npoints, replace=True) 86 | xyz_points = xyz_points[choice, :] 87 | labels = labels[choice] 88 | return xyz_points, labels 89 | 90 | def __len__(self): 91 | return len(self.file_lists) 92 | 93 | 94 | if __name__ == '__main__': 95 | shapenet = 
ShapeNet(data_root='/root/shapenetcore_partanno_segmentation_benchmark_v0_normal', split='test', npoints=2500) 96 | print(shapenet.__len__()) 97 | print(shapenet.__getitem__(0)) -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhulf0804/Pointnet2.PyTorch/1b98042fa286ce13db5cbfeb498f0f64dc1487b4/data/__init__.py -------------------------------------------------------------------------------- /data/provider.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: charlesq34 3 | addr: https://github.com/charlesq34/pointnet2/blob/master/utils/provider.py 4 | 5 | update: zhulf 6 | ''' 7 | 8 | import numpy as np 9 | 10 | 11 | def pc_normalize(pc): 12 | mean = np.mean(pc, axis=0) 13 | pc -= mean 14 | m = np.max(np.sqrt(np.sum(np.power(pc, 2), axis=1))) 15 | pc /= m 16 | return pc 17 | 18 | 19 | def shuffle_points(pc): 20 | idx = np.arange(pc.shape[0]) 21 | np.random.shuffle(idx) 22 | return pc[idx,:] 23 | 24 | 25 | def rotate_point_cloud(pc): 26 | rotation_angle = np.random.uniform() * 2 * np.pi 27 | cosval = np.cos(rotation_angle) 28 | sinval = np.sin(rotation_angle) 29 | rotation_matrix = np.array([[cosval, 0, sinval], 30 | [0, 1, 0], 31 | [-sinval, 0, cosval]]) 32 | rotated_pc = np.dot(pc, rotation_matrix) 33 | return rotated_pc 34 | 35 | 36 | def rotate_point_cloud_with_normal(pc_normal): 37 | rotation_angle = np.random.uniform() * 2 * np.pi 38 | cosval = np.cos(rotation_angle) 39 | sinval = np.sin(rotation_angle) 40 | rotation_matrix = np.array([[cosval, 0, sinval], 41 | [0, 1, 0], 42 | [-sinval, 0, cosval]]) 43 | 44 | pc_normal[:,0:3] = np.dot(pc_normal[:, 0:3], rotation_matrix) 45 | pc_normal[:,3:6] = np.dot(pc_normal[:, 3:6], rotation_matrix) 46 | return pc_normal 47 | 48 | 49 | def rotate_perturbation_point_cloud_with_normal(pc_normal, angle_sigma=0.06, 
angle_clip=0.18): 50 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 51 | Rx = np.array([[1,0,0], 52 | [0,np.cos(angles[0]),-np.sin(angles[0])], 53 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 54 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 55 | [0,1,0], 56 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 57 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 58 | [np.sin(angles[2]),np.cos(angles[2]),0], 59 | [0,0,1]]) 60 | R = np.dot(Rz, np.dot(Ry,Rx)) 61 | pc_normal[:,0:3] = np.dot(pc_normal[:, :3], R) 62 | pc_normal[:,3:6] = np.dot(pc_normal[:, 3:], R) 63 | return pc_normal 64 | 65 | 66 | def rotate_point_cloud_by_angle(pc, rotation_angle): 67 | cosval = np.cos(rotation_angle) 68 | sinval = np.sin(rotation_angle) 69 | rotation_matrix = np.array([[cosval, 0, sinval], 70 | [0, 1, 0], 71 | [-sinval, 0, cosval]]) 72 | pc = np.dot(pc, rotation_matrix) 73 | return pc 74 | 75 | 76 | def rotate_point_cloud_by_angle_with_normal(pc_normal, rotation_angle): 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, 0, sinval], 80 | [0, 1, 0], 81 | [-sinval, 0, cosval]]) 82 | pc_normal[:, :3] = np.dot(pc_normal[:, :3], rotation_matrix) 83 | pc_normal[:, 3:6] = np.dot(pc_normal[:, 3:6], rotation_matrix) 84 | return pc_normal 85 | 86 | 87 | 88 | def rotate_perturbation_point_cloud(pc, angle_sigma=0.06, angle_clip=0.18): 89 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 90 | Rx = np.array([[1,0,0], 91 | [0,np.cos(angles[0]),-np.sin(angles[0])], 92 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 93 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 94 | [0,1,0], 95 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 96 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 97 | [np.sin(angles[2]),np.cos(angles[2]),0], 98 | [0,0,1]]) 99 | R = np.dot(Rz, np.dot(Ry,Rx)) 100 | pc = np.dot(pc, R) 101 | return pc 102 | 103 | 104 | def jitter_point_cloud(pc, 
sigma=0.01, clip=0.05): 105 | N, C = pc.shape 106 | assert(clip > 0) 107 | jittered_data = np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 108 | jittered_data += pc 109 | return jittered_data 110 | 111 | 112 | def shift_point_cloud(pc, shift_range=0.1): 113 | N, C = pc.shape 114 | shifts = np.random.uniform(-shift_range, shift_range, (1, C)) 115 | pc += shifts 116 | return pc 117 | 118 | 119 | def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25): 120 | scale = np.random.uniform(scale_low, scale_high, 1) 121 | pc *= scale 122 | return pc 123 | 124 | 125 | def random_point_dropout(pc, max_dropout_ratio=0.875): 126 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 127 | drop_idx = np.where(np.random.random((pc.shape[0]))<=dropout_ratio)[0] 128 | if len(drop_idx)>0: 129 | pc[drop_idx,:] = pc[0,:] # set to the first point 130 | return pc 131 | 132 | 133 | def augment_pc(pc_normal): 134 | rotated_pc_normal = rotate_point_cloud_with_normal(pc_normal) 135 | rotated_pc_normal = rotate_perturbation_point_cloud_with_normal(rotated_pc_normal) 136 | jittered_pc = random_scale_point_cloud(rotated_pc_normal[:, :3]) 137 | jittered_pc = shift_point_cloud(jittered_pc) 138 | jittered_pc = jitter_point_cloud(jittered_pc) 139 | rotated_pc_normal[:, :3] = jittered_pc 140 | return rotated_pc_normal 141 | 142 | 143 | if __name__ == '__main__': 144 | pc = np.random.randn(4, 6) 145 | print(pc) 146 | pc = augment_pc(pc) 147 | print(pc) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | from models.pointnet2_cls import pointnet2_cls_ssg, pointnet2_cls_msg 8 | from models.pointnet2_seg import pointnet2_seg_ssg 9 | from data.ModelNet40 import ModelNet40 10 | from data.ShapeNet import 
ShapeNet 11 | from utils.IoU import cal_accuracy_iou 12 | 13 | 14 | def evaluate_cls(model_id, data_root, checkpoint, npoints, dims=6, nclasses=40): 15 | print('Loading..') 16 | Models = { 17 | 'pointnet2_cls_ssg': pointnet2_cls_ssg, 18 | 'pointnet2_cls_msg': pointnet2_cls_msg 19 | } 20 | Model = Models[model_id] 21 | modelnet40_test = ModelNet40(data_root=data_root, split='test', npoints=npoints) 22 | test_loader = DataLoader(dataset=modelnet40_test, 23 | batch_size=64, shuffle=False, 24 | num_workers=1) 25 | device = torch.device('cuda') 26 | model = Model(dims, nclasses) 27 | model = model.to(device) 28 | model.load_state_dict(torch.load(checkpoint)) 29 | model.eval() 30 | print('Loading {} completed'.format(checkpoint)) 31 | print("Dataset: {}, Evaluating..".format(len(modelnet40_test))) 32 | total_correct, total_seen = 0, 0 33 | for data, labels in tqdm(test_loader): 34 | labels = labels.to(device) 35 | xyz, points = data[:, :, :3], data[:, :, 3:] 36 | with torch.no_grad(): 37 | pred = model(xyz.to(device), points.to(device)) 38 | pred = torch.max(pred, dim=-1)[1] 39 | total_correct += torch.sum(pred == labels) 40 | total_seen += xyz.shape[0] 41 | print("Evaluating completed!") 42 | print('Corr: {}, Seen: {}, Acc: {:.4f}'.format(total_correct, total_seen, total_correct / float(total_seen))) 43 | 44 | 45 | def evaluate_seg(data_root, checkpoint, npoints=2048, dims=6, nclasses=50): 46 | print('Loading..') 47 | shapenet_test = ShapeNet(data_root=data_root, split='test', npoints=npoints) 48 | test_loader = DataLoader(dataset=shapenet_test, batch_size=64, shuffle=False, num_workers=4) 49 | device = torch.device('cuda') 50 | model = pointnet2_seg_ssg(dims, nclasses) 51 | model = model.to(device) 52 | model.load_state_dict(torch.load(checkpoint)) 53 | model.eval() 54 | print('Loading {} completed'.format(checkpoint)) 55 | print("Dataset: {}, Evaluating..".format(len(shapenet_test))) 56 | preds, labels = [], [] 57 | for data, label in tqdm(test_loader): 58 | 
labels.append(label) 59 | xyz, points = data[:, :, :3], data[:, :, 3:] 60 | with torch.no_grad(): 61 | pred = model(xyz.to(device), points.to(device)) 62 | pred = torch.max(pred, dim=1)[1].cpu().detach().numpy() 63 | preds.append(pred) 64 | iou, acc = cal_accuracy_iou(np.concatenate(preds, axis=0), np.concatenate(labels, axis=0), shapenet_test.seg_classes) 65 | print("Weighted Acc: {:.4f}".format(acc)) 66 | print("Weighted Average IoU: {:.4f}".format(iou)) 67 | print('='*40) 68 | print("Evaluating completed !") 69 | 70 | 71 | if __name__ == '__main__': 72 | fire.Fire() -------------------------------------------------------------------------------- /evaluate_custom.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | from models.pointnet2_cls import pointnet2_cls_ssg, pointnet2_cls_msg 8 | from data.CustomDataset import CustomDataset 9 | from utils.common import setup_seed 10 | 11 | 12 | def evaluate_cls(model_id, data_root, checkpoint, nclasses, npoints=-1, dims=6): 13 | setup_seed(222) 14 | print('Loading..') 15 | Models = { 16 | 'pointnet2_cls_ssg': pointnet2_cls_ssg, 17 | 'pointnet2_cls_msg': pointnet2_cls_msg 18 | } 19 | Model = Models[model_id] 20 | custom_test = CustomDataset(data_root=data_root, split='test', npoints=npoints) 21 | test_loader = DataLoader(dataset=custom_test, 22 | batch_size=1, shuffle=False, 23 | num_workers=1) 24 | device = torch.device('cuda') 25 | model = Model(dims, nclasses) 26 | model = model.to(device) 27 | model.load_state_dict(torch.load(checkpoint)) 28 | model.eval() 29 | print('Loading {} completed'.format(checkpoint)) 30 | print("Dataset: {}, Evaluating..".format(len(custom_test))) 31 | total_correct, total_seen = 0, 0 32 | for data, labels in tqdm(test_loader): 33 | labels = labels.to(device) 34 | xyz, points = data[:, :, :3], data[:, :, 3:] 35 
| with torch.no_grad(): 36 | pred = model(xyz.to(device), points.to(device)) 37 | pred = torch.max(pred, dim=-1)[1] 38 | total_correct += torch.sum(pred == labels) 39 | total_seen += xyz.shape[0] 40 | print("Evaluating completed!") 41 | print('Corr: {}, Seen: {}, Acc: {:.4f}'.format(total_correct, total_seen, total_correct / float(total_seen))) 42 | 43 | 44 | if __name__ == '__main__': 45 | fire.Fire() -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhulf0804/Pointnet2.PyTorch/1b98042fa286ce13db5cbfeb498f0f64dc1487b4/models/__init__.py -------------------------------------------------------------------------------- /models/pointnet2_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.set_abstraction import PointNet_SA_Module, PointNet_SA_Module_MSG 5 | 6 | 7 | class pointnet2_cls_ssg(nn.Module): 8 | def __init__(self, in_channels, nclasses): 9 | super(pointnet2_cls_ssg, self).__init__() 10 | self.pt_sa1 = PointNet_SA_Module(M=512, radius=0.2, K=32, in_channels=in_channels, mlp=[64, 64, 128], group_all=False) 11 | self.pt_sa2 = PointNet_SA_Module(M=128, radius=0.4, K=64, in_channels=131, mlp=[128, 128, 256], group_all=False) 12 | self.pt_sa3 = PointNet_SA_Module(M=None, radius=None, K=None, in_channels=259, mlp=[256, 512, 1024], group_all=True) 13 | self.fc1 = nn.Linear(1024, 512, bias=False) 14 | self.bn1 = nn.BatchNorm1d(512) 15 | self.dropout1 = nn.Dropout(0.5) 16 | self.fc2 = nn.Linear(512, 256, bias=False) 17 | self.bn2 = nn.BatchNorm1d(256) 18 | self.dropout2 = nn.Dropout(0.5) 19 | self.cls = nn.Linear(256, nclasses) 20 | 21 | def forward(self, xyz, points): 22 | batchsize = xyz.shape[0] 23 | new_xyz, new_points = self.pt_sa1(xyz, points) 24 | new_xyz, new_points = 
self.pt_sa2(new_xyz, new_points) 25 | new_xyz, new_points = self.pt_sa3(new_xyz, new_points) 26 | net = new_points.view(batchsize, -1) 27 | net = self.dropout1(F.relu(self.bn1(self.fc1(net)))) 28 | net = self.dropout2(F.relu(self.bn2(self.fc2(net)))) 29 | net = self.cls(net) 30 | return net 31 | 32 | 33 | class pointnet2_cls_msg(nn.Module): 34 | def __init__(self, in_channels, nclasses): 35 | super(pointnet2_cls_msg, self).__init__() 36 | self.pt_sa1 = PointNet_SA_Module_MSG(M=512, 37 | radiuses=[0.1, 0.2, 0.4], 38 | Ks=[16, 32, 128], 39 | in_channels=in_channels, 40 | mlps=[[32, 32, 64], 41 | [64, 64, 128], 42 | [64, 96, 128]]) 43 | self.pt_sa2 = PointNet_SA_Module_MSG(M=128, 44 | radiuses=[0.2, 0.4, 0.8], 45 | Ks=[32, 64, 128], 46 | in_channels=323, 47 | mlps=[[64, 64, 128], 48 | [128, 128, 256], 49 | [128, 128, 256]]) 50 | self.pt_sa3 = PointNet_SA_Module(M=None, radius=None, K=None, in_channels=643, mlp=[256, 512, 1024], group_all=True) 51 | self.fc1 = nn.Linear(1024, 512, bias=False) 52 | self.bn1 = nn.BatchNorm1d(512) 53 | self.dropout1 = nn.Dropout(0.5) 54 | self.fc2 = nn.Linear(512, 256, bias=False) 55 | self.bn2 = nn.BatchNorm1d(256) 56 | self.dropout2 = nn.Dropout(0.5) 57 | self.cls = nn.Linear(256, nclasses) 58 | 59 | def forward(self, xyz, points): 60 | batchsize = xyz.shape[0] 61 | new_xyz, new_points = self.pt_sa1(xyz, points) 62 | new_xyz, new_points = self.pt_sa2(new_xyz, new_points) 63 | new_xyz, new_points = self.pt_sa3(new_xyz, new_points) 64 | net = new_points.view(batchsize, -1) 65 | net = self.dropout1(F.relu(self.bn1(self.fc1(net)))) 66 | net = self.dropout2(F.relu(self.bn2(self.fc2(net)))) 67 | net = self.cls(net) 68 | return net 69 | 70 | 71 | class cls_loss(nn.Module): 72 | def __init__(self): 73 | super(cls_loss, self).__init__() 74 | self.loss = nn.CrossEntropyLoss() 75 | def forward(self, pred, lable): 76 | ''' 77 | 78 | :param pred: shape=(B, nclass) 79 | :param lable: shape=(B, ) 80 | :return: loss 81 | ''' 82 | loss = self.loss(pred, 
lable) 83 | return loss 84 | 85 | 86 | if __name__ == '__main__': 87 | xyz = torch.randn(16, 2048, 3) 88 | points = torch.randn(16, 2048, 3) 89 | label = torch.randint(0, 40, size=(16, )) 90 | ssg_model = pointnet2_cls_ssg(6, 40) 91 | 92 | print(ssg_model) 93 | #net = ssg_model(xyz, points) 94 | #print(net.shape) 95 | #print(label.shape) 96 | #loss = cls_loss() 97 | #loss = loss(net, label) 98 | #print(loss) -------------------------------------------------------------------------------- /models/pointnet2_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.set_abstraction import PointNet_SA_Module, PointNet_SA_Module_MSG 5 | from utils.feature_propagation import PointNet_FP_Module 6 | 7 | 8 | class pointnet2_seg_ssg(nn.Module): 9 | def __init__(self, in_channels, nclasses): 10 | super(pointnet2_seg_ssg, self).__init__() 11 | self.pt_sa1 = PointNet_SA_Module(M=512, radius=0.2, K=32, in_channels=in_channels, mlp=[64, 64, 128], group_all=False) 12 | self.pt_sa2 = PointNet_SA_Module(M=128, radius=0.4, K=64, in_channels=131, mlp=[128, 128, 256], group_all=False) 13 | self.pt_sa3 = PointNet_SA_Module(M=None, radius=None, K=None, in_channels=259, mlp=[256, 512, 1024], group_all=True) 14 | 15 | self.pt_fp1 = PointNet_FP_Module(in_channels=1024+256, mlp=[256, 256], bn=True) 16 | self.pt_fp2 = PointNet_FP_Module(in_channels=256 + 128, mlp=[256, 128], bn=True) 17 | self.pt_fp3 = PointNet_FP_Module(in_channels=128 + 6, mlp=[128, 128, 128], bn=True) 18 | 19 | self.conv1 = nn.Conv1d(128, 128, 1, stride=1, bias=False) 20 | self.bn1 = nn.BatchNorm1d(128) 21 | self.dropout1 = nn.Dropout(0.5) 22 | self.cls = nn.Conv1d(128, nclasses, 1, stride=1) 23 | 24 | def forward(self, l0_xyz, l0_points): 25 | l1_xyz, l1_points = self.pt_sa1(l0_xyz, l0_points) 26 | l2_xyz, l2_points = self.pt_sa2(l1_xyz, l1_points) 27 | l3_xyz, l3_points = self.pt_sa3(l2_xyz, 
                                        l2_points)

        l2_points = self.pt_fp1(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.pt_fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.pt_fp3(l0_xyz, l1_xyz, torch.cat([l0_points, l0_xyz], dim=-1), l1_points)

        # (B, N, 128) -> (B, 128, N) for the 1D convolutions
        net = l0_points.permute(0, 2, 1).contiguous()
        net = self.dropout1(F.relu(self.bn1(self.conv1(net))))
        net = self.cls(net)

        return net


class seg_loss(nn.Module):
    def __init__(self):
        super(seg_loss, self).__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, pred, label):
        '''

        :param pred: shape=(B, C, N), per-point class scores as returned by pointnet2_seg_ssg
        :param label: shape=(B, N)
        :return: loss
        '''
        loss = self.loss(pred, label)
        return loss


if __name__ == '__main__':
    in_channels = 6
    n_classes = 50
    l0_xyz = torch.randn(4, 1024, 3)
    l0_points = torch.randn(4, 1024, 3)
    model = pointnet2_seg_ssg(in_channels, n_classes)
    net = model(l0_xyz, l0_points)
    print(net.shape)
--------------------------------------------------------------------------------
/train_clss.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import os
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models.pointnet2_cls import pointnet2_cls_ssg, pointnet2_cls_msg, cls_loss
from data.ModelNet40 import ModelNet40


def train_one_epoch(train_loader, model, loss_func, optimizer, device):
    losses, total_seen, total_correct = [], 0, 0
    for data, labels in train_loader:
        optimizer.zero_grad()  # Important
        labels = labels.to(device)
        xyz, points = data[:, :, :3], data[:, :, 3:]
        pred = model(xyz.to(device), points.to(device))
        loss = loss_func(pred, labels)

        loss.backward()
        optimizer.step()
        pred = torch.max(pred, dim=-1)[1]
        # .item() keeps the running count a plain Python int
        total_correct += torch.sum(pred == labels).item()
        total_seen += xyz.shape[0]
        losses.append(loss.item())
    return np.mean(losses), total_correct, total_seen, total_correct / float(total_seen)


def test_one_epoch(test_loader, model, loss_func, device):
    losses, total_seen, total_correct = [], 0, 0
    for data, labels in test_loader:
        labels = labels.to(device)
        xyz, points = data[:, :, :3], data[:, :, 3:]
        with torch.no_grad():
            pred = model(xyz.to(device), points.to(device))
            loss = loss_func(pred, labels)

            pred = torch.max(pred, dim=-1)[1]
            total_correct += torch.sum(pred == labels).item()
            total_seen += xyz.shape[0]
            losses.append(loss.item())
    return np.mean(losses), total_correct, total_seen, total_correct / float(total_seen)


def train(train_loader, test_loader, model, loss_func, optimizer, scheduler, device, ngpus, nepoches, log_interval, log_dir, checkpoint_interval):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    tensorboard_dir = os.path.join(log_dir, 'tensorboard')
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(tensorboard_dir)

    for epoch in range(nepoches):
        # Save a checkpoint and evaluate on the test set every checkpoint_interval epochs
        if epoch % checkpoint_interval == 0:
            print('='*40)
            if ngpus > 1:
                torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, "pointnet2_cls_%d.pth" % epoch))
            else:
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "pointnet2_cls_%d.pth" % epoch))
            model.eval()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            loss, total_correct, total_seen, acc = test_one_epoch(test_loader, model, loss_func, device)
            print('Test Epoch: {} / {}, lr: {:.6f}, Loss: {:.2f}, Corr: {}, Total: {}, Acc: {:.4f}'.format(epoch,
                  nepoches, lr, loss, total_correct, total_seen, acc))
            writer.add_scalar('test loss', loss, epoch)
            writer.add_scalar('test acc', acc, epoch)
        model.train()
        loss, total_correct, total_seen, acc = train_one_epoch(train_loader, model, loss_func, optimizer, device)
        writer.add_scalar('train loss', loss, epoch)
        writer.add_scalar('train acc', acc, epoch)
        if epoch % log_interval == 0:
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print('Train Epoch: {} / {}, lr: {:.6f}, Loss: {:.2f}, Corr: {}, Total: {}, Acc: {:.4f}'.format(epoch, nepoches, lr, loss, total_correct, total_seen, acc))
        scheduler.step()


if __name__ == '__main__':
    Models = {
        'pointnet2_cls_ssg': pointnet2_cls_ssg,
        'pointnet2_cls_msg': pointnet2_cls_msg
    }
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, required=True, help='Root to the dataset')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--npoints', type=int, default=1024, help='Number of the training points')
    parser.add_argument('--nclasses', type=int, default=40, help='Number of classes')
    # Caveat: argparse applies bool() to the raw string, so any non-empty value (even 'False') parses as True.
    parser.add_argument('--augment', type=bool, default=False, help='Augment the train data')
    parser.add_argument('--dp', type=bool, default=False, help='Random input dropout during training')
    parser.add_argument('--model', type=str, default='pointnet2_cls_ssg', help='Model name')
    parser.add_argument('--gpus', type=str, default='0', help='Comma-separated CUDA device ids')
    parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='Weight decay passed to Adam')
    parser.add_argument('--nepoches', type=int, default=251, help='Number of training epochs')
    parser.add_argument('--step_size', type=int, default=20, help='StepLR step size')
    parser.add_argument('--gamma', type=float,
                        default=0.7, help='StepLR gamma')
    parser.add_argument('--log_interval', type=int, default=10, help='Print interval')
    parser.add_argument('--log_dir', type=str, required=True, help='Train/val loss and accuracy logs')
    parser.add_argument('--checkpoint_interval', type=int, default=10, help='Checkpoint saved interval')
    args = parser.parse_args()
    print(args)

    device_ids = list(map(int, args.gpus.strip().split(','))) if ',' in args.gpus else [int(args.gpus)]
    ngpus = len(device_ids)

    modelnet40_train = ModelNet40(data_root=args.data_root, split='train', npoints=args.npoints, augment=args.augment, dp=args.dp)
    modelnet40_test = ModelNet40(data_root=args.data_root, split='test', npoints=args.npoints)
    train_loader = DataLoader(dataset=modelnet40_train, batch_size=args.batch_size // ngpus, shuffle=True, num_workers=4)
    test_loader = DataLoader(dataset=modelnet40_test, batch_size=args.batch_size // ngpus, shuffle=False, num_workers=4)
    print('Train set: {}'.format(len(modelnet40_train)))
    print('Test set: {}'.format(len(modelnet40_test)))

    Model = Models[args.model]
    model = Model(6, args.nclasses)
    # Multi-GPU
    device = torch.device("cuda:{}".format(device_ids[0]) if torch.cuda.is_available() else "cpu")
    if ngpus > 1 and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    loss = cls_loss().to(device)
    #optimizer = torch.optim.SGD(model.parameters(), lr=args.init_lr, momentum=args.momentum)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate
    )

    # Use the --gamma flag (default 0.7) rather than a hard-coded value
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    tic = time.time()
    train(train_loader=train_loader,
          test_loader=test_loader,
          model=model,
          loss_func=loss,
          optimizer=optimizer,
          scheduler=scheduler,
          device=device,
          ngpus=ngpus,
          nepoches=args.nepoches,
          log_interval=args.log_interval,
          log_dir=args.log_dir,
          checkpoint_interval=args.checkpoint_interval,
          )
    toc = time.time()
    print('Training completed, {:.2f} minutes'.format((toc - tic) / 60))
--------------------------------------------------------------------------------
/train_custom_cls.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import os
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models.pointnet2_cls import pointnet2_cls_ssg, pointnet2_cls_msg, cls_loss
from data.CustomDataset import CustomDataset
from utils.common import setup_seed


def train_one_epoch(train_loader, model, loss_func, optimizer, device):
    losses, total_seen, total_correct = [], 0, 0
    for data, labels in train_loader:
        optimizer.zero_grad()  # Important
        labels = labels.to(device)
        xyz, points = data[:, :, :3], data[:, :, 3:]
        pred = model(xyz.to(device), points.to(device))
        loss = loss_func(pred, labels)

        loss.backward()
        optimizer.step()
        pred = torch.max(pred, dim=-1)[1]
        # .item() keeps the running count a plain Python int
        total_correct += torch.sum(pred == labels).item()
        total_seen += xyz.shape[0]
        losses.append(loss.item())
    return np.mean(losses), total_correct, total_seen, total_correct / float(total_seen)


def test_one_epoch(test_loader, model, loss_func, device):
    losses, total_seen, total_correct = [], 0, 0
    for data, labels in test_loader:
        labels = labels.to(device)
        xyz, points = data[:, :, :3], data[:, :, 3:]
        with torch.no_grad():
            pred = model(xyz.to(device), points.to(device))
            loss = loss_func(pred, labels)

            pred = torch.max(pred, dim=-1)[1]
            total_correct += torch.sum(pred == labels).item()
            total_seen += xyz.shape[0]
            losses.append(loss.item())
    return np.mean(losses), total_correct, total_seen, total_correct / float(total_seen)


def train(train_loader, test_loader, model, loss_func, optimizer, scheduler, device, ngpus, nepoches, log_interval, log_dir, checkpoint_interval):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    tensorboard_dir = os.path.join(log_dir, 'tensorboard')
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(tensorboard_dir)

    for epoch in range(nepoches):
        # Save a checkpoint and evaluate on the test set every checkpoint_interval epochs
        if epoch % checkpoint_interval == 0:
            print('='*40)
            if ngpus > 1:
                torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, "pointnet2_cls_%d.pth" % epoch))
            else:
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "pointnet2_cls_%d.pth" % epoch))
            model.eval()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            loss, total_correct, total_seen, acc = test_one_epoch(test_loader, model, loss_func, device)
            print('Test Epoch: {} / {}, lr: {:.6f}, Loss: {:.2f}, Corr: {}, Total: {}, Acc: {:.4f}'.format(epoch, nepoches, lr, loss, total_correct, total_seen, acc))
            writer.add_scalar('test loss', loss, epoch)
            writer.add_scalar('test acc', acc, epoch)
        model.train()
        loss, total_correct, total_seen, acc = train_one_epoch(train_loader, model, loss_func, optimizer, device)
        writer.add_scalar('train loss', loss, epoch)
        writer.add_scalar('train acc', acc, epoch)
        if epoch % log_interval == 0:
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print('Train Epoch: {} / {}, lr: {:.6f}, Loss: {:.2f}, Corr: {}, Total: {}, Acc: {:.4f}'.format(epoch, nepoches, lr, loss,
                  total_correct, total_seen, acc))
        scheduler.step()


if __name__ == '__main__':
    Models = {
        'pointnet2_cls_ssg': pointnet2_cls_ssg,
        'pointnet2_cls_msg': pointnet2_cls_msg
    }
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, required=True, help='Root to the dataset')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--npoints', type=int, default=1024, help='Number of the training points')
    parser.add_argument('--nclasses', type=int, required=True, help='Number of classes')
    # Caveat: argparse applies bool() to the raw string, so any non-empty value (even 'False') parses as True.
    parser.add_argument('--augment', type=bool, default=False, help='Augment the train data')
    parser.add_argument('--dp', type=bool, default=False, help='Random input dropout during training')
    parser.add_argument('--model', type=str, default='pointnet2_cls_ssg', help='Model name')
    parser.add_argument('--gpus', type=str, default='0', help='Comma-separated CUDA device ids')
    parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='Weight decay passed to Adam')
    parser.add_argument('--nepoches', type=int, default=251, help='Number of training epochs')
    parser.add_argument('--step_size', type=int, default=20, help='StepLR step size')
    parser.add_argument('--gamma', type=float, default=0.7, help='StepLR gamma')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    parser.add_argument('--log_interval', type=int, default=10, help='Print interval')
    parser.add_argument('--log_dir', type=str, default='work_dirs', help='Train/val loss and accuracy logs')
    parser.add_argument('--checkpoint_interval', type=int, default=10, help='Checkpoint saved interval')
    args = parser.parse_args()
    print(args)
    setup_seed(args.seed)

    device_ids = list(map(int, args.gpus.strip().split(','))) if ',' in args.gpus else [int(args.gpus)]
    ngpus = len(device_ids)

    custom_train = CustomDataset(data_root=args.data_root, split='train', npoints=args.npoints, augment=args.augment, dp=args.dp)
    custom_test = CustomDataset(data_root=args.data_root, split='test', npoints=-1)
    train_loader = DataLoader(dataset=custom_train, batch_size=args.batch_size // ngpus, shuffle=True, num_workers=4)
    test_loader = DataLoader(dataset=custom_test, batch_size=1, shuffle=False, num_workers=1)
    print('Train set: {}'.format(len(custom_train)))
    print('Test set: {}'.format(len(custom_test)))

    Model = Models[args.model]
    model = Model(6, args.nclasses)
    # Multi-GPU
    device = torch.device("cuda:{}".format(device_ids[0]) if torch.cuda.is_available() else "cpu")
    if ngpus > 1 and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    loss = cls_loss().to(device)
    #optimizer = torch.optim.SGD(model.parameters(), lr=args.init_lr, momentum=args.momentum)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate
    )

    # Use the --gamma flag (default 0.7) rather than a hard-coded value
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    tic = time.time()
    train(train_loader=train_loader,
          test_loader=test_loader,
          model=model,
          loss_func=loss,
          optimizer=optimizer,
          scheduler=scheduler,
          device=device,
          ngpus=ngpus,
          nepoches=args.nepoches,
          log_interval=args.log_interval,
          log_dir=args.log_dir,
          checkpoint_interval=args.checkpoint_interval,
          )
    toc = time.time()
    print('Training completed, {:.2f} minutes'.format((toc - tic) / 60))
--------------------------------------------------------------------------------
/train_part_seg.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import os
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models.pointnet2_seg import pointnet2_seg_ssg, seg_loss
from data.ShapeNet import ShapeNet
from utils.IoU import cal_accuracy_iou


def train_one_epoch(train_loader, seg_classes, model, loss_func, optimizer, device, pt):
    losses, preds, labels = [], [], []
    for data, label in train_loader:
        labels.append(label)
        optimizer.zero_grad()  # Important
        label = label.long().to(device)
        xyz, points = data[:, :, :3], data[:, :, 3:]
        pred = model(xyz.to(device), points.to(device))
        loss = loss_func(pred, label)

        loss.backward()
        optimizer.step()
        pred = torch.max(pred, dim=1)[1]
        preds.append(pred.cpu().detach().numpy())
        losses.append(loss.item())
    iou, acc = cal_accuracy_iou(np.concatenate(preds, axis=0), np.concatenate(labels, axis=0), seg_classes, pt)
    return np.mean(losses), iou, acc


def test_one_epoch(test_loader, seg_classes, model, loss_func, device):
    losses, preds, labels = [], [], []
    for data, label in test_loader:
        labels.append(label)
        label = label.long().to(device)
        xyz, points = data[:, :, :3], data[:, :, 3:]
        with torch.no_grad():
            pred = model(xyz.to(device), points.to(device))
            loss = loss_func(pred, label)
            pred = torch.max(pred, dim=1)[1]
            preds.append(pred.cpu().detach().numpy())
            losses.append(loss.item())
    iou, acc = cal_accuracy_iou(np.concatenate(preds, axis=0), np.concatenate(labels, axis=0), seg_classes)
    return np.mean(losses), iou, acc


def train(train_loader, test_loader, seg_classes, model, loss_func, optimizer, scheduler, device, ngpus, nepoches, log_interval, log_dir,
          checkpoint_interval):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    checkpoint_dir = os.path.join(log_dir, 'checkpoints')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    tensorboard_dir = os.path.join(log_dir, 'tensorboard')
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(tensorboard_dir)

    for epoch in range(nepoches):
        if epoch % checkpoint_interval == 0:
            if ngpus > 1:
                torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, "pointnet2_seg_%d.pth" % epoch))
            else:
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, "pointnet2_seg_%d.pth" % epoch))
            model.eval()
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            loss, iou, acc = test_one_epoch(test_loader, seg_classes, model, loss_func, device)
            print('Test Epoch: {} / {}, lr: {:.6f}, Loss: {:.2f}, IoU: {:.4f}, Acc: {:.4f}'.format(epoch, nepoches, lr, loss, iou, acc))
            writer.add_scalar('test loss', loss, epoch)
            writer.add_scalar('test iou', iou, epoch)
            writer.add_scalar('test acc', acc, epoch)
        model.train()
        pt = False
        if epoch % log_interval == 0:
            pt = True
        loss, iou, acc = train_one_epoch(train_loader, seg_classes, model, loss_func, optimizer, device, pt)
        writer.add_scalar('train loss', loss, epoch)
        writer.add_scalar('train iou', iou, epoch)
        writer.add_scalar('train acc', acc, epoch)
        if epoch % log_interval == 0:
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print('Train Epoch: {} / {}, lr: {:.6f}, Loss: {:.2f}, IoU: {:.4f}, Acc: {:.4f}'.format(epoch, nepoches, lr, loss, iou, acc))
        scheduler.step()


if __name__ == '__main__':
    Models = {
        'pointnet2_seg_ssg': pointnet2_seg_ssg,
    }
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, required=True, help='Root to the dataset')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--npoints', type=int, default=2500, help='Number of the training points')
    parser.add_argument('--nclasses', type=int, default=50, help='Number of classes')
    # Caveat: argparse applies bool() to the raw string, so any non-empty value (even 'False') parses as True.
    parser.add_argument('--augment', type=bool, default=False, help='Augment the train data')
    parser.add_argument('--dp', type=bool, default=False, help='Random input dropout during training')
    parser.add_argument('--model', type=str, default='pointnet2_seg_ssg', help='Model name')
    parser.add_argument('--gpus', type=str, default='0', help='Comma-separated CUDA device ids')
    parser.add_argument('--lr', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='Weight decay passed to Adam')
    parser.add_argument('--nepoches', type=int, default=251, help='Number of training epochs')
    parser.add_argument('--step_size', type=int, default=20, help='StepLR step size')
    parser.add_argument('--gamma', type=float, default=0.7, help='StepLR gamma')
    parser.add_argument('--log_interval', type=int, default=10, help='Print interval')
    parser.add_argument('--log_dir', type=str, required=True, help='Train/val loss and accuracy logs')
    parser.add_argument('--checkpoint_interval', type=int, default=10, help='Checkpoint saved interval')
    args = parser.parse_args()
    print(args)

    device_ids = list(map(int, args.gpus.strip().split(','))) if ',' in args.gpus else [int(args.gpus)]
    ngpus = len(device_ids)

    shapenet_train = ShapeNet(data_root=args.data_root, split='trainval', npoints=args.npoints, augment=args.augment, dp=args.dp)
    shapenet_test = ShapeNet(data_root=args.data_root, split='test', npoints=args.npoints)
    train_loader = DataLoader(dataset=shapenet_train, batch_size=args.batch_size // ngpus, shuffle=True, num_workers=4)
    test_loader = DataLoader(dataset=shapenet_test,
                             batch_size=args.batch_size // ngpus, shuffle=False, num_workers=4)
    print('Train set: {}'.format(len(shapenet_train)))
    print('Test set: {}'.format(len(shapenet_test)))

    Model = Models[args.model]
    model = Model(6, args.nclasses)
    # Multi-GPU
    device = torch.device("cuda:{}".format(device_ids[0]) if torch.cuda.is_available() else "cpu")
    if ngpus > 1 and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    loss = seg_loss().to(device)
    #optimizer = torch.optim.SGD(model.parameters(), lr=args.init_lr, momentum=args.momentum)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.decay_rate
    )

    # Use the --gamma flag (default 0.7) rather than a hard-coded value
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    tic = time.time()
    train(train_loader=train_loader,
          test_loader=test_loader,
          seg_classes=shapenet_train.seg_classes,
          model=model,
          loss_func=loss,
          optimizer=optimizer,
          scheduler=scheduler,
          device=device,
          ngpus=ngpus,
          nepoches=args.nepoches,
          log_interval=args.log_interval,
          log_dir=args.log_dir,
          checkpoint_interval=args.checkpoint_interval,
          )
    toc = time.time()
    print('Training completed, {:.2f} minutes'.format((toc - tic) / 60))
--------------------------------------------------------------------------------
/utils/IoU.py:
--------------------------------------------------------------------------------
import numpy as np


def cal_accuracy_iou(preds, labels, seg_classes, pt=True):
    '''

    :param preds: shape=(B, N)
    :param labels: shape=(B, N)
    :param seg_classes: dict: cat->labels
    :param pt: bool, print per-category accuracy and IoU if True
    :return: weighted_average_iou, weighted_acc
    '''
    nclasses, n = len(seg_classes), len(preds)
    shape_ious = {cat: [] for cat in seg_classes}
    shape_count = {cat:
                   0.0 for cat in seg_classes}
    shape_points_seen = {cat: 0.0 for cat in seg_classes}
    shape_points_correct = {cat: 0.0 for cat in seg_classes}
    seg2cat = {}
    for k, vs in seg_classes.items():
        for v in vs:
            seg2cat[v] = k
    for i in range(n):
        pred, label = preds[i], labels[i]
        npoints = len(pred)
        cat = seg2cat[label[0]]
        shape_count[cat] += 1
        shape_points_seen[cat] += npoints
        shape_points_correct[cat] += np.sum(pred == label)
        part_ious = []
        for l in seg_classes[cat]:
            intersection = np.sum(np.all([pred == l, label == l], axis=0))
            union = np.sum(np.any([pred == l, label == l], axis=0))
            if union < 1:
                part_ious.append(1.0)
                continue
            part_ious.append(intersection / union)
        shape_ious[cat].append(np.mean(part_ious))

    if pt:
        print('='*40)
    weighted_acc = 0.0
    weighted_average_iou = 0.0
    accs, ious = [], []
    for cat in sorted(seg_classes.keys()):
        acc = shape_points_correct[cat] / float(shape_points_seen[cat])
        iou = np.mean(shape_ious[cat])
        if pt:
            print('{} | acc: {:.4f}, iou: {:.4f}'.format(cat, acc, iou))
        accs.append(round(acc * 100, 1))
        ious.append(round(iou * 100, 1))
        weighted_acc += shape_count[cat] * acc
        weighted_average_iou += shape_count[cat] * iou
    #print('accs: ', accs)
    #print('ious: ', ious)
    weighted_acc = weighted_acc / np.sum(list(shape_count.values())).astype(np.float32)
    weighted_average_iou = weighted_average_iou / np.sum(list(shape_count.values())).astype(np.float32)
    return weighted_average_iou, weighted_acc
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhulf0804/Pointnet2.PyTorch/1b98042fa286ce13db5cbfeb498f0f64dc1487b4/utils/__init__.py
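The per-shape part IoU that `cal_accuracy_iou` in `utils/IoU.py` averages can be looked at in isolation. The block below is a minimal sketch and not part of the repository: `part_iou` is a hypothetical helper name, it uses plain Python lists instead of NumPy arrays, and it only mirrors the inner loop over `seg_classes[cat]`, including the convention that a part absent from both prediction and label scores 1.0.

```python
def part_iou(pred, label, part_ids):
    """Mean IoU over the part ids of a single shape (illustrative helper).

    pred, label: equal-length sequences of per-point part ids.
    part_ids: the part ids belonging to the shape's category.
    """
    ious = []
    for part in part_ids:
        # Points where prediction and label both equal this part
        inter = sum(1 for p, l in zip(pred, label) if p == part and l == part)
        # Points where either prediction or label equals this part
        union = sum(1 for p, l in zip(pred, label) if p == part or l == part)
        # A part missing from both counts as a perfect match,
        # like the `union < 1` branch in cal_accuracy_iou
        ious.append(1.0 if union == 0 else inter / union)
    return sum(ious) / len(ious)


# Toy shape with 4 points whose category owns the (made-up) parts 0, 1 and 2
pred = [0, 0, 1, 1]
label = [0, 1, 1, 1]
print(part_iou(pred, label, [0, 1, 2]))
```

Here part 0 has IoU 1/2, part 1 has IoU 2/3, and part 2 appears nowhere so it contributes 1.0; the shape's score is the mean of the three.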
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


def get_dists(points1, points2):
    '''
    Calculate pairwise Euclidean distances between two sets of points
    :param points1: shape=(B, M, C)
    :param points2: shape=(B, N, C)
    :return: dists: shape=(B, M, N)
    '''
    B, M, C = points1.shape
    _, N, _ = points2.shape
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
    dists = torch.sum(torch.pow(points1, 2), dim=-1).view(B, M, 1) + \
            torch.sum(torch.pow(points2, 2), dim=-1).view(B, 1, N)
    dists -= 2 * torch.matmul(points1, points2.permute(0, 2, 1))
    # Replace negative values caused by rounding with a small positive number (important when dist = 0)
    dists = torch.where(dists < 0, torch.ones_like(dists) * 1e-7, dists)
    return torch.sqrt(dists).float()


def gather_points(points, inds):
    '''

    :param points: shape=(B, N, C)
    :param inds: shape=(B, M) or shape=(B, M, K)
    :return: sampling points: shape=(B, M, C) or shape=(B, M, K, C)
    '''
    device = points.device
    B, N, C = points.shape
    inds_shape = list(inds.shape)
    inds_shape[1:] = [1] * len(inds_shape[1:])
    repeat_shape = list(inds.shape)
    repeat_shape[0] = 1
    # Batch index tensor with the same shape as inds, so advanced indexing picks per-batch points
    batchlists = torch.arange(0, B, dtype=torch.long).to(device).reshape(inds_shape).repeat(repeat_shape)
    return points[batchlists, inds, :]


def setup_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
--------------------------------------------------------------------------------
/utils/feature_propagation.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from utils.common import gather_points, get_dists


def three_nn(xyz1, xyz2):
    '''

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :return: dists:
             shape=(B, N1, 3), inds: shape=(B, N1, 3)
    '''
    dists = get_dists(xyz1, xyz2)
    dists, inds = torch.sort(dists, dim=-1)
    dists, inds = dists[:, :, :3], inds[:, :, :3]
    return dists, inds


def three_interpolate(xyz1, xyz2, points2):
    '''

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :param points2: shape=(B, N2, C2)
    :return: interpolated_points: shape=(B, N1, C2)
    '''
    _, _, C2 = points2.shape
    dists, inds = three_nn(xyz1, xyz2)
    inversed_dists = 1.0 / (dists + 1e-8)
    weight = inversed_dists / torch.sum(inversed_dists, dim=-1, keepdim=True)  # shape=(B, N1, 3)
    weight = torch.unsqueeze(weight, -1).repeat(1, 1, 1, C2)
    interpolated_points = gather_points(points2, inds)  # shape=(B, N1, 3, C2)
    interpolated_points = torch.sum(weight * interpolated_points, dim=2)
    return interpolated_points


class PointNet_FP_Module(nn.Module):
    def __init__(self, in_channels, mlp, bn=True):
        super(PointNet_FP_Module, self).__init__()
        self.backbone = nn.Sequential()
        bias = not bn
        for i, out_channels in enumerate(mlp):
            self.backbone.add_module('Conv_{}'.format(i), nn.Conv2d(in_channels,
                                                                     out_channels,
                                                                     1,
                                                                     stride=1,
                                                                     padding=0,
                                                                     bias=bias))
            if bn:
                self.backbone.add_module('Bn_{}'.format(i), nn.BatchNorm2d(out_channels))
            self.backbone.add_module('Relu_{}'.format(i), nn.ReLU())
            in_channels = out_channels

    def forward(self, xyz1, xyz2, points1, points2):
        '''

        :param xyz1: shape=(B, N1, 3)
        :param xyz2: shape=(B, N2, 3) (N1 >= N2)
        :param points1: shape=(B, N1, C1)
        :param points2: shape=(B, N2, C2)
        :return: new_points2: shape = (B, N1, mlp[-1])
        '''
        B, N1, C1 = points1.shape
        _, N2, C2 = points2.shape
        if N2 == 1:
            interpolated_points = points2.repeat(1, N1, 1)
        else:
            interpolated_points = three_interpolate(xyz1, xyz2,
                                                    points2)
        cat_interpolated_points = torch.cat([interpolated_points, points1], dim=-1).permute(0, 2, 1).contiguous()
        new_points = torch.squeeze(self.backbone(torch.unsqueeze(cat_interpolated_points, -1)), dim=-1)
        return new_points.permute(0, 2, 1).contiguous()
--------------------------------------------------------------------------------
/utils/grouping.py:
--------------------------------------------------------------------------------
import torch
from utils.common import gather_points, get_dists


def ball_query(xyz, new_xyz, radius, K):
    '''

    :param xyz: shape=(B, N, 3)
    :param new_xyz: shape=(B, M, 3)
    :param radius: float
    :param K: int, maximum number of samples per group
    :return: shape=(B, M, K)
    '''
    device = xyz.device
    B, N, C = xyz.shape
    M = new_xyz.shape[1]
    grouped_inds = torch.arange(0, N, dtype=torch.long).to(device).view(1, 1, N).repeat(B, M, 1)
    dists = get_dists(new_xyz, xyz)
    # Mark points outside the ball with the sentinel index N so they sort last
    grouped_inds[dists > radius] = N
    grouped_inds = torch.sort(grouped_inds, dim=-1)[0][:, :, :K]
    # Fill unused slots (still N) by repeating the first index of each group
    grouped_min_inds = grouped_inds[:, :, 0:1].repeat(1, 1, K)
    grouped_inds[grouped_inds == N] = grouped_min_inds[grouped_inds == N]
    return grouped_inds
--------------------------------------------------------------------------------
/utils/sampling.py:
--------------------------------------------------------------------------------
import torch
from utils.common import get_dists


def fps(xyz, M):
    '''
    Sample M points from xyz with the farthest point sampling (FPS) algorithm.
    :param xyz: shape=(B, N, 3)
    :param M: int, number of points to sample
    :return: inds: shape=(B, M)
    '''
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(size=(B, M), dtype=torch.long).to(device)
    # Distance from every point to its nearest already-selected centroid
    dists = torch.ones(B, N).to(device) * 1e5
    # Start from a random point in each batch element
    inds = torch.randint(0, N, size=(B, ), dtype=torch.long).to(device)
    batchlists = torch.arange(0, B, dtype=torch.long).to(device)
    for i in range(M):
        centroids[:, i] = inds
        cur_point = xyz[batchlists, inds, :]  # (B, 3)
        cur_dist = torch.squeeze(get_dists(torch.unsqueeze(cur_point, 1), xyz), dim=1)
        dists[cur_dist < dists] = cur_dist[cur_dist < dists]
        # The next centroid is the point farthest from all selected centroids
        inds = torch.max(dists, dim=1)[1]
    return centroids
--------------------------------------------------------------------------------
/utils/set_abstraction.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from utils.sampling import fps
from utils.grouping import ball_query
from utils.common import gather_points


def sample_and_group(xyz, points, M, radius, K, use_xyz=True):
    '''
    :param xyz: shape=(B, N, 3)
    :param points: shape=(B, N, C)
    :param M: int
    :param radius: float
    :param K: int
    :param use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
    :return: new_xyz, shape=(B, M, 3); new_points, shape=(B, M, K, C+3);
             group_inds, shape=(B, M, K); grouped_xyz, shape=(B, M, K, 3)
    '''
    new_xyz = gather_points(xyz, fps(xyz, M))
    grouped_inds = ball_query(xyz, new_xyz, radius, K)
    grouped_xyz = gather_points(xyz, grouped_inds)
    # Express each group in coordinates relative to its centroid
    grouped_xyz -= torch.unsqueeze(new_xyz, 2).repeat(1, 1, K, 1)
    if points is not None:
        grouped_points = gather_points(points, grouped_inds)
        if use_xyz:
            new_points = torch.cat((grouped_xyz.float(), grouped_points.float()), dim=-1)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz
    return new_xyz, new_points, grouped_inds, grouped_xyz


def sample_and_group_all(xyz, points, use_xyz=True):
    '''

    :param xyz: shape=(B, M, 3)
    :param points: shape=(B, M, C)
    :param use_xyz: bool, if True concat XYZ with point features, otherwise just use point features
    :return: new_xyz, shape=(B, 1, 3); new_points, shape=(B, 1, M, C+3);
             group_inds, shape=(B, 1, M); grouped_xyz, shape=(B, 1, M, 3)
    '''
    B, M, C = xyz.shape
    # Create the outputs on the same device as the input to avoid CPU/GPU mismatches
    new_xyz = torch.zeros(B, 1, C, device=xyz.device)
    grouped_inds = torch.arange(0, M, device=xyz.device).long().view(1, 1, M).repeat(B, 1, 1)
    grouped_xyz = xyz.view(B, 1, M, C)
    if points is not None:
        if use_xyz:
            new_points = torch.cat([xyz.float(), points.float()], dim=2)
        else:
            new_points = points
        new_points = torch.unsqueeze(new_points, dim=1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points, grouped_inds, grouped_xyz


class PointNet_SA_Module(nn.Module):
    def __init__(self, M, radius, K, in_channels, mlp, group_all, bn=True, pooling='max', use_xyz=True):
        super(PointNet_SA_Module, self).__init__()
        self.M = M
        self.radius = radius
        self.K = K
        self.in_channels = in_channels
        self.mlp = mlp
        self.group_all = group_all
        self.bn = bn
        self.pooling = pooling
        self.use_xyz = use_xyz
        self.backbone = nn.Sequential()
        for i, out_channels in enumerate(mlp):
            self.backbone.add_module('Conv{}'.format(i),
                                     nn.Conv2d(in_channels, out_channels, 1,
                                               stride=1, padding=0, bias=False))
            if bn:
                self.backbone.add_module('Bn{}'.format(i),
                                         nn.BatchNorm2d(out_channels))
            self.backbone.add_module('Relu{}'.format(i), nn.ReLU())
            in_channels = out_channels

    def forward(self, xyz, points):
        if self.group_all:
            new_xyz, new_points, grouped_inds, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
        else:
            new_xyz, new_points, grouped_inds, grouped_xyz = sample_and_group(xyz=xyz,
                                                                              points=points,
                                                                              M=self.M,
                                                                              radius=self.radius,
                                                                              K=self.K,
                                                                              use_xyz=self.use_xyz)
        # (B, M, K, C) -> (B, C, K, M) so that Conv2d sees the channels first.
        new_points = self.backbone(new_points.permute(0, 3, 2, 1).contiguous())
        # Pool over the K neighbours of each group.
        if self.pooling == 'avg':
            new_points = torch.mean(new_points, dim=2)
        else:
            new_points = torch.max(new_points, dim=2)[0]
        new_points = new_points.permute(0, 2, 1).contiguous()
        return new_xyz, new_points


class PointNet_SA_Module_MSG(nn.Module):
    def __init__(self, M, radiuses, Ks, in_channels, mlps, bn=True, pooling='max', use_xyz=True):
        super(PointNet_SA_Module_MSG, self).__init__()
        self.M = M
        self.radiuses = radiuses
        self.Ks = Ks
        self.in_channels = in_channels
        self.mlps = mlps
        self.bn = bn
        self.pooling = pooling
        self.use_xyz = use_xyz
        # One shared MLP per scale (radius, K).
        self.backbones = nn.ModuleList()
        for j in range(len(mlps)):
            mlp = mlps[j]
            backbone = nn.Sequential()
            in_channels = self.in_channels
            for i, out_channels in enumerate(mlp):
                backbone.add_module('Conv{}_{}'.format(j, i),
                                    nn.Conv2d(in_channels, out_channels, 1,
                                              stride=1, padding=0, bias=False))
                if bn:
                    backbone.add_module('Bn{}_{}'.format(j, i),
                                        nn.BatchNorm2d(out_channels))
                backbone.add_module('Relu{}_{}'.format(j, i), nn.ReLU())
                in_channels = out_channels
            self.backbones.append(backbone)

    def forward(self, xyz, points):
        # All scales share the same sampled centroids.
        new_xyz = gather_points(xyz, fps(xyz, self.M))
        new_points_all = []
        for i in range(len(self.mlps)):
            radius = self.radiuses[i]
            K = self.Ks[i]
            grouped_inds = ball_query(xyz, new_xyz, radius, K)
            grouped_xyz = gather_points(xyz, grouped_inds)
            grouped_xyz -= torch.unsqueeze(new_xyz, 2).repeat(1, 1, K, 1)
            if points is not None:
                grouped_points = gather_points(points, grouped_inds)
                if self.use_xyz:
                    new_points = torch.cat(
                        (grouped_xyz.float(), grouped_points.float()),
                        dim=-1)
                else:
                    new_points = grouped_points
            else:
                new_points = grouped_xyz
            new_points = self.backbones[i](new_points.permute(0, 3, 2, 1).contiguous())
            if self.pooling == 'avg':
                new_points = torch.mean(new_points, dim=2)
            else:
                new_points = torch.max(new_points, dim=2)[0]
            new_points = new_points.permute(0, 2, 1).contiguous()
            new_points_all.append(new_points)
        # Concatenate the per-scale features along the channel axis.
        return new_xyz, torch.cat(new_points_all, dim=-1)


if __name__ == '__main__':
    def setup_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    setup_seed(2)
    xyz = torch.randn(4, 1024, 3)
    points = torch.randn(4, 1024, 3)

    M, radius, K = 5, 5, 6
    new_xyz, new_points, grouped_inds, grouped_xyz = sample_and_group(xyz, points, M, radius, K)
    print(new_xyz[0])
    print(new_points[0])

    print('='*20, 'backbone', '='*20)
    M, radius, K, in_channels, mlp = 2, 0.2, 3, 6, [32, 64, 128]
    sa_module = PointNet_SA_Module(M, radius, K, in_channels, mlp, group_all=False)
    new_xyz, new_points = sa_module(xyz, points)
    print('new_xyz: ', new_xyz.shape)
    print('new_points: ', new_points.shape)

    print('='*20, 'backbone msg', '='*20)
    M, radius_list, K_list, in_channels, mlp_list = 2, [0.2, 0.4], [3, 4], 6, [[32, 64, 128], [64, 64]]
    sa_module_msg = PointNet_SA_Module_MSG(M, radius_list, K_list, in_channels, mlp_list)
    new_xyz, new_points_cat = sa_module_msg(xyz, points)
    print('new_xyz: ', new_xyz.shape)
    print(new_points_cat.shape)

--------------------------------------------------------------------------------
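The `fps` routine in `/utils/sampling.py` repeatedly selects the point that is farthest from everything chosen so far. The block below is a minimal, dependency-free sketch of that idea for a single point cloud. It is not the repository's batched PyTorch implementation. The name `fps_reference` and the fixed `start` index are assumptions made here for illustration (the repository draws the first index at random).

```python
import math


def fps_reference(points, m, start=0):
    """Pick m indices from `points` by farthest point sampling.

    Hypothetical helper for illustration only: `points` is a list of
    (x, y, z) tuples and `start` is the index of the first pick, fixed
    here so that the result is reproducible.
    """
    n = len(points)
    # Distance from every point to its nearest already-chosen point.
    min_dists = [math.inf] * n
    chosen = []
    cur = start
    for _ in range(m):
        chosen.append(cur)
        for j in range(n):
            d = math.dist(points[cur], points[j])
            if d < min_dists[j]:
                min_dists[j] = d
        # Next pick: the point farthest from the chosen set.
        cur = max(range(n), key=lambda j: min_dists[j])
    return chosen


cloud = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (5.0, 5.0, 0.0)]
picked = fps_reference(cloud, 3)
print(picked)  # [0, 3, 2]
```

Starting from index 0, the sketch first jumps to the far corner (index 3) and then to index 2, which is the remaining point farthest from both picks. The batched `fps` follows the same update rule with tensors of shape `(B, N)`.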