├── LICENSE
├── README.md
├── data
│   └── ntu
│       ├── ntu120.list
│       └── ntu60.list
├── datasets
│   ├── msr.py
│   └── ntu60.py
├── imgs
│   ├── arch.png
│   ├── equation.png
│   ├── intro.png
│   └── pstconv.png
├── models
│   └── sequence_classification.py
├── modules
│   ├── _ext_src
│   │   ├── include
│   │   │   ├── ball_query.h
│   │   │   ├── cuda_utils.h
│   │   │   ├── group_points.h
│   │   │   ├── interpolate.h
│   │   │   ├── sampling.h
│   │   │   └── utils.h
│   │   └── src
│   │       ├── ball_query.cpp
│   │       ├── ball_query_gpu.cu
│   │       ├── bindings.cpp
│   │       ├── group_points.cpp
│   │       ├── group_points_gpu.cu
│   │       ├── interpolate.cpp
│   │       ├── interpolate_gpu.cu
│   │       ├── sampling.cpp
│   │       └── sampling_gpu.cu
│   ├── pointnet2_test.py
│   ├── pointnet2_utils.py
│   ├── pst_convolutions.py
│   └── setup.py
├── scripts
│   ├── depth2point4ntu.py
│   └── depth2point4ntu.sh
├── train-msr.py
├── train-ntu.py
└── utils.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 Hehe Fan
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # [PSTNet: Point Spatio-Temporal Convolution on Point Cloud Sequences](https://openreview.net/pdf?id=O3bqkf_Puys)
 2 | ![](https://github.com/hehefan/Point-Spatio-Temporal-Convolution/blob/main/imgs/intro.png)
 3 | 
 4 | ![](https://github.com/hehefan/Point-Spatio-Temporal-Convolution/blob/main/imgs/equation.png)
 5 | 
 6 | ## Introduction
 7 | Point cloud sequences are irregular and unordered in the spatial dimension while exhibiting regularity and order in the temporal dimension. Therefore, existing grid-based convolutions for conventional video processing cannot be directly applied to spatio-temporal modeling of raw point cloud sequences. In this paper, we propose a point spatio-temporal (PST) convolution to achieve informative representations of point cloud sequences. The proposed PST convolution first disentangles space and time in point cloud sequences. Then, a spatial convolution is employed to capture the local structure of points in 3D space, and a temporal convolution is used to model the dynamics of the spatial regions along the time dimension.
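The `PSTConv` module in `modules/pst_convolutions.py` implements this two-step convolution: it consumes point coordinates (plus optional per-point features) and returns spatially subsampled anchor coordinates together with convolved features. A minimal usage sketch — the hyper-parameter values are illustrative (borrowed from the MSR-Action3D model in `models/sequence_classification.py`), and the input is assumed to be batched as `(batch, frames, points, 3)`:
```
import torch
from modules.pst_convolutions import PSTConv

conv = PSTConv(in_planes=0,                   # raw coordinates only, no input features
               mid_planes=45,
               out_planes=64,
               spatial_kernel_size=[1.5, 9],  # [ball-query radius, neighbours per anchor]
               temporal_kernel_size=1,
               spatial_stride=2,              # keep half of the points per frame
               temporal_stride=1,
               temporal_padding=[0, 0]).cuda()

xyzs = torch.rand(4, 16, 2048, 3).cuda()      # (batch, frames, points, 3)
new_xys, new_features = conv(xyzs, None)      # anchor coordinates and their features
```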
 8 | ![](https://github.com/hehefan/Point-Spatio-Temporal-Convolution/blob/main/imgs/pstconv.png)
 9 | Furthermore, we incorporate the proposed PST convolution into a deep network, namely PSTNet, to extract features of 3D point cloud sequences in a spatio-temporally hierarchical manner.
10 | ![](https://github.com/hehefan/Point-Spatio-Temporal-Convolution/blob/main/imgs/arch.png)
11 | 
12 | ## Installation
13 | 
14 | The code is tested with Red Hat Enterprise Linux Workstation release 7.7 (Maipo), g++ (GCC) 8.3.1, PyTorch v1.2, CUDA 10.2 and cuDNN v7.6.
15 | 
16 | Install PyTorch v1.2:
17 | ```
18 | pip install torch==1.2.0 torchvision==0.4.0
19 | ```
20 | 
21 | Compile the CUDA layers for [PointNet++](http://arxiv.org/abs/1706.02413), which we use for furthest point sampling (FPS) and radius-based neighbour search:
22 | ```
23 | cd modules
24 | python setup.py install
25 | ```
26 | To check that the compilation succeeded, run `python modules/pst_convolutions.py` and verify that a forward pass completes.
27 | 
28 | Install [Mayavi](https://docs.enthought.com/mayavi/mayavi/installation.html) for point cloud visualization (optional; a desktop environment is required).
29 | 
30 | ## Citation
31 | If you find our work useful in your research, please consider citing:
32 | ```
33 | @inproceedings{fan2021pstnet,
34 |   title={PSTNet: Point Spatio-Temporal Convolution on Point Cloud Sequences},
35 |   author={Hehe Fan and Xin Yu and Yuhang Ding and Yi Yang and Mohan Kankanhalli},
36 |   booktitle={International Conference on Learning Representations},
37 |   year={2021}
38 | }
39 | ```
40 | 
41 | ## Related Repos
42 | 1. PointNet++ PyTorch implementation: https://github.com/facebookresearch/votenet/tree/master/pointnet2
43 | 2. MeteorNet: https://github.com/xingyul/meteornet
44 | 3. 3DV: https://github.com/3huo/3DV-Action
45 | 4. P4Transformer: https://github.com/hehefan/P4Transformer
46 | 5. PointRNN (TensorFlow implementation): https://github.com/hehefan/PointRNN
47 | 6. PointRNN (PyTorch implementation): https://github.com/hehefan/PointRNN-PyTorch
48 | 7. Awesome Dynamic Point Cloud / Point Cloud Video / Point Cloud Sequence / 4D Point Cloud Analysis: https://github.com/hehefan/Awesome-Dynamic-Point-Cloud-Analysis
49 | 
--------------------------------------------------------------------------------
/datasets/msr.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import numpy as np
 4 | from torch.utils.data import Dataset
 5 | 
 6 | class MSRAction3D(Dataset):
 7 |     def __init__(self, root, frames_per_clip=16, frame_interval=1, num_points=2048, train=True):
 8 |         super(MSRAction3D, self).__init__()
 9 | 
10 |         self.videos = []
11 |         self.labels = []
12 |         self.index_map = []
13 |         index = 0
14 |         # Subjects 1-5 form the training split; the remaining subjects form the
15 |         # test split. index_map holds one (video_index, start_frame) per clip.
16 |         for video_name in os.listdir(root):
17 |             if train and (int(video_name.split('_')[1].split('s')[1]) <= 5):
18 |                 video = np.load(os.path.join(root, video_name), allow_pickle=True)['point_clouds']
19 |                 self.videos.append(video)
20 |                 label = int(video_name.split('_')[0][1:])-1
21 |                 self.labels.append(label)
22 | 
23 |                 nframes = video.shape[0]
24 |                 for t in range(0, nframes-frame_interval*(frames_per_clip-1)):
25 |                     self.index_map.append((index, t))
26 |                 index += 1
27 | 
28 |             if not train and (int(video_name.split('_')[1].split('s')[1]) > 5):
29 |                 video = np.load(os.path.join(root, video_name), allow_pickle=True)['point_clouds']
30 |                 self.videos.append(video)
31 |                 label = int(video_name.split('_')[0][1:])-1
32 |                 self.labels.append(label)
33 | 
34 |                 nframes = video.shape[0]
35 |                 for t in range(0, nframes-frame_interval*(frames_per_clip-1)):
36 |                     self.index_map.append((index, t))
37 |                 index += 1
38 | 
39 |         self.frames_per_clip = frames_per_clip
40 |         self.frame_interval = frame_interval
41 |         self.num_points = num_points
42 |         self.train = train
43 |         self.num_classes = max(self.labels) + 1
44 | 
45 | 
46 |     def __len__(self):
47 |         return len(self.index_map)
48 | 
49 |     def __getitem__(self, idx):
50 |         index, t = self.index_map[idx]
51 | 
52 |         video = self.videos[index]
53 |         label = self.labels[index]
54 | 
55 |         clip = [video[t+i*self.frame_interval] for i in range(self.frames_per_clip)]
56 |         # Resample every frame to exactly num_points points: subsample large
57 |         # frames at random, and pad small frames by repeating points.
58 |         for i, p in enumerate(clip):
59 |             if p.shape[0] > self.num_points:
60 |                 r = np.random.choice(p.shape[0], size=self.num_points, replace=False)
61 |             else:
62 |                 repeat, residue = self.num_points // p.shape[0], self.num_points % p.shape[0]
63 |                 r = np.random.choice(p.shape[0], size=residue, replace=False)
64 |                 r = np.concatenate([np.arange(p.shape[0]) for _ in range(repeat)] + [r], axis=0)
65 |             clip[i] = p[r, :]
66 |         clip = np.array(clip)
67 | 
68 |         if self.train:
69 |             # scale the points
70 |             scales = np.random.uniform(0.9, 1.1, size=3)
71 |             clip = clip * scales
72 | 
73 |         clip = clip / 300
74 | 
75 |         return clip.astype(np.float32), label, index
76 | 
77 | if __name__ == '__main__':
78 |     dataset = MSRAction3D(root='../data/msr_action', frames_per_clip=16)
79 |     clip, label, video_idx = dataset[0]
80 |     print(clip)
81 |     print(label)
82 |     print(video_idx)
83 |     print(dataset.num_classes)
--------------------------------------------------------------------------------
/datasets/ntu60.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | import numpy as np
 4 | from torch.utils.data import Dataset
 5 | 
 6 | Cross_Subject = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38]
 7 | 
 8 | class NTU60Subject(Dataset):
 9 |     def __init__(self, root, meta, frames_per_clip=23, step_between_clips=2, num_points=2048, train=True):
10 |         super(NTU60Subject, self).__init__()
11 | 
12 |         self.videos = []
13 |         self.labels = []
14 |         self.index_map = []
15 |         index = 0
16 | 
17 |         with open(meta, 'r') as f:
18 |             for line in f:
19 |                 name, nframes = line.split()
20 |                 subject = int(name[9:12])
21 |                 if train:
22 |                     if subject in Cross_Subject:
23 |                         label = int(name[-3:]) - 1
24 |                         nframes = int(nframes)
25 |                         for t in range(0, nframes-step_between_clips*(frames_per_clip-1), step_between_clips):
26 |                             self.index_map.append((index, t))
27 |                         index += 1
28 |                         self.labels.append(label)
29 |                         self.videos.append(os.path.join(root, name+'.npz'))
30 |                 else:
31 |                     if subject not in Cross_Subject:
32 |                         label = int(name[-3:]) - 1
33 |                         nframes = int(nframes)
34 |                         for t in range(0, nframes-step_between_clips*(frames_per_clip-1), step_between_clips):
35 |                             self.index_map.append((index, t))
36 |                         index += 1
37 |                         self.labels.append(label)
38 |                         self.videos.append(os.path.join(root, name+'.npz'))
39 | 
40 |         self.frames_per_clip = frames_per_clip
41 |         self.step_between_clips = step_between_clips
42 |         self.num_points = num_points
43 |         self.train = train
44 |         self.num_classes = max(self.labels) + 1
45 | 
46 | 
47 |     def __len__(self):
48 |         return len(self.index_map)
49 | 
50 |     def __getitem__(self, idx):
51 |         index, t = self.index_map[idx]
52 | 
53 |         video = self.videos[index]
54 |         video = np.load(video, allow_pickle=True)['data'] * 100
55 |         label = self.labels[index]
56 | 
57 |         clip = [video[t+i*self.step_between_clips] for i in range(self.frames_per_clip)]
58 |         for i, p in enumerate(clip):
59 |             if p.shape[0] > self.num_points:
60 |                 r = np.random.choice(p.shape[0], size=self.num_points, replace=False)
61 |             else:
62 |                 repeat, residue = self.num_points // p.shape[0], self.num_points % p.shape[0]
63 |                 r = np.random.choice(p.shape[0], size=residue, replace=False)
64 |                 r = np.concatenate([np.arange(p.shape[0]) for _ in range(repeat)] + [r], axis=0)
65 |             clip[i] = p[r, :]
66 |         clip = np.array(clip)
67 | 
68 |         if self.train:
69 |             # scale the points
70 |             scales = np.random.uniform(0.9, 1.1, size=3)
71 |             clip = clip * scales
72 | 
73 |         return clip.astype(np.float32), label, index
74 | 
75 | if __name__ == '__main__':
76 |     dataset = NTU60Subject(root='/scratch/HeheFan-data/data/ntu/video', meta='/scratch/HeheFan-data/data/ntu/ntu60.list', frames_per_clip=16)
77 |     clip, label, video_idx = dataset[0]
78 |     data = clip[0]
79 |     print(data[:,0].max()-data[:,0].min())
80 |     print(data[:,1].max()-data[:,1].min())
81 |     print(data[:,2].max()-data[:,2].min())
82 |     #print(clip)
83 |     print(label)
84 |     print(video_idx)
85 |     print(dataset.num_classes)
86 | 
--------------------------------------------------------------------------------
/imgs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehefan/Point-Spatio-Temporal-Convolution/04f9b47bce1907e28f3f3862dda047178c5e81dd/imgs/arch.png
--------------------------------------------------------------------------------
/imgs/equation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehefan/Point-Spatio-Temporal-Convolution/04f9b47bce1907e28f3f3862dda047178c5e81dd/imgs/equation.png
--------------------------------------------------------------------------------
/imgs/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehefan/Point-Spatio-Temporal-Convolution/04f9b47bce1907e28f3f3862dda047178c5e81dd/imgs/intro.png
--------------------------------------------------------------------------------
/imgs/pstconv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hehefan/Point-Spatio-Temporal-Convolution/04f9b47bce1907e28f3f3862dda047178c5e81dd/imgs/pstconv.png
--------------------------------------------------------------------------------
/models/sequence_classification.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | import numpy as np
  5 | import sys
  6 | import os
  7 | 
  8 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  9 | ROOT_DIR = os.path.dirname(BASE_DIR)
 10 | sys.path.append(ROOT_DIR)
 11 | sys.path.append(os.path.join(ROOT_DIR, 'modules'))
 12 | 
 13 | from pst_convolutions import PSTConv
 14 | 
 15 | class MSRAction(nn.Module):
 16 |     def __init__(self, radius=1.5, nsamples=3*3, num_classes=20):
 17 |         super(MSRAction, self).__init__()
 18 | 
 19 |         self.conv1 = PSTConv(in_planes=0,
 20 |                              mid_planes=45,
 21 |                              out_planes=64,
 22 |                              spatial_kernel_size=[radius, nsamples],
 23 |                              temporal_kernel_size=1,
 24 |                              spatial_stride=2,
 25 |                              temporal_stride=1,
 26 |                              temporal_padding=[0,0],
 27 |                              spatial_aggregation="multiplication",
 28 |                              spatial_pooling="sum")
 29 | 
 30 |         self.conv2a = PSTConv(in_planes=64,
 31 |                               mid_planes=96,
 32 |                               out_planes=128,
 33 |                               spatial_kernel_size=[2*radius, nsamples],
 34 |                               temporal_kernel_size=3,
 35 |                               spatial_stride=2,
 36 |                               temporal_stride=2,
 37 |                               temporal_padding=[1,0],
 38 |                               spatial_aggregation="multiplication",
 39 |                               spatial_pooling="sum")
 40 | 
 41 |         self.conv2b = PSTConv(in_planes=128,
 42 |                               mid_planes=192,
 43 |                               out_planes=256,
 44 |                               spatial_kernel_size=[2*radius, nsamples],
 45 |                               temporal_kernel_size=3,
 46 |                               spatial_stride=1,
 47 |                               temporal_stride=1,
 48 |                               temporal_padding=[1,1],
 49 |                               spatial_aggregation="multiplication",
 50 |                               spatial_pooling="sum")
 51 | 
 52 |         self.conv3a = PSTConv(in_planes=256,
 53 |                               mid_planes=284,
 54 |                               out_planes=512,
 55 |                               spatial_kernel_size=[2*2*radius, nsamples],
 56 |                               temporal_kernel_size=3,
 57 |                               spatial_stride=2,
 58 |                               temporal_stride=2,
 59 |                               temporal_padding=[1,0],
 60 |                               spatial_aggregation="multiplication",
 61 |                               spatial_pooling="sum")
 62 | 
 63 |         self.conv3b = PSTConv(in_planes=512,
 64 |                               mid_planes=768,
 65 |                               out_planes=1024,
 66 |                               spatial_kernel_size=[2*2*radius, nsamples],
 67 |                               temporal_kernel_size=3,
 68 |                               spatial_stride=1,
 69 |                               temporal_stride=1,
 70 |                               temporal_padding=[1,1],
 71 |                               spatial_aggregation="multiplication",
 72 |                               spatial_pooling="sum")
 73 | 
 74 |         self.conv4 = PSTConv(in_planes=1024,
 75 |                              mid_planes=1536,
 76 |                              out_planes=2048,
 77 |                              spatial_kernel_size=[2*2*radius, nsamples],
 78 |                              temporal_kernel_size=1,
 79 |                              spatial_stride=2,
 80 |                              temporal_stride=1,
 81 |                              temporal_padding=[0,0],
 82 |                              spatial_aggregation="multiplication",
 83 |                              spatial_pooling="sum")
 84 | 
 85 |         self.fc = nn.Linear(2048, num_classes)
 86 | 
 87 |     def forward(self, xyzs):
 88 | 
 89 |         new_xys, new_features = self.conv1(xyzs, None)
 90 |         new_features = F.relu(new_features)
 91 | 
 92 |         new_xys, new_features = self.conv2a(new_xys, new_features)
 93 |         new_features = F.relu(new_features)
 94 | 
 95 |         new_xys, new_features = self.conv2b(new_xys, new_features)
 96 |         new_features = F.relu(new_features)
 97 | 
 98 |         new_xys, new_features = self.conv3a(new_xys, new_features)
 99 |         new_features = F.relu(new_features)
100 | 
101 |         new_xys, new_features = self.conv3b(new_xys, new_features)
102 |         new_features = F.relu(new_features)
103 | 
104 |         new_xys, new_features = self.conv4(new_xys, new_features)              # (B, L, C, N)
105 | 
106 |         new_features = torch.mean(input=new_features, dim=-1, keepdim=False)   # (B, L, C)
107 | 
108 |         new_feature = torch.max(input=new_features, dim=1, keepdim=False)[0]   # (B, C)
109 | 
110 |         out = self.fc(new_feature)
111 | 
112 |         return out
113 | 
114 | class NTU(nn.Module):
115 |     def __init__(self, radius=0.1, nsamples=3*3, num_classes=20):
116 |         super(NTU, self).__init__()
117 | 
118 |         self.conv1 = PSTConv(in_planes=0,
119 |                              mid_planes=45,
120 |                              out_planes=64,
121 |                              spatial_kernel_size=[radius, nsamples],
122 |                              temporal_kernel_size=1,
123 |                              spatial_stride=2,
124 |                              temporal_stride=1,
125 |                              temporal_padding=[0,0])
126 | 
127 |         self.conv2a = PSTConv(in_planes=64,
128 |                               mid_planes=96,
129 |                               out_planes=128,
130 |                               spatial_kernel_size=[2*radius, nsamples],
131 |                               temporal_kernel_size=3,
132 |                               spatial_stride=2,
133 |                               temporal_stride=2,
134 |                               temporal_padding=[0,0])
135 | 
136 |         self.conv2b = PSTConv(in_planes=128,
137 |                               mid_planes=192,
138 |                               out_planes=256,
139 |                               spatial_kernel_size=[2*radius, nsamples],
140 |                               temporal_kernel_size=3,
141 |                               spatial_stride=1,
142 |                               temporal_stride=1,
143 |                               temporal_padding=[0,0])
144 | 
145 |         self.conv3a = PSTConv(in_planes=256,
146 |                               mid_planes=384,
147 |                               out_planes=512,
148 |                               spatial_kernel_size=[2*2*radius, nsamples],
149 |                               temporal_kernel_size=3,
150 |                               spatial_stride=2,
151 |                               temporal_stride=2,
152 |                               temporal_padding=[0,0])
153 | 
154 |         self.conv3b = PSTConv(in_planes=512,
155 |                               mid_planes=768,
156 |                               out_planes=1024,
157 |                               spatial_kernel_size=[2*2*radius, nsamples],
158 |                               temporal_kernel_size=3,
159 |                               spatial_stride=1,
160 |                               temporal_stride=1,
161 |                               temporal_padding=[0,0])
162 | 
163 |         self.conv4 = PSTConv(in_planes=1024,
164 |                              mid_planes=1536,
165 |                              out_planes=2048,
166 |                              spatial_kernel_size=[2*2*radius, nsamples],
167 |                              temporal_kernel_size=1,
168 |                              spatial_stride=2,
169 |                              temporal_stride=1,
170 |                              temporal_padding=[0,0])
171 | 
172 |         self.fc = nn.Linear(2048, num_classes)
173 | 
174 |     def forward(self, xyzs):
175 | 
176 |         new_xys, new_features = self.conv1(xyzs, None)
177 |         new_features = F.relu(new_features)
178 | 
179 |         new_xys, new_features = self.conv2a(new_xys, new_features)
180 |         new_features = F.relu(new_features)
181 | 
182 |         new_xys, new_features = self.conv2b(new_xys, new_features)
183 |         new_features = F.relu(new_features)
184 | 
185 |         new_xys, new_features = self.conv3a(new_xys, new_features)
186 |         new_features = F.relu(new_features)
187 | 
188 |         new_xys, new_features = self.conv3b(new_xys, new_features)
189 |         new_features = F.relu(new_features)
190 | 
191 |         new_xys, new_features = self.conv4(new_xys, new_features)              # (B, L, C, N)
192 | 
193 |         new_features = torch.mean(input=new_features, dim=-1, keepdim=False)  # (B, L, C)
194 | 
195 |         new_feature = torch.max(input=new_features, dim=1, keepdim=False)[0]  # (B, C)
196 | 
197 |         out = self.fc(new_feature)
198 | 
199 |         return out
200 | 
--------------------------------------------------------------------------------
/modules/_ext_src/include/ball_query.h:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
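// ball_query: for each of the m query centers in new_xyz (b, m, 3), finds the
// indices (b, m, nsample) of up to nsample points in xyz (b, n, 3) that lie
// within `radius` of the center (see ball_query_gpu.cu for details).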
 5 | 
 6 | #pragma once
 7 | #include <torch/extension.h>
 8 | 
 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
10 |                       const int nsample);
11 | 
--------------------------------------------------------------------------------
/modules/_ext_src/include/cuda_utils.h:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #ifndef _CUDA_UTILS_H
 7 | #define _CUDA_UTILS_H
 8 | 
 9 | #include <ATen/ATen.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <cmath>
12 | 
13 | #include <cuda.h>
14 | #include <cuda_runtime.h>
15 | 
16 | #include <vector>
17 | 
18 | #define TOTAL_THREADS 512
19 | 
20 | inline int opt_n_threads(int work_size) {
21 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
22 | 
23 |   return max(min(1 << pow_2, TOTAL_THREADS), 1);
24 | }
25 | 
26 | inline dim3 opt_block_config(int x, int y) {
27 |   const int x_threads = opt_n_threads(x);
28 |   const int y_threads =
29 |       max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
30 |   dim3 block_config(x_threads, y_threads, 1);
31 | 
32 |   return block_config;
33 | }
34 | 
35 | #define CUDA_CHECK_ERRORS()                                           \
36 |   do {                                                                \
37 |     cudaError_t err = cudaGetLastError();                             \
38 |     if (cudaSuccess != err) {                                         \
39 |       fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n",  \
40 |               cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
41 |               __FILE__);                                              \
42 |       exit(-1);                                                       \
43 |     }                                                                 \
44 |   } while (0)
45 | 
46 | #endif
47 | 
--------------------------------------------------------------------------------
/modules/_ext_src/include/group_points.h:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #pragma once
 7 | #include <torch/extension.h>
 8 | 
 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 | 
--------------------------------------------------------------------------------
/modules/_ext_src/include/interpolate.h:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #pragma once
 7 | 
 8 | #include <torch/extension.h>
 9 | #include <vector>
10 | 
11 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
13 |                              at::Tensor weight);
14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
15 |                                   at::Tensor weight, const int m);
16 | 
--------------------------------------------------------------------------------
/modules/_ext_src/include/sampling.h:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
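// gather_points indexes point features ((b, c, n) with idx (b, m) -> (b, c, m));
// furthest_point_sampling greedily selects nsamples indices of a (b, n, 3) point
// set so that each new pick maximises its distance to the picks made so far.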
 5 | 
 6 | #pragma once
 7 | #include <torch/extension.h>
 8 | 
 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
--------------------------------------------------------------------------------
/modules/_ext_src/include/utils.h:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #pragma once
 7 | #include <ATen/cuda/CUDAContext.h>
 8 | #include <torch/extension.h>
 9 | 
10 | #define CHECK_CUDA(x)                                          \
11 |   do {                                                         \
12 |     AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
13 |   } while (0)
14 | 
15 | #define CHECK_CONTIGUOUS(x)                                         \
16 |   do {                                                              \
17 |     AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
18 |   } while (0)
19 | 
20 | #define CHECK_IS_INT(x)                              \
21 |   do {                                               \
22 |     AT_CHECK(x.scalar_type() == at::ScalarType::Int, \
23 |              #x " must be an int tensor");           \
24 |   } while (0)
25 | 
26 | #define CHECK_IS_FLOAT(x)                              \
27 |   do {                                                 \
28 |     AT_CHECK(x.scalar_type() == at::ScalarType::Float, \
29 |              #x " must be a float tensor");            \
30 |   } while (0)
31 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/ball_query.cpp:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #include "ball_query.h"
 7 | #include "utils.h"
 8 | 
 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
10 |                                      int nsample, const float *new_xyz,
11 |                                      const float *xyz, int *idx);
12 | 
13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
14 |                       const int nsample) {
15 |   CHECK_CONTIGUOUS(new_xyz);
16 |   CHECK_CONTIGUOUS(xyz);
17 |   CHECK_IS_FLOAT(new_xyz);
18 |   CHECK_IS_FLOAT(xyz);
19 | 
20 |   if (new_xyz.type().is_cuda()) {
21 |     CHECK_CUDA(xyz);
22 |   }
23 | 
24 |   at::Tensor idx =
25 |       torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
26 |                    at::device(new_xyz.device()).dtype(at::ScalarType::Int));
27 | 
28 |   if (new_xyz.type().is_cuda()) {
29 |     query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
30 |                                     radius, nsample, new_xyz.data<float>(),
31 |                                     xyz.data<float>(), idx.data<int>());
32 |   } else {
33 |     AT_CHECK(false, "CPU not supported");
34 |   }
35 | 
36 |   return idx;
37 | }
38 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
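// Note on the kernel below: as soon as the first in-radius neighbour is found,
// its index is written to every slot of the output row, so query balls with
// fewer than nsample neighbours end up padded with copies of the first one.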
 5 | 
 6 | #include <math.h>
 7 | #include <stdio.h>
 8 | #include <stdlib.h>
 9 | 
10 | #include "cuda_utils.h"
11 | 
12 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
13 | // output: idx(b, m, nsample)
14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
15 |                                         int nsample,
16 |                                         const float *__restrict__ new_xyz,
17 |                                         const float *__restrict__ xyz,
18 |                                         int *__restrict__ idx) {
19 |   int batch_index = blockIdx.x;
20 |   xyz += batch_index * n * 3;
21 |   new_xyz += batch_index * m * 3;
22 |   idx += m * nsample * batch_index;
23 | 
24 |   int index = threadIdx.x;
25 |   int stride = blockDim.x;
26 | 
27 |   float radius2 = radius * radius;
28 |   for (int j = index; j < m; j += stride) {
29 |     float new_x = new_xyz[j * 3 + 0];
30 |     float new_y = new_xyz[j * 3 + 1];
31 |     float new_z = new_xyz[j * 3 + 2];
32 |     for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
33 |       float x = xyz[k * 3 + 0];
34 |       float y = xyz[k * 3 + 1];
35 |       float z = xyz[k * 3 + 2];
36 |       float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
37 |                  (new_z - z) * (new_z - z);
38 |       if (d2 < radius2) {
39 |         if (cnt == 0) {
40 |           for (int l = 0; l < nsample; ++l) {
41 |             idx[j * nsample + l] = k;
42 |           }
43 |         }
44 |         idx[j * nsample + cnt] = k;
45 |         ++cnt;
46 |       }
47 |     }
48 |   }
49 | }
50 | 
51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
52 |                                      int nsample, const float *new_xyz,
53 |                                      const float *xyz, int *idx) {
54 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
55 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
56 |       b, n, m, radius, nsample, new_xyz, xyz, idx);
57 | 
58 |   CUDA_CHECK_ERRORS();
59 | }
60 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/bindings.cpp:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #include "ball_query.h"
 7 | #include "group_points.h"
 8 | #include "interpolate.h"
 9 | #include "sampling.h"
10 | 
11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12 |   m.def("gather_points", &gather_points);
13 |   m.def("gather_points_grad", &gather_points_grad);
14 |   m.def("furthest_point_sampling", &furthest_point_sampling);
15 | 
16 |   m.def("three_nn", &three_nn);
17 |   m.def("three_interpolate", &three_interpolate);
18 |   m.def("three_interpolate_grad", &three_interpolate_grad);
19 | 
20 |   m.def("ball_query", &ball_query);
21 | 
22 |   m.def("group_points", &group_points);
23 |   m.def("group_points_grad", &group_points_grad);
24 | }
25 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/group_points.cpp:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #include "group_points.h"
 7 | #include "utils.h"
 8 | 
 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
10 |                                  const float *points, const int *idx,
11 |                                  float *out);
12 | 
13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
14 |                                       int nsample, const float *grad_out,
15 |                                       const int *idx, float *grad_points);
16 | 
17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
18 |   CHECK_CONTIGUOUS(points);
19 |   CHECK_CONTIGUOUS(idx);
20 |   CHECK_IS_FLOAT(points);
21 |   CHECK_IS_INT(idx);
22 | 
23 |   if (points.type().is_cuda()) {
24 |     CHECK_CUDA(idx);
25 |   }
26 | 
27 |   at::Tensor output =
28 |       torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
29 |                    at::device(points.device()).dtype(at::ScalarType::Float));
30 | 
31 |   if (points.type().is_cuda()) {
32 |     group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
33 |                                 idx.size(1), idx.size(2), points.data<float>(),
34 |                                 idx.data<int>(), output.data<float>());
35 |   } else {
36 |     AT_CHECK(false, "CPU not supported");
37 |   }
38 | 
39 |   return output;
40 | }
41 | 
42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
43 |   CHECK_CONTIGUOUS(grad_out);
44 |   CHECK_CONTIGUOUS(idx);
45 |   CHECK_IS_FLOAT(grad_out);
46 |   CHECK_IS_INT(idx);
47 | 
48 |   if (grad_out.type().is_cuda()) {
49 |     CHECK_CUDA(idx);
50 |   }
51 | 
52 |   at::Tensor output =
53 |       torch::zeros({grad_out.size(0), grad_out.size(1), n},
54 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
55 | 
56 |   if (grad_out.type().is_cuda()) {
57 |     group_points_grad_kernel_wrapper(
58 |         grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
59 |         grad_out.data<float>(), idx.data<int>(), output.data<float>());
60 |   } else {
61 |     AT_CHECK(false, "CPU not supported");
62 |   }
63 | 
64 |   return output;
65 | }
66 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
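// The forward kernel below gathers features into (b, c, npoints, nsample)
// groups; the backward kernel scatters gradients back with atomicAdd, since
// the same source point may appear in several groups.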
 5 | 
 6 | #include <stdio.h>
 7 | #include <stdlib.h>
 8 | 
 9 | #include "cuda_utils.h"
10 | 
11 | // input: points(b, c, n) idx(b, npoints, nsample)
12 | // output: out(b, c, npoints, nsample)
13 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
14 |                                     int nsample,
15 |                                     const float *__restrict__ points,
16 |                                     const int *__restrict__ idx,
17 |                                     float *__restrict__ out) {
18 |   int batch_index = blockIdx.x;
19 |   points += batch_index * n * c;
20 |   idx += batch_index * npoints * nsample;
21 |   out += batch_index * npoints * nsample * c;
22 | 
23 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
24 |   const int stride = blockDim.y * blockDim.x;
25 |   for (int i = index; i < c * npoints; i += stride) {
26 |     const int l = i / npoints;
27 |     const int j = i % npoints;
28 |     for (int k = 0; k < nsample; ++k) {
29 |       int ii = idx[j * nsample + k];
30 |       out[(l * npoints + j) * nsample + k] = points[l * n + ii];
31 |     }
32 |   }
33 | }
34 | 
35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
36 |                                  const float *points, const int *idx,
37 |                                  float *out) {
38 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
39 | 
40 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
41 |       b, c, n, npoints, nsample, points, idx, out);
42 | 
43 |   CUDA_CHECK_ERRORS();
44 | }
45 | 
46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
47 | // output: grad_points(b, c, n)
48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
49 |                                          int nsample,
50 |                                          const float *__restrict__ grad_out,
51 |                                          const int *__restrict__ idx,
52 |                                          float *__restrict__ grad_points) {
53 |   int batch_index = blockIdx.x;
54 |   grad_out += batch_index * npoints * nsample * c;
55 |   idx += batch_index * npoints * nsample;
56 |   grad_points += batch_index * n * c;
57 | 
58 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
59 |   const int stride = blockDim.y * blockDim.x;
60 |   for (int i = index; i < c * npoints; i += stride) {
61 |     const int l = i / npoints;
62 |     const int j = i % npoints;
63 |     for (int k = 0; k < nsample; ++k) {
64 |       int ii = idx[j * nsample + k];
65 |       atomicAdd(grad_points + l * n + ii,
66 |                 grad_out[(l * npoints + j) * nsample + k]);
67 |     }
68 |   }
69 | }
70 | 
71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
72 |                                       int nsample, const float *grad_out,
73 |                                       const int *idx, float *grad_points) {
74 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
75 | 
76 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
77 |       b, c, n, npoints, nsample, grad_out, idx, grad_points);
78 | 
79 |   CUDA_CHECK_ERRORS();
80 | }
81 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/interpolate.cpp:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
  5 | 
  6 | #include "interpolate.h"
  7 | #include "utils.h"
  8 | 
  9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
 10 |                              const float *known, float *dist2, int *idx);
 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
 12 |                                       const float *points, const int *idx,
 13 |                                       const float *weight, float *out);
 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
 15 |                                            const float *grad_out,
 16 |                                            const int *idx, const float *weight,
 17 |                                            float *grad_points);
 18 | 
 19 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
 20 |   CHECK_CONTIGUOUS(unknowns);
 21 |   CHECK_CONTIGUOUS(knows);
 22 |   CHECK_IS_FLOAT(unknowns);
 23 |   CHECK_IS_FLOAT(knows);
 24 | 
 25 |   if (unknowns.type().is_cuda()) {
 26 |     CHECK_CUDA(knows);
 27 |   }
 28 | 
 29 |   at::Tensor idx =
 30 |       torch::zeros({unknowns.size(0), unknowns.size(1), 3},
 31 |                    at::device(unknowns.device()).dtype(at::ScalarType::Int));
 32 |   at::Tensor dist2 =
 33 |       torch::zeros({unknowns.size(0), unknowns.size(1), 3},
 34 |                    at::device(unknowns.device()).dtype(at::ScalarType::Float));
 35 | 
 36 |   if (unknowns.type().is_cuda()) {
 37 |     three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
 38 |                             unknowns.data<float>(), knows.data<float>(),
 39 |                             dist2.data<float>(), idx.data<int>());
 40 |   } else {
 41 |     AT_CHECK(false, "CPU not supported");
 42 |   }
 43 | 
 44 |   return {dist2, idx};
 45 | }
 46 | 
 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
 48 |                              at::Tensor weight) {
 49 |   CHECK_CONTIGUOUS(points);
 50 |   CHECK_CONTIGUOUS(idx);
 51 |   CHECK_CONTIGUOUS(weight);
 52 |   CHECK_IS_FLOAT(points);
 53 |   CHECK_IS_INT(idx);
 54 |   CHECK_IS_FLOAT(weight);
 55 | 
 56 |   if (points.type().is_cuda()) {
 57 |     CHECK_CUDA(idx);
 58 |     CHECK_CUDA(weight);
 59 |   }
 60 | 
 61 |   at::Tensor output =
 62 |       torch::zeros({points.size(0), points.size(1), idx.size(1)},
 63 |                    at::device(points.device()).dtype(at::ScalarType::Float));
 64 | 
 65 |   if (points.type().is_cuda()) {
 66 |     three_interpolate_kernel_wrapper(
 67 |         points.size(0), points.size(1), points.size(2), idx.size(1),
 68 |         points.data<float>(), idx.data<int>(), weight.data<float>(),
 69 |         output.data<float>());
 70 |   } else {
 71 |     AT_CHECK(false, "CPU not supported");
 72 |   }
 73 | 
 74 |   return output;
 75 | }
 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
 77 |                                   at::Tensor weight, const int m) {
 78 |   CHECK_CONTIGUOUS(grad_out);
 79 |   CHECK_CONTIGUOUS(idx);
 80 |   CHECK_CONTIGUOUS(weight);
 81 |   CHECK_IS_FLOAT(grad_out);
 82 |   CHECK_IS_INT(idx);
 83 |   CHECK_IS_FLOAT(weight);
 84 | 
 85 |   if (grad_out.type().is_cuda()) {
 86 |     CHECK_CUDA(idx);
 87 |     CHECK_CUDA(weight);
 88 |   }
 89 | 
 90 |   at::Tensor output =
 91 |       torch::zeros({grad_out.size(0), grad_out.size(1), m},
 92 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
 93 | 
 94 |   if (grad_out.type().is_cuda()) {
 95 |     three_interpolate_grad_kernel_wrapper(
 96 |         grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
 97 |         grad_out.data<float>(), idx.data<int>(), weight.data<float>(),
 98 |         output.data<float>());
 99 |   } else {
100 |     AT_CHECK(false, "CPU not supported");
101 |   }
102 | 
103 |   return output;
104 | }
105 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
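// three_nn below tracks the three smallest squared distances with a small
// insertion sort; three_interpolate then blends the neighbour features as
// out = w1*f(i1) + w2*f(i2) + w3*f(i3), and the gradient kernel scatters the
// same weights back with atomicAdd.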
  5 | 
  6 | #include <math.h>
  7 | #include <stdio.h>
  8 | #include <stdlib.h>
  9 | 
 10 | #include "cuda_utils.h"
 11 | 
 12 | // input: unknown(b, n, 3) known(b, m, 3)
 13 | // output: dist2(b, n, 3), idx(b, n, 3)
 14 | __global__ void three_nn_kernel(int b, int n, int m,
 15 |                                 const float *__restrict__ unknown,
 16 |                                 const float *__restrict__ known,
 17 |                                 float *__restrict__ dist2,
 18 |                                 int *__restrict__ idx) {
 19 |   int batch_index = blockIdx.x;
 20 |   unknown += batch_index * n * 3;
 21 |   known += batch_index * m * 3;
 22 |   dist2 += batch_index * n * 3;
 23 |   idx += batch_index * n * 3;
 24 | 
 25 |   int index = threadIdx.x;
 26 |   int stride = blockDim.x;
 27 |   for (int j = index; j < n; j += stride) {
 28 |     float ux = unknown[j * 3 + 0];
 29 |     float uy = unknown[j * 3 + 1];
 30 |     float uz = unknown[j * 3 + 2];
 31 | 
 32 |     double best1 = 1e40, best2 = 1e40, best3 = 1e40;
 33 |     int besti1 = 0, besti2 = 0, besti3 = 0;
 34 |     for (int k = 0; k < m; ++k) {
 35 |       float x = known[k * 3 + 0];
 36 |       float y = known[k * 3 + 1];
 37 |       float z = known[k * 3 + 2];
 38 |       float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
 39 |       if (d < best1) {
 40 |         best3 = best2;
 41 |         besti3 = besti2;
 42 |         best2 = best1;
 43 |         besti2 = besti1;
 44 |         best1 = d;
 45 |         besti1 = k;
 46 |       } else if (d < best2) {
 47 |         best3 = best2;
 48 |         besti3 = besti2;
 49 |         best2 = d;
 50 |         besti2 = k;
 51 |       } else if (d < best3) {
 52 |         best3 = d;
 53 |         besti3 = k;
 54 |       }
 55 |     }
 56 |     dist2[j * 3 + 0] = best1;
 57 |     dist2[j * 3 + 1] = best2;
 58 |     dist2[j * 3 + 2] = best3;
 59 | 
 60 |     idx[j * 3 + 0] = besti1;
 61 |     idx[j * 3 + 1] = besti2;
 62 |     idx[j * 3 + 2] = besti3;
 63 |   }
 64 | }
 65 | 
 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
 67 |                              const float *known, float *dist2, int *idx) {
 68 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 69 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
 70 |                                                       dist2, idx);
 71 | 
 72 |   CUDA_CHECK_ERRORS();
 73 | }
 74 | 
 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
 76 | // output: out(b, c, n)
 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
 78 |                                          const float *__restrict__ points,
 79 |                                          const int *__restrict__ idx,
 80 |                                          const float *__restrict__ weight,
 81 |                                          float *__restrict__ out) {
 82 |   int batch_index = blockIdx.x;
 83 |   points += batch_index * m * c;
 84 | 
 85 |   idx += batch_index * n * 3;
 86 |   weight += batch_index * n * 3;
 87 | 
 88 |   out += batch_index * n * c;
 89 | 
 90 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
 91 |   const int stride = blockDim.y * blockDim.x;
 92 |   for (int i = index; i < c * n; i += stride) {
 93 |     const int l = i / n;
 94 |     const int j = i % n;
 95 |     float w1 = weight[j * 3 + 0];
 96 |     float w2 = weight[j * 3 + 1];
 97 |     float w3 = weight[j * 3 + 2];
 98 | 
 99 |     int i1 = idx[j * 3 + 0];
100 |     int i2 = idx[j * 3 + 1];
101 |     int i3 = idx[j * 3 + 2];
102 | 
103 |     out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
104 |              points[l * m + i3] * w3;
105 |   }
106 | }
107 | 
108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
109 |                                       const float *points, const int *idx,
110 |                                       const float *weight, float *out) {
111 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
112 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
113 |       b, c, m, n, points, idx, weight, out);
114 | 
115 |   CUDA_CHECK_ERRORS();
116 | }
117 | 
118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
119 | // output: grad_points(b, c, m)
120 | 
121 | __global__ void three_interpolate_grad_kernel(
122 |     int b, int c, int n, int m, const float *__restrict__ grad_out,
123 |     const int *__restrict__ idx, const float *__restrict__ weight,
124 |     float *__restrict__ grad_points) {
125 |   int batch_index = blockIdx.x;
126 |   grad_out += batch_index * n * c;
127 |   idx += batch_index * n * 3;
128 |   weight += batch_index * n * 3;
129 |   grad_points += batch_index * m * c;
130 | 
131 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
132 |   const int stride = blockDim.y * blockDim.x;
133 |   for (int i = index; i < c * n; i += stride) {
134 |     const int l = i / n;
135 |     const int j = i % n;
136 |     float w1 = weight[j * 3 + 0];
137 |     float w2 = weight[j * 3 + 1];
138 |     float w3 = weight[j * 3 + 2];
139 | 
140 |     int i1 = idx[j * 3 + 0];
141 |     int i2 = idx[j * 3 + 1];
142 |     int i3 = idx[j * 3 + 2];
143 | 
144 |     atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
145 |     atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
146 |     atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
147 |   }
148 | }
149 | 
150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
151 |                                            const float *grad_out,
152 |                                            const int *idx, const float *weight,
153 |                                            float *grad_points) {
154 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
155 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
156 |       b, c, n, m, grad_out, idx, weight, grad_points);
157 | 
158 |   CUDA_CHECK_ERRORS();
159 | }
160 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/sampling.cpp:
--------------------------------------------------------------------------------
 1 | // Copyright (c) Facebook, Inc. and its affiliates.
 2 | // 
 3 | // This source code is licensed under the MIT license found in the
 4 | // LICENSE file in the root directory of this source tree.
 5 | 
 6 | #include "sampling.h"
 7 | #include "utils.h"
 8 | 
 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
10 |                                   const float *points, const int *idx,
11 |                                   float *out);
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 |                                        const float *grad_out, const int *idx,
14 |                                        float *grad_points);
15 | 
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 |                                             const float *dataset, float *temp,
18 |                                             int *idxs);
19 | 
20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
21 |   CHECK_CONTIGUOUS(points);
22 |   CHECK_CONTIGUOUS(idx);
23 |   CHECK_IS_FLOAT(points);
24 |   CHECK_IS_INT(idx);
25 | 
26 |   if (points.type().is_cuda()) {
27 |     CHECK_CUDA(idx);
28 |   }
29 | 
30 |   at::Tensor output =
31 |       torch::zeros({points.size(0), points.size(1), idx.size(1)},
32 |                    at::device(points.device()).dtype(at::ScalarType::Float));
33 | 
34 |   if (points.type().is_cuda()) {
35 |     gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
36 |                                  idx.size(1), points.data<float>(),
37 |                                  idx.data<int>(), output.data<float>());
38 |   } else {
39 |     AT_CHECK(false, "CPU not supported");
40 |   }
41 | 
42 |   return output;
43 | }
44 | 
45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
46 |                               const int n) {
47 |   CHECK_CONTIGUOUS(grad_out);
48 |   CHECK_CONTIGUOUS(idx);
49 |   CHECK_IS_FLOAT(grad_out);
50 |   CHECK_IS_INT(idx);
51 | 
52 |   if (grad_out.type().is_cuda()) {
53 |     CHECK_CUDA(idx);
54 |   }
55 | 
56 |   at::Tensor output =
57 |       torch::zeros({grad_out.size(0), grad_out.size(1), n},
58 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
59 | 
60 |   if (grad_out.type().is_cuda()) {
61 |     gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
62 |                                       idx.size(1), grad_out.data<float>(),
63 |                                       idx.data<int>(), output.data<float>());
64 |   } else {
65 |     AT_CHECK(false, "CPU not supported");
66 |   }
67 | 
68 |   return output;
69 | }
70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
71 |   CHECK_CONTIGUOUS(points);
72 |   CHECK_IS_FLOAT(points);
73 | 
74 |   at::Tensor output =
75 |       torch::zeros({points.size(0), nsamples},
76 |                    at::device(points.device()).dtype(at::ScalarType::Int));
77 | 
78 |   at::Tensor tmp =
79 |       torch::full({points.size(0), points.size(1)}, 1e10,
80 |                   at::device(points.device()).dtype(at::ScalarType::Float));
81 | 
82 |   if (points.type().is_cuda()) {
83 |     furthest_point_sampling_kernel_wrapper(
84 |         points.size(0), points.size(1), nsamples, points.data<float>(),
85 |         tmp.data<float>(), output.data<int>());
86 |   } else {
87 |     AT_CHECK(false, "CPU not supported");
88 |   }
89 | 
90 |   return output;
91 | }
92 | 
--------------------------------------------------------------------------------
/modules/_ext_src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
  1 | // Copyright (c) Facebook, Inc. and its affiliates.
  2 | // 
  3 | // This source code is licensed under the MIT license found in the
  4 | // LICENSE file in the root directory of this source tree.
  5 | 
  6 | #include <stdio.h>
  7 | #include <stdlib.h>
  8 | 
  9 | #include "cuda_utils.h"
 10 | 
 11 | // input: points(b, c, n) idx(b, m)
 12 | // output: out(b, c, m)
 13 | __global__ void gather_points_kernel(int b, int c, int n, int m,
 14 |                                      const float *__restrict__ points,
 15 |                                      const int *__restrict__ idx,
 16 |                                      float *__restrict__ out) {
 17 |   for (int i = blockIdx.x; i < b; i += gridDim.x) {
 18 |     for (int l = blockIdx.y; l < c; l += gridDim.y) {
 19 |       for (int j = threadIdx.x; j < m; j += blockDim.x) {
 20 |         int a = idx[i * m + j];
 21 |         out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
 22 |       }
 23 |     }
 24 |   }
 25 | }
 26 | 
 27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
 28 |                                   const float *points, const int *idx,
 29 |                                   float *out) {
 30 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
 31 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
 32 |                                                              points, idx, out);
 33 | 
 34 |   CUDA_CHECK_ERRORS();
 35 | }
 36 | 
 37 | // input: grad_out(b, c, m) idx(b, m)
 38 | // output: grad_points(b, c, n)
 39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
 40 |                                           const float *__restrict__ grad_out,
 41 |                                           const int *__restrict__ idx,
 42 |                                           float *__restrict__ grad_points) {
 43 |   for (int i = blockIdx.x; i < b; i += gridDim.x) {
 44 |     for (int l = blockIdx.y; l < c; l += gridDim.y) {
 45 |       for (int j = threadIdx.x; j < m; j += blockDim.x) {
 46 |         int a = idx[i * m + j];
 47 |         atomicAdd(grad_points + (i * c + l) * n + a,
 48 |                   grad_out[(i * c + l) * m + j]);
 49 |       }
 50 |     }
 51 |   }
 52 | }
 53 | 
 54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
 55 |                                        const float *grad_out, const int *idx,
 56 |                                        float *grad_points) {
 57 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
 58 |                               at::cuda::getCurrentCUDAStream()>>>(
 59 |       b, c, n, npoints, grad_out, idx, grad_points);
 60 | 
 61 |   CUDA_CHECK_ERRORS();
 62 | }
 63 | 
 64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
 65 |                          int idx1, int idx2) {
 66 |   const float v1 = dists[idx1], v2 = dists[idx2];
 67 |   const int i1 = dists_i[idx1], i2 = dists_i[idx2];
 68 |   dists[idx1] = max(v1, v2);
 69 |   dists_i[idx1] = v2 > v1 ? i2 : i1;
 70 | }
 71 | 
 72 | // Input dataset: (b, n, 3), tmp: (b, n)
 73 | // Output idxs (b, m)
 74 | template <unsigned int block_size>
 75 | __global__ void furthest_point_sampling_kernel(
 76 |     int b, int n, int m, const float *__restrict__ dataset,
 77 |     float *__restrict__ temp, int *__restrict__ idxs) {
 78 |   if (m <= 0) return;
 79 |   __shared__ float dists[block_size];
 80 |   __shared__ int dists_i[block_size];
 81 | 
 82 |   int batch_index = blockIdx.x;
 83 |   dataset += batch_index * n * 3;
 84 |   temp += batch_index * n;
 85 |   idxs += batch_index * m;
 86 | 
 87 |   int tid = threadIdx.x;
 88 |   const int stride = block_size;
 89 | 
 90 |   int old = 0;
 91 |   if (threadIdx.x == 0) idxs[0] = old;
 92 | 
 93 |   __syncthreads();
 94 |   for (int j = 1; j < m; j++) {
 95 |     int besti = 0;
 96 |     float best = -1;
 97 |     float x1 = dataset[old * 3 + 0];
 98 |     float y1 = dataset[old * 3 + 1];
 99 |     float z1 = dataset[old * 3 + 2];
100 |     for (int k = tid; k < n; k += stride) {
101 |       float x2, y2, z2;
102 |       x2 = dataset[k * 3 + 0];
103 |       y2 = dataset[k * 3 + 1];
104 |       z2 = dataset[k * 3 + 2];
105 |       float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
106 |       if (mag <= 1e-3) continue;
107 | 
108 |       float d =
109 |           (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
110 | 
111 |       float d2 = min(d, temp[k]);
112 |       temp[k] = d2;
113 |       besti = d2 > best ? k : besti;
114 |       best = d2 > best ? d2 : best;
115 |     }
116 |     dists[tid] = best;
117 |     dists_i[tid] = besti;
118 |     __syncthreads();
119 | 
120 |     if (block_size >= 512) {
121 |       if (tid < 256) {
122 |         __update(dists, dists_i, tid, tid + 256);
123 |       }
124 |       __syncthreads();
125 |     }
126 |     if (block_size >= 256) {
127 |       if (tid < 128) {
128 |         __update(dists, dists_i, tid, tid + 128);
129 |       }
130 |       __syncthreads();
131 |     }
132 |     if (block_size >= 128) {
133 |       if (tid < 64) {
134 |         __update(dists, dists_i, tid, tid + 64);
135 |       }
136 |       __syncthreads();
137 |     }
138 |     if (block_size >= 64) {
139 |       if (tid < 32) {
140 |         __update(dists, dists_i, tid, tid + 32);
141 |       }
142 |       __syncthreads();
143 |     }
144 |     if (block_size >= 32) {
145 |       if (tid < 16) {
146 |         __update(dists, dists_i, tid, tid + 16);
147 |       }
148 |       __syncthreads();
149 |     }
150 |     if (block_size >= 16) {
151 |       if (tid < 8) {
152 |         __update(dists, dists_i, tid, tid + 8);
153 |       }
154 |       __syncthreads();
155 |     }
156 |     if (block_size >= 8) {
157 |       if (tid < 4) {
158 |         __update(dists, dists_i, tid, tid + 4);
159 |       }
160 |       __syncthreads();
161 |     }
162 |     if (block_size >= 4) {
163 |       if (tid < 2) {
164 |         __update(dists, dists_i, tid, tid + 2);
165 |       }
166 |       __syncthreads();
167 |     }
168 |     if (block_size >= 2) {
169 |       if (tid < 1) {
170 |         __update(dists, dists_i, tid, tid + 1);
171 |       }
172 |       __syncthreads();
173 |     }
174 | 
175 |     old = dists_i[0];
176 |     if (tid == 0) idxs[j] = old;
177 |   }
178 | }
179 | 
180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
181 |                                             const float *dataset, float *temp,
182 |                                             int *idxs) {
183 |   unsigned int n_threads = opt_n_threads(n);
184 | 
185 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
186 | 
187 |   switch (n_threads) {
188 |     case 512:
189 |       furthest_point_sampling_kernel<512>
190 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
191 |       break;
192 |     case 256:
193 |       furthest_point_sampling_kernel<256>
194 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
195 |       break;
196 |     case 128:
197 |       furthest_point_sampling_kernel<128>
198 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
199 |       break;
200 |     case 64:
201 |       furthest_point_sampling_kernel<64>
202 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
203 |       break;
204 |     case 32:
205 |       furthest_point_sampling_kernel<32>
206 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
207 |       break;
208 |     case 16:
209 |       furthest_point_sampling_kernel<16>
210 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
211 |       break;
212 |     case 8:
213 |       furthest_point_sampling_kernel<8>
214 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
215 |       break;
216 |     case 4:
217 |       furthest_point_sampling_kernel<4>
218 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
219 |       break;
220 |     case 2:
221 |       furthest_point_sampling_kernel<2>
222 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
223 |       break;
224 |     case 1:
225 |       furthest_point_sampling_kernel<1>
226 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
227 |       break;
228 |     default:
229 |       furthest_point_sampling_kernel<512>
230 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
231 |   }
232 | 
233 |   CUDA_CHECK_ERRORS();
234 | }
235 | 
--------------------------------------------------------------------------------
/modules/pointnet2_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | # 
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | ''' Testing customized ops. '''
 7 | 
 8 | import torch
 9 | from torch.autograd import gradcheck
10 | import numpy as np
11 | 
12 | import os
13 | import sys
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append(BASE_DIR)
16 | import pointnet2_utils
17 | 
18 | def test_interpolation_grad():
19 |     batch_size = 1
20 |     feat_dim = 2
21 |     m = 4
22 |     feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda()
23 | 
24 |     def interpolate_func(inputs):
25 |         idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda()
26 |         weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda()
27 |         interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight)
28 |         return interpolated_feats
29 | 
30 |     assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1))
31 | 
32 | if __name__=='__main__':
33 |     test_interpolation_grad()
--------------------------------------------------------------------------------
/modules/pointnet2_utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Facebook, Inc. and its affiliates.
 2 | # 
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
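# The autograd Functions below wrap the compiled `_ext` CUDA kernels (furthest
# point sampling, gathering, three-nearest-neighbour interpolation, grouping
# and ball query) so that they can be used like ordinary PyTorch operations.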
  5 | 
  6 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch '''
  7 | from __future__ import (
  8 |     division,
  9 |     absolute_import,
 10 |     with_statement,
 11 |     print_function,
 12 |     unicode_literals,
 13 | )
 14 | import torch
 15 | from torch.autograd import Function
 16 | import torch.nn as nn
 17 | import sys
 18 | 
 19 | try:
 20 |     import builtins
 21 | except:
 22 |     import __builtin__ as builtins
 23 | 
 24 | try:
 25 |     import pointnet2._ext as _ext
 26 | except ImportError:
 27 |     if not getattr(builtins, "__POINTNET2_SETUP__", False):
 28 |         raise ImportError(
 29 |             "Could not import _ext module.\n"
 30 |             "Please see the setup instructions in the README: "
 31 |             "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst"
 32 |         )
 33 | 
 34 | if False:
 35 |     # Workaround for type hints without depending on the `typing` module
 36 |     from typing import *
 37 | 
 38 | class FurthestPointSampling(Function):
 39 |     @staticmethod
 40 |     def forward(ctx, xyz, npoint):
 41 |         # type: (Any, torch.Tensor, int) -> torch.Tensor
 42 |         r"""
 43 |         Uses iterative furthest point sampling to select a set of npoint features that have the largest
 44 |         minimum distance
 45 | 
 46 |         Parameters
 47 |         ----------
 48 |         xyz : torch.Tensor
 49 |             (B, N, 3) tensor where N > npoint
 50 |         npoint : int32
 51 |             number of features in the sampled set
 52 | 
 53 |         Returns
 54 |         -------
 55 |         torch.Tensor
 56 |             (B, npoint) tensor containing the set
 57 |         """
 58 |         fps_inds = _ext.furthest_point_sampling(xyz, npoint)
 59 |         ctx.mark_non_differentiable(fps_inds)
 60 |         return fps_inds
 61 | 
 62 |     @staticmethod
 63 |     def backward(xyz, a=None):
 64 |         return None, None
 65 | 
 66 | 
 67 | furthest_point_sample = FurthestPointSampling.apply
 68 | 
 69 | 
 70 | class GatherOperation(Function):
 71 |     @staticmethod
 72 |     def forward(ctx, features, idx):
 73 |         # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
 74 |         r"""
 75 | 
 76 |         Parameters
 77 |         ----------
 78 |         features : torch.Tensor
 79 |             (B, C, N) tensor
 80 | 
 81 |         idx : torch.Tensor
 82 |             (B, npoint) tensor of the features to gather
 83 | 
 84 |         Returns
 85 |         -------
 86 |         torch.Tensor
 87 |             (B, C, npoint) tensor
 88 |         """
 89 | 
 90 |         _, C, N = features.size()
 91 | 
 92 |         ctx.for_backwards = (idx, C, N)
 93 | 
 94 |         return _ext.gather_points(features, idx)
 95 | 
 96 |     @staticmethod
 97 |     def backward(ctx, grad_out):
 98 |         idx, C, N = ctx.for_backwards
 99 | 
100 |         grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
101 |         return grad_features, None
102 | 
103 | 
104 | gather_operation = GatherOperation.apply
105 | 
106 | 
107 | class ThreeNN(Function):
108 |     @staticmethod
109 |     def forward(ctx, unknown, known):
110 |         # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
111 |         r"""
112 |         Find the three nearest neighbors of unknown in known
113 |         Parameters
114 |         ----------
115 |         unknown : torch.Tensor
116 |             (B, n, 3) tensor of the points to interpolate to
117 |         known : torch.Tensor
118 |             (B, m, 3) tensor of the known points
119 | 
120 |         Returns
121 |         -------
122 |         dist : torch.Tensor
123 |             (B, n, 3) l2 distance to the three nearest neighbors
124 |         idx : torch.Tensor
125 |             (B, n, 3) index of 3 nearest neighbors
126 |         """
127 |         dist2, idx = _ext.three_nn(unknown, known)
128 | 
129 |         return torch.sqrt(dist2), idx
130 | 
131 |     @staticmethod
132 |     def backward(ctx, a=None, b=None):
133 |         return None, None
134 | 
135 | 
136 | three_nn = ThreeNN.apply
137 | 
138 | 
139 | class ThreeInterpolate(Function):
140 |     @staticmethod
141 |     def forward(ctx, features, idx, weight):
142 |         # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
139 | class ThreeInterpolate(Function):
140 |     @staticmethod
141 |     def forward(ctx, features, idx, weight):
142 |         # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
143 |         r"""
144 |         Performs weighted linear interpolation on 3 features
145 |         Parameters
146 |         ----------
147 |         features : torch.Tensor
148 |             (B, c, m) Features descriptors to be interpolated from
149 |         idx : torch.Tensor
150 |             (B, n, 3) three nearest neighbors of the target features in features
151 |         weight : torch.Tensor
152 |             (B, n, 3) weights
153 |
154 |         Returns
155 |         -------
156 |         torch.Tensor
157 |             (B, c, n) tensor of the interpolated features
158 |         """
159 |         B, c, m = features.size()
160 |         n = idx.size(1)
161 |
162 |         ctx.three_interpolate_for_backward = (idx, weight, m)
163 |
164 |         return _ext.three_interpolate(features, idx, weight)
165 |
166 |     @staticmethod
167 |     def backward(ctx, grad_out):
168 |         # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
169 |         r"""
170 |         Parameters
171 |         ----------
172 |         grad_out : torch.Tensor
173 |             (B, c, n) tensor with gradients of outputs
174 |
175 |         Returns
176 |         -------
177 |         grad_features : torch.Tensor
178 |             (B, c, m) tensor with gradients of features
179 |
180 |         None
181 |
182 |         None
183 |         """
184 |         idx, weight, m = ctx.three_interpolate_for_backward
185 |
186 |         grad_features = _ext.three_interpolate_grad(
187 |             grad_out.contiguous(), idx, weight, m
188 |         )
189 |
190 |         return grad_features, None, None
191 |
192 |
193 | three_interpolate = ThreeInterpolate.apply
194 |
195 |
196 | class GroupingOperation(Function):
197 |     @staticmethod
198 |     def forward(ctx, features, idx):
199 |         # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
200 |         r"""
201 |
202 |         Parameters
203 |         ----------
204 |         features : torch.Tensor
205 |             (B, C, N) tensor of features to group
206 |         idx : torch.Tensor
207 |             (B, npoint, nsample) tensor containing the indices of features to group with
208 |
209 |         Returns
210 |         -------
211 |         torch.Tensor
212 |             (B, C, npoint, nsample) tensor
213 |         """
214 |         B, nfeatures, nsample = idx.size()
215 |         _, C, N = features.size()
216 |
217 |         ctx.for_backwards = (idx, N)
218 |
219 |         return _ext.group_points(features, idx)
220 |
221 |     @staticmethod
222 |     def backward(ctx, grad_out):
223 |         # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
224 |         r"""
225 |
226 |         Parameters
227 |         ----------
228 |         grad_out : torch.Tensor
229 |             (B, C, npoint, nsample) tensor of the gradients of the output from forward
230 |
231 |         Returns
232 |         -------
233 |         torch.Tensor
234 |             (B, C, N) gradient of the features
235 |         None
236 |         """
237 |         idx, N = ctx.for_backwards
238 |
239 |         grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
240 |
241 |         return grad_features, None
242 |
243 |
244 | grouping_operation = GroupingOperation.apply
245 |
246 |
247 | class BallQuery(Function):
248 |     @staticmethod
249 |     def forward(ctx, radius, nsample, xyz, new_xyz):
250 |         # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
251 |         r"""
252 |
253 |         Parameters
254 |         ----------
255 |         radius : float
256 |             radius of the balls
257 |         nsample : int
258 |             maximum number of features in the balls
259 |         xyz : torch.Tensor
260 |             (B, N, 3) xyz coordinates of the features
261 |         new_xyz : torch.Tensor
262 |             (B, npoint, 3) centers of the ball query
263 |
264 |         Returns
265 |         -------
266 |         torch.Tensor
267 |             (B, npoint, nsample) tensor with the indices of the features that form the query balls
268 |         """
269 |         inds = _ext.ball_query(new_xyz, xyz, radius, nsample)
270 |         ctx.mark_non_differentiable(inds)
271 |         return inds
272 |
273 |     @staticmethod
274 |     def backward(ctx, a=None):
275 |         return None, None, None, None
276 |
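# Note (illustrative, not in the original file): in this CUDA implementation,
# ball_query fills all `nsample` slots for every center; when fewer than
# `nsample` points fall within `radius`, the first in-radius index is repeated
# as padding, so downstream grouping always sees a dense index tensor, e.g.:
#   idx = ball_query(0.2, 16, xyz, new_xyz)        # xyz: (B, N, 3), new_xyz: (B, m, 3)
#   grouped = grouping_operation(features, idx)    # features: (B, C, N) -> (B, C, m, 16)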
277 | 278 | ball_query = BallQuery.apply 279 | 280 | 281 | class QueryAndGroup(nn.Module): 282 | r""" 283 | Groups with a ball query of radius 284 | 285 | Parameters 286 | --------- 287 | radius : float32 288 | Radius of ball 289 | nsample : int32 290 | Maximum number of features to gather in the ball 291 | """ 292 | 293 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 294 | # type: (QueryAndGroup, float, int, bool) -> None 295 | super(QueryAndGroup, self).__init__() 296 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 297 | self.ret_grouped_xyz = ret_grouped_xyz 298 | self.normalize_xyz = normalize_xyz 299 | self.sample_uniformly = sample_uniformly 300 | self.ret_unique_cnt = ret_unique_cnt 301 | if self.ret_unique_cnt: 302 | assert(self.sample_uniformly) 303 | 304 | def forward(self, xyz, new_xyz, features=None): 305 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 306 | r""" 307 | Parameters 308 | ---------- 309 | xyz : torch.Tensor 310 | xyz coordinates of the features (B, N, 3) 311 | new_xyz : torch.Tensor 312 | centriods (B, npoint, 3) 313 | features : torch.Tensor 314 | Descriptors of the features (B, C, N) 315 | 316 | Returns 317 | ------- 318 | new_features : torch.Tensor 319 | (B, 3 + C, npoint, nsample) tensor 320 | """ 321 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 322 | 323 | if self.sample_uniformly: 324 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 325 | for i_batch in range(idx.shape[0]): 326 | for i_region in range(idx.shape[1]): 327 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 328 | num_unique = unique_ind.shape[0] 329 | unique_cnt[i_batch, i_region] = num_unique 330 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 331 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 332 | idx[i_batch, i_region, :] = all_ind 333 | 334 | 335 | xyz_trans = xyz.transpose(1, 2).contiguous() 336 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 337 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 338 | if self.normalize_xyz: 339 | grouped_xyz /= self.radius 340 | 341 | if features is not None: 342 | grouped_features = grouping_operation(features, idx) 343 | if self.use_xyz: 344 | new_features = torch.cat( 345 | [grouped_xyz, grouped_features], dim=1 346 | ) # (B, C + 3, npoint, nsample) 347 | else: 348 | new_features = grouped_features 349 | else: 350 | assert ( 351 | self.use_xyz 352 | ), "Cannot have not features and not use xyz as a feature!" 
353 | new_features = grouped_xyz 354 | 355 | ret = [new_features] 356 | if self.ret_grouped_xyz: 357 | ret.append(grouped_xyz) 358 | if self.ret_unique_cnt: 359 | ret.append(unique_cnt) 360 | if len(ret) == 1: 361 | return ret[0] 362 | else: 363 | return tuple(ret) 364 | 365 | 366 | class GroupAll(nn.Module): 367 | r""" 368 | Groups all features 369 | 370 | Parameters 371 | --------- 372 | """ 373 | 374 | def __init__(self, use_xyz=True, ret_grouped_xyz=False): 375 | # type: (GroupAll, bool) -> None 376 | super(GroupAll, self).__init__() 377 | self.use_xyz = use_xyz 378 | 379 | def forward(self, xyz, new_xyz, features=None): 380 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 381 | r""" 382 | Parameters 383 | ---------- 384 | xyz : torch.Tensor 385 | xyz coordinates of the features (B, N, 3) 386 | new_xyz : torch.Tensor 387 | Ignored 388 | features : torch.Tensor 389 | Descriptors of the features (B, C, N) 390 | 391 | Returns 392 | ------- 393 | new_features : torch.Tensor 394 | (B, C + 3, 1, N) tensor 395 | """ 396 | 397 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 398 | if features is not None: 399 | grouped_features = features.unsqueeze(2) 400 | if self.use_xyz: 401 | new_features = torch.cat( 402 | [grouped_xyz, grouped_features], dim=1 403 | ) # (B, 3 + C, 1, N) 404 | else: 405 | new_features = grouped_features 406 | else: 407 | new_features = grouped_xyz 408 | 409 | if self.ret_grouped_xyz: 410 | return new_features, grouped_xyz 411 | else: 412 | return new_features 413 | -------------------------------------------------------------------------------- /modules/pst_convolutions.py: -------------------------------------------------------------------------------- 1 | """ Point Point Spatio-Temporal (PST) Convolutions and Transposed Convolutions 2 | 3 | From: "PSTNet: Point Spatio-Temporal Convolution on Point Cloud Sequences" 4 | 5 | Author: Hehe Fan 6 | Date: July 2020 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch import Tensor 13 | 14 | import math 15 | import os 16 | import sys 17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 18 | sys.path.append(BASE_DIR) 19 | 20 | import pointnet2_utils 21 | from typing import List 22 | 23 | def kaiming_uniform(tensor, size): 24 | fan = size[1] * size[2] * size[3] 25 | gain = math.sqrt(2.0 / (1 + math.sqrt(5) ** 2)) 26 | std = gain / math.sqrt(fan) 27 | bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 28 | with torch.no_grad(): 29 | return tensor.uniform_(-bound, bound) 30 | 31 | def uniform(tensor, a, b): 32 | with torch.no_grad(): 33 | return tensor.uniform_(a, b) 34 | 35 | class PSTConv(nn.Module): 36 | def __init__(self, 37 | in_planes: int, 38 | mid_planes: int, 39 | out_planes: int, 40 | spatial_kernel_size: [float, int], 41 | temporal_kernel_size: int, 42 | spatial_stride: int = 1, 43 | temporal_stride: int = 1, 44 | temporal_padding: [int, int] = [0, 0], 45 | padding_mode: str = "zeros", 46 | spatial_aggregation: str = "addition", 47 | spatial_pooling: str = "max", 48 | bias: bool = False, 49 | batch_norm: bool = True): 50 | """ 51 | Args: 52 | in_planes: C, number of point feature channels in the input. it is 0 if point features are not available. 
53 | mid_planes: C_m, number of channels produced by the spatial convolution 54 | out_planes: C', number of channels produced by the temporal convolution 55 | spatial_kernel_size: (r, k), radius and nsamples 56 | temporal_kernel_size: odd 57 | spatial_stride: spatial sub-sampling rate, >= 1 58 | temporal_stride: controls the stride for the temporal cross correlation, >= 1 59 | temporal_padding: 60 | padding_mode: "zeros" or "replicate" 61 | spatial_aggregation: controls the way to aggregate point displacements and point features, "addition" or "multiplication" 62 | spatial_pooling: "max", "sum" or "avg" 63 | bias: 64 | batch_norm: 65 | """ 66 | super().__init__() 67 | 68 | assert (padding_mode in ["zeros", "replicate"]), "PSTConv: 'padding_mode' should be 'zeros' or 'replicate'!" 69 | assert (spatial_aggregation in ["addition", "multiplication"]), "PSTConv: 'spatial_aggregation' should be 'addition' or 'multiplication'!" 70 | assert (spatial_pooling in ["max", "sum", "avg"]), "PSTConv: 'spatial_pooling' should be 'max', 'sum' or 'avg'!" 71 | 72 | self.in_planes = in_planes 73 | self.mid_planes = mid_planes 74 | self.out_planes = out_planes 75 | 76 | self.r, self.k = spatial_kernel_size 77 | self.spatial_stride = spatial_stride 78 | 79 | self.temporal_kernel_size = temporal_kernel_size 80 | self.temporal_radius = math.floor(temporal_kernel_size/2) 81 | self.temporal_stride = temporal_stride 82 | self.temporal_padding = temporal_padding 83 | self.padding_mode = padding_mode 84 | 85 | self.spatial_aggregation = spatial_aggregation 86 | self.spatial_pooling = spatial_pooling 87 | 88 | if in_planes != 0: 89 | self.spatial_conv_f = nn.Conv2d(in_channels=in_planes, out_channels=mid_planes, kernel_size=1, stride=1, padding=0, bias=bias) 90 | kaiming_uniform(self.spatial_conv_f.weight, size=[mid_planes, in_planes+3, 1, 1]) 91 | if bias: 92 | bound = 1 / math.sqrt(in_planes+3) 93 | uniform(self.spatial_conv_f.bias, -bound, bound) 94 | 95 | self.spatial_conv_d = nn.Conv2d(in_channels=3, out_channels=mid_planes, kernel_size=1, stride=1, padding=0, bias=bias) 96 | kaiming_uniform(self.spatial_conv_d.weight, size=[mid_planes, in_planes+3, 1, 1]) 97 | if bias: 98 | bound = 1 / math.sqrt(in_planes+3) 99 | uniform(self.spatial_conv_d.bias, -bound, bound) 100 | 101 | self.batch_norm = nn.BatchNorm1d(num_features=temporal_kernel_size*mid_planes) if batch_norm else False 102 | self.relu = nn.ReLU(inplace=True) 103 | 104 | self.temporal = nn.Conv1d(in_channels=temporal_kernel_size*mid_planes, out_channels=out_planes, kernel_size=1, stride=1, padding=0, bias=bias) 105 | 106 | def forward(self, xyzs: torch.Tensor, features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): 107 | """ 108 | Args: 109 | xyzs: torch.Tensor 110 | (B, L, N, 3) tensor of sequence of the xyz coordinates 111 | features: torch.Tensor 112 | (B, L, C, N) tensor of sequence of the features 113 | """ 114 | device = xyzs.get_device() 115 | 116 | nframes = xyzs.size(1) # L 117 | npoints = xyzs.size(2) # N 118 | 119 | if self.temporal_kernel_size > 1 and self.temporal_stride > 1: 120 | assert ((nframes + sum(self.temporal_padding) - self.temporal_kernel_size) % self.temporal_stride == 0), "PSTConv: Temporal parameter error!" 
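        # Worked example of the temporal bookkeeping above (illustrative, not in
        # the original file): with L = 8 input frames, temporal_kernel_size = 3,
        # temporal_stride = 3 and temporal_padding = [1, 0], the padded sequence
        # has 8 + 1 = 9 frames, (9 - 3) % 3 == 0 satisfies the assert, and the
        # anchor loop below visits frames 1, 4 and 7 of the padded sequence,
        # producing L' = 3 output frames. These are exactly the settings used in
        # the __main__ demo at the bottom of this file.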
121 | 122 | xyzs = torch.split(tensor=xyzs, split_size_or_sections=1, dim=1) 123 | xyzs = [torch.squeeze(input=xyz, dim=1).contiguous() for xyz in xyzs] 124 | 125 | if self.in_planes != 0: 126 | features = torch.split(tensor=features, split_size_or_sections=1, dim=1) 127 | features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in features] 128 | 129 | if self.padding_mode == "zeros": 130 | xyz_padding = torch.zeros(xyzs[0].size(), dtype=torch.float32, device=device) 131 | for i in range(self.temporal_padding[0]): 132 | xyzs = [xyz_padding] + xyzs 133 | for i in range(self.temporal_padding[1]): 134 | xyzs = xyzs + [xyz_padding] 135 | 136 | if self.in_planes != 0: 137 | feature_padding = torch.zeros(features[0].size(), dtype=torch.float32, device=device) 138 | for i in range(self.temporal_padding[0]): 139 | features = [feature_padding] + features 140 | for i in range(self.temporal_padding[1]): 141 | features = features + [feature_padding] 142 | else: # "replicate" 143 | for i in range(self.temporal_padding[0]): 144 | xyzs = [xyzs[0]] + xyzs 145 | for i in range(self.temporal_padding[1]): 146 | xyzs = xyzs + [xyzs[-1]] 147 | 148 | if self.in_planes != 0: 149 | for i in range(self.temporal_padding[0]): 150 | features = [features[0]] + features 151 | for i in range(self.temporal_padding[1]): 152 | features = features + [features[-1]] 153 | 154 | new_xyzs = [] 155 | new_features = [] 156 | for t in range(self.temporal_radius, len(xyzs)-self.temporal_radius, self.temporal_stride): # temporal anchor frames 157 | # spatial anchor point subsampling by FPS 158 | anchor_idx = pointnet2_utils.furthest_point_sample(xyzs[t], npoints//self.spatial_stride) # (B, N//self.spatial_stride) 159 | anchor_xyz_flipped = pointnet2_utils.gather_operation(xyzs[t].transpose(1, 2).contiguous(), anchor_idx) # (B, 3, N//self.spatial_stride) 160 | anchor_xyz_expanded = torch.unsqueeze(anchor_xyz_flipped, 3) # (B, 3, N//spatial_stride, 1) 161 | anchor_xyz = anchor_xyz_flipped.transpose(1, 2).contiguous() # (B, N//spatial_stride, 3) 162 | 163 | # spatial convolution 164 | spatial_features = [] 165 | for i in range(t-self.temporal_radius, t+self.temporal_radius+1): 166 | neighbor_xyz = xyzs[i] 167 | 168 | idx = pointnet2_utils.ball_query(self.r, self.k, neighbor_xyz, anchor_xyz) 169 | 170 | neighbor_xyz_flipped = neighbor_xyz.transpose(1, 2).contiguous() # (B, 3, N) 171 | neighbor_xyz_grouped = pointnet2_utils.grouping_operation(neighbor_xyz_flipped, idx) # (B, 3, N//spatial_stride, k) 172 | 173 | displacement = neighbor_xyz_grouped - anchor_xyz_expanded # (B, 3, N//spatial_stride, k) 174 | displacement = self.spatial_conv_d(displacement) # (B, mid_planes, N//spatial_stride, k) 175 | 176 | if self.in_planes != 0: 177 | neighbor_feature_grouped = pointnet2_utils.grouping_operation(features[i], idx) # (B, in_planes, N//spatial_stride, k) 178 | feature = self.spatial_conv_f(neighbor_feature_grouped) # (B, mid_planes, N//spatial_stride, k) 179 | 180 | if self.spatial_aggregation == "addition": 181 | spatial_feature = feature + displacement 182 | else: 183 | spatial_feature = feature * displacement 184 | 185 | else: 186 | spatial_feature = displacement 187 | 188 | if self.spatial_pooling == 'max': 189 | spatial_feature, _ = torch.max(input=spatial_feature, dim=-1, keepdim=False) # (B, mid_planes, N//spatial_stride) 190 | elif self.spatial_pooling == 'sum': 191 | spatial_feature = torch.sum(input=spatial_feature, dim=-1, keepdim=False) # (B, mid_planes, N//spatial_stride) 192 | else: 193 | spatial_feature = 
torch.mean(input=spatial_feature, dim=-1, keepdim=False) # (B, mid_planes, N//spatial_stride) 194 | 195 | spatial_features.append(spatial_feature) 196 | 197 | spatial_features = torch.cat(tensors=spatial_features, dim=1, out=None) # (B, temporal_kernel_size*mid_planes, N//spatial_stride) 198 | 199 | # batch norm and relu 200 | if self.batch_norm: 201 | spatial_features = self.batch_norm(spatial_features) 202 | 203 | spatial_features = self.relu(spatial_features) 204 | 205 | # temporal convolution 206 | spatio_temporal_feature = self.temporal(spatial_features) 207 | 208 | new_xyzs.append(anchor_xyz) 209 | new_features.append(spatio_temporal_feature) 210 | 211 | new_xyzs = torch.stack(tensors=new_xyzs, dim=1) 212 | new_features = torch.stack(tensors=new_features, dim=1) 213 | 214 | return new_xyzs, new_features 215 | 216 | class PSTConvTranspose(nn.Module): 217 | def __init__(self, 218 | in_planes: int, 219 | mid_planes: int, 220 | out_planes: int, 221 | temporal_kernel_size: int, 222 | temporal_stride: int = 1, 223 | temporal_padding: [int, int] = [0, 0], 224 | original_in_planes: int = 0, 225 | bias: bool = False, 226 | batch_norm: bool = True, 227 | activation: bool = True): 228 | """ 229 | Args: 230 | in_planes: C'. when point features are not available, in_planes is 0. 231 | mid_planes: C'_m 232 | out_planes: C" 233 | temporal_kernel_size: odd 234 | temporal_stride: controls the stride for the temporal cross correlation, >= 1 235 | temporal_padding: <=0, removes unnecessary temporal transposed features 236 | original_in_planes: C, used for skip connection from original points. when original point features are not available, original_in_planes is 0. 237 | bias: whether to use bias 238 | batch_norm: whether to use batch norm 239 | activation: 240 | """ 241 | super().__init__() 242 | 243 | self.in_planes = in_planes 244 | self.mid_planes = mid_planes 245 | self.out_planes = out_planes 246 | 247 | # temporal parameters 248 | self.temporal_kernel_size = temporal_kernel_size 249 | self.temporal_radius = math.floor(self.temporal_kernel_size/2) 250 | self.temporal_stride = temporal_stride 251 | self.temporal_padding = temporal_padding 252 | 253 | # temporal transposed convolution 254 | self.temporal_conv = nn.Conv1d(in_channels=in_planes, out_channels=temporal_kernel_size*mid_planes, kernel_size=1, stride=1, padding=0, bias=bias) 255 | 256 | self.batch_norm = nn.BatchNorm1d(num_features=mid_planes) if batch_norm else False 257 | self.activation = nn.ReLU(inplace=True) if activation else False 258 | 259 | # spatial interpolation convolution 260 | self.spatial_conv = nn.Conv1d(in_channels=mid_planes+original_in_planes, out_channels=out_planes, kernel_size=1, stride=1, padding=0, bias=bias) 261 | 262 | 263 | def forward(self, xyzs: torch.Tensor, original_xyzs: torch.Tensor, features: torch.Tensor, original_features: torch.Tensor = None) -> torch.Tensor: 264 | r""" 265 | Parameters 266 | ---------- 267 | xyzs : torch.Tensor 268 | (B, L', N', 3) tensor of the xyz positions of the convolved features 269 | original_xyzs : torch.Tensor 270 | (B, L, N, 3) tensor of the xyz positions of the original points 271 | 272 | features : torch.Tensor 273 | (B, L', C', N') tensor of the features to be propigated to 274 | original_features : torch.Tensor 275 | (B, L, C, N) tensor of original point features for skip connection 276 | 277 | Returns 278 | ------- 279 | new_features : torch.Tensor 280 | (B, L, C", N) tensor of the features of the unknown features 281 | """ 282 | 283 | L1 = original_xyzs.size(1) 284 | N1 
= original_xyzs.size(2) 285 | 286 | L2 = xyzs.size(1) 287 | N2 = xyzs.size(2) 288 | 289 | if self.temporal_kernel_size > 1 and self.temporal_stride > 1: 290 | assert ((L2 - 1) * self.temporal_stride + sum(self.temporal_padding) + self.temporal_kernel_size == L1), "PSTConvTranspose: Temporal parameter error!" 291 | 292 | xyzs = torch.split(tensor=xyzs, split_size_or_sections=1, dim=1) 293 | xyzs = [torch.squeeze(input=xyz, dim=1).contiguous() for xyz in xyzs] 294 | 295 | features = torch.split(tensor=features, split_size_or_sections=1, dim=1) 296 | features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in features] 297 | 298 | new_xyzs = original_xyzs 299 | 300 | original_xyzs = torch.split(tensor=original_xyzs, split_size_or_sections=1, dim=1) 301 | original_xyzs = [torch.squeeze(input=original_xyz, dim=1).contiguous() for original_xyz in original_xyzs] 302 | 303 | if original_features is not None: 304 | original_features = torch.split(tensor=original_features, split_size_or_sections=1, dim=1) 305 | original_features = [torch.squeeze(input=feature, dim=1).contiguous() for feature in original_features] 306 | 307 | # temporal transposed convolution 308 | temporal_trans_features = [] 309 | for feature in features: 310 | feature = self.temporal_conv(feature) 311 | feature = torch.split(tensor=feature, split_size_or_sections=self.mid_planes, dim=1) 312 | temporal_trans_features.append(feature) 313 | 314 | # temporal interpolation 315 | temporal_interpolated_xyzs = [] 316 | temporal_interpolated_features = [] 317 | 318 | middles = [] 319 | deltas = [] 320 | for t2 in range(1, L2+1): 321 | middle = t2 + (t2-1)*(self.temporal_stride-1) + self.temporal_radius + self.temporal_padding[0] 322 | middles.append(middle) 323 | delta = range(middle - self.temporal_radius, middle + self.temporal_radius + self.temporal_padding[1] + 1) 324 | deltas.append(delta) 325 | 326 | for t1 in range(1, L1+1): 327 | seed_xyzs = [] 328 | seed_features = [] 329 | for t2 in range(L2): 330 | delta = deltas[t2] 331 | if t1 in delta: 332 | seed_xyzs.append(xyzs[t2]) 333 | seed_feature = temporal_trans_features[t2][t1-middles[t2]+self.temporal_radius] 334 | if self.batch_norm: 335 | seed_feature = self.batch_norm(seed_feature) 336 | if self.activation: 337 | seed_feature = self.activation(seed_feature) 338 | seed_features.append(seed_feature) 339 | seed_xyzs = torch.cat(seed_xyzs, dim=1) 340 | seed_features = torch.cat(seed_features, dim=2) 341 | temporal_interpolated_xyzs.append(seed_xyzs) 342 | temporal_interpolated_features.append(seed_features) 343 | 344 | # spatial interpolation 345 | new_features = [] 346 | for t1 in range(L1): 347 | neighbor_xyz = temporal_interpolated_xyzs[t1] # [B, N', 3] 348 | anchor_xyz = original_xyzs[t1] # [B, N, 3] 349 | 350 | dist, idx = pointnet2_utils.three_nn(anchor_xyz, neighbor_xyz) 351 | 352 | dist_recip = 1.0 / (dist + 1e-8) 353 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 354 | weight = dist_recip / norm 355 | 356 | interpolated_feats = pointnet2_utils.three_interpolate(temporal_interpolated_features[t1], idx, weight) 357 | 358 | if original_features is not None: 359 | new_feature = torch.cat([interpolated_feats, original_features[t1]], dim=1) 360 | else: 361 | new_feature = interpolated_feats 362 | 363 | new_feature = self.spatial_conv(new_feature) 364 | 365 | new_features.append(new_feature) 366 | 367 | new_features = torch.stack(tensors=new_features, dim=1) 368 | 369 | return new_xyzs, new_features 370 | 371 | if __name__ == '__main__': 372 | xyzs = 
torch.zeros(4, 8, 512, 3).cuda() 373 | features = torch.zeros(4, 8, 16, 512).cuda() 374 | 375 | conv = PSTConv(in_planes=16, 376 | mid_planes=32, 377 | out_planes=64, 378 | spatial_kernel_size=[1.0, 3], 379 | temporal_kernel_size=3, 380 | spatial_stride=2, 381 | temporal_stride=3, 382 | temporal_padding=[1, 0], 383 | padding_mode="replicate").cuda() 384 | 385 | new_xyzs, new_features = conv(xyzs, features) 386 | 387 | deconv = PSTConvTranspose(in_planes=64, 388 | mid_planes=128, 389 | out_planes=256, 390 | temporal_kernel_size=3, 391 | temporal_stride=3, 392 | temporal_padding=[-1, 0], 393 | original_in_planes=16).cuda() 394 | 395 | out_xyzs, out_features = deconv(new_xyzs, xyzs, new_features, features) 396 | 397 | print("-----------------------------") 398 | print("Input:") 399 | print(xyzs.shape) 400 | print(features.shape) 401 | print("-----------------------------") 402 | print("PST convolution:") 403 | print(new_xyzs.shape) 404 | print(new_features.shape) 405 | print("-----------------------------") 406 | print("PST transposed convolution:") 407 | print(out_xyzs.shape) 408 | print(out_features.shape) 409 | print("-----------------------------") 410 | -------------------------------------------------------------------------------- /modules/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | 10 | _ext_src_root = "_ext_src" 11 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 12 | "{}/src/*.cu".format(_ext_src_root) 13 | ) 14 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 15 | 16 | setup( 17 | name='pointnet2', 18 | ext_modules=[ 19 | CUDAExtension( 20 | name='pointnet2._ext', 21 | sources=_ext_sources, 22 | extra_compile_args={ 23 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 24 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 25 | }, 26 | ) 27 | ], 28 | cmdclass={ 29 | 'build_ext': BuildExtension 30 | } 31 | ) 32 | -------------------------------------------------------------------------------- /scripts/depth2point4ntu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from matplotlib.image import imread 5 | from glob import glob 6 | import argparse 7 | parser = argparse.ArgumentParser(description='Depth to Point Cloud') 8 | 9 | parser.add_argument('--input', default='/home/hehefan/Data/ntu/nturgb+d_depth_masked', type=str) 10 | parser.add_argument('--output', default='/scratch/ntu/video', type=str) 11 | parser.add_argument('-n', '--action', type=int) 12 | 13 | args = parser.parse_args() 14 | 15 | 16 | W = 512 17 | H = 424 18 | 19 | ''' 20 | def mkdir(directory): 21 | if not os.path.exists(directory): 22 | os.makedirs(directory) 23 | 24 | mkdir(args.output) 25 | ''' 26 | 27 | xx, yy = np.meshgrid(np.arange(W), np.arange(H)) 28 | focal = 280 29 | 30 | for video_path in sorted(glob('%s/*A0%02d'%(args.input, args.action))): 31 | video_name = video_path.split('/')[-1] 32 | 33 | point_clouds = [] 34 | for img_name in sorted(os.listdir(video_path)): 35 | img_path = os.path.join(video_path, img_name) 36 | img = imread(img_path) # (H, W) 37 | 38 | depth_min 
= img[img > 0].min() 39 | depth_map = img 40 | 41 | x = xx[depth_map > 0] 42 | y = yy[depth_map > 0] 43 | z = depth_map[depth_map > 0] 44 | x = (x - W / 2) / focal * z 45 | y = (y - H / 2) / focal * z 46 | 47 | points = np.stack([x, y, z], axis=-1) 48 | point_clouds.append(points) 49 | 50 | np.savez_compressed(os.path.join(args.output, video_name + '.npz'), data=point_clouds) 51 | print('Action %02d finished!'%args.action) 52 | -------------------------------------------------------------------------------- /scripts/depth2point4ntu.sh: -------------------------------------------------------------------------------- 1 | nohup python depth2point4ntu.py -n 1 & 2 | nohup python depth2point4ntu.py -n 2 & 3 | nohup python depth2point4ntu.py -n 3 & 4 | nohup python depth2point4ntu.py -n 4 & 5 | nohup python depth2point4ntu.py -n 5 & 6 | nohup python depth2point4ntu.py -n 6 & 7 | nohup python depth2point4ntu.py -n 7 & 8 | nohup python depth2point4ntu.py -n 8 & 9 | nohup python depth2point4ntu.py -n 9 & 10 | nohup python depth2point4ntu.py -n 10 & 11 | nohup python depth2point4ntu.py -n 11 & 12 | nohup python depth2point4ntu.py -n 12 & 13 | nohup python depth2point4ntu.py -n 13 & 14 | nohup python depth2point4ntu.py -n 14 & 15 | nohup python depth2point4ntu.py -n 15 & 16 | nohup python depth2point4ntu.py -n 16 & 17 | nohup python depth2point4ntu.py -n 17 & 18 | nohup python depth2point4ntu.py -n 18 & 19 | nohup python depth2point4ntu.py -n 19 & 20 | nohup python depth2point4ntu.py -n 20 & 21 | nohup python depth2point4ntu.py -n 21 & 22 | nohup python depth2point4ntu.py -n 22 & 23 | nohup python depth2point4ntu.py -n 23 & 24 | nohup python depth2point4ntu.py -n 24 & 25 | nohup python depth2point4ntu.py -n 25 & 26 | nohup python depth2point4ntu.py -n 26 & 27 | nohup python depth2point4ntu.py -n 27 & 28 | nohup python depth2point4ntu.py -n 28 & 29 | nohup python depth2point4ntu.py -n 29 & 30 | nohup python depth2point4ntu.py -n 30 & 31 | nohup python depth2point4ntu.py -n 31 & 32 | nohup python depth2point4ntu.py -n 32 & 33 | nohup python depth2point4ntu.py -n 33 & 34 | nohup python depth2point4ntu.py -n 34 & 35 | nohup python depth2point4ntu.py -n 35 & 36 | nohup python depth2point4ntu.py -n 36 & 37 | nohup python depth2point4ntu.py -n 37 & 38 | nohup python depth2point4ntu.py -n 38 & 39 | nohup python depth2point4ntu.py -n 39 & 40 | nohup python depth2point4ntu.py -n 40 & 41 | nohup python depth2point4ntu.py -n 41 & 42 | nohup python depth2point4ntu.py -n 42 & 43 | nohup python depth2point4ntu.py -n 43 & 44 | nohup python depth2point4ntu.py -n 44 & 45 | nohup python depth2point4ntu.py -n 45 & 46 | nohup python depth2point4ntu.py -n 46 & 47 | nohup python depth2point4ntu.py -n 47 & 48 | nohup python depth2point4ntu.py -n 48 & 49 | nohup python depth2point4ntu.py -n 49 & 50 | nohup python depth2point4ntu.py -n 50 & 51 | nohup python depth2point4ntu.py -n 51 & 52 | nohup python depth2point4ntu.py -n 52 & 53 | nohup python depth2point4ntu.py -n 53 & 54 | nohup python depth2point4ntu.py -n 54 & 55 | nohup python depth2point4ntu.py -n 55 & 56 | nohup python depth2point4ntu.py -n 56 & 57 | nohup python depth2point4ntu.py -n 57 & 58 | nohup python depth2point4ntu.py -n 58 & 59 | nohup python depth2point4ntu.py -n 59 & 60 | nohup python depth2point4ntu.py -n 60 & 61 | nohup python depth2point4ntu.py -n 61 & 62 | nohup python depth2point4ntu.py -n 62 & 63 | nohup python depth2point4ntu.py -n 63 & 64 | nohup python depth2point4ntu.py -n 64 & 65 | nohup python depth2point4ntu.py -n 65 & 66 | nohup python 
depth2point4ntu.py -n 66 & 67 | nohup python depth2point4ntu.py -n 67 & 68 | nohup python depth2point4ntu.py -n 68 & 69 | nohup python depth2point4ntu.py -n 69 & 70 | nohup python depth2point4ntu.py -n 70 & 71 | nohup python depth2point4ntu.py -n 71 & 72 | nohup python depth2point4ntu.py -n 72 & 73 | nohup python depth2point4ntu.py -n 73 & 74 | nohup python depth2point4ntu.py -n 74 & 75 | nohup python depth2point4ntu.py -n 75 & 76 | nohup python depth2point4ntu.py -n 76 & 77 | nohup python depth2point4ntu.py -n 77 & 78 | nohup python depth2point4ntu.py -n 78 & 79 | nohup python depth2point4ntu.py -n 79 & 80 | nohup python depth2point4ntu.py -n 80 & 81 | nohup python depth2point4ntu.py -n 81 & 82 | nohup python depth2point4ntu.py -n 82 & 83 | nohup python depth2point4ntu.py -n 83 & 84 | nohup python depth2point4ntu.py -n 84 & 85 | nohup python depth2point4ntu.py -n 85 & 86 | nohup python depth2point4ntu.py -n 86 & 87 | nohup python depth2point4ntu.py -n 87 & 88 | nohup python depth2point4ntu.py -n 88 & 89 | nohup python depth2point4ntu.py -n 89 & 90 | nohup python depth2point4ntu.py -n 90 & 91 | nohup python depth2point4ntu.py -n 91 & 92 | nohup python depth2point4ntu.py -n 92 & 93 | nohup python depth2point4ntu.py -n 93 & 94 | nohup python depth2point4ntu.py -n 94 & 95 | nohup python depth2point4ntu.py -n 95 & 96 | nohup python depth2point4ntu.py -n 96 & 97 | nohup python depth2point4ntu.py -n 97 & 98 | nohup python depth2point4ntu.py -n 98 & 99 | nohup python depth2point4ntu.py -n 99 & 100 | nohup python depth2point4ntu.py -n 100 & 101 | nohup python depth2point4ntu.py -n 101 & 102 | nohup python depth2point4ntu.py -n 102 & 103 | nohup python depth2point4ntu.py -n 103 & 104 | nohup python depth2point4ntu.py -n 104 & 105 | nohup python depth2point4ntu.py -n 105 & 106 | nohup python depth2point4ntu.py -n 106 & 107 | nohup python depth2point4ntu.py -n 107 & 108 | nohup python depth2point4ntu.py -n 108 & 109 | nohup python depth2point4ntu.py -n 109 & 110 | nohup python depth2point4ntu.py -n 110 & 111 | nohup python depth2point4ntu.py -n 111 & 112 | nohup python depth2point4ntu.py -n 112 & 113 | nohup python depth2point4ntu.py -n 113 & 114 | nohup python depth2point4ntu.py -n 114 & 115 | nohup python depth2point4ntu.py -n 115 & 116 | nohup python depth2point4ntu.py -n 116 & 117 | nohup python depth2point4ntu.py -n 117 & 118 | nohup python depth2point4ntu.py -n 118 & 119 | nohup python depth2point4ntu.py -n 119 & 120 | nohup python depth2point4ntu.py -n 120 & 121 | -------------------------------------------------------------------------------- /train-msr.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import datetime 3 | import os 4 | import time 5 | import sys 6 | import numpy as np 7 | import torch 8 | import torch.utils.data 9 | from torch.utils.data.dataloader import default_collate 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import torchvision 13 | from torchvision import transforms 14 | 15 | import utils 16 | 17 | from datasets.msr import MSRAction3D 18 | import models.sequence_classification as Models 19 | 20 | def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq): 21 | model.train() 22 | metric_logger = utils.MetricLogger(delimiter=" ") 23 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 24 | metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}')) 25 | 26 | header 
= 'Epoch: [{}]'.format(epoch) 27 | for clip, target, _ in metric_logger.log_every(data_loader, print_freq, header): 28 | start_time = time.time() 29 | clip, target = clip.to(device), target.to(device) 30 | output = model(clip) 31 | loss = criterion(output, target) 32 | 33 | optimizer.zero_grad() 34 | loss.backward() 35 | optimizer.step() 36 | 37 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 38 | batch_size = clip.shape[0] 39 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 40 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 41 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 42 | metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time)) 43 | lr_scheduler.step() 44 | sys.stdout.flush() 45 | 46 | def evaluate(model, criterion, data_loader, device): 47 | model.eval() 48 | metric_logger = utils.MetricLogger(delimiter=" ") 49 | header = 'Test:' 50 | video_prob = {} 51 | video_label = {} 52 | with torch.no_grad(): 53 | for clip, target, video_idx in metric_logger.log_every(data_loader, 100, header): 54 | clip = clip.to(device, non_blocking=True) 55 | target = target.to(device, non_blocking=True) 56 | output = model(clip) 57 | loss = criterion(output, target) 58 | 59 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 60 | prob = F.softmax(input=output, dim=1) 61 | 62 | # FIXME need to take into account that the datasets 63 | # could have been padded in distributed setup 64 | batch_size = clip.shape[0] 65 | target = target.cpu().numpy() 66 | video_idx = video_idx.cpu().numpy() 67 | prob = prob.cpu().numpy() 68 | for i in range(0, batch_size): 69 | idx = video_idx[i] 70 | if idx in video_prob: 71 | video_prob[idx] += prob[i] 72 | else: 73 | video_prob[idx] = prob[i] 74 | video_label[idx] = target[i] 75 | metric_logger.update(loss=loss.item()) 76 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 77 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 78 | # gather the stats from all processes 79 | metric_logger.synchronize_between_processes() 80 | 81 | print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'.format(top1=metric_logger.acc1, top5=metric_logger.acc5)) 82 | 83 | # video level prediction 84 | video_pred = {k: np.argmax(v) for k, v in video_prob.items()} 85 | pred_correct = [video_pred[k]==video_label[k] for k in video_pred] 86 | total_acc = np.mean(pred_correct) 87 | 88 | class_count = [0] * data_loader.dataset.num_classes 89 | class_correct = [0] * data_loader.dataset.num_classes 90 | 91 | for k, v in video_pred.items(): 92 | label = video_label[k] 93 | class_count[label] += 1 94 | class_correct[label] += (v==label) 95 | class_acc = [c/float(s) for c, s in zip(class_correct, class_count)] 96 | 97 | print(' * Video Acc@1 %f'%total_acc) 98 | print(' * Class Acc@1 %s'%str(class_acc)) 99 | 100 | return total_acc 101 | 102 | 103 | def main(args): 104 | 105 | if args.output_dir: 106 | utils.mkdir(args.output_dir) 107 | 108 | print(args) 109 | print("torch version: ", torch.__version__) 110 | print("torchvision version: ", torchvision.__version__) 111 | 112 | np.random.seed(args.seed) 113 | torch.manual_seed(args.seed) 114 | torch.cuda.manual_seed(args.seed) 115 | torch.backends.cudnn.deterministic = True 116 | torch.backends.cudnn.benchmark = False 117 | 118 | device = torch.device('cuda') 119 | 120 | # Data loading code 121 | print("Loading data") 122 | 123 | st = time.time() 124 | 125 | dataset = MSRAction3D( 126 | root=args.data_path, 127 | 
frames_per_clip=args.clip_len,
128 |         frame_interval=args.frame_interval,
129 |         num_points=args.num_points,
130 |         train=True
131 |     )
132 |
133 |     dataset_test = MSRAction3D(
134 |         root=args.data_path,
135 |         frames_per_clip=args.clip_len,
136 |         frame_interval=args.frame_interval,
137 |         num_points=args.num_points,
138 |         train=False
139 |     )
140 |
141 |     print("Creating data loaders")
142 |
143 |     data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
144 |
145 |     data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
146 |
147 |     print("Creating model")
148 |     Model = getattr(Models, args.model)
149 |     model = Model(radius=args.radius, nsamples=args.nsamples, num_classes=dataset.num_classes)
150 |     if torch.cuda.device_count() > 1:
151 |         model = nn.DataParallel(model)
152 |     model.to(device)
153 |
154 |     criterion = nn.CrossEntropyLoss()
155 |
156 |     lr = args.lr
157 |     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
158 |
159 |     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
160 |     # between different epochs
161 |     warmup_iters = args.lr_warmup_epochs * len(data_loader)
162 |     lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
163 |     lr_scheduler = utils.WarmupMultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma, warmup_iters=warmup_iters, warmup_factor=1e-5)
164 |
165 |     model_without_ddp = model
166 |
167 |     if args.resume:
168 |         checkpoint = torch.load(args.resume, map_location='cpu')
169 |         model_without_ddp.load_state_dict(checkpoint['model'])
170 |         optimizer.load_state_dict(checkpoint['optimizer'])
171 |         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
172 |         args.start_epoch = checkpoint['epoch'] + 1
173 |
174 |
175 |     print("Start training")
176 |     start_time = time.time()
177 |     acc = 0
178 |     for epoch in range(args.start_epoch, args.epochs):
179 |         train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq)
180 |
181 |         acc = max(acc, evaluate(model, criterion, data_loader_test, device=device))
182 |
183 |         if args.output_dir:
184 |             checkpoint = {
185 |                 'model': model_without_ddp.state_dict(),
186 |                 'optimizer': optimizer.state_dict(),
187 |                 'lr_scheduler': lr_scheduler.state_dict(),
188 |                 'epoch': epoch,
189 |                 'args': args}
190 |             utils.save_on_master(
191 |                 checkpoint,
192 |                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
193 |             utils.save_on_master(
194 |                 checkpoint,
195 |                 os.path.join(args.output_dir, 'checkpoint.pth'))
196 |
197 |     total_time = time.time() - start_time
198 |     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
199 |     print('Training time {}'.format(total_time_str))
200 |     print('Accuracy {}'.format(acc))
201 |
202 |
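# Example invocation (illustrative; adjust paths to your setup):
#
#   python train-msr.py --data-path data/MSR-Action3D --clip-len 16 --frame-interval 1 \
#       --num-points 2048 --batch-size 16 --lr 0.01 --output-dir output/MSR-Action3D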
203 | def parse_args():
204 |     import argparse
205 |     parser = argparse.ArgumentParser(description='PSTNet Training')
206 |
207 |     parser.add_argument('--data-path', default='data/MSR-Action3D', type=str, help='dataset')
208 |     parser.add_argument('--seed', default=0, type=int, help='random seed')
209 |     parser.add_argument('--model', default='MSRAction', type=str, help='model')
210 |     parser.add_argument('--radius', default=0.5, type=float, help='radius for the ball query')
211 |     parser.add_argument('--nsamples', default=9, type=int, help='number of neighbors for the ball query')
212 |     parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip')
213 |     parser.add_argument('--frame-interval', default=1, type=int, metavar='N', help='interval between sampled frames')
214 |     parser.add_argument('--num-points', default=2048, type=int, metavar='N', help='number of points per frame')
215 |     parser.add_argument('-b', '--batch-size', default=16, type=int)
216 |     parser.add_argument('--epochs', default=35, type=int, metavar='N', help='number of total epochs to run')
217 |     parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', help='number of data loading workers (default: 10)')
218 |     parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
219 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
220 |     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay')
221 |     parser.add_argument('--lr-milestones', nargs='+', default=[20, 30], type=int, help='decrease lr on milestones')
222 |     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
223 |     parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs')
224 |     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
225 |     parser.add_argument('--output-dir', default='output/MSR-Action3D', type=str, help='path where to save')
226 |     parser.add_argument('--resume', default='', help='resume from checkpoint')
227 |     parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch')
228 |
229 |     args = parser.parse_args()
230 |
231 |     return args
232 |
233 |
234 | if __name__ == "__main__":
235 |     args = parse_args()
236 |     main(args)
237 |
--------------------------------------------------------------------------------
/train-ntu.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import datetime
3 | import os
4 | import time
5 | import sys
6 | import numpy as np
7 | import torch
8 | import torch.utils.data
9 | from torch.utils.data.dataloader import default_collate
10 | from torch import nn
11 | import torch.nn.functional as F
12 | import torchvision
13 | from torchvision import transforms
14 |
15 | import utils
16 | from datasets.ntu60 import NTU60Subject
17 | import models.sequence_classification as Models
18 |
19 | def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq):
20 |     model.train()
21 |     metric_logger = utils.MetricLogger(delimiter=" ")
22 |     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
23 |     metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}'))
24 |
25 |     header = 'Epoch: [{}]'.format(epoch)
26 |     for clip, target, _ in metric_logger.log_every(data_loader, print_freq, header):
27 |         start_time = time.time()
28 |         clip, target = clip.to(device), target.to(device)
29 |         output = model(clip)
30 |         loss = criterion(output, target)
31 |
32 |         optimizer.zero_grad()
33 |         loss.backward()
34 |         optimizer.step()
35 |
36 |         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
37 |         batch_size = clip.shape[0]
38 |         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
39 |         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
40 |         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
41 |
metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time)) 42 | lr_scheduler.step() 43 | sys.stdout.flush() 44 | 45 | def evaluate(model, criterion, data_loader, device): 46 | model.eval() 47 | metric_logger = utils.MetricLogger(delimiter=" ") 48 | header = 'Test:' 49 | video_prob = {} 50 | video_label = {} 51 | with torch.no_grad(): 52 | for clip, target, video_idx in metric_logger.log_every(data_loader, 100, header): 53 | clip = clip.to(device, non_blocking=True) 54 | target = target.to(device, non_blocking=True) 55 | output = model(clip) 56 | loss = criterion(output, target) 57 | 58 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 59 | prob = F.softmax(input=output, dim=1) 60 | 61 | # FIXME need to take into account that the datasets 62 | # could have been padded in distributed setup 63 | batch_size = clip.shape[0] 64 | target = target.cpu().numpy() 65 | video_idx = video_idx.cpu().numpy() 66 | prob = prob.cpu().numpy() 67 | for i in range(0, batch_size): 68 | idx = video_idx[i] 69 | if idx in video_prob: 70 | video_prob[idx] += prob[i] 71 | else: 72 | video_prob[idx] = prob[i] 73 | video_label[idx] = target[i] 74 | metric_logger.update(loss=loss.item()) 75 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 76 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 77 | # gather the stats from all processes 78 | metric_logger.synchronize_between_processes() 79 | 80 | print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}'.format(top1=metric_logger.acc1, top5=metric_logger.acc5)) 81 | 82 | # video level prediction 83 | video_pred = {k: np.argmax(v) for k, v in video_prob.items()} 84 | pred_correct = [video_pred[k]==video_label[k] for k in video_pred] 85 | total_acc = np.mean(pred_correct) 86 | 87 | class_count = [0] * data_loader.dataset.num_classes 88 | class_correct = [0] * data_loader.dataset.num_classes 89 | 90 | for k, v in video_pred.items(): 91 | label = video_label[k] 92 | class_count[label] += 1 93 | class_correct[label] += (v==label) 94 | class_acc = [c/float(s) for c, s in zip(class_correct, class_count)] 95 | 96 | print(' * Video Acc@1 %f'%total_acc) 97 | print(' * Class Acc@1 %s'%str(class_acc)) 98 | 99 | return total_acc 100 | 101 | 102 | def main(args): 103 | 104 | if args.output_dir: 105 | utils.mkdir(args.output_dir) 106 | 107 | print(args) 108 | print("torch version: ", torch.__version__) 109 | print("torchvision version: ", torchvision.__version__) 110 | 111 | np.random.seed(args.seed) 112 | torch.manual_seed(args.seed) 113 | torch.cuda.manual_seed(args.seed) 114 | torch.backends.cudnn.deterministic = True 115 | torch.backends.cudnn.benchmark = False 116 | 117 | device = torch.device('cuda') 118 | 119 | # Data loading code 120 | print("Loading data") 121 | 122 | st = time.time() 123 | 124 | dataset = NTU60Subject( 125 | root=args.data_path, 126 | meta=args.data_meta, 127 | frames_per_clip=args.clip_len, 128 | step_between_clips=args.frame_step, 129 | num_points=args.num_points, 130 | train=True 131 | ) 132 | 133 | dataset_test = NTU60Subject( 134 | root=args.data_path, 135 | meta=args.data_meta, 136 | frames_per_clip=args.clip_len, 137 | step_between_clips=args.frame_step, 138 | num_points=args.num_points, 139 | train=False 140 | ) 141 | 142 | print("Creating data loaders") 143 | 144 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 145 | 146 | data_loader_test = torch.utils.data.DataLoader(dataset_test, 
batch_size=args.batch_size, num_workers=args.workers, pin_memory=True) 147 | 148 | print("Creating model") 149 | Model = getattr(Models, args.model) 150 | model = Model(radius=args.radius, nsamples=args.nsamples, num_classes=dataset.num_classes) 151 | if torch.cuda.device_count() > 1: 152 | model = nn.DataParallel(model) 153 | model.to(device) 154 | 155 | criterion = nn.CrossEntropyLoss() 156 | 157 | lr = args.lr 158 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) 159 | 160 | # convert scheduler to be per iteration, not per epoch, for warmup that lasts 161 | # between different epochs 162 | warmup_iters = args.lr_warmup_epochs * len(data_loader) 163 | lr_milestones = [len(data_loader) * m for m in args.lr_milestones] 164 | lr_scheduler = utils.WarmupMultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma, warmup_iters=warmup_iters, warmup_factor=1e-5) 165 | 166 | model_without_ddp = model 167 | 168 | if args.resume: 169 | checkpoint = torch.load(args.resume, map_location='cpu') 170 | model_without_ddp.load_state_dict(checkpoint['model']) 171 | optimizer.load_state_dict(checkpoint['optimizer']) 172 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 173 | args.start_epoch = checkpoint['epoch'] + 1 174 | 175 | 176 | print("Start training") 177 | start_time = time.time() 178 | acc = 0 179 | for epoch in range(args.start_epoch, args.epochs): 180 | train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq) 181 | 182 | acc = max(acc, evaluate(model, criterion, data_loader_test, device=device)) 183 | 184 | if args.output_dir: 185 | checkpoint = { 186 | 'model': model_without_ddp.state_dict(), 187 | 'optimizer': optimizer.state_dict(), 188 | 'lr_scheduler': lr_scheduler.state_dict(), 189 | 'epoch': epoch, 190 | 'args': args} 191 | utils.save_on_master( 192 | checkpoint, 193 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) 194 | utils.save_on_master( 195 | checkpoint, 196 | os.path.join(args.output_dir, 'checkpoint.pth')) 197 | 198 | total_time = time.time() - start_time 199 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 200 | print('Training time {}'.format(total_time_str)) 201 | print('Accuracy {}'.format(acc)) 202 | 203 | 204 | def parse_args(): 205 | import argparse 206 | parser = argparse.ArgumentParser(description='PSTNet Training') 207 | 208 | parser.add_argument('--data-path', default='data/video', type=str, help='dataset') 209 | parser.add_argument('--data-meta', default='data/ntu60.list', help='dataset') 210 | parser.add_argument('--seed', default=0, type=int, help='random seed') 211 | parser.add_argument('--model', default='NTU', type=str, help='model') 212 | parser.add_argument('--radius', default=0.1, type=float, help='radius for the ball query') 213 | parser.add_argument('--nsamples', default=9, type=int, help='number of neighbors for the ball query') 214 | parser.add_argument('--clip-len', default=23, type=int, metavar='N', help='number of frames per clip') 215 | parser.add_argument('--frame-step', default=2, type=int, metavar='N', help='steps between frame sampling') 216 | parser.add_argument('--num-points', default=2048, type=int, metavar='N', help='number of points per frame') 217 | parser.add_argument('-b', '--batch-size', default=16, type=int) 218 | parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') 219 | parser.add_argument('-j', '--workers', default=16, 
type=int, metavar='N', help='number of data loading workers (default: 16)') 220 | parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') 221 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 222 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') 223 | parser.add_argument('--lr-milestones', nargs='+', default=[5, 10], type=int, help='decrease lr on milestones') 224 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 225 | parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs') 226 | parser.add_argument('--print-freq', default=100, type=int, help='print frequency') 227 | parser.add_argument('--output-dir', default='', type=str, help='path where to save') 228 | parser.add_argument('--resume', default='', help='resume from checkpoint') 229 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') 230 | 231 | args = parser.parse_args() 232 | 233 | return args 234 | 235 | 236 | if __name__ == "__main__": 237 | args = parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import defaultdict, deque 3 | import datetime 4 | import time 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import errno 9 | from bisect import bisect_right 10 | import os 11 | 12 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 13 | def __init__( 14 | self, 15 | optimizer, 16 | milestones, 17 | gamma=0.1, 18 | warmup_factor=1.0 / 3, 19 | warmup_iters=5, 20 | warmup_method="linear", 21 | last_epoch=-1, 22 | ): 23 | if not milestones == sorted(milestones): 24 | raise ValueError( 25 | "Milestones should be a list of" " increasing integers. Got {}", 26 | milestones, 27 | ) 28 | 29 | if warmup_method not in ("constant", "linear"): 30 | raise ValueError( 31 | "Only 'constant' or 'linear' warmup_method accepted" 32 | "got {}".format(warmup_method) 33 | ) 34 | self.milestones = milestones 35 | self.gamma = gamma 36 | self.warmup_factor = warmup_factor 37 | self.warmup_iters = warmup_iters 38 | self.warmup_method = warmup_method 39 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 40 | 41 | def get_lr(self): 42 | warmup_factor = 1 43 | if self.last_epoch < self.warmup_iters: 44 | if self.warmup_method == "constant": 45 | warmup_factor = self.warmup_factor 46 | elif self.warmup_method == "linear": 47 | alpha = float(self.last_epoch) / self.warmup_iters 48 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 49 | return [ 50 | base_lr * 51 | warmup_factor * 52 | self.gamma ** bisect_right(self.milestones, self.last_epoch) 53 | for base_lr in self.base_lrs 54 | ] 55 | 56 | class SmoothedValue(object): 57 | """Track a series of values and provide access to smoothed values over a 58 | window or the global series average. 
59 | """ 60 | 61 | def __init__(self, window_size=20, fmt=None): 62 | if fmt is None: 63 | fmt = "{median:.4f} ({global_avg:.4f})" 64 | self.deque = deque(maxlen=window_size) 65 | self.total = 0.0 66 | self.count = 0 67 | self.fmt = fmt 68 | 69 | def update(self, value, n=1): 70 | self.deque.append(value) 71 | self.count += n 72 | self.total += value * n 73 | 74 | def synchronize_between_processes(self): 75 | """ 76 | Warning: does not synchronize the deque! 77 | """ 78 | if not is_dist_avail_and_initialized(): 79 | return 80 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 81 | dist.barrier() 82 | dist.all_reduce(t) 83 | t = t.tolist() 84 | self.count = int(t[0]) 85 | self.total = t[1] 86 | 87 | @property 88 | def median(self): 89 | d = torch.tensor(list(self.deque)) 90 | return d.median().item() 91 | 92 | @property 93 | def avg(self): 94 | d = torch.tensor(list(self.deque), dtype=torch.float32) 95 | return d.mean().item() 96 | 97 | @property 98 | def global_avg(self): 99 | return self.total / self.count 100 | 101 | @property 102 | def max(self): 103 | return max(self.deque) 104 | 105 | @property 106 | def value(self): 107 | return self.deque[-1] 108 | 109 | def __str__(self): 110 | return self.fmt.format( 111 | median=self.median, 112 | avg=self.avg, 113 | global_avg=self.global_avg, 114 | max=self.max, 115 | value=self.value) 116 | 117 | 118 | class MetricLogger(object): 119 | def __init__(self, delimiter="\t"): 120 | self.meters = defaultdict(SmoothedValue) 121 | self.delimiter = delimiter 122 | 123 | def update(self, **kwargs): 124 | for k, v in kwargs.items(): 125 | if isinstance(v, torch.Tensor): 126 | v = v.item() 127 | assert isinstance(v, (float, int)) 128 | self.meters[k].update(v) 129 | 130 | def __getattr__(self, attr): 131 | if attr in self.meters: 132 | return self.meters[attr] 133 | if attr in self.__dict__: 134 | return self.__dict__[attr] 135 | raise AttributeError("'{}' object has no attribute '{}'".format( 136 | type(self).__name__, attr)) 137 | 138 | def __str__(self): 139 | loss_str = [] 140 | for name, meter in self.meters.items(): 141 | loss_str.append( 142 | "{}: {}".format(name, str(meter)) 143 | ) 144 | return self.delimiter.join(loss_str) 145 | 146 | def synchronize_between_processes(self): 147 | for meter in self.meters.values(): 148 | meter.synchronize_between_processes() 149 | 150 | def add_meter(self, name, meter): 151 | self.meters[name] = meter 152 | 153 | def log_every(self, iterable, print_freq, header=None): 154 | i = 0 155 | if not header: 156 | header = '' 157 | start_time = time.time() 158 | end = time.time() 159 | iter_time = SmoothedValue(fmt='{avg:.4f}') 160 | data_time = SmoothedValue(fmt='{avg:.4f}') 161 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 162 | if torch.cuda.is_available(): 163 | log_msg = self.delimiter.join([ 164 | header, 165 | '[{0' + space_fmt + '}/{1}]', 166 | 'eta: {eta}', 167 | '{meters}', 168 | 'time: {time}', 169 | 'data: {data}', 170 | 'max mem: {memory:.0f}' 171 | ]) 172 | else: 173 | log_msg = self.delimiter.join([ 174 | header, 175 | '[{0' + space_fmt + '}/{1}]', 176 | 'eta: {eta}', 177 | '{meters}', 178 | 'time: {time}', 179 | 'data: {data}' 180 | ]) 181 | MB = 1024.0 * 1024.0 182 | for obj in iterable: 183 | data_time.update(time.time() - end) 184 | yield obj 185 | iter_time.update(time.time() - end) 186 | if i % print_freq == 0: 187 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 188 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 189 | if 
torch.cuda.is_available(): 190 | print(log_msg.format( 191 | i, len(iterable), eta=eta_string, 192 | meters=str(self), 193 | time=str(iter_time), data=str(data_time), 194 | memory=torch.cuda.max_memory_allocated() / MB)) 195 | else: 196 | print(log_msg.format( 197 | i, len(iterable), eta=eta_string, 198 | meters=str(self), 199 | time=str(iter_time), data=str(data_time))) 200 | i += 1 201 | end = time.time() 202 | total_time = time.time() - start_time 203 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 204 | print('{} Total time: {}'.format(header, total_time_str)) 205 | 206 | 207 | def accuracy(output, target, topk=(1,)): 208 | """Computes the accuracy over the k top predictions for the specified values of k""" 209 | with torch.no_grad(): 210 | maxk = max(topk) 211 | batch_size = target.size(0) 212 | 213 | _, pred = output.topk(maxk, 1, True, True) 214 | pred = pred.t() 215 | correct = pred.eq(target[None]) 216 | 217 | res = [] 218 | for k in topk: 219 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 220 | res.append(correct_k * (100.0 / batch_size)) 221 | return res 222 | 223 | 224 | def mkdir(path): 225 | try: 226 | os.makedirs(path) 227 | except OSError as e: 228 | if e.errno != errno.EEXIST: 229 | raise 230 | 231 | 232 | def setup_for_distributed(is_master): 233 | """ 234 | This function disables printing when not in master process 235 | """ 236 | import builtins as __builtin__ 237 | builtin_print = __builtin__.print 238 | 239 | def print(*args, **kwargs): 240 | force = kwargs.pop('force', False) 241 | if is_master or force: 242 | builtin_print(*args, **kwargs) 243 | 244 | __builtin__.print = print 245 | 246 | 247 | def is_dist_avail_and_initialized(): 248 | if not dist.is_available(): 249 | return False 250 | if not dist.is_initialized(): 251 | return False 252 | return True 253 | 254 | 255 | def get_world_size(): 256 | if not is_dist_avail_and_initialized(): 257 | return 1 258 | return dist.get_world_size() 259 | 260 | 261 | def get_rank(): 262 | if not is_dist_avail_and_initialized(): 263 | return 0 264 | return dist.get_rank() 265 | 266 | 267 | def is_main_process(): 268 | return get_rank() == 0 269 | 270 | 271 | def save_on_master(*args, **kwargs): 272 | if is_main_process(): 273 | torch.save(*args, **kwargs) 274 | 275 | 276 | def init_distributed_mode(args): 277 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 278 | args.rank = int(os.environ["RANK"]) 279 | args.world_size = int(os.environ['WORLD_SIZE']) 280 | args.gpu = int(os.environ['LOCAL_RANK']) 281 | elif 'SLURM_PROCID' in os.environ: 282 | args.rank = int(os.environ['SLURM_PROCID']) 283 | args.gpu = args.rank % torch.cuda.device_count() 284 | elif hasattr(args, "rank"): 285 | pass 286 | else: 287 | print('Not using distributed mode') 288 | args.distributed = False 289 | return 290 | 291 | args.distributed = True 292 | 293 | torch.cuda.set_device(args.gpu) 294 | args.dist_backend = 'nccl' 295 | print('| distributed init (rank {}): {}'.format( 296 | args.rank, args.dist_url), flush=True) 297 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 298 | world_size=args.world_size, rank=args.rank) 299 | setup_for_distributed(args.rank == 0) 300 | 301 | --------------------------------------------------------------------------------
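For reference, a minimal sketch of how `utils.WarmupMultiStepLR` behaves (illustrative; the model and numbers are placeholders). Because the training scripts above call `lr_scheduler.step()` once per batch rather than once per epoch, milestones and warmup length are expressed in iterations:

```python
import torch
from utils import WarmupMultiStepLR

model = torch.nn.Linear(4, 2)                       # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Milestones and warmup are counted in iterations, matching the per-batch
# lr_scheduler.step() calls in train-msr.py / train-ntu.py.
scheduler = WarmupMultiStepLR(optimizer, milestones=[100, 200], gamma=0.1,
                              warmup_iters=50, warmup_factor=1e-5)
lrs = []
for _ in range(250):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# The learning rate ramps linearly from ~1e-7 toward 0.01 over the first 50
# iterations, then drops by 10x after crossing iterations 100 and 200.
```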