├── .gitignore ├── LICENSE ├── README.md ├── config │ ├── cls.yaml │ ├── model │ │ ├── Hengshuang.yaml │ │ ├── Menghao.yaml │ │ └── Nico.yaml │ └── partseg.yaml ├── dataset.py ├── models │ ├── Hengshuang │ │ ├── model.py │ │ └── transformer.py │ ├── Menghao │ │ └── model.py │ └── Nico │ ├── model.py │ └── transformer.py ├── pointnet_util.py ├── provider.py ├── requirements.txt ├── train_cls.py └── train_partseg.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__/ 3 | modelnet40_normal_resampled/ 4 | outputs/ 5 | log/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yang You 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Various Point Transformers 2 | 3 | Recently, various methods have applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), and [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a PyTorch implementation of these methods and aims to compare them under a fair setting. All three methods are currently implemented; their hyperparameters are still being tuned. 4 | 5 | 6 | ## Classification 7 | ### Data Preparation 8 | Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `modelnet40_normal_resampled`. 9 | 10 | ### Run 11 | Select which method to use in `config/cls.yaml` and run 12 | ``` 13 | python train_cls.py 14 | ``` 15 | ### Results 16 | Models are trained for 200 epochs with Adam, decaying the learning rate by 0.3 every 50 epochs; data augmentation follows [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). For Hengshuang and Nico, the initial LR is 1e-3 (I would appreciate it if someone could fine-tune these hyper-parameters); for Menghao, the initial LR is 1e-4, as suggested by the [author](https://github.com/MenghaoGuo). 
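Since the configs are composed by Hydra, the model can also be selected from the command line instead of editing the YAML; a sketch assuming Hydra's standard override syntax (the names match the files under `config/model`): ``` python train_cls.py model=Hengshuang ```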
ModelNet40 classification results (instance average) are listed below: 17 | | Model | Accuracy | 18 | |--|--| 19 | | Hengshuang | 91.7 | 20 | | Menghao | 92.6 | 21 | | Nico | 85.5 | 22 | 23 | 24 | ## Part Segmentation 25 | ### Data Preparation 26 | Download the aligned **ShapeNet** dataset [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and save it in `data/shapenetcore_partanno_segmentation_benchmark_v0_normal`. 27 | 28 | ### Run 29 | Select which method to use in `config/partseg.yaml` and run 30 | ``` 31 | python train_partseg.py 32 | ``` 33 | ### Results 34 | Currently only Hengshuang's method is implemented. 35 | 36 | ### Miscellaneous 37 | Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch. 38 | Code for [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688) is adapted from the author's Jittor implementation https://github.com/MenghaoGuo/PCT. 39 | 40 | -------------------------------------------------------------------------------- /config/cls.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | epoch: 200 3 | learning_rate: 1e-3 4 | gpu: 1 5 | num_point: 1024 6 | optimizer: Adam 7 | weight_decay: 1e-4 8 | normal: True 9 | 10 | defaults: 11 | - model: Menghao 12 | 13 | hydra: 14 | run: 15 | dir: log/cls/${model.name} 16 | 17 | sweep: 18 | dir: log/cls 19 | subdir: ${model.name} 20 | -------------------------------------------------------------------------------- /config/model/Hengshuang.yaml: -------------------------------------------------------------------------------- 1 | nneighbor: 16 2 | nblocks: 4 3 | transformer_dim: 512 4 | name: Hengshuang 5 | -------------------------------------------------------------------------------- /config/model/Menghao.yaml: -------------------------------------------------------------------------------- 1 | name: Menghao 2 | -------------------------------------------------------------------------------- /config/model/Nico.yaml: -------------------------------------------------------------------------------- 1 | n_head: 8 2 | m: 4 3 | k: 64 4 | global_k: 128 5 | global_dim: 512 6 | local_dim: 256 7 | reduce_dim: 64 8 | name: Nico 9 | -------------------------------------------------------------------------------- /config/partseg.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | epoch: 200 3 | learning_rate: 1e-3 4 | gpu: 1 5 | num_point: 1024 6 | optimizer: Adam 7 | weight_decay: 1e-4 8 | normal: True 9 | lr_decay: 0.5 10 | step_size: 20 11 | 12 | defaults: 13 | - model: Hengshuang 14 | 15 | hydra: 16 | run: 17 | dir: log/partseg/${model.name} 18 | 19 | sweep: 20 | dir: log/partseg 21 | subdir: ${model.name} 22 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset 4 | import torch 5 | from pointnet_util import farthest_point_sample, pc_normalize 6 | import json 7 | 8 | 9 | class ModelNetDataLoader(Dataset): 10 | def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000): 11 | self.root = root 12 | self.npoints = npoint 13 | self.uniform = uniform 14 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') 15 | 16 | self.cat = [line.rstrip() for line
in open(self.catfile)] 17 | self.classes = dict(zip(self.cat, range(len(self.cat)))) 18 | self.normal_channel = normal_channel 19 | 20 | shape_ids = {} 21 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] 22 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] 23 | 24 | assert (split == 'train' or split == 'test') 25 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] 26 | # list of (shape_name, shape_txt_file_path) tuple 27 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 28 | in range(len(shape_ids[split]))] 29 | print('The size of %s data is %d'%(split,len(self.datapath))) 30 | 31 | self.cache_size = cache_size # how many data points to cache in memory 32 | self.cache = {} # from index to (point_set, cls) tuple 33 | 34 | def __len__(self): 35 | return len(self.datapath) 36 | 37 | def _get_item(self, index): 38 | if index in self.cache: 39 | point_set, cls = self.cache[index] 40 | else: 41 | fn = self.datapath[index] 42 | cls = self.classes[self.datapath[index][0]] 43 | cls = np.array([cls]).astype(np.int32) 44 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 45 | if self.uniform: 46 | point_set = farthest_point_sample(point_set, self.npoints) 47 | else: 48 | point_set = point_set[0:self.npoints,:] 49 | 50 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 51 | 52 | if not self.normal_channel: 53 | point_set = point_set[:, 0:3] 54 | 55 | if len(self.cache) < self.cache_size: 56 | self.cache[index] = (point_set, cls) 57 | 58 | return point_set, cls 59 | 60 | def __getitem__(self, index): 61 | return self._get_item(index) 62 | 63 | 64 | class PartNormalDataset(Dataset): 65 | def __init__(self, root='./data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False): 66 | self.npoints = npoints 67 | self.root = root 68 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') 69 | self.cat = {} 70 | self.normal_channel = normal_channel 71 | 72 | 73 | with open(self.catfile, 'r') as f: 74 | for line in f: 75 | ls = line.strip().split() 76 | self.cat[ls[0]] = ls[1] 77 | self.cat = {k: v for k, v in self.cat.items()} 78 | self.classes_original = dict(zip(self.cat, range(len(self.cat)))) 79 | 80 | if not class_choice is None: 81 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 82 | # print(self.cat) 83 | 84 | self.meta = {} 85 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: 86 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 87 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: 88 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 89 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: 90 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 91 | for item in self.cat: 92 | # print('category', item) 93 | self.meta[item] = [] 94 | dir_point = os.path.join(self.root, self.cat[item]) 95 | fns = sorted(os.listdir(dir_point)) 96 | # print(fns[0][0:-4]) 97 | if split == 'trainval': 98 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] 99 | elif split == 'train': 100 | fns = [fn for fn in fns if fn[0:-4] in train_ids] 101 | elif split == 'val': 102 | fns = [fn for fn in fns if fn[0:-4] in 
val_ids] 103 | elif split == 'test': 104 | fns = [fn for fn in fns if fn[0:-4] in test_ids] 105 | else: 106 | print('Unknown split: %s. Exiting..' % (split)) 107 | exit(-1) 108 | 109 | # print(os.path.basename(fns)) 110 | for fn in fns: 111 | token = (os.path.splitext(os.path.basename(fn))[0]) 112 | self.meta[item].append(os.path.join(dir_point, token + '.txt')) 113 | 114 | self.datapath = [] 115 | for item in self.cat: 116 | for fn in self.meta[item]: 117 | self.datapath.append((item, fn)) 118 | 119 | self.classes = {} 120 | for i in self.cat.keys(): 121 | self.classes[i] = self.classes_original[i] 122 | 123 | # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels 124 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 125 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 126 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 127 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 128 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 129 | 130 | # for cat in sorted(self.seg_classes.keys()): 131 | # print(cat, self.seg_classes[cat]) 132 | 133 | self.cache = {} # from index to (point_set, cls, seg) tuple 134 | self.cache_size = 20000 135 | 136 | 137 | def __getitem__(self, index): 138 | if index in self.cache: 139 | point_set, cls, seg = self.cache[index] 140 | else: 141 | fn = self.datapath[index] 142 | cat = self.datapath[index][0] 143 | cls = self.classes[cat] 144 | cls = np.array([cls]).astype(np.int32) 145 | data = np.loadtxt(fn[1]).astype(np.float32) 146 | if not self.normal_channel: 147 | point_set = data[:, 0:3] 148 | else: 149 | point_set = data[:, 0:6] 150 | seg = data[:, -1].astype(np.int32) 151 | if len(self.cache) < self.cache_size: 152 | self.cache[index] = (point_set, cls, seg) 153 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 154 | 155 | choice = np.random.choice(len(seg), self.npoints, replace=True) 156 | # resample 157 | point_set = point_set[choice, :] 158 | seg = seg[choice] 159 | 160 | return point_set, cls, seg 161 | 162 | def __len__(self): 163 | return len(self.datapath) 164 | 165 | 166 | if __name__ == '__main__': 167 | data = ModelNetDataLoader('modelnet40_normal_resampled/', split='train', uniform=False, normal_channel=True) 168 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) 169 | for point,label in DataLoader: 170 | print(point.shape) 171 | print(label.shape) -------------------------------------------------------------------------------- /models/Hengshuang/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pointnet_util import PointNetFeaturePropagation, PointNetSetAbstraction 4 | from .transformer import TransformerBlock 5 | 6 | 7 | class TransitionDown(nn.Module): 8 | def __init__(self, k, nneighbor, channels): 9 | super().__init__() 10 | self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True) 11 | 12 | def forward(self, xyz, points): 13 | return self.sa(xyz, points) 14 | 15 | 16 | class TransitionUp(nn.Module): 17 | def __init__(self, dim1, dim2, dim_out): 18 | class SwapAxes(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def forward(self, x): 23 | return x.transpose(1, 2) 24 | 25 | super().__init__() 26 | self.fc1 = nn.Sequential( 27 | nn.Linear(dim1, dim_out), 28 | SwapAxes(), 
29 | nn.BatchNorm1d(dim_out), # TODO 30 | SwapAxes(), 31 | nn.ReLU(), 32 | ) 33 | self.fc2 = nn.Sequential( 34 | nn.Linear(dim2, dim_out), 35 | SwapAxes(), 36 | nn.BatchNorm1d(dim_out), # TODO 37 | SwapAxes(), 38 | nn.ReLU(), 39 | ) 40 | self.fp = PointNetFeaturePropagation(-1, []) 41 | 42 | def forward(self, xyz1, points1, xyz2, points2): 43 | feats1 = self.fc1(points1) 44 | feats2 = self.fc2(points2) 45 | feats1 = self.fp(xyz2.transpose(1, 2), xyz1.transpose(1, 2), None, feats1.transpose(1, 2)).transpose(1, 2) 46 | return feats1 + feats2 47 | 48 | 49 | class Backbone(nn.Module): 50 | def __init__(self, cfg): 51 | super().__init__() 52 | npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim 53 | self.fc1 = nn.Sequential( 54 | nn.Linear(d_points, 32), 55 | nn.ReLU(), 56 | nn.Linear(32, 32) 57 | ) 58 | self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor) 59 | self.transition_downs = nn.ModuleList() 60 | self.transformers = nn.ModuleList() 61 | for i in range(nblocks): 62 | channel = 32 * 2 ** (i + 1) 63 | self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel])) 64 | self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) 65 | self.nblocks = nblocks 66 | 67 | def forward(self, x): 68 | xyz = x[..., :3] 69 | points = self.transformer1(xyz, self.fc1(x))[0] 70 | 71 | xyz_and_feats = [(xyz, points)] 72 | for i in range(self.nblocks): 73 | xyz, points = self.transition_downs[i](xyz, points) 74 | points = self.transformers[i](xyz, points)[0] 75 | xyz_and_feats.append((xyz, points)) 76 | return points, xyz_and_feats 77 | 78 | 79 | class PointTransformerCls(nn.Module): 80 | def __init__(self, cfg): 81 | super().__init__() 82 | self.backbone = Backbone(cfg) 83 | npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim 84 | self.fc2 = nn.Sequential( 85 | nn.Linear(32 * 2 ** nblocks, 256), 86 | nn.ReLU(), 87 | nn.Linear(256, 64), 88 | nn.ReLU(), 89 | nn.Linear(64, n_c) 90 | ) 91 | self.nblocks = nblocks 92 | 93 | def forward(self, x): 94 | points, _ = self.backbone(x) 95 | res = self.fc2(points.mean(1)) 96 | return res 97 | 98 | 99 | class PointTransformerSeg(nn.Module): 100 | def __init__(self, cfg): 101 | super().__init__() 102 | self.backbone = Backbone(cfg) 103 | npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim 104 | self.fc2 = nn.Sequential( 105 | nn.Linear(32 * 2 ** nblocks, 512), 106 | nn.ReLU(), 107 | nn.Linear(512, 512), 108 | nn.ReLU(), 109 | nn.Linear(512, 32 * 2 ** nblocks) 110 | ) 111 | self.transformer2 = TransformerBlock(32 * 2 ** nblocks, cfg.model.transformer_dim, nneighbor) 112 | self.nblocks = nblocks 113 | self.transition_ups = nn.ModuleList() 114 | self.transformers = nn.ModuleList() 115 | for i in reversed(range(nblocks)): 116 | channel = 32 * 2 ** i 117 | self.transition_ups.append(TransitionUp(channel * 2, channel, channel)) 118 | self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) 119 | 120 | self.fc3 = nn.Sequential( 121 | nn.Linear(32, 64), 122 | nn.ReLU(), 123 | nn.Linear(64, 64), 124 | nn.ReLU(), 125 | nn.Linear(64, n_c) 126 | ) 127 | 128 | def forward(self, x): 129 | points, xyz_and_feats = self.backbone(x) 130 | xyz = xyz_and_feats[-1][0] 131 | points = self.transformer2(xyz, 
self.fc2(points))[0] 132 | 133 | for i in range(self.nblocks): 134 | points = self.transition_ups[i](xyz, points, xyz_and_feats[- i - 2][0], xyz_and_feats[- i - 2][1]) 135 | xyz = xyz_and_feats[- i - 2][0] 136 | points = self.transformers[i](xyz, points)[0] 137 | 138 | return self.fc3(points) 139 | 140 | 141 | -------------------------------------------------------------------------------- /models/Hengshuang/transformer.py: -------------------------------------------------------------------------------- 1 | from pointnet_util import index_points, square_distance 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class TransformerBlock(nn.Module): 8 | def __init__(self, d_points, d_model, k) -> None: 9 | super().__init__() 10 | self.fc1 = nn.Linear(d_points, d_model) 11 | self.fc2 = nn.Linear(d_model, d_points) 12 | self.fc_delta = nn.Sequential( 13 | nn.Linear(3, d_model), 14 | nn.ReLU(), 15 | nn.Linear(d_model, d_model) 16 | ) 17 | self.fc_gamma = nn.Sequential( 18 | nn.Linear(d_model, d_model), 19 | nn.ReLU(), 20 | nn.Linear(d_model, d_model) 21 | ) 22 | self.w_qs = nn.Linear(d_model, d_model, bias=False) 23 | self.w_ks = nn.Linear(d_model, d_model, bias=False) 24 | self.w_vs = nn.Linear(d_model, d_model, bias=False) 25 | self.k = k 26 | 27 | # xyz: b x n x 3, features: b x n x f 28 | def forward(self, xyz, features): 29 | dists = square_distance(xyz, xyz) 30 | knn_idx = dists.argsort()[:, :, :self.k] # b x n x k 31 | knn_xyz = index_points(xyz, knn_idx) 32 | 33 | pre = features 34 | x = self.fc1(features) 35 | q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx) 36 | 37 | pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz) # b x n x k x f 38 | 39 | attn = self.fc_gamma(q[:, :, None] - k + pos_enc) 40 | attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2) # b x n x k x f 41 | 42 | res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc) 43 | res = self.fc2(res) + pre 44 | return res, attn 45 | -------------------------------------------------------------------------------- /models/Menghao/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pointnet_util import farthest_point_sample, index_points, square_distance 4 | 5 | 6 | def sample_and_group(npoint, nsample, xyz, points): 7 | B, N, C = xyz.shape 8 | S = npoint 9 | 10 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 11 | 12 | new_xyz = index_points(xyz, fps_idx) 13 | new_points = index_points(points, fps_idx) 14 | 15 | dists = square_distance(new_xyz, xyz) # B x npoint x N 16 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K 17 | 18 | grouped_points = index_points(points, idx) 19 | grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) 20 | new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) 21 | return new_xyz, new_points 22 | 23 | 24 | class Local_op(nn.Module): 25 | def __init__(self, in_channels, out_channels): 26 | super().__init__() 27 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 28 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) 29 | self.bn1 = nn.BatchNorm1d(out_channels) 30 | self.bn2 = nn.BatchNorm1d(out_channels) 31 | self.relu = nn.ReLU() 32 | 33 | def forward(self, x): 34 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) 35 | x = x.permute(0, 1, 3, 2) 36 | x = x.reshape(-1, d, 
s) 37 | batch_size, _, N = x.size() 38 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 39 | x = self.relu(self.bn2(self.conv2(x))) # B, D, N 40 | x = torch.max(x, 2)[0] 41 | x = x.view(batch_size, -1) 42 | x = x.reshape(b, n, -1).permute(0, 2, 1) 43 | return x 44 | 45 | 46 | class SA_Layer(nn.Module): 47 | def __init__(self, channels): 48 | super().__init__() 49 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 50 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 51 | self.q_conv.weight = self.k_conv.weight 52 | self.v_conv = nn.Conv1d(channels, channels, 1) 53 | self.trans_conv = nn.Conv1d(channels, channels, 1) 54 | self.after_norm = nn.BatchNorm1d(channels) 55 | self.act = nn.ReLU() 56 | self.softmax = nn.Softmax(dim=-1) 57 | 58 | def forward(self, x): 59 | x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c 60 | x_k = self.k_conv(x)# b, c, n 61 | x_v = self.v_conv(x) 62 | energy = x_q @ x_k # b, n, n 63 | attention = self.softmax(energy) 64 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) 65 | x_r = x_v @ attention # b, c, n 66 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) 67 | x = x + x_r 68 | return x 69 | 70 | 71 | class StackedAttention(nn.Module): 72 | def __init__(self, channels=256): 73 | super().__init__() 74 | self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) 75 | self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) 76 | 77 | self.bn1 = nn.BatchNorm1d(channels) 78 | self.bn2 = nn.BatchNorm1d(channels) 79 | 80 | self.sa1 = SA_Layer(channels) 81 | self.sa2 = SA_Layer(channels) 82 | self.sa3 = SA_Layer(channels) 83 | self.sa4 = SA_Layer(channels) 84 | 85 | self.relu = nn.ReLU() 86 | 87 | def forward(self, x): 88 | # 89 | # b, 3, npoint, nsample 90 | # conv2d 3 -> 128 channels 1, 1 91 | # b * npoint, c, nsample 92 | # permute reshape 93 | batch_size, _, N = x.size() 94 | 95 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 96 | x = self.relu(self.bn2(self.conv2(x))) 97 | 98 | x1 = self.sa1(x) 99 | x2 = self.sa2(x1) 100 | x3 = self.sa3(x2) 101 | x4 = self.sa4(x3) 102 | 103 | x = torch.cat((x1, x2, x3, x4), dim=1) 104 | 105 | return x 106 | 107 | 108 | class PointTransformerCls(nn.Module): 109 | def __init__(self, cfg): 110 | super().__init__() 111 | output_channels = cfg.num_class 112 | d_points = cfg.input_dim 113 | self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False) 114 | self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) 115 | self.bn1 = nn.BatchNorm1d(64) 116 | self.bn2 = nn.BatchNorm1d(64) 117 | self.gather_local_0 = Local_op(in_channels=128, out_channels=128) 118 | self.gather_local_1 = Local_op(in_channels=256, out_channels=256) 119 | self.pt_last = StackedAttention() 120 | 121 | self.relu = nn.ReLU() 122 | self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), 123 | nn.BatchNorm1d(1024), 124 | nn.LeakyReLU(negative_slope=0.2)) 125 | 126 | self.linear1 = nn.Linear(1024, 512, bias=False) 127 | self.bn6 = nn.BatchNorm1d(512) 128 | self.dp1 = nn.Dropout(p=0.5) 129 | self.linear2 = nn.Linear(512, 256) 130 | self.bn7 = nn.BatchNorm1d(256) 131 | self.dp2 = nn.Dropout(p=0.5) 132 | self.linear3 = nn.Linear(256, output_channels) 133 | 134 | def forward(self, x): 135 | xyz = x[..., :3] 136 | x = x.permute(0, 2, 1) 137 | batch_size, _, _ = x.size() 138 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 139 | x = self.relu(self.bn2(self.conv2(x))) # B, D, N 140 | x = x.permute(0, 2, 1) 141 | new_xyz, new_feature = 
sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) 142 | feature_0 = self.gather_local_0(new_feature) 143 | feature = feature_0.permute(0, 2, 1) 144 | new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) 145 | feature_1 = self.gather_local_1(new_feature) 146 | 147 | x = self.pt_last(feature_1) 148 | x = torch.cat([x, feature_1], dim=1) 149 | x = self.conv_fuse(x) 150 | x = torch.max(x, 2)[0] 151 | x = x.view(batch_size, -1) 152 | 153 | x = self.relu(self.bn6(self.linear1(x))) 154 | x = self.dp1(x) 155 | x = self.relu(self.bn7(self.linear2(x))) 156 | x = self.dp2(x) 157 | x = self.linear3(x) 158 | 159 | return x -------------------------------------------------------------------------------- /models/Nico/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pointnet_util import PointNetSetAbstractionMsg 4 | from .transformer import MultiHeadAttention 5 | 6 | 7 | class SortNet(nn.Module): 8 | def __init__(self, d_model, d_points=6, k=64): 9 | super().__init__() 10 | self.fc = nn.Sequential( 11 | nn.Linear(d_model, 256), 12 | nn.ReLU(), 13 | nn.Linear(256, 64), 14 | nn.ReLU(), 15 | nn.Linear(64, 1) 16 | ) 17 | self.sa = PointNetSetAbstractionMsg(k, [0.1, 0.2, 0.4], [16, 32, 128], d_model, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 18 | self.fc_agg = nn.Sequential( 19 | nn.Linear(64 + 128 + 128, 256), 20 | nn.ReLU(), 21 | nn.Linear(256, 256), 22 | nn.ReLU(), 23 | nn.Linear(256, d_model - 1 - d_points), 24 | ) 25 | self.k = k 26 | self.d_points = d_points 27 | 28 | def forward(self, points, features): 29 | score = self.fc(features) 30 | topk_idx = torch.topk(score[..., 0], self.k, 1)[1] 31 | features_abs = self.sa(points[..., :3], features, topk_idx)[1] 32 | res = torch.cat((self.fc_agg(features_abs), 33 | torch.gather(score, 1, topk_idx[..., None].expand(-1, -1, score.size(-1))), 34 | torch.gather(points, 1, topk_idx[..., None].expand(-1, -1, points.size(-1)))), -1) 35 | return res 36 | 37 | 38 | class LocalFeatureGeneration(nn.Module): 39 | def __init__(self, d_model, m, k, d_points=6, n_head=4): 40 | super().__init__() 41 | self.fc = nn.Sequential( 42 | nn.Linear(d_points, 64), 43 | nn.ReLU(), 44 | nn.Linear(64, 256), 45 | nn.ReLU(), 46 | nn.Linear(256, d_model) 47 | ) 48 | self.sortnets = nn.ModuleList([SortNet(d_model, k=k) for _ in range(m)]) 49 | self.att = MultiHeadAttention(n_head, d_model, d_model, d_model // n_head, d_model // n_head) 50 | 51 | def forward(self, points): 52 | x = self.fc(points) 53 | x, _ = self.att(x, x, x) 54 | out = torch.cat([sortnet(points, x) for sortnet in self.sortnets], 1) 55 | return out, x 56 | 57 | 58 | class GlobalFeatureGeneration(nn.Module): 59 | def __init__(self, d_model, k, d_points=6, n_head=4): 60 | super().__init__() 61 | self.fc = nn.Sequential( 62 | nn.Linear(d_points, 64), 63 | nn.ReLU(), 64 | nn.Linear(64, 256), 65 | nn.ReLU(), 66 | nn.Linear(256, d_model) 67 | ) 68 | self.sa = PointNetSetAbstractionMsg(k, [0.1, 0.2, 0.4], [16, 32, 128], d_model, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 69 | self.att = MultiHeadAttention(n_head, d_model, d_model, d_model // n_head, d_model // n_head) 70 | self.fc_agg = nn.Sequential( 71 | nn.Linear(64 + 128 + 128, 256), 72 | nn.ReLU(), 73 | nn.Linear(256, 256), 74 | nn.ReLU(), 75 | nn.Linear(256, d_model), 76 | ) 77 | 78 | def forward(self, points): 79 | x = self.fc(points) 80 | x, _ = self.att(x, x, x) 81 | out = self.fc_agg(self.sa(points[..., :3], x)[1]) 82 | 
return out, x 83 | 84 | 85 | class PointTransformerCls(nn.Module): 86 | def __init__(self, cfg): 87 | super().__init__() 88 | d_model_l, d_model_g, d_reduce, m, k, n_c, d_points, n_head \ 89 | = cfg.model.global_dim, cfg.model.local_dim, cfg.model.reduce_dim, cfg.model.m, cfg.model.k, cfg.num_class, cfg.input_dim, cfg.model.n_head 90 | self.lfg = LocalFeatureGeneration(d_model=d_model_l, m=m, k=k, d_points=d_points) 91 | self.gfg = GlobalFeatureGeneration(d_model=d_model_g, k=cfg.model.global_k, d_points=d_points) 92 | self.lg_att = MultiHeadAttention(n_head, d_model_l, d_model_g, d_model_l // n_head, d_model_l // n_head) 93 | self.fc = nn.Sequential( 94 | nn.Linear(d_model_l, 256), 95 | nn.ReLU(), 96 | nn.Linear(256, 256), 97 | nn.ReLU(), 98 | nn.Linear(256, d_reduce), 99 | ) 100 | self.fc_cls = nn.Sequential( 101 | nn.Linear(k * m * d_reduce, 1024), 102 | nn.ReLU(), 103 | nn.Linear(1024, 256), 104 | nn.ReLU(), 105 | nn.Linear(256, 64), 106 | nn.ReLU(), 107 | nn.Linear(64, n_c) 108 | ) 109 | 110 | def forward(self, points): 111 | local_features = self.lfg(points)[0] 112 | global_features = self.gfg(points)[0] 113 | lg_features = self.lg_att(local_features, global_features, global_features)[0] 114 | x = self.fc(lg_features).reshape(points.size(0), -1) 115 | out = self.fc_cls(x) 116 | return out 117 | 118 | -------------------------------------------------------------------------------- /models/Nico/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | # reference https://github.com/jadore801120/attention-is-all-you-need-pytorch 8 | 9 | class Attention(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, q, k, v): 14 | attn = q @ k.transpose(-1, -2) 15 | attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-1) 16 | output = attn @ v 17 | 18 | return output, attn 19 | 20 | 21 | class MultiHeadAttention(nn.Module): 22 | ''' Multi-Head Attention module ''' 23 | 24 | def __init__(self, n_head, d_model_q, d_model_kv, d_k, d_v): 25 | super().__init__() 26 | 27 | self.n_head = n_head 28 | self.d_k = d_k 29 | self.d_v = d_v 30 | 31 | self.w_qs = nn.Linear(d_model_q, n_head * d_k, bias=False) 32 | self.w_ks = nn.Linear(d_model_kv, n_head * d_k, bias=False) 33 | self.w_vs = nn.Linear(d_model_kv, n_head * d_v, bias=False) 34 | self.fc = nn.Linear(n_head * d_v, d_model_q, bias=False) 35 | 36 | self.attention = Attention() 37 | 38 | self.layer_norm1 = nn.LayerNorm(n_head * d_v, eps=1e-6) 39 | self.layer_norm2 = nn.LayerNorm(d_model_q, eps=1e-6) 40 | 41 | 42 | def forward(self, q, k, v): 43 | 44 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 45 | b_size, n_q, n_k = q.size(0), q.size(1), k.size(1) 46 | 47 | residual = q 48 | 49 | # Pass through the pre-attention projection: b x k x (n*dv) 50 | # Separate different heads: b x k x n x dv 51 | q = self.w_qs(q).view(-1, n_q, n_head, d_k) 52 | k = self.w_ks(k).view(-1, n_k, n_head, d_k) 53 | v = self.w_vs(v).view(-1, n_k, n_head, d_v) 54 | 55 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 56 | 57 | # get b x n x k x dv 58 | q, attn = self.attention(q, k, v) 59 | 60 | # b x k x ndv 61 | q = q.transpose(1, 2).contiguous().view(b_size, n_q, -1) 62 | s = self.layer_norm1(residual + q) 63 | res = self.layer_norm2(s + self.fc(s)) 64 | 65 | return res, attn -------------------------------------------------------------------------------- /pointnet_util.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | 8 | # reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You 9 | 10 | 11 | def timeit(tag, t): 12 | print("{}: {}s".format(tag, time() - t)) 13 | return time() 14 | 15 | def pc_normalize(pc): 16 | centroid = np.mean(pc, axis=0) 17 | pc = pc - centroid 18 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 19 | pc = pc / m 20 | return pc 21 | 22 | def square_distance(src, dst): 23 | """ 24 | Calculate the squared Euclidean distance between each pair of points. 25 | src^T * dst = xn * xm + yn * ym + zn * zm; 26 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 27 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 28 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 29 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 30 | Input: 31 | src: source points, [B, N, C] 32 | dst: target points, [B, M, C] 33 | Output: 34 | dist: per-point squared distance, [B, N, M] 35 | """ 36 | return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1) 37 | 38 | 39 | def index_points(points, idx): 40 | """ 41 | Input: 42 | points: input points data, [B, N, C] 43 | idx: sample index data, [B, S, [K]] 44 | Return: 45 | new_points: indexed points data, [B, S, [K], C] 46 | """ 47 | raw_size = idx.size() 48 | idx = idx.reshape(raw_size[0], -1) 49 | res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1))) 50 | return res.reshape(*raw_size, -1) 51 | 52 | 53 | def farthest_point_sample(xyz, npoint): 54 | """ 55 | Input: 56 | xyz: pointcloud data, [B, N, 3] 57 | npoint: number of samples 58 | Return: 59 | centroids: sampled pointcloud index, [B, npoint] 60 | """ 61 | device = xyz.device 62 | B, N, C = xyz.shape 63 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 64 | distance = torch.ones(B, N).to(device) * 1e10 65 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 66 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 67 | for i in range(npoint): 68 | centroids[:, i] = farthest 69 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 70 | dist = torch.sum((xyz - centroid) ** 2, -1) 71 | distance = torch.min(distance, dist) 72 | farthest = torch.max(distance, -1)[1] 73 | return centroids 74 | 75 | 76 | def query_ball_point(radius, nsample, xyz, new_xyz): 77 | """ 78 | Input: 79 | radius: local region radius 80 | nsample: max sample number in local region 81 | xyz: all points, [B, N, 3] 82 | new_xyz: query points, [B, S, 3] 83 | Return: 84 | group_idx: grouped points index, [B, S, nsample] 85 | """ 86 | device = xyz.device 87 | B, N, C = xyz.shape 88 | _, S, _ = new_xyz.shape 89 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 90 | sqrdists = square_distance(new_xyz, xyz) 91 | group_idx[sqrdists > radius ** 2] = N 92 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 93 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 94 | mask = group_idx == N 95 | group_idx[mask] = group_first[mask] 96 | return group_idx 97 | 98 | 99 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False): 100 | """ 101 | Input: 102 | npoint: number of sampled centroids 103 | radius: local region radius 104 | nsample: max sample number in local region 105 | xyz: input points position data, [B, N, 3] 106 | points: input points data, [B, N, D] 107 | Return: 108 | new_xyz: sampled points position data, [B, npoint, 3] 109 | new_points: sampled
points data, [B, npoint, nsample, 3+D] 110 | """ 111 | B, N, C = xyz.shape 112 | S = npoint 113 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 114 | torch.cuda.empty_cache() 115 | new_xyz = index_points(xyz, fps_idx) 116 | torch.cuda.empty_cache() 117 | if knn: 118 | dists = square_distance(new_xyz, xyz) # B x npoint x N 119 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K 120 | else: 121 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 122 | torch.cuda.empty_cache() 123 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 124 | torch.cuda.empty_cache() 125 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 126 | torch.cuda.empty_cache() 127 | 128 | if points is not None: 129 | grouped_points = index_points(points, idx) 130 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 131 | else: 132 | new_points = grouped_xyz_norm 133 | if returnfps: 134 | return new_xyz, new_points, grouped_xyz, fps_idx 135 | else: 136 | return new_xyz, new_points 137 | 138 | 139 | def sample_and_group_all(xyz, points): 140 | """ 141 | Input: 142 | xyz: input points position data, [B, N, 3] 143 | points: input points data, [B, N, D] 144 | Return: 145 | new_xyz: sampled points position data, [B, 1, 3] 146 | new_points: sampled points data, [B, 1, N, 3+D] 147 | """ 148 | device = xyz.device 149 | B, N, C = xyz.shape 150 | new_xyz = torch.zeros(B, 1, C).to(device) 151 | grouped_xyz = xyz.view(B, 1, N, C) 152 | if points is not None: 153 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 154 | else: 155 | new_points = grouped_xyz 156 | return new_xyz, new_points 157 | 158 | 159 | class PointNetSetAbstraction(nn.Module): 160 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False): 161 | super(PointNetSetAbstraction, self).__init__() 162 | self.npoint = npoint 163 | self.radius = radius 164 | self.nsample = nsample 165 | self.knn = knn 166 | self.mlp_convs = nn.ModuleList() 167 | self.mlp_bns = nn.ModuleList() 168 | last_channel = in_channel 169 | for out_channel in mlp: 170 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 171 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 172 | last_channel = out_channel 173 | self.group_all = group_all 174 | 175 | def forward(self, xyz, points): 176 | """ 177 | Input: 178 | xyz: input points position data, [B, N, C] 179 | points: input points data, [B, N, C] 180 | Return: 181 | new_xyz: sampled points position data, [B, S, C] 182 | new_points_concat: sample points feature data, [B, S, D'] 183 | """ 184 | if self.group_all: 185 | new_xyz, new_points = sample_and_group_all(xyz, points) 186 | else: 187 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn) 188 | # new_xyz: sampled points position data, [B, npoint, C] 189 | # new_points: sampled points data, [B, npoint, nsample, C+D] 190 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 191 | for i, conv in enumerate(self.mlp_convs): 192 | bn = self.mlp_bns[i] 193 | new_points = F.relu(bn(conv(new_points))) 194 | 195 | new_points = torch.max(new_points, 2)[0].transpose(1, 2) 196 | return new_xyz, new_points 197 | 198 | 199 | class PointNetSetAbstractionMsg(nn.Module): 200 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False): 201 | super(PointNetSetAbstractionMsg, self).__init__() 202 | self.npoint = npoint 203 | self.radius_list = radius_list 204 | 
self.nsample_list = nsample_list 205 | self.knn = knn 206 | self.conv_blocks = nn.ModuleList() 207 | self.bn_blocks = nn.ModuleList() 208 | for i in range(len(mlp_list)): 209 | convs = nn.ModuleList() 210 | bns = nn.ModuleList() 211 | last_channel = in_channel + 3 212 | for out_channel in mlp_list[i]: 213 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 214 | bns.append(nn.BatchNorm2d(out_channel)) 215 | last_channel = out_channel 216 | self.conv_blocks.append(convs) 217 | self.bn_blocks.append(bns) 218 | 219 | def forward(self, xyz, points, seed_idx=None): 220 | """ 221 | Input: 222 | xyz: input points position data, [B, N, C] 223 | points: input points data, [B, N, D] 224 | Return: 225 | new_xyz: sampled points position data, [B, S, C] 226 | new_points_concat: sampled points feature data, [B, S, D'] 227 | """ 228 | 229 | B, N, C = xyz.shape 230 | S = self.npoint 231 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx) 232 | new_points_list = [] 233 | for i, radius in enumerate(self.radius_list): 234 | K = self.nsample_list[i] 235 | if self.knn: 236 | dists = square_distance(new_xyz, xyz) # B x npoint x N 237 | group_idx = dists.argsort()[:, :, :K] # B x npoint x K 238 | else: 239 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 240 | grouped_xyz = index_points(xyz, group_idx) 241 | grouped_xyz -= new_xyz.view(B, S, 1, C) 242 | if points is not None: 243 | grouped_points = index_points(points, group_idx) 244 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 245 | else: 246 | grouped_points = grouped_xyz 247 | 248 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 249 | for j in range(len(self.conv_blocks[i])): 250 | conv = self.conv_blocks[i][j] 251 | bn = self.bn_blocks[i][j] 252 | grouped_points = F.relu(bn(conv(grouped_points))) 253 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 254 | new_points_list.append(new_points) 255 | 256 | new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2) 257 | return new_xyz, new_points_concat 258 | 259 | 260 | # Note: this function swaps N and C 261 | class PointNetFeaturePropagation(nn.Module): 262 | def __init__(self, in_channel, mlp): 263 | super(PointNetFeaturePropagation, self).__init__() 264 | self.mlp_convs = nn.ModuleList() 265 | self.mlp_bns = nn.ModuleList() 266 | last_channel = in_channel 267 | for out_channel in mlp: 268 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 269 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 270 | last_channel = out_channel 271 | 272 | def forward(self, xyz1, xyz2, points1, points2): 273 | """ 274 | Input: 275 | xyz1: input points position data, [B, C, N] 276 | xyz2: sampled input points position data, [B, C, S] 277 | points1: input points data, [B, D, N] 278 | points2: input points data, [B, D, S] 279 | Return: 280 | new_points: upsampled points data, [B, D', N] 281 | """ 282 | xyz1 = xyz1.permute(0, 2, 1) 283 | xyz2 = xyz2.permute(0, 2, 1) 284 | 285 | points2 = points2.permute(0, 2, 1) 286 | B, N, C = xyz1.shape 287 | _, S, _ = xyz2.shape 288 | 289 | if S == 1: 290 | interpolated_points = points2.repeat(1, N, 1) 291 | else: 292 | dists = square_distance(xyz1, xyz2) 293 | dists, idx = dists.sort(dim=-1) 294 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 295 | 296 | dist_recip = 1.0 / (dists + 1e-8) 297 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 298 | weight = dist_recip / norm 299 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
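# The block above implements inverse-distance-weighted interpolation: for each of the N dense points, take its 3 nearest sampled neighbours, weight their features by 1/(distance + 1e-8), and normalize the weights to sum to 1 before the weighted sum.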
300 | 301 | if points1 is not None: 302 | points1 = points1.permute(0, 2, 1) 303 | new_points = torch.cat([points1, interpolated_points], dim=-1) 304 | else: 305 | new_points = interpolated_points 306 | 307 | new_points = new_points.permute(0, 2, 1) 308 | for i, conv in enumerate(self.mlp_convs): 309 | bn = self.mlp_bns[i] 310 | new_points = F.relu(bn(conv(new_points))) 311 | return new_points -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize the batch data: use coordinates of the block centered at the origin. 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | label: B,... numpy array 27 | Return: 28 | shuffled data, label and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 36 | Use the same shuffling idx for the entire batch. 37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud(batch_data): 47 | """ Randomly rotate the point clouds to augment the dataset; 48 | rotation is per shape, around the up (y) axis 49 | Input: 50 | BxNx3 array, original batch of point clouds 51 | Return: 52 | BxNx3 array, rotated batch of point clouds 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | for k in range(batch_data.shape[0]): 56 | rotation_angle = np.random.uniform() * 2 * np.pi 57 | cosval = np.cos(rotation_angle) 58 | sinval = np.sin(rotation_angle) 59 | rotation_matrix = np.array([[cosval, 0, sinval], 60 | [0, 1, 0], 61 | [-sinval, 0, cosval]]) 62 | shape_pc = batch_data[k, ...] 63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 64 | return rotated_data 65 | 66 | def rotate_point_cloud_z(batch_data): 67 | """ Randomly rotate the point clouds to augment the dataset; 68 | rotation is per shape, around the z-axis 69 | Input: 70 | BxNx3 array, original batch of point clouds 71 | Return: 72 | BxNx3 array, rotated batch of point clouds 73 | """ 74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 75 | for k in range(batch_data.shape[0]): 76 | rotation_angle = np.random.uniform() * 2 * np.pi 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, sinval, 0], 80 | [-sinval, cosval, 0], 81 | [0, 0, 1]]) 82 | shape_pc = batch_data[k, ...] 83 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 84 | return rotated_data 85 | 86 | def rotate_point_cloud_with_normal(batch_xyz_normal): 87 | ''' Randomly rotate XYZ, normal point cloud.
88 | Input: 89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last three are the normal 90 | Output: 91 | B,N,6, rotated XYZ, normal point cloud 92 | ''' 93 | for k in range(batch_xyz_normal.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_xyz_normal[k,:,0:3] 101 | shape_normal = batch_xyz_normal[k,:,3:6] 102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 104 | return batch_xyz_normal 105 | 106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | BxNx6 array, original batch of point clouds and point normals 110 | Return: 111 | BxNx6 array, rotated batch of point clouds and point normals 112 | """ 113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 114 | for k in range(batch_data.shape[0]): 115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 116 | Rx = np.array([[1,0,0], 117 | [0,np.cos(angles[0]),-np.sin(angles[0])], 118 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 120 | [0,1,0], 121 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 123 | [np.sin(angles[2]),np.cos(angles[2]),0], 124 | [0,0,1]]) 125 | R = np.dot(Rz, np.dot(Ry,Rx)) 126 | shape_pc = batch_data[k,:,0:3] 127 | shape_normal = batch_data[k,:,3:6] 128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 130 | return rotated_data 131 | 132 | 133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 134 | """ Rotate the point cloud around the up direction by a given angle. 135 | Input: 136 | BxNx3 array, original batch of point clouds 137 | Return: 138 | BxNx3 array, rotated batch of point clouds 139 | """ 140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 141 | for k in range(batch_data.shape[0]): 142 | #rotation_angle = np.random.uniform() * 2 * np.pi 143 | cosval = np.cos(rotation_angle) 144 | sinval = np.sin(rotation_angle) 145 | rotation_matrix = np.array([[cosval, 0, sinval], 146 | [0, 1, 0], 147 | [-sinval, 0, cosval]]) 148 | shape_pc = batch_data[k,:,0:3] 149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 150 | return rotated_data 151 | 152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 153 | """ Rotate the point cloud around the up direction by a given angle.
154 | Input: 155 | BxNx6 array, original batch of point clouds with normal 156 | scalar, angle of rotation 157 | Return: 158 | BxNx6 array, rotated batch of point clouds with normal 159 | """ 160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 161 | for k in range(batch_data.shape[0]): 162 | #rotation_angle = np.random.uniform() * 2 * np.pi 163 | cosval = np.cos(rotation_angle) 164 | sinval = np.sin(rotation_angle) 165 | rotation_matrix = np.array([[cosval, 0, sinval], 166 | [0, 1, 0], 167 | [-sinval, 0, cosval]]) 168 | shape_pc = batch_data[k,:,0:3] 169 | shape_normal = batch_data[k,:,3:6] 170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 172 | return rotated_data 173 | 174 | 175 | 176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 177 | """ Randomly perturb the point clouds by small rotations 178 | Input: 179 | BxNx3 array, original batch of point clouds 180 | Return: 181 | BxNx3 array, rotated batch of point clouds 182 | """ 183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 184 | for k in range(batch_data.shape[0]): 185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 186 | Rx = np.array([[1,0,0], 187 | [0,np.cos(angles[0]),-np.sin(angles[0])], 188 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 190 | [0,1,0], 191 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 193 | [np.sin(angles[2]),np.cos(angles[2]),0], 194 | [0,0,1]]) 195 | R = np.dot(Rz, np.dot(Ry,Rx)) 196 | shape_pc = batch_data[k, ...] 197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 198 | return rotated_data 199 | 200 | 201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 202 | """ Randomly jitter points. Jittering is per point. 203 | Input: 204 | BxNx3 array, original batch of point clouds 205 | Return: 206 | BxNx3 array, jittered batch of point clouds 207 | """ 208 | B, N, C = batch_data.shape 209 | assert(clip > 0) 210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 211 | jittered_data += batch_data 212 | return jittered_data 213 | 214 | def shift_point_cloud(batch_data, shift_range=0.1): 215 | """ Randomly shift point cloud. Shift is per point cloud. 216 | Input: 217 | BxNx3 array, original batch of point clouds 218 | Return: 219 | BxNx3 array, shifted batch of point clouds 220 | """ 221 | B, N, C = batch_data.shape 222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 223 | for batch_index in range(B): 224 | batch_data[batch_index,:,:] += shifts[batch_index,:] 225 | return batch_data 226 | 227 |
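# Usage note: train_cls.py chains the augmentations above on the numpy batch before converting it back to a torch tensor: points = provider.random_point_dropout(points); points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3]); points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])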
228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 229 | """ Randomly scale the point cloud. Scale is per point cloud. 230 | Input: 231 | BxNx3 array, original batch of point clouds 232 | Return: 233 | BxNx3 array, scaled batch of point clouds 234 | """ 235 | B, N, C = batch_data.shape 236 | scales = np.random.uniform(scale_low, scale_high, B) 237 | for batch_index in range(B): 238 | batch_data[batch_index,:,:] *= scales[batch_index] 239 | return batch_data 240 | 241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 242 | ''' batch_pc: BxNx3 ''' 243 | for b in range(batch_pc.shape[0]): 244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 246 | if len(drop_idx)>0: 247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 248 | return batch_pc 249 | 250 | 251 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | tqdm 4 | hydra-core==1.2 5 | omegaconf 6 | -------------------------------------------------------------------------------- /train_cls.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | from dataset import ModelNetDataLoader 6 | import argparse 7 | import numpy as np 8 | import os 9 | import torch 10 | import datetime 11 | import logging 12 | from pathlib import Path 13 | from tqdm import tqdm 14 | import sys 15 | import provider 16 | import importlib 17 | import shutil 18 | import hydra 19 | import omegaconf 20 | 21 | 22 | def test(model, loader, num_class=40): 23 | mean_correct = [] 24 | class_acc = np.zeros((num_class,3)) 25 | for j, data in tqdm(enumerate(loader), total=len(loader)): 26 | points, target = data 27 | target = target[:, 0] 28 | points, target = points.cuda(), target.cuda() 29 | classifier = model.eval() 30 | pred = classifier(points) 31 | pred_choice = pred.data.max(1)[1] 32 | for cat in np.unique(target.cpu()): 33 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() 34 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) 35 | class_acc[cat,1]+=1 36 | correct = pred_choice.eq(target.long().data).cpu().sum() 37 | mean_correct.append(correct.item()/float(points.size()[0])) 38 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1] 39 | class_acc = np.mean(class_acc[:,2]) 40 | instance_acc = np.mean(mean_correct) 41 | return instance_acc, class_acc 42 | 43 | 44 | @hydra.main(config_path='config', config_name='cls') 45 | def main(args): 46 | omegaconf.OmegaConf.set_struct(args, False) 47 | 48 | '''HYPER PARAMETER''' 49 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 50 | logger = logging.getLogger(__name__) 51 | 52 | print(omegaconf.OmegaConf.to_yaml(args)) # DictConfig.pretty() was removed from recent OmegaConf releases 53 | 54 | '''DATA LOADING''' 55 | logger.info('Load dataset ...') 56 | DATA_PATH = hydra.utils.to_absolute_path('modelnet40_normal_resampled/') 57 | 58 | TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train', normal_channel=args.normal) 59 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal) 60 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4) 61 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4) 62 | 63 | '''MODEL LOADING''' 64 | args.num_class = 40 65 | args.input_dim = 6 if args.normal else 3
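# The selected model's source file is snapshotted into Hydra's run directory (the current working directory at this point), and the class is then imported dynamically by its config name below.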
66 | shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.') 67 | 68 | classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerCls')(args).cuda() 69 | criterion = torch.nn.CrossEntropyLoss() 70 | 71 | try: 72 | checkpoint = torch.load('best_model.pth') 73 | start_epoch = checkpoint['epoch'] 74 | classifier.load_state_dict(checkpoint['model_state_dict']) 75 | logger.info('Using pretrained model') 76 | except: 77 | logger.info('No existing model, starting training from scratch...') 78 | start_epoch = 0 79 | 80 | 81 | if args.optimizer == 'Adam': 82 | optimizer = torch.optim.Adam( 83 | classifier.parameters(), 84 | lr=args.learning_rate, 85 | betas=(0.9, 0.999), 86 | eps=1e-08, 87 | weight_decay=args.weight_decay 88 | ) 89 | else: 90 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 91 | 92 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3) 93 | global_epoch = 0 94 | global_step = 0 95 | best_instance_acc = 0.0 96 | best_class_acc = 0.0 97 | best_epoch = 0 98 | mean_correct = [] 99 | 100 | '''TRAINING''' 101 | logger.info('Start training...') 102 | for epoch in range(start_epoch,args.epoch): 103 | logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 104 | mean_correct = [] # reset each epoch so the reported train accuracy is per-epoch rather than cumulative 105 | classifier.train() 106 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): 107 | points, target = data 108 | points = points.data.numpy() 109 | points = provider.random_point_dropout(points) 110 | points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3]) 111 | points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3]) 112 | points = torch.Tensor(points) 113 | target = target[:, 0] 114 | 115 | points, target = points.cuda(), target.cuda() 116 | optimizer.zero_grad() 117 | 118 | pred = classifier(points) 119 | loss = criterion(pred, target.long()) 120 | pred_choice = pred.data.max(1)[1] 121 | correct = pred_choice.eq(target.long().data).cpu().sum() 122 | mean_correct.append(correct.item() / float(points.size()[0])) 123 | loss.backward() 124 | optimizer.step() 125 | global_step += 1 126 | 127 | scheduler.step() 128 | 129 | train_instance_acc = np.mean(mean_correct) 130 | logger.info('Train Instance Accuracy: %f' % train_instance_acc) 131 | 132 | 133 | with torch.no_grad(): 134 | instance_acc, class_acc = test(classifier.eval(), testDataLoader) 135 | 136 | if (instance_acc >= best_instance_acc): 137 | best_instance_acc = instance_acc 138 | best_epoch = epoch + 1 139 | 140 | if (class_acc >= best_class_acc): 141 | best_class_acc = class_acc 142 | logger.info('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc)) 143 | logger.info('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc)) 144 | 145 | if (instance_acc >= best_instance_acc): 146 | logger.info('Save model...') 147 | savepath = 'best_model.pth' 148 | logger.info('Saving at %s'% savepath) 149 | state = { 150 | 'epoch': best_epoch, 151 | 'instance_acc': instance_acc, 152 | 'class_acc': class_acc, 153 | 'model_state_dict': classifier.state_dict(), 154 | 'optimizer_state_dict': optimizer.state_dict(), 155 | } 156 | torch.save(state, savepath) 157 | global_epoch += 1 158 | 159 | logger.info('End of training...') 160 | 161 | if __name__ == '__main__': 162 | main() -------------------------------------------------------------------------------- /train_partseg.py: 
--------------------------------------------------------------------------------
/train_partseg.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | import argparse
6 | import os
7 | import torch
8 | import datetime
9 | import logging
10 | import sys
11 | import importlib
12 | import shutil
13 | import provider
14 | import numpy as np
15 | 
16 | from pathlib import Path
17 | from tqdm import tqdm
18 | from dataset import PartNormalDataset
19 | import hydra
20 | import omegaconf
21 | 
22 | 
23 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
24 |                'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
25 |                'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
26 |                'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
27 | seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
28 | for cat in seg_classes.keys():
29 |     for label in seg_classes[cat]:
30 |         seg_label_to_cat[label] = cat
31 | 
32 | 
33 | def inplace_relu(m):
34 |     classname = m.__class__.__name__
35 |     if classname.find('ReLU') != -1:
36 |         m.inplace = True
37 | 
38 | def to_categorical(y, num_classes):
39 |     """ 1-hot encodes a tensor """
40 |     new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
41 |     if (y.is_cuda):
42 |         return new_y.cuda()
43 |     return new_y
44 | 
45 | @hydra.main(config_path='config', config_name='partseg', version_base=None)
46 | def main(args):
47 |     omegaconf.OmegaConf.set_struct(args, False)
48 | 
49 |     '''HYPER PARAMETER'''
50 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
51 |     logger = logging.getLogger(__name__)
52 | 
53 |     root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal/')
54 | 
55 |     TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal)
56 |     trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
57 |     TEST_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
58 |     testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
59 | 
60 |     '''MODEL LOADING'''
61 |     args.input_dim = (6 if args.normal else 3) + 16
62 |     args.num_class = 50
63 |     num_category = 16
64 |     num_part = args.num_class
65 |     shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
66 | 
67 |     classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
68 |     criterion = torch.nn.CrossEntropyLoss()
69 | 
70 |     try:
71 |         checkpoint = torch.load('best_model.pth')
72 |         start_epoch = checkpoint['epoch']
73 |         classifier.load_state_dict(checkpoint['model_state_dict'])
74 |         logger.info('Use pretrain model')
75 |     except Exception:  # a bare except here would also swallow KeyboardInterrupt
76 |         logger.info('No existing model, starting training from scratch...')
77 |         start_epoch = 0
78 | 
79 |     if args.optimizer == 'Adam':
80 |         optimizer = torch.optim.Adam(
81 |             classifier.parameters(),
82 |             lr=args.learning_rate,
83 |             betas=(0.9, 0.999),
84 |             eps=1e-08,
85 |             weight_decay=args.weight_decay
86 |         )
87 |     else:
88 |         optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
89 | 
90 |     def bn_momentum_adjust(m, momentum):
91 |         if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
92 |             m.momentum = momentum
93 | 
94 |     LEARNING_RATE_CLIP = 1e-5
95 |     MOMENTUM_ORIGINAL = 0.1
96 |     MOMENTUM_DECAY = 0.5
97 |     MOMENTUM_DECAY_STEP = args.step_size
98 | 
99 |     best_acc = 0
100 |     global_epoch = 0
101 |     best_class_avg_iou = 0
102 |     best_instance_avg_iou = 0
103 | 
104 |     for epoch in range(start_epoch, args.epoch):
105 |         mean_correct = []
106 | 
107 |         logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
108 |         '''Adjust learning rate and BN momentum'''
109 |         lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
110 |         logger.info('Learning rate:%f' % lr)
111 |         for param_group in optimizer.param_groups:
112 |             param_group['lr'] = lr
113 |         momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
114 |         if momentum < 0.01:
115 |             momentum = 0.01
116 |         logger.info('BN momentum updated to: %f' % momentum)
117 |         classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
118 |         classifier = classifier.train()
119 | 
120 |         '''learning one epoch'''
121 |         for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
122 |             points = points.data.numpy()
123 |             points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
124 |             points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
125 |             points = torch.Tensor(points)
126 | 
127 |             points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
128 |             optimizer.zero_grad()
129 | 
130 |             seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
131 |             seg_pred = seg_pred.contiguous().view(-1, num_part)
132 |             target = target.view(-1, 1)[:, 0]
133 |             pred_choice = seg_pred.data.max(1)[1]
134 | 
135 |             correct = pred_choice.eq(target.data).cpu().sum()
136 |             mean_correct.append(correct.item() / (args.batch_size * args.num_point))
137 |             loss = criterion(seg_pred, target)
138 |             loss.backward()
139 |             optimizer.step()
140 | 
141 |         train_instance_acc = np.mean(mean_correct)
142 |         logger.info('Train accuracy is: %.5f' % train_instance_acc)
143 | 
144 |         with torch.no_grad():
145 |             test_metrics = {}
146 |             total_correct = 0
147 |             total_seen = 0
148 |             total_seen_class = [0 for _ in range(num_part)]
149 |             total_correct_class = [0 for _ in range(num_part)]
150 |             shape_ious = {cat: [] for cat in seg_classes.keys()}
151 |             seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
152 | 
153 |             for cat in seg_classes.keys():
154 |                 for label in seg_classes[cat]:
155 |                     seg_label_to_cat[label] = cat
156 | 
157 |             classifier = classifier.eval()
158 | 
159 |             for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
160 |                 cur_batch_size, NUM_POINT, _ = points.size()
161 |                 points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
162 |                 seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
163 |                 cur_pred_val = seg_pred.cpu().data.numpy()
164 |                 cur_pred_val_logits = cur_pred_val
165 |                 cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
166 |                 target = target.cpu().data.numpy()
167 | 
168 |                 for i in range(cur_batch_size):
169 |                     cat = seg_label_to_cat[target[i, 0]]
170 |                     logits = cur_pred_val_logits[i, :, :]
171 |                     cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
172 | 
173 |                 correct = np.sum(cur_pred_val == target)
174 |                 total_correct += correct
175 |                 total_seen += (cur_batch_size * NUM_POINT)
176 | 
177 |                 for l in range(num_part):
178 |                     total_seen_class[l] += np.sum(target == l)
179 |                     total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
180 | 
181 |                 for i in range(cur_batch_size):
182 |                     segp = cur_pred_val[i, :]
183 |                     segl = target[i, :]
184 |                     cat = seg_label_to_cat[segl[0]]
185 |                     part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
186 |                     for l in seg_classes[cat]:
187 |                         if (np.sum(segl == l) == 0) and (
188 |                                 np.sum(segp == l) == 0):  # part is not present, no prediction as well
189 |                             part_ious[l - seg_classes[cat][0]] = 1.0
190 |                         else:
191 |                             part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
192 |                                 np.sum((segl == l) | (segp == l)))
193 |                     shape_ious[cat].append(np.mean(part_ious))
194 | 
195 |             all_shape_ious = []
196 |             for cat in shape_ious.keys():
197 |                 for iou in shape_ious[cat]:
198 |                     all_shape_ious.append(iou)
199 |                 shape_ious[cat] = np.mean(shape_ious[cat])
200 |             mean_shape_ious = np.mean(list(shape_ious.values()))
201 |             test_metrics['accuracy'] = total_correct / float(total_seen)
202 |             test_metrics['class_avg_accuracy'] = np.mean(
203 |                 np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))  # np.float was removed in NumPy 1.24
204 |             for cat in sorted(shape_ious.keys()):
205 |                 logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
206 |             test_metrics['class_avg_iou'] = mean_shape_ious
207 |             test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
208 | 
209 |         logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % (
210 |             epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))
211 |         if (test_metrics['instance_avg_iou'] >= best_instance_avg_iou):
212 |             logger.info('Save model...')
213 |             savepath = 'best_model.pth'
214 |             logger.info('Saving at %s' % savepath)
215 |             state = {
216 |                 'epoch': epoch,
217 |                 'train_acc': train_instance_acc,
218 |                 'test_acc': test_metrics['accuracy'],
219 |                 'class_avg_iou': test_metrics['class_avg_iou'],
220 |                 'instance_avg_iou': test_metrics['instance_avg_iou'],
221 |                 'model_state_dict': classifier.state_dict(),
222 |                 'optimizer_state_dict': optimizer.state_dict(),
223 |             }
224 |             torch.save(state, savepath)
225 |             logger.info('Saving model....')
226 | 
227 |         if test_metrics['accuracy'] > best_acc:
228 |             best_acc = test_metrics['accuracy']
229 |         if test_metrics['class_avg_iou'] > best_class_avg_iou:
230 |             best_class_avg_iou = test_metrics['class_avg_iou']
231 |         if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
232 |             best_instance_avg_iou = test_metrics['instance_avg_iou']
233 |         logger.info('Best accuracy is: %.5f' % best_acc)
234 |         logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
235 |         logger.info('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)
236 |         global_epoch += 1
237 | 
238 | 
239 | if __name__ == '__main__':
240 |     main()
--------------------------------------------------------------------------------
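One detail of the evaluation loop in `train_partseg.py` worth calling out: since each ShapeNet category owns a fixed slice of the 50 part labels, per-point predictions are taken as an argmax over that category's columns only, then offset back into the global label range, so a shape can never be assigned parts from another category. A standalone sketch of that masking step on toy logits; the values are illustrative only:
```
# Minimal sketch of the category-masked argmax used during evaluation.
# Toy inputs; in the script the logits are (N, 50) classifier scores per shape.
import numpy as np

seg_classes = {'Laptop': [28, 29], 'Mug': [36, 37]}
logits = np.random.randn(1024, 50)                    # per-point scores over all parts
cat = 'Mug'                                           # ground-truth category of the shape
pred = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
assert set(np.unique(pred)) <= set(seg_classes[cat])  # only this category's parts appear
```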