├── chamfer_distance ├── __init__.py ├── chamfer_distance.py ├── chamfer_distance.cu └── chamfer_distance.cpp ├── exceptions └── exceptions.py ├── modules ├── _ext_src │ ├── include │ │ ├── ball_query.h │ │ ├── group_points.h │ │ ├── sampling.h │ │ ├── interpolate.h │ │ ├── utils.h │ │ └── cuda_utils.h │ └── src │ │ ├── bindings.cpp │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── sampling.cpp │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ └── sampling_gpu.cu ├── setup.py ├── pointnet2_test.py ├── transformer.py ├── pytorch_utils.py ├── pointnet2_utils.py └── pointnet2_modules.py ├── README.md ├── LICENSE ├── models ├── Folding.py ├── PSTNet.py ├── sim_model_test.py └── CLR_Model.py ├── data_aug ├── CLR_MSR.py └── sim_data_test.py ├── logger.py ├── datasets ├── msr.py └── MSRAction_all.list ├── .gitignore ├── 0-pretrain-test.py ├── 0-pretrain-msr.py ├── utils.py └── 1-linear-msr.py /chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance 2 | -------------------------------------------------------------------------------- /exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | class BaseSimCLRException(Exception): 2 | """Base exception""" 3 | 4 | 5 | class InvalidBackboneError(BaseSimCLRException): 6 | """Raised when the choice of backbone Convnet is invalid.""" 7 | 8 | 9 | class InvalidDatasetSelection(BaseSimCLRException): 10 | """Raised when the choice of dataset is invalid.""" 11 | -------------------------------------------------------------------------------- /modules/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /modules/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /modules/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /modules/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /modules/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("gather_points", &gather_points); 13 | m.def("gather_points_grad", &gather_points_grad); 14 | m.def("furthest_point_sampling", &furthest_point_sampling); 15 | 16 | m.def("three_nn", &three_nn); 17 | m.def("three_interpolate", &three_interpolate); 18 | m.def("three_interpolate_grad", &three_interpolate_grad); 19 | 20 | m.def("ball_query", &ball_query); 21 | 22 | m.def("group_points", &group_points); 23 | m.def("group_points_grad", &group_points_grad); 24 | } 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointCMP: Contrastive Mask Prediction for Self-supervised Learning on Point Cloud Videos (CVPR2023) 2 | 3 | ## Introduction 4 | In this paper, we propose a contrastive mask prediction (PointCMP) framework for self-supervised learning on point cloud videos. Specifically, our PointCMP employs a two-branch structure to achieve simultaneous learning of both local and global spatio-temporal information. 
On top of this two-branch structure, a mutual similarity based augmentation module is developed to synthesize hard samples at the feature level. 5 | 6 | ## Installation 7 | The code is tested with Python 3.7.12, PyTorch 1.7.1, GCC 9.4.0, and CUDA 10.2. 8 | Compile the CUDA layers for [PointNet++](http://arxiv.org/abs/1706.02413): 9 | ``` 10 | cd modules 11 | python setup.py install 12 | ``` 13 | 14 | ## Related Repositories 15 | We thank the authors of related repositories: 16 | 1. PSTNet: https://github.com/hehefan/Point-Spatio-Temporal-Convolution 17 | 2. P4Transformer: https://github.com/hehefan/P4Transformer 18 | 19 | -------------------------------------------------------------------------------- /modules/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os 10 | 11 | _ext_src_root = "_ext_src" 12 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 13 | "{}/src/*.cu".format(_ext_src_root) 14 | ) 15 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 16 | 17 | headers = "-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), '_ext_src', 'include') 18 | 19 | setup( 20 | name='pointnet2', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='pointnet2._ext', 24 | sources=_ext_sources, 25 | extra_compile_args={ 26 | "cxx": ["-O2", headers], 27 | "nvcc": ["-O2", headers], 28 | }, 29 | ) 30 | ], 31 | cmdclass={ 32 | 'build_ext': BuildExtension 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zhiqiang Shen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /modules/pointnet2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Testing customized ops. ''' 7 | 8 | import torch 9 | from torch.autograd import gradcheck 10 | import numpy as np 11 | 12 | import os 13 | import sys 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(BASE_DIR) 16 | import pointnet2_utils 17 | 18 | def test_interpolation_grad(): 19 | batch_size = 1 20 | feat_dim = 2 21 | m = 4 22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() 23 | 24 | def interpolate_func(inputs): 25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() 26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() 27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) 28 | return interpolated_feats 29 | 30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) 31 | 32 | if __name__=='__main__': 33 | test_interpolation_grad() 34 | -------------------------------------------------------------------------------- /modules/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /modules/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /modules/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
class ChamferDistanceFunction(torch.autograd.Function):
    """Autograd bridge to the compiled chamfer-distance extension ``cd``.

    ``forward`` takes two 3-D point tensors ``xyz1`` (batch, n, _) and
    ``xyz2`` (batch, m, _) and returns ``(dist1, dist2)`` with shapes
    (batch, n) and (batch, m).  The extension is handed pre-allocated
    output buffers and is expected to fill them in place.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # Allocate the outputs directly on the input's device.  Creating
        # them on the CPU and calling .cuda() would place them on the
        # *current* CUDA device (wrong on multi-GPU setups) and costs a
        # host-to-device copy per call.
        device = xyz1.device
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int, device=device)

        if not xyz1.is_cuda:
            cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        else:
            # Make the input's device current while the extension runs.
            with torch.cuda.device(device):
                cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)

        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)

        return dist1, dist2

    @staticmethod
    def backward(ctx, graddist1, graddist2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors

        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()

        # Gradient buffers live next to the tensors they belong to.
        gradxyz1 = torch.zeros(xyz1.size(), device=xyz1.device)
        gradxyz2 = torch.zeros(xyz2.size(), device=xyz2.device)

        if not graddist1.is_cuda:
            cd.backward(xyz1, xyz2, gradxyz1, gradxyz2,
                        graddist1, graddist2, idx1, idx2)
        else:
            with torch.cuda.device(graddist1.device):
                cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2,
                                 graddist1, graddist2, idx1, idx2)

        return gradxyz1, gradxyz2
class FoldingNetSingle(nn.Module):
    """A single folding stage: one ``PointwiseMLP`` without a final ReLU.

    Args:
        dims: layer widths passed straight to ``PointwiseMLP``; the first
            entry is the input width and the last the output width.
    """

    def __init__(self, dims):
        super(FoldingNetSingle, self).__init__()
        self.mlp = PointwiseMLP(dims, doLastRelu=False)

    def forward(self, X):
        # Invoke the module itself rather than ``.forward`` so that any
        # forward/backward hooks registered on ``self.mlp`` are honoured.
        return self.mlp(X)
// Launches query_ball_point_kernel on the current PyTorch CUDA stream.
//
// The kernel takes its batch index from blockIdx.x and strides over the m
// query points with threadIdx.x / blockDim.x, so the launch uses one block
// per batch element and a 1-D block sized by opt_n_threads(m).
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
                                     int nsample, const float *new_xyz,
                                     const float *xyz, int *idx) {
  // Nothing to compute; a zero-sized grid or block is an invalid launch
  // configuration and would trip CUDA_CHECK_ERRORS.
  if (b <= 0 || m <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
      b, n, m, radius, nsample, new_xyz, xyz, idx);

  CUDA_CHECK_ERRORS();
}
5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /modules/transformer.py: 
class FeedForward(nn.Module):
    """Two-layer MLP block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: width of the input and of the output.
        hidden_dim: width of the intermediate layer.
        dropout: drop probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Expansion half followed by the projection back to ``dim``.
        expand = [nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout)]
        project = [nn.Linear(hidden_dim, dim), nn.Dropout(dropout)]
        self.net = nn.Sequential(*expand, *project)

    def forward(self, x):
        out = self.net(x)
        return out
class CLRMSRSubject(Dataset):
    """Clip dataset over point-cloud videos stored as ``.npz`` files.

    ``meta`` is a text file with one ``<name> <nframes>`` pair per line.
    A video is used only when ``train`` is true and the integer following
    the ``s`` in the second ``_``-separated field of its name is <= 5;
    with ``train=False`` nothing is indexed and the dataset is empty.

    Each item is one clip of ``frames_per_clip`` frames (taken every
    ``step_between_frames`` frames), resampled to ``num_points`` points
    per frame and cut into ``sub_clips`` equal sub-clips along the frame
    axis.  ``frames_per_clip`` must be divisible by ``sub_clips`` or
    ``np.split`` raises ``ValueError`` in ``__getitem__``.

    Args:
        root: directory holding ``<name>.npz`` files with a
            ``point_clouds`` entry.
        meta: path of the list file described above.
        frames_per_clip: number of frames in one clip.
        step_between_clips: stride between the start frames of clips.
        num_points: number of points every frame is resampled to.
        sub_clips: number of sub-clips one clip is split into.
        step_between_frames: stride between frames inside a clip.
        train: enables indexing and the random scaling augmentation.
    """

    def __init__(self, root, meta, frames_per_clip=16, step_between_clips=1, num_points=2048, sub_clips=4, step_between_frames=1, train=True):
        super(CLRMSRSubject, self).__init__()

        self.sub_clips = sub_clips
        self.root = root
        self.videos = []
        self.index_map = []
        index = 0

        with open(meta, 'r') as f:
            for line in f:
                # Tolerate blank lines instead of failing on the unpack.
                if not line.strip():
                    continue
                name, nframes = line.split()
                if train:
                    if int(name.split('_')[1].split('s')[1]) <= 5:
                        nframes = int(nframes)
                        # Every start frame that leaves room for a full clip.
                        for t in range(0, nframes-step_between_frames*(frames_per_clip-1), step_between_clips):
                            self.index_map.append((index, t))
                        index += 1
                        self.videos.append(os.path.join(root, name+'.npz'))

        self.frames_per_clip = frames_per_clip
        self.step_between_clips = step_between_clips
        self.step_between_frames = step_between_frames
        self.num_points = num_points
        self.train = train

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return ``(clips, video_index)`` for the ``idx``-th clip.

        ``clips`` is a float32 array of shape
        ``[sub_clips, frames_per_clip // sub_clips, num_points, 3]``.
        """
        index, t = self.index_map[idx]

        video_name = self.videos[index]
        # Close the archive as soon as the array is read; leaving the
        # NpzFile open leaks one file handle per fetched item.
        with np.load(video_name, allow_pickle=True) as archive:
            video = archive['point_clouds']

        clip = [video[t+i*self.step_between_frames] for i in range(self.frames_per_clip)]
        for i, p in enumerate(clip):
            if p.shape[0] > self.num_points:
                # Too many points: subsample without replacement.
                r = np.random.choice(p.shape[0], size=self.num_points, replace=False)
            else:
                # Too few points: repeat the whole frame, then top up with
                # a random subset to reach exactly num_points.
                repeat, residue = self.num_points // p.shape[0], self.num_points % p.shape[0]
                r = np.random.choice(p.shape[0], size=residue, replace=False)
                r = np.concatenate([np.arange(p.shape[0]) for _ in range(repeat)] + [r], axis=0)
            clip[i] = p[r, :]
        clip = np.array(clip)

        if self.train:
            # One random scale factor per coordinate axis, shared by the clip.
            scales = np.random.uniform(0.9, 1.1, size=3)
            clip = clip * scales

        clip = clip.astype(np.float32) / 300  # [L, N, 3]

        clips = np.split(clip, indices_or_sections=self.sub_clips, axis=0)  # S*[L', N, 3]
        clips = np.array(clips)  # [S, L', N, 3]

        return clips, index
// Launches group_points_kernel on the current PyTorch CUDA stream.
//
// The kernel takes its batch index from blockIdx.x and flattens a 2-D
// thread block (threadIdx.y * blockDim.x + threadIdx.x) to stride over the
// c * npoints work items, so the launch uses one block per batch element
// and a 2-D block from opt_block_config.
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
                                 const float *points, const int *idx,
                                 float *out) {
  // Nothing to compute; a zero-sized grid or block is an invalid launch
  // configuration and would trip CUDA_CHECK_ERRORS.
  if (b <= 0 || c <= 0 || npoints <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
      b, c, n, npoints, nsample, points, idx, out);

  CUDA_CHECK_ERRORS();
}
// Launches group_points_grad_kernel on the current PyTorch CUDA stream.
//
// Mirrors the forward wrapper: one block per batch element (the kernel
// reads blockIdx.x as the batch index) and a 2-D block from
// opt_block_config, which the kernel flattens to stride over c * npoints.
// The kernel accumulates into grad_points with atomicAdd, so the caller
// must pass a zero-initialised buffer.
void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
                                      int nsample, const float *grad_out,
                                      const int *idx, float *grad_points) {
  // Nothing to accumulate; a zero-sized grid or block is an invalid launch
  // configuration and would trip CUDA_CHECK_ERRORS.
  if (b <= 0 || c <= 0 || npoints <= 0) {
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
      b, c, n, npoints, nsample, grad_out, idx, grad_points);

  CUDA_CHECK_ERRORS();
}
class CLRMSRSubject(Dataset):
    """MSR-Action3D point-cloud clips for contrastive pre-training.

    Every item consists of two augmented views of the same clip, each cut
    into ``sub_clips`` consecutive sub-clips, plus the index of the source
    video.

    Args:
        root: directory containing the ``<name>.npz`` files; each file must
            provide a ``point_clouds`` entry holding one point array per frame.
        meta: text file with one ``<name> <nframes>`` pair per line.
        frames_per_clip: number of frames sampled for one clip.
        step_between_clips: stride (in frames) between clip start positions.
        num_points: number of points every frame is resampled to.
        sub_clips: number of equally long sub-clips a clip is split into;
            must divide ``frames_per_clip``.
        step_between_frames: stride (in frames) between sampled frames.
        train: only videos of subjects 1-5 are indexed, and only when this
            flag is true; with ``train=False`` the dataset stays empty.

    Raises:
        ValueError: if ``frames_per_clip`` cannot be split evenly into
            ``sub_clips`` parts.
    """

    def __init__(self, root, meta, frames_per_clip=16, step_between_clips=1, num_points=2048, sub_clips=4, step_between_frames=1, train=True):
        super().__init__()

        # Fail at construction time instead of inside a DataLoader worker,
        # where np.split would raise on the first __getitem__ call.
        if sub_clips <= 0 or frames_per_clip % sub_clips != 0:
            raise ValueError(
                "frames_per_clip (%d) must be divisible by sub_clips (%d)"
                % (frames_per_clip, sub_clips)
            )

        self.sub_clips = sub_clips
        self.root = root
        self.videos = []
        self.index_map = []
        index = 0

        with open(meta, 'r') as f:
            for line in f:
                name, nframes = line.split()
                if not train:
                    continue
                # Names look like "a01_s03_e02"; the "sNN" field is the subject id.
                if int(name.split('_')[1].split('s')[1]) <= 5:
                    nframes = int(nframes)
                    last_start = nframes - step_between_frames * (frames_per_clip - 1)
                    for t in range(0, last_start, step_between_clips):
                        self.index_map.append((index, t))
                    index += 1
                    self.videos.append(os.path.join(root, name + '.npz'))

        self.frames_per_clip = frames_per_clip
        self.step_between_clips = step_between_clips
        self.step_between_frames = step_between_frames
        self.num_points = num_points
        self.train = train

    def __len__(self):
        return len(self.index_map)

    @staticmethod
    def _resample(points, num_points):
        """Return exactly ``num_points`` rows of ``points``.

        Larger frames are subsampled without replacement; smaller frames are
        repeated as a whole and topped up with a random subset.
        """
        count = points.shape[0]
        if count > num_points:
            chosen = np.random.choice(count, size=num_points, replace=False)
        else:
            repeat, residue = num_points // count, num_points % count
            extra = np.random.choice(count, size=residue, replace=False)
            chosen = np.concatenate([np.arange(count) for _ in range(repeat)] + [extra], axis=0)
        return points[chosen, :]

    def __getitem__(self, idx):
        """Return ``(view1, view2, video_index)``; both views are float32 ``[S, L', N, 3]``."""
        index, t = self.index_map[idx]

        # The archive is opened lazily per item; close it as soon as the
        # frames are extracted so workers do not accumulate open handles.
        with np.load(self.videos[index], allow_pickle=True) as data:
            video = data['point_clouds']

        frames = [video[t + i * self.step_between_frames] for i in range(self.frames_per_clip)]
        base = np.array([self._resample(p, self.num_points) for p in frames])  # [L, N, 3]

        # View 1: random anisotropic scaling only.
        scales = np.random.uniform(0.9, 1.1, size=3)
        clip = base * scales / 300
        clips = np.array(np.split(clip, indices_or_sections=self.sub_clips, axis=0))  # [S, L', N, 3]

        # View 2: independent scaling plus per-point jitter and a global shift.
        # `base` is never modified in place, so no explicit copy is required.
        scalesv2 = np.random.uniform(0.9, 1.1, size=3)
        clipv2 = base * scalesv2 / 300

        jittered_data = np.random.normal(0, 0.01, size=(clipv2.shape[0], clipv2.shape[1], 3)).clip(-0.02, 0.02)
        translation = np.random.normal(0, 0.01, size=3).clip(-0.05, 0.05)
        clipv2 = clipv2 + jittered_data + translation

        clipsv2 = np.array(np.split(clipv2, indices_or_sections=self.sub_clips, axis=0))  # [S, L', N, 3]

        return clips.astype(np.float32), clipsv2.astype(np.float32), index
# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(
    output=None, distributed_rank=0, *, color=True, name="moco", abbrev_name=None
):
    """
    Create the logger called `name`, set its level to DEBUG and attach handlers.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of the calling process. Only rank 0 logs to
            stdout; ranks above 0 write to `<file>.rank<N>` when `output` is given.
        color (bool): use the colored formatter for the stdout handler.
        name (str): the root module name of this logger
        abbrev_name (str): abbreviation that replaces `name` in colored output;
            defaults to `name`.

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith((".txt", ".log")):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        # A bare file name such as "train.log" has an empty dirname, and
        # os.makedirs("") raises; only create a directory when there is one.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger
class MSRAction3D(Dataset):
    """MSR-Action3D point-cloud clips with action labels.

    Videos of subjects 1-5 form the training split and the remaining
    subjects the test split. All selected videos are loaded into memory at
    construction time.

    Args:
        root: directory containing the ``<name>.npz`` files; each file must
            provide a ``point_clouds`` entry indexed by frame.
        meta: text file with one ``<name> <nframes>`` pair per line. The
            frame count in the file is ignored; the loaded video decides.
        frames_per_clip: number of frames sampled for one clip.
        step_between_clips: stride (in frames) between clip start positions.
        num_points: number of points every frame is resampled to.
        step_between_frames: stride (in frames) between sampled frames.
        sub_clips: stored on the instance; not used by ``__getitem__``.
        train: select the training split and enable random scaling.
    """

    def __init__(self, root, meta, frames_per_clip=16, step_between_clips=1, num_points=2048, step_between_frames=1, sub_clips=4, train=True):
        super().__init__()

        self.videos = []
        self.labels = []
        self.index_map = []
        index = 0

        with open(meta, 'r') as f:
            for line in f:
                video_name, _ = line.split()

                # Names look like "a01_s03_e02": action id, subject id, trial.
                subject = int(video_name.split('_')[1].split('s')[1])
                if (subject <= 5) != bool(train):
                    continue

                # Close the archive right after reading so that loading many
                # videos does not keep one file handle open per video.
                with np.load(os.path.join(root, video_name + '.npz'), allow_pickle=True) as data:
                    video = data['point_clouds']
                self.videos.append(video)
                self.labels.append(int(video_name.split('_')[0][1:]) - 1)

                nframes = video.shape[0]
                for t in range(0, nframes - step_between_frames * (frames_per_clip - 1), step_between_clips):
                    self.index_map.append((index, t))
                index += 1

        self.sub_clips = sub_clips
        self.frames_per_clip = frames_per_clip
        self.step_between_clips = step_between_clips
        self.step_between_frames = step_between_frames
        self.num_points = num_points
        self.train = train
        # An empty split has no classes; max() on an empty list would raise.
        self.num_classes = max(self.labels) + 1 if self.labels else 0

    def __len__(self):
        return len(self.index_map)

    @staticmethod
    def _resample(points, num_points):
        """Return exactly ``num_points`` rows of ``points``.

        Larger frames are subsampled without replacement; smaller frames are
        repeated as a whole and topped up with a random subset.
        """
        count = points.shape[0]
        if count > num_points:
            chosen = np.random.choice(count, size=num_points, replace=False)
        else:
            repeat, residue = num_points // count, num_points % count
            extra = np.random.choice(count, size=residue, replace=False)
            chosen = np.concatenate([np.arange(count) for _ in range(repeat)] + [extra], axis=0)
        return points[chosen, :]

    def __getitem__(self, idx):
        """Return ``(clip, label, video_index)``; ``clip`` is float32 ``[L, N, 3]``."""
        index, t = self.index_map[idx]

        video = self.videos[index]
        label = self.labels[index]

        frames = [video[t + i * self.step_between_frames] for i in range(self.frames_per_clip)]
        clip = np.array([self._resample(p, self.num_points) for p in frames])

        if self.train:
            # scale the points
            scales = np.random.uniform(0.9, 1.1, size=3)
            clip = clip * scales

        clip = clip / 300

        return clip.astype(np.float32), label, index
python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /modules/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), 
class Encoder(nn.Module):
    """Stack of six PSTConv layers that encodes a point-cloud clip.

    Channel widths grow 64 -> 128 -> 256 -> 512 -> 1024 -> 2048 while the
    spatial search radius grows from ``radius`` to ``4 * radius``.

    Args:
        radius: spatial search radius of the first layer.
        nsamples: number of neighbours passed in every spatial kernel.
    """

    def __init__(self,
                 radius=1.5,
                 nsamples=3*3):
        super(Encoder, self).__init__()

        def make_layer(in_planes, mid_planes, out_planes, radius_scale,
                       temporal_kernel_size, spatial_stride, temporal_stride,
                       temporal_padding):
            # All layers share the aggregation and pooling mode.
            return PSTConv(in_planes=in_planes,
                           mid_planes=mid_planes,
                           out_planes=out_planes,
                           spatial_kernel_size=[radius_scale * radius, nsamples],
                           temporal_kernel_size=temporal_kernel_size,
                           spatial_stride=spatial_stride,
                           temporal_stride=temporal_stride,
                           temporal_padding=temporal_padding,
                           spatial_aggregation="multiplication",
                           spatial_pooling="sum")

        # Attribute names and creation order are kept so that parameter
        # registration (and thus state_dict layout) is unchanged.
        self.conv1 = make_layer(0, 45, 64, 1, 1, 2, 1, [0, 0])
        self.conv2a = make_layer(64, 96, 128, 2, 3, 2, 2, [1, 0])
        self.conv2b = make_layer(128, 192, 256, 2, 3, 1, 1, [1, 1])
        self.conv3a = make_layer(256, 384, 512, 4, 3, 2, 2, [1, 0])
        self.conv3b = make_layer(512, 768, 1024, 4, 3, 1, 1, [1, 1])
        self.conv4 = make_layer(1024, 1536, 2048, 4, 1, 2, 1, [0, 0])

    def forward(self, clips_input):
        """Run all layers; ReLU follows every layer except the last one."""
        xyzs, features = self.conv1(clips_input, None)

        later_layers = (self.conv2a, self.conv2b, self.conv3a, self.conv3b, self.conv4)
        for layer in later_layers:
            features = F.relu(features)
            xyzs, features = layer(xyzs, features)

        # xyzs: [B*S, L', N, 3]   features: [B*S, L', C, N]
        return xyzs, features
def forward(self, clips, clipsv2):
    """Encode two augmented views and build the contrastive targets.

    Args:
        clips: first view, ``[B, S, L', N, 3]``.
        clipsv2: second view with the same shape as ``clips``.

    Returns:
        ``(view1, view2)`` when ``self.mask_channel`` is false, otherwise
        ``(view1, view2, erase_global)``. ``view1`` and ``view2`` are
        L2-normalised ``[B, C]`` embeddings; ``erase_global`` is
        ``[B, 10, C]`` and holds embeddings of view 1 recomputed with its
        most similar channels zeroed. ``view2`` and ``erase_global`` are
        detached from the graph.
    """
    # Tensor.get_device() yields -1 for CPU tensors, which is not a valid
    # target for .to(); the device attribute works for CPU and CUDA alike.
    device = clips.device

    batch_size, sub_clips, l_sub_clip, n_point, c_xyz = clips.shape
    clips = clips.reshape((-1, l_sub_clip, n_point, c_xyz))
    clipsv2 = clipsv2.reshape((-1, l_sub_clip, n_point, c_xyz))
    clips_input = torch.cat(tensors=(clips, clipsv2), dim=0)  # [2*B*S, L', N, 3]

    new_xys, new_features = self.encoder(clips_input)

    new_features = new_features.permute(0, 1, 3, 2)  # [2*B*S, L, N, C]
    _, l_out, n_out, _ = new_features.shape
    new_features = self.mlp_head(new_features)
    new_features = new_features.reshape((2, batch_size, sub_clips, l_out, n_out, self.token_dim))  # [2, B, S, L, N, C]
    assert l_out == 1

    new_features = torch.squeeze(new_features, dim=-3).contiguous()  # [2, B, S, N, C]

    # Global embedding: average over points, then max over sub-clips.
    view_global = torch.mean(input=new_features, dim=-2, keepdim=False)  # [2, B, S, C]
    view_global = torch.max(input=view_global, dim=-2, keepdim=False)[0]  # [2, B, C]
    view_global = self.v2_fc(view_global)
    view_global = F.normalize(view_global, dim=-1)  # [2, B, C]

    view1_ = view_global[0]
    view2_ = view_global[1]

    if not self.mask_channel:
        return view1_, view2_.detach()

    num_masks = 10  # mask 0 drops every high-similarity channel, 1..9 a random 95% of them

    # Every result of this branch is returned detached, so no graph is needed.
    with torch.no_grad():
        feats_v1 = new_features[0].detach()  # [B, S, N, C]
        global_v1 = torch.mean(input=feats_v1, dim=-2, keepdim=False)  # [B, S, C]
        global_v1 = torch.max(input=global_v1, dim=-2, keepdim=False)[0]  # [B, C]
        global_v1 = F.normalize(global_v1, dim=-1)

        point_feats = F.normalize(feats_v1, dim=-1)
        point_feats = point_feats.reshape((batch_size, sub_clips * n_out, self.token_dim))  # [B, S*N, C]

        mask_high_sim = torch.ones(
            (batch_size, num_masks, sub_clips, n_out, self.token_dim),
            dtype=torch.float32, device=device)  # [B, 10, S, N, C]
        first_masked = int(self.token_dim * 0.8)

        for bi in range(batch_size):
            channel_score = point_feats[bi] * global_v1[bi]  # [S*N, C]
            # Double argsort turns scores into per-point channel ranks;
            # summing ranks over points orders channels by overall similarity.
            channel_rank = channel_score.argsort(dim=-1).argsort(dim=-1)
            channel_order = channel_rank.sum(dim=0).argsort(dim=-1)
            high_similarity_idx = channel_order[first_masked:]
            c_idx_num = high_similarity_idx.shape[0]
            mask_high_sim[bi, 0, :, :, high_similarity_idx] = 0

            for mci in range(1, num_masks):
                picked = random.sample(range(c_idx_num), int(c_idx_num * 0.95))
                c_index = torch.LongTensor(picked).to(device)
                high_sim_idx_random = torch.index_select(high_similarity_idx, 0, c_index)
                mask_high_sim[bi, mci, :, :, high_sim_idx_random] = 0

        erase_global = feats_v1.unsqueeze(1) * mask_high_sim  # [B, 10, S, N, C]
        erase_global = torch.mean(input=erase_global, dim=-2, keepdim=False)  # [B, 10, S, C]
        erase_global = torch.max(input=erase_global, dim=-2, keepdim=False)[0]  # [B, 10, C]
        erase_global = self.v2_fc(erase_global)
        erase_global = F.normalize(erase_global, dim=-1)

    return view1_, view2_.detach(), erase_global.detach()
// For every query point in `unknown`, finds the three closest points of
// `known` and stores their squared distances and indices.
//
// input: unknown(b, n, 3) known(b, m, 3)
// output: dist2(b, n, 3), idx(b, n, 3)
//
// blockIdx.x is used as the batch index; the threads of a block stride over
// the n query points in steps of blockDim.x.
// NOTE(review): `b` is not read inside the kernel, so the launch is presumably
// expected to use exactly b blocks -- confirm against the wrapper.
__global__ void three_nn_kernel(int b, int n, int m,
                                const float *__restrict__ unknown,
                                const float *__restrict__ known,
                                float *__restrict__ dist2,
                                int *__restrict__ idx) {
  // Advance all pointers to the slice of this batch element.
  int batch_index = blockIdx.x;
  unknown += batch_index * n * 3;
  known += batch_index * m * 3;
  dist2 += batch_index * n * 3;
  idx += batch_index * n * 3;

  int index = threadIdx.x;
  int stride = blockDim.x;
  for (int j = index; j < n; j += stride) {
    float ux = unknown[j * 3 + 0];
    float uy = unknown[j * 3 + 1];
    float uz = unknown[j * 3 + 2];

    // Running top-3 in ascending order (best1 <= best2 <= best3). If m < 3
    // the unused slots keep the 1e40 sentinel and index 0.
    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;
    for (int k = 0; k < m; ++k) {
      float x = known[k * 3 + 0];
      float y = known[k * 3 + 1];
      float z = known[k * 3 + 2];
      // Squared Euclidean distance; no square root is taken.
      float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
      // Insert d into the sorted triple, shifting worse entries down.
      if (d < best1) {
        best3 = best2;
        besti3 = besti2;
        best2 = best1;
        besti2 = besti1;
        best1 = d;
        besti1 = k;
      } else if (d < best2) {
        best3 = best2;
        besti3 = besti2;
        best2 = d;
        besti2 = k;
      } else if (d < best3) {
        best3 = d;
        besti3 = k;
      }
    }
    // The double accumulators are narrowed to float on store.
    dist2[j * 3 + 0] = best1;
    dist2[j * 3 + 1] = best2;
    dist2[j * 3 + 2] = best3;

    idx[j * 3 + 0] = besti1;
    idx[j * 3 + 1] = besti2;
    idx[j * 3 + 2] = besti3;
  }
}
// Weighted three-point interpolation of per-channel features.
//
// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
// output: out(b, c, n)
//
// blockIdx.x is used as the batch index; the threads of a (2-D) block are
// flattened and stride together over the c * n output elements.
__global__ void three_interpolate_kernel(int b, int c, int m, int n,
                                         const float *__restrict__ points,
                                         const int *__restrict__ idx,
                                         const float *__restrict__ weight,
                                         float *__restrict__ out) {
  // Move every pointer to the slice of this batch element.
  const int batch = blockIdx.x;
  points += batch * m * c;
  idx += batch * n * 3;
  weight += batch * n * 3;
  out += batch * n * c;

  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.y * blockDim.x;
  const int total = c * n;

  for (int flat = tid; flat < total; flat += num_threads) {
    const int channel = flat / n;
    const int target = flat % n;

    // The three neighbours of this target point and their weights.
    const float *w = weight + target * 3;
    const int *neighbour = idx + target * 3;
    const float *row = points + channel * m;

    out[flat] = row[neighbour[0]] * w[0] + row[neighbour[1]] * w[1] +
                row[neighbour[2]] * w[2];
  }
}
= threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, 
n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /chamfer_distance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 
39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 
112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | 
const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /modules/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 

// NOTE(review): the two system header names are presumed to be the C stdio /
// stdlib headers -- confirm against the build.
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"

// Gather selected points along the last axis.
// input: points(b, c, n) idx(b, m)
// output: out(b, c, m)
__global__ void gather_points_kernel(int b, int c, int n, int m,
                                     const float *__restrict__ points,
                                     const int *__restrict__ idx,
                                     float *__restrict__ out) {
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int l = blockIdx.y; l < c; l += gridDim.y) {
      for (int j = threadIdx.x; j < m; j += blockDim.x) {
        int a = idx[i * m + j];
        out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
      }
    }
  }
}

// Host-side launcher for gather_points_kernel.
void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
                                  const float *points, const int *idx,
                                  float *out) {
  // Grid (b, c) matches the blockIdx.x / blockIdx.y loops in the kernel; the
  // threads of a block share the npoints outputs of one (batch, channel) row.
  gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
                         at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
                                                             points, idx, out);

  CUDA_CHECK_ERRORS();
}

// Backward of gather_points: scatter-add the output gradient to the sources.
// input: grad_out(b, c, m) idx(b, m)
// output: grad_points(b, c, n)
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
                                          const float *__restrict__ grad_out,
                                          const int *__restrict__ idx,
                                          float *__restrict__ grad_points) {
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int l = blockIdx.y; l < c; l += gridDim.y) {
      for (int j = threadIdx.x; j < m; j += blockDim.x) {
        int a = idx[i * m + j];
        // idx may name the same source point more than once, so the
        // accumulation has to be atomic.
        atomicAdd(grad_points + (i * c + l) * n + a,
                  grad_out[(i * c + l) * m + j]);
      }
    }
  }
}

// Host-side launcher for gather_points_grad_kernel.
void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
                                       const float *grad_out, const int *idx,
                                       float *grad_points) {
  gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
                              at::cuda::getCurrentCUDAStream()>>>(
      b, c, n, npoints, grad_out, idx, grad_points);

  CUDA_CHECK_ERRORS();
}

// One step of the arg-max reduction: slot idx1 keeps the larger of the two
// distances together with the point index that produced it.
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
                         int idx1, int idx2) {
  const float v1 = dists[idx1], v2 = dists[idx2];
  const int i1 = dists_i[idx1], i2 = dists_i[idx2];
  dists[idx1] = max(v1, v2);
  dists_i[idx1] = v2 > v1 ? i2 : i1;
}

// Iterative furthest point sampling; one thread block per batch element.
// Input dataset: (b, n, 3), tmp: (b, n)
// Output idxs (b, m)
//
// `temp` holds, per point, the running minimum squared distance to the
// points selected so far; this kernel only ever lowers it via min(), so the
// caller is presumably expected to pre-fill it with a large value -- verify
// at the call site.  block_size must equal the number of launched threads
// and be a power of two for the reduction below to be complete.
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * 3;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  // The first sample is always point 0.
  int old = 0;
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    float x1 = dataset[old * 3 + 0];
    float y1 = dataset[old * 3 + 1];
    float z1 = dataset[old * 3 + 2];
    for (int k = tid; k < n; k += stride) {
      float x2, y2, z2;
      x2 = dataset[k * 3 + 0];
      y2 = dataset[k * 3 + 1];
      z2 = dataset[k * 3 + 2];
      // Points (almost) at the origin are never selected.
      float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
      if (mag <= 1e-3) continue;

      float d =
          (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);

      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    // Shared-memory arg-max reduction.  block_size is a compile-time
    // constant, so the trip count is fixed per instantiation.
#pragma unroll
    for (int s = block_size / 2; s > 0; s >>= 1) {
      if (tid < s) {
        __update(dists, dists_i, tid, tid + s);
      }
      __syncthreads();
    }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
    // Every thread must have read dists_i[0] before any thread overwrites
    // slot 0 in the next iteration.
    __syncthreads();
  }
}

// Host-side launcher: picks the kernel instantiation whose block_size equals
// the thread count returned by opt_n_threads(n).
void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
                                            const float *dataset, float *temp,
                                            int *idxs) {
  unsigned int n_threads = opt_n_threads(n);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#define FPS_LAUNCH(BLOCK_SIZE)                \
  furthest_point_sampling_kernel<BLOCK_SIZE>  \
      <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs)

  switch (n_threads) {
    case 512:
      FPS_LAUNCH(512);
      break;
    case 256:
      FPS_LAUNCH(256);
      break;
    case 128:
      FPS_LAUNCH(128);
      break;
    case 64:
      FPS_LAUNCH(64);
      break;
    case 32:
      FPS_LAUNCH(32);
      break;
    case 16:
      FPS_LAUNCH(16);
      break;
    case 8:
      FPS_LAUNCH(8);
      break;
    case 4:
      FPS_LAUNCH(4);
      break;
    case 2:
      FPS_LAUNCH(2);
      break;
    case 1:
      FPS_LAUNCH(1);
      break;
    default:
      FPS_LAUNCH(512);
  }

#undef FPS_LAUNCH

  CUDA_CHECK_ERRORS();
}

--------------------------------------------------------------------------------
/modules/pytorch_utils.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch '''
import torch
import torch.nn as nn
from typing import List, Tuple


class SharedMLP(nn.Sequential):
    """Stack of 1x1 ``Conv2d`` blocks, one per consecutive pair in ``args``.

    ``args`` lists the channel widths; layer ``i`` maps ``args[i]`` to
    ``args[i + 1]`` channels and is registered as ``name + 'layer<i>'``.
    When both ``first`` and ``preact`` are set, layer 0 is built without
    batch norm and without activation.
    """

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = ""
    ):
        super().__init__()

        for i, (c_in, c_out) in enumerate(zip(args[:-1], args[1:])):
            # Only the leading layer of a pre-activated "first" MLP is bare.
            bare = first and preact and i == 0
            layer = Conv2d(
                c_in,
                c_out,
                bn=(not bare) and bn,
                activation=None if bare else activation,
                preact=preact
            )
            self.add_module(name + 'layer{}'.format(i), layer)


class _BNBase(nn.Sequential):
    """Wrap one batch-norm layer and reset its affine parameters to (1, 0)."""

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        norm = batch_norm(in_size)
        self.add_module(name + "bn", norm)

        nn.init.constant_(norm.weight, 1.0)
        nn.init.constant_(norm.bias, 0)


class BatchNorm1d(_BNBase):
    """``_BNBase`` around ``nn.BatchNorm1d``."""

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)


class BatchNorm2d(_BNBase):
    """``_BNBase`` around ``nn.BatchNorm2d``."""

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)


class BatchNorm3d(_BNBase):
    """``_BNBase`` around ``nn.BatchNorm3d``."""

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)


class _ConvBase(nn.Sequential):
    """Convolution with optional batch norm and activation.

    Sub-modules are registered as ``name + 'conv'``, ``name + 'bn'`` and
    ``name + 'activation'``.  With ``preact`` the order is bn, activation,
    conv and the norm is sized for ``in_size``; otherwise it is conv, bn,
    activation and the norm is sized for ``out_size``.  The convolution only
    gets a bias term when ``bias`` is set and ``bn`` is not.
    """

    def __init__(
            self,
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=None,
            batch_norm=None,
            bias=True,
            preact=False,
            name=""
    ):
        super().__init__()

        use_bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=use_bias
        )
        init(conv_unit.weight)
        if use_bias:
            nn.init.constant_(conv_unit.bias, 0)

        # Norm and activation always travel together, either before or
        # after the convolution.
        extras = []
        if bn:
            extras.append(('bn', batch_norm(in_size if preact else out_size)))
        if activation is not None:
            extras.append(('activation', activation))

        core = [('conv', conv_unit)]
        ordered = extras + core if preact else core + extras
        for suffix, module in ordered:
            self.add_module(name + suffix, module)


class Conv1d(_ConvBase):
    """``_ConvBase`` specialised to ``nn.Conv1d`` / ``BatchNorm1d``."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name
        )


class Conv2d(_ConvBase):
    """``_ConvBase`` specialised to ``nn.Conv2d`` / ``BatchNorm2d``."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name
        )
189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not 
isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /0-pretrain-test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import datetime 3 | import os 4 | import time 5 | import sys 6 | import random 7 | import numpy as np 8 | from tensorboardX import SummaryWriter 9 | 10 | import torch 11 | import torch.utils.data 12 | import torch.distributed as dist 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | import utils 17 | from logger import setup_logger 18 | from data_aug.sim_data_test import CLRMSRSubject 19 | from models.sim_model_test import ContrastiveLearningModel 20 | from timm.scheduler import CosineLRScheduler 21 | 22 | 23 | def train(model, optimizer, lr_scheduler, data_loader, 24 | device, epoch, print_freq, logger, clip_loss): 25 | 26 | batch_time = utils.AverageMeter() 27 | losses = utils.AverageMeter() 28 | 29 | model.train() 30 | for i, (clips, clipsv2, index) in enumerate(data_loader): 31 | start_time = time.time() 32 | # [B, S, L, N, 3] 33 | clips = clips.to(device) 34 | clipsv2 = clipsv2.to(device) 35 | 36 | mask_channel = True 37 | if mask_channel: 38 | view1_, view2_, erase_global = model(clips, clipsv2) 39 | 40 | view12 = torch.matmul(view1_, view2_.transpose(0,1)) # [B, B] 41 | view1e = view1_.unsqueeze(1) * erase_global # [B, H, C] 42 | view1e = torch.sum(view1e, dim=-1, keepdim=False) # [B, H] 43 | 44 | 
view_score = torch.cat(tensors=(view12, view1e), dim=1) # [B, B+H] 45 | view_score = view_score / 0.08 46 | else: 47 | view1_, view2_ = model(clips, clipsv2) 48 | view12 = torch.matmul(view1_, view2_.transpose(0,1)) # [B, B] 49 | view_score = view12 / 0.08 50 | 51 | target_view_score = torch.arange(view_score.shape[0]).to(device) 52 | loss = clip_loss(view_score, target_view_score) 53 | 54 | batch_size = clips.shape[0] 55 | lr_ = optimizer.param_groups[-1]["lr"] 56 | 57 | losses.update(loss.item(), batch_size) 58 | 59 | optimizer.zero_grad() 60 | loss.backward() 61 | optimizer.step() 62 | lr_scheduler.step() 63 | 64 | batch_time.update(time.time() - start_time) 65 | 66 | if i % print_freq == 0: 67 | logger.info(('Epoch:[{0}][{1}/{2}]\t' 68 | 'lr:{lr:.5f}\t' 69 | 'Loss:{loss.val:.3f} ({loss.avg:.3f})'.format( 70 | epoch, i, len(data_loader), lr=lr_, loss=losses))) 71 | 72 | return losses.avg 73 | 74 | 75 | def main(args): 76 | 77 | # Fix the seed 78 | random.seed(args.seed) 79 | np.random.seed(args.seed) 80 | torch.manual_seed(args.seed) 81 | torch.cuda.manual_seed(args.seed) 82 | torch.cuda.manual_seed_all(args.seed) 83 | torch.backends.cudnn.deterministic = True 84 | torch.backends.cudnn.benchmark = False 85 | 86 | device = torch.device("cuda") 87 | 88 | # Check folders and setup logger 89 | log_dir = os.path.join(args.log_dir, args.model) 90 | utils.mkdir(log_dir) 91 | 92 | with open(os.path.join(log_dir, 'args.txt'), 'w') as f: 93 | f.write(str(args)) 94 | 95 | logger = setup_logger(output=log_dir, distributed_rank=0, name=args.model) 96 | tf_writer = SummaryWriter(log_dir=log_dir) 97 | 98 | # Data loading code 99 | train_dataset = CLRMSRSubject( 100 | root=args.data_path, 101 | meta=args.data_meta, 102 | frames_per_clip=args.clip_len, 103 | step_between_clips=args.clip_stride, 104 | step_between_frames=args.frame_stride, 105 | num_points=args.num_points, 106 | sub_clips=args.sub_clips, 107 | train=True 108 | ) 109 | train_loader = torch.utils.data.DataLoader( 
110 | train_dataset, 111 | batch_size=args.batch_size, 112 | shuffle=True, 113 | num_workers=args.workers, 114 | pin_memory=True, 115 | drop_last=True 116 | ) 117 | # Creat Contrastive Learning Model 118 | model = ContrastiveLearningModel( 119 | radius=args.radius, 120 | nsamples=args.nsamples, 121 | representation_dim=args.representation_dim, 122 | temperature=args.temperature, 123 | pretraining=True 124 | ) 125 | # Distributed model 126 | if torch.cuda.device_count() > 1: 127 | model = nn.DataParallel(model) 128 | model.to(device) 129 | 130 | optimizer = torch.optim.Adam( 131 | model.parameters(), args.lr, weight_decay=args.weight_decay 132 | ) 133 | # param_groups = add_weight_decay(model, weight_decay=args.weight_decay) 134 | # optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay) 135 | # optimizer = torch.optim.SGD(model.parameters(), args.lr, 136 | # momentum=args.momentum, 137 | # weight_decay=args.weight_decay 138 | # ) 139 | # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 140 | # optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1 141 | # ) 142 | 143 | warmup_iters = args.lr_warmup_epochs * len(train_loader) 144 | epochs_iters = args.epochs * len(train_loader) 145 | lr_scheduler = utils.WarmupCosineLR(optimizer, 146 | T_max=epochs_iters, 147 | warmup_iters=warmup_iters, 148 | last_epoch=-1 149 | ) 150 | 151 | clip_loss = nn.CrossEntropyLoss() 152 | 153 | if args.resume: 154 | if os.path.isfile(args.resume): 155 | logger.info(("===> Loading checkpoint for resume '{}'".format(args.resume))) 156 | checkpoint = torch.load(args.resume, map_location='cpu') 157 | model.load_state_dict(checkpoint['model']) 158 | optimizer.load_state_dict(checkpoint['optimizer']) 159 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 160 | args.start_epoch = checkpoint['epoch'] + 1 161 | logger.info(("===> Loaded checkpoint with epoch {}".format(checkpoint['epoch']))) 162 | else: 163 | logger.info(("===> There is no 
checkpoint at '{}'".format(args.resume))) 164 | 165 | start_time = time.time() 166 | for epoch in range(args.start_epoch, args.epochs): 167 | train_loss = train(model, optimizer, lr_scheduler, train_loader, device, 168 | epoch, args.print_freq, logger, clip_loss 169 | ) 170 | tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch) 171 | tf_writer.add_scalar('loss/train', train_loss, epoch) 172 | tf_writer.flush() 173 | 174 | checkpoint = { 175 | 'model': model.state_dict(), 176 | 'optimizer': optimizer.state_dict(), 177 | 'lr_scheduler': lr_scheduler.state_dict(), 178 | 'epoch': epoch, 179 | 'args': args 180 | } 181 | torch.save( 182 | checkpoint, 183 | os.path.join(log_dir, 'checkpoint_{}.pth'.format(epoch)) 184 | ) 185 | logger.info('====================================') 186 | 187 | total_time = time.time() - start_time 188 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 189 | logger.info(('Training time {}'.format(total_time_str))) 190 | 191 | 192 | def parse_args(): 193 | import argparse 194 | parser = argparse.ArgumentParser(description='PyTorch SimCLR') 195 | 196 | parser.add_argument('--data-path', default='/data/MSRAction', metavar='DIR', help='path to dataset') 197 | parser.add_argument('--data-meta', default='datasets/MSRAction_all.list', help='dataset') 198 | parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run') 199 | parser.add_argument('--lr-warmup-epochs', default=5, type=int, help='number of warmup epochs') 200 | parser.add_argument('-b', '--batch-size', default=80, type=int, metavar='N', help='batch size') 201 | parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float, metavar='LR', help='initial learning rate', dest='lr') 202 | parser.add_argument('--temperature', default=0.01, type=float, help='softmax temperature (default: 0.07)') 203 | parser.add_argument('--representation-dim', default=1024, type=int, metavar='N', help='representation dim') 204 
| 205 | parser.add_argument('--sub-clips', default=4, type=int, metavar='N', help='number of sub clips') 206 | parser.add_argument('--radius', default=0.3, type=float, help='radius for the ball query') 207 | parser.add_argument('--nsamples', default=9, type=int, help='number of neighbors for the ball query') 208 | parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip') 209 | parser.add_argument('--clip-stride', default=1, type=int, metavar='N', help='number of steps between clips') 210 | parser.add_argument('--frame-stride', default=1, type=int, metavar='N', help='number of steps between clips') 211 | parser.add_argument('--num-points', default=1024, type=int, metavar='N', help='number of points per frame') 212 | 213 | parser.add_argument('--model', default='MSR', type=str, help='model') 214 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', help='number of data loading workers (default: 32)') 215 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 216 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') 217 | parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. 
') 218 | parser.add_argument('--print-freq', default=200, type=int, help='Log every n steps') 219 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') 220 | parser.add_argument('--log-dir', default='log_ssl/', type=str, help='path where to save') 221 | parser.add_argument('--resume', default='', help='resume from checkpoint') 222 | 223 | args = parser.parse_args() 224 | 225 | return args 226 | 227 | 228 | if __name__ == "__main__": 229 | args = parse_args() 230 | main(args) 231 | -------------------------------------------------------------------------------- /0-pretrain-msr.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import datetime 3 | import os 4 | import time 5 | import sys 6 | import random 7 | import numpy as np 8 | from tensorboardX import SummaryWriter 9 | 10 | import torch 11 | import torch.utils.data 12 | import torch.distributed as dist 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | import utils 17 | from logger import setup_logger 18 | from data_aug.CLR_MSR import CLRMSRSubject 19 | from models.CLR_Model import ContrastiveLearningModel 20 | from timm.scheduler import CosineLRScheduler 21 | 22 | 23 | def train(model, optimizer, lr_scheduler, data_loader, 24 | device, epoch, print_freq, logger, criterion_global): 25 | 26 | batch_time = utils.AverageMeter() 27 | data_time = utils.AverageMeter() 28 | losses = utils.AverageMeter() 29 | top1 = utils.AverageMeter() 30 | top5 = utils.AverageMeter() 31 | 32 | model.train() 33 | for i, (clips, index) in enumerate(data_loader): 34 | start_time = time.time() 35 | 36 | clips = clips.to(device) # [B, S, L, N, 3] 37 | 38 | loss, acc1, acc5, view1_global, erase_global, regression_global = model(clips) 39 | loss, acc1, acc5 = loss.mean(), acc1.mean(0), acc5.mean(0) 40 | 41 | view11 = torch.matmul(view1_global, regression_global.transpose(0,1)) # [B, B] 42 | 43 | view12 
def main(args):
    """Run contrastive pre-training and write one checkpoint per epoch.

    Seeds every RNG from ``args.seed``, builds the ``CLRMSRSubject`` training
    loader and the ``ContrastiveLearningModel``, optionally resumes from
    ``args.resume``, then trains from ``args.start_epoch`` up to
    ``args.epochs``. Scalars are logged to TensorBoard and checkpoints are
    saved under ``<args.log_dir>/<args.model>``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (see ``parse_args``).
    """
    # Fix the seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda")

    # Check folders and setup logger
    log_dir = os.path.join(args.log_dir, args.model)
    utils.mkdir(log_dir)

    with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
        f.write(str(args))

    logger = setup_logger(output=log_dir, distributed_rank=0, name=args.model)
    tf_writer = SummaryWriter(log_dir=log_dir)

    # Everything below runs under try/finally so the event file is flushed
    # and closed even when training is interrupted by an exception.
    try:
        # Data loading code
        train_dataset = CLRMSRSubject(
            root=args.data_path,
            meta=args.data_meta,
            frames_per_clip=args.clip_len,
            step_between_clips=args.clip_stride,
            step_between_frames=args.frame_stride,
            num_points=args.num_points,
            sub_clips=args.sub_clips,
            train=True
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True
        )
        # Create the contrastive learning model
        model = ContrastiveLearningModel(
            radius=args.radius,
            nsamples=args.nsamples,
            representation_dim=args.representation_dim,
            temperature=args.temperature,
            pretraining=True
        )
        # Single-process multi-GPU (DataParallel); state-dict keys then carry
        # a 'module.' prefix.
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

        criterion_global = nn.CrossEntropyLoss()

        optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

        # The scheduler is stepped once per iteration (see train()), so both
        # horizons are expressed in iterations rather than epochs.
        warmup_iters = args.lr_warmup_epochs * len(train_loader)
        epochs_iters = args.epochs * len(train_loader)
        lr_scheduler = utils.WarmupCosineLR(optimizer,
                                            T_max=epochs_iters,
                                            warmup_iters=warmup_iters,
                                            last_epoch=-1
                                            )

        if args.resume:
            if os.path.isfile(args.resume):
                logger.info(("===> Loading checkpoint for resume '{}'".format(args.resume)))
                checkpoint = torch.load(args.resume, map_location='cpu')
                model.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1
                logger.info(("===> Loaded checkpoint with epoch {}".format(checkpoint['epoch'])))
            else:
                # Best effort: a missing checkpoint is logged and training
                # starts from args.start_epoch.
                logger.info(("===> There is no checkpoint at '{}'".format(args.resume)))

        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, train_top1, train_top5 = train(model, optimizer, lr_scheduler, train_loader, device,
                                                       epoch, args.print_freq, logger, criterion_global
                                                       )
            tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
            tf_writer.add_scalar('loss/train', train_loss, epoch)
            tf_writer.add_scalar('acc/train_top1', train_top1, epoch)
            tf_writer.add_scalar('acc/train_top5', train_top5, epoch)
            tf_writer.flush()

            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }
            torch.save(
                checkpoint,
                os.path.join(log_dir, 'checkpoint_{}.pth'.format(epoch))
            )
            logger.info('====================================')

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logger.info(('Training time {}'.format(total_time_str)))
    finally:
        tf_writer.close()
parser.add_argument('-b', '--batch-size', default=96, type=int, metavar='N', help='batch size') 204 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, metavar='LR', help='initial learning rate', dest='lr') 205 | parser.add_argument('--temperature', default=0.01, type=float, help='softmax temperature (default: 0.07)') 206 | parser.add_argument('--representation-dim', default=1024, type=int, metavar='N', help='representation dim') 207 | 208 | parser.add_argument('--sub-clips', default=4, type=int, metavar='N', help='number of sub clips') 209 | parser.add_argument('--radius', default=0.3, type=float, help='radius for the ball query') 210 | parser.add_argument('--nsamples', default=9, type=int, help='number of neighbors for the ball query') 211 | parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip') 212 | parser.add_argument('--clip-stride', default=1, type=int, metavar='N', help='number of steps between clips') 213 | parser.add_argument('--frame-stride', default=1, type=int, metavar='N', help='number of steps between clips') 214 | parser.add_argument('--num-points', default=1024, type=int, metavar='N', help='number of points per frame') 215 | 216 | parser.add_argument('--model', default='MSR', type=str, help='model') 217 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', help='number of data loading workers (default: 32)') 218 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 219 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') 220 | parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. 
class RandomDropout(nn.Module):
    r"""
    Feature dropout whose drop rate is re-sampled on every forward pass

    Parameters
    ---------
    p : float
        Upper bound of the uniform range [0, p] the drop rate is drawn from
    inplace : bool
        Forwarded unchanged to ``pt_utils.feature_dropout_no_scaling``
    """

    def __init__(self, p=0.5, inplace=False):
        super(RandomDropout, self).__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, X):
        # Draw this call's drop rate uniformly from [0, p].
        theta = torch.Tensor(1).uniform_(0, self.p)[0]
        # Pass the boolean mode flag ``self.training``. ``self.train`` is the
        # bound nn.Module.train method and is always truthy, which kept
        # dropout switched on after model.eval().
        return pt_utils.feature_dropout_no_scaling(X, theta, self.training, self.inplace)
class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Performs weight linear interpolation on 3 features
        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        # Unpacking keeps the original check that features is 3-D; only the
        # last size is needed by the backward kernel.
        _, _, m = features.size()

        # Tensors go through save_for_backward so autograd tracks them; the
        # plain int is kept as a ctx attribute.
        ctx.save_for_backward(idx, weight)
        ctx.three_interpolate_m = m

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features

        None

        None
        """
        idx, weight = ctx.saved_tensors

        grad_features = _ext.three_interpolate_grad(
            grad_out.contiguous(), idx, weight, ctx.three_interpolate_m
        )

        # No gradient is returned for idx or weight.
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
: torch.Tensor 275 | (B, npoint, 3) centers of the ball query 276 | 277 | Returns 278 | ------- 279 | torch.Tensor 280 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 281 | """ 282 | inds = _ext.ball_query(new_xyz, xyz, radius, nsample) 283 | ctx.mark_non_differentiable(inds) 284 | return inds 285 | 286 | @staticmethod 287 | def backward(ctx, a=None): 288 | return None, None, None, None 289 | 290 | 291 | ball_query = BallQuery.apply 292 | 293 | 294 | class QueryAndGroup(nn.Module): 295 | r""" 296 | Groups with a ball query of radius 297 | 298 | Parameters 299 | --------- 300 | radius : float32 301 | Radius of ball 302 | nsample : int32 303 | Maximum number of features to gather in the ball 304 | """ 305 | 306 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 307 | # type: (QueryAndGroup, float, int, bool) -> None 308 | super(QueryAndGroup, self).__init__() 309 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 310 | self.ret_grouped_xyz = ret_grouped_xyz 311 | self.normalize_xyz = normalize_xyz 312 | self.sample_uniformly = sample_uniformly 313 | self.ret_unique_cnt = ret_unique_cnt 314 | if self.ret_unique_cnt: 315 | assert(self.sample_uniformly) 316 | 317 | def forward(self, xyz, new_xyz, features=None): 318 | # type: (QueryAndGroup, torch.Tensor. 
torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 319 | r""" 320 | Parameters 321 | ---------- 322 | xyz : torch.Tensor 323 | xyz coordinates of the features (B, N, 3) 324 | new_xyz : torch.Tensor 325 | centriods (B, npoint, 3) 326 | features : torch.Tensor 327 | Descriptors of the features (B, C, N) 328 | 329 | Returns 330 | ------- 331 | new_features : torch.Tensor 332 | (B, 3 + C, npoint, nsample) tensor 333 | """ 334 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 335 | 336 | if self.sample_uniformly: 337 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 338 | for i_batch in range(idx.shape[0]): 339 | for i_region in range(idx.shape[1]): 340 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 341 | num_unique = unique_ind.shape[0] 342 | unique_cnt[i_batch, i_region] = num_unique 343 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 344 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 345 | idx[i_batch, i_region, :] = all_ind 346 | 347 | 348 | xyz_trans = xyz.transpose(1, 2).contiguous() 349 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 350 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 351 | if self.normalize_xyz: 352 | grouped_xyz /= self.radius 353 | 354 | if features is not None: 355 | grouped_features = grouping_operation(features, idx) 356 | if self.use_xyz: 357 | new_features = torch.cat( 358 | [grouped_xyz, grouped_features], dim=1 359 | ) # (B, C + 3, npoint, nsample) 360 | else: 361 | new_features = grouped_features 362 | else: 363 | assert ( 364 | self.use_xyz 365 | ), "Cannot have not features and not use xyz as a feature!" 
class GroupAll(nn.Module):
    r"""
    Groups all features into a single group

    Parameters
    ---------
    use_xyz : bool
        Concatenate the xyz coordinates in front of the features
    ret_grouped_xyz : bool
        Also return the grouped xyz tensor from forward
    """

    def __init__(self, use_xyz=True, ret_grouped_xyz=False):
        # type: (GroupAll, bool, bool) -> None
        super(GroupAll, self).__init__()
        self.use_xyz = use_xyz
        # forward() reads this flag; it was previously never stored, so every
        # call raised AttributeError.
        self.ret_grouped_xyz = ret_grouped_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor; (B, C, 1, N) when use_xyz is False and
            (B, 3, 1, N) when features is None
        grouped_xyz : torch.Tensor
            (B, 3, 1, N) tensor, only returned when ret_grouped_xyz is True
        """

        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        if self.ret_grouped_xyz:
            return new_features, grouped_xyz
        else:
            return new_features
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: (batch, num_classes) tensor of scores.
        target: (batch,) tensor of class indices.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage
        of samples whose target is among the k highest scores. A k larger
        than num_classes is evaluated over all classes.
    """
    with torch.no_grad():
        # topk() raises when asked for more entries than there are classes
        # (e.g. topk=(1, 5) on a 3-way head), so clamp the request.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row r holds each sample's rank-r class
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past maxk simply keeps every available row.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Window statistics (median, avg, max, value) cover the last
    ``window_size`` values passed to ``update``; ``global_avg`` covers the
    whole series. Every statistic is 0.0 until the first update, so a fresh
    meter can be formatted without raising.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``; it weighs ``n`` samples in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        # Guard the empty window explicitly rather than reducing an empty tensor.
        if not self.deque:
            return 0.0
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        if not self.deque:
            return 0.0
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        # count is 0 before the first update (or if only n=0 updates were made).
        if not self.count:
            return 0.0
        return self.total / self.count

    @property
    def max(self):
        if not self.deque:
            return 0.0
        return max(self.deque)

    @property
    def value(self):
        if not self.deque:
            return 0.0
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
v.item() 159 | assert isinstance(v, (float, int)) 160 | self.meters[k].update(v) 161 | 162 | def __getattr__(self, attr): 163 | if attr in self.meters: 164 | return self.meters[attr] 165 | if attr in self.__dict__: 166 | return self.__dict__[attr] 167 | raise AttributeError("'{}' object has no attribute '{}'".format( 168 | type(self).__name__, attr)) 169 | 170 | def __str__(self): 171 | loss_str = [] 172 | for name, meter in self.meters.items(): 173 | loss_str.append( 174 | "{}: {}".format(name, str(meter)) 175 | ) 176 | return self.delimiter.join(loss_str) 177 | 178 | def synchronize_between_processes(self): 179 | for meter in self.meters.values(): 180 | meter.synchronize_between_processes() 181 | 182 | def add_meter(self, name, meter): 183 | self.meters[name] = meter 184 | 185 | def log_every(self, iterable, print_freq, header=None): 186 | i = 0 187 | if not header: 188 | header = '' 189 | start_time = time.time() 190 | end = time.time() 191 | iter_time = SmoothedValue(fmt='{avg:.4f}') 192 | data_time = SmoothedValue(fmt='{avg:.4f}') 193 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 194 | if torch.cuda.is_available(): 195 | log_msg = self.delimiter.join([ 196 | header, 197 | '[{0' + space_fmt + '}/{1}]', 198 | 'eta: {eta}', 199 | '{meters}', 200 | 'time: {time}', 201 | 'data: {data}', 202 | 'max mem: {memory:.0f}' 203 | ]) 204 | else: 205 | log_msg = self.delimiter.join([ 206 | header, 207 | '[{0' + space_fmt + '}/{1}]', 208 | 'eta: {eta}', 209 | '{meters}', 210 | 'time: {time}', 211 | 'data: {data}' 212 | ]) 213 | MB = 1024.0 * 1024.0 214 | for obj in iterable: 215 | data_time.update(time.time() - end) 216 | yield obj 217 | iter_time.update(time.time() - end) 218 | if i % print_freq == 0: 219 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 220 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 221 | if torch.cuda.is_available(): 222 | print(log_msg.format( 223 | i, len(iterable), eta=eta_string, 224 | meters=str(self), 225 | 
def mkdir(path):
    """Create directory ``path`` (and missing parents); no-op if it exists.

    Raises:
        OSError: if the directory cannot be created, including
            ``FileExistsError`` when ``path`` exists but is not a directory.
            The previous errno check swallowed that case, so callers only
            failed later when writing into the missing directory.
    """
    os.makedirs(path, exist_ok=True)
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step LR decay preceded by a constant or linear warmup.

    While ``last_epoch < warmup_iters`` the base LR is scaled by
    ``warmup_factor`` ("constant") or by a factor that grows linearly from
    ``warmup_factor`` to 1 ("linear"). The LR is additionally multiplied by
    ``gamma`` once for every milestone that is <= ``last_epoch``.

    Args:
        optimizer: wrapped optimizer.
        milestones: increasing step indices (list or tuple).
        gamma: multiplicative decay applied at each milestone.
        warmup_factor: LR scale at the start of warmup.
        warmup_iters: number of warmup steps.
        warmup_method: "constant" or "linear".
        last_epoch: index of the last step, -1 to start fresh.

    Raises:
        ValueError: if ``milestones`` is not sorted in increasing order or
            ``warmup_method`` is unknown.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=5,
        warmup_method="linear",
        last_epoch=-1,
    ):
        # Normalise to a list: sorted() always returns a list, so comparing
        # it with a tuple used to reject correctly ordered tuple milestones.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            # The message is now actually formatted; it used to be passed to
            # ValueError as a (template, milestones) tuple.
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the LR of every param group for the current ``last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # alpha runs from 0 to 1 across the warmup phase.
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr *
            warmup_factor *
            self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split trainable parameters into no-decay and decay optimizer groups.

    A parameter goes into the no-decay group when it is 1-D, its name ends
    with ".bias", its name contains "token", or its name is in ``skip_list``.
    Parameters with ``requires_grad=False`` are left out entirely.

    Args:
        model: an ``nn.Module``, optionally wrapped in a container exposing
            the real network as ``.module`` (e.g. ``nn.DataParallel``).
        weight_decay: weight decay for the decay group.
        skip_list: parameter names (relative to the unwrapped model) that
            must not be decayed.

    Returns:
        Two param-group dicts for a torch optimizer: the no-decay group
        (``weight_decay`` 0.) followed by the decay group.
    """
    # Unwrap when wrapped; a bare module is used as is. Previously
    # ``model.module`` was dereferenced unconditionally, which raised
    # AttributeError on a single-GPU (unwrapped) model.
    # NOTE(review): a model with its own sub-module named "module" would
    # also be unwrapped here - confirm no model relies on that name.
    net = getattr(model, 'module', model)

    decay = []
    no_decay = []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or 'token' in name or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
def evaluate(model, criterion, data_loader, device, print_freq, logger):
    """Evaluate clip-level metrics and video-level accuracy.

    Each batch yields ``(clip, target, video_idx)``. Clip-level loss and
    top-1/top-5 accuracy are averaged over all clips. For the video-level
    score, the softmax probabilities of all clips sharing a ``video_idx`` are
    summed and the arg-max class is compared with the video's label.

    Returns:
        ``(loss_avg, top1_avg, top5_avg, video_level_accuracy)``.
    """
    batch_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.eval()

    video_prob = {}
    video_label = {}
    with torch.no_grad():
        for i, (clip, target, video_idx) in enumerate(data_loader):
            start_time = time.time()
            clip = clip.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(clip)
            loss = criterion(output, target)
            batch_size = clip.shape[0]

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_time.update(time.time() - start_time)

            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            if i % print_freq == 0:
                logger.info(('Test: [{0}/{1}]\t'
                             'Batch-Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                             'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                             'Top1: {top1.val:.3f} ({top1.avg:.3f})\t'
                             'Top5: {top5.val:.3f} ({top5.avg:.3f})'.format(
                                 i, len(data_loader), batch_time=batch_time,
                                 loss=losses, top1=top1, top5=top5)))

            prob = F.softmax(input=output, dim=1)

            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            target = target.cpu().numpy()
            video_idx = video_idx.cpu().numpy()
            prob = prob.cpu().numpy()
            # A dedicated index name: the loop used to reuse ``i`` and so
            # clobbered the batch counter of the enclosing loop.
            for sample in range(batch_size):
                idx = video_idx[sample]
                if idx in video_prob:
                    video_prob[idx] += prob[sample]
                else:
                    video_prob[idx] = prob[sample]
                    video_label[idx] = target[sample]

    # video level prediction
    video_pred = {k: np.argmax(v) for k, v in video_prob.items()}
    pred_correct = [video_pred[k] == video_label[k] for k in video_pred]
    total_acc = torch.tensor(np.mean(pred_correct)).to(device)

    num_classes = data_loader.dataset.num_classes
    class_count = [0] * num_classes
    class_correct = [0] * num_classes

    for k, v in video_pred.items():
        label = video_label[k]
        class_count[label] += 1
        class_correct[label] += (v == label)
    # A class without any video in this split is reported as NaN instead of
    # aborting the whole evaluation with ZeroDivisionError.
    class_acc = torch.tensor([
        c / float(s) if s else float('nan')
        for c, s in zip(class_correct, class_count)
    ]).to(device)

    logger.info(('Video-level Total-acc: {:.5f}\t'.format(total_acc.item())))
    logger.info(('Video-level Class-acc: {}'.format(np.round(class_acc.tolist(), 3))))

    return losses.avg, top1.avg, top5.avg, total_acc.item()
ContrastiveLearningModel( 205 | radius=args.radius, 206 | nsamples=args.nsamples, 207 | representation_dim=args.representation_dim, 208 | num_classes=dataset.num_classes, 209 | pretraining=False 210 | ) 211 | # Distributed model 212 | if torch.cuda.device_count() > 1: 213 | model = nn.DataParallel(model) 214 | model.to(device) 215 | 216 | logger.info(("===> Loading checkpoint for finetune '{}'".format(args.finetune))) 217 | checkpoint = torch.load(args.finetune, map_location='cpu') 218 | state_dict = checkpoint['model'] 219 | 220 | for k in list(state_dict.keys()): 221 | if not k.startswith(('module.encoder')): 222 | del state_dict[k] 223 | 224 | log = model.load_state_dict(state_dict, strict=False) 225 | assert log.missing_keys == ['module.fc_out.weight', 'module.fc_out.bias'] 226 | 227 | # freeze all layers but the last fc 228 | for name, param in model.named_parameters(): 229 | if not name.startswith(('module.fc_out')): 230 | param.requires_grad = False 231 | 232 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 233 | assert len(parameters) == 2 234 | logger.info(("===> Loaded checkpoint with epoch {}".format(checkpoint['epoch']))) 235 | 236 | criterion = LabelSmoothingCrossEntropy(smoothing=0.1) 237 | 238 | optimizer = torch.optim.SGD(parameters, lr=args.lr, 239 | momentum=args.momentum, 240 | weight_decay=args.weight_decay 241 | ) 242 | warmup_iters = args.lr_warmup_epochs * len(train_loader) 243 | epochs_iters = args.epochs * len(train_loader) 244 | lr_scheduler = utils.WarmupCosineLR(optimizer, 245 | T_max=epochs_iters, 246 | warmup_iters=warmup_iters, 247 | last_epoch=-1 248 | ) 249 | start_time = time.time() 250 | acc = 0 251 | for epoch in range(args.start_epoch, args.epochs): 252 | train_loss, train_top1, train_top5 = train(model, criterion, optimizer, 253 | lr_scheduler, train_loader, device, 254 | epoch, args.print_freq, logger) 255 | 256 | test_loss, test_top1, test_top5, total_acc = evaluate(model, criterion, val_loader, 257 
| device, args.print_freq, logger) 258 | acc = max(acc, total_acc) 259 | 260 | logger.info(("Best total acc: '{}'".format(acc))) 261 | 262 | tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch) 263 | tf_writer.add_scalar('loss/train', train_loss, epoch) 264 | tf_writer.add_scalar('acc/train_top1', train_top1, epoch) 265 | tf_writer.add_scalar('acc/train_top5', train_top5, epoch) 266 | tf_writer.add_scalar('loss/test', test_loss, epoch) 267 | tf_writer.add_scalar('acc/test_top1', test_top1, epoch) 268 | tf_writer.add_scalar('acc/test_top5', test_top5, epoch) 269 | tf_writer.add_scalar('acc/total_acc_best', acc, epoch) 270 | tf_writer.flush() 271 | 272 | checkpoint = { 273 | 'model': model.state_dict(), 274 | 'optimizer': optimizer.state_dict(), 275 | 'lr_scheduler': lr_scheduler.state_dict(), 276 | 'epoch': epoch, 277 | 'args': args} 278 | torch.save( 279 | checkpoint, 280 | os.path.join(output_dir, 'model_{}.pth'.format(epoch))) 281 | 282 | total_time = time.time() - start_time 283 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 284 | logger.info(('Training time {}'.format(total_time_str))) 285 | 286 | 287 | def parse_args(): 288 | import argparse 289 | parser = argparse.ArgumentParser(description='PSTNet Training') 290 | 291 | parser.add_argument('--data-path', default='/data/MSRAction', metavar='DIR', help='path to dataset') 292 | parser.add_argument('--data-meta', default='datasets/MSRAction_all.list', help='dataset') 293 | parser.add_argument('--seed', default=0, type=int, help='random seed') 294 | parser.add_argument('--model', default='MSR', type=str, help='model') 295 | 296 | parser.add_argument('--radius', default=0.3, type=float, help='radius for the ball query') 297 | parser.add_argument('--nsamples', default=9, type=int, help='number of neighbors for the ball query') 298 | parser.add_argument('--clip-len', default=16, type=int, metavar='N', help='number of frames per clip') 299 | parser.add_argument('--sub-clips', 
default=4, type=int, metavar='N', help='number of sub-clips') 300 | parser.add_argument('--clip-stride', default=1, type=int, metavar='N', help='number of steps between clips') 301 | parser.add_argument('--frame-stride', default=1, type=int, metavar='N', help='number of steps between frames') 302 | # Following PSTNet, when using a small 'clip-len', increasing 'frame-stride' appropriately will help improve accuracy. 303 | parser.add_argument('--num-points', default=2048, type=int, metavar='N', help='number of points per frame') 304 | parser.add_argument('--representation-dim', default=1024, type=int, metavar='N', help='representation dim') 305 | 306 | parser.add_argument('-b', '--batch-size', default=24, type=int) 307 | parser.add_argument('--lr', default=0.015, type=float, help='initial learning rate') 308 | parser.add_argument('--finetune', default='MSR/checkpoint.pth', help='finetune from checkpoint') 309 | 310 | parser.add_argument('--epochs', default=35, type=int, metavar='N', help='number of total epochs to run') 311 | parser.add_argument('--lr-milestones', nargs='+', default=[20, 30], type=int, help='decrease lr on milestones') 312 | parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs') 313 | 314 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', help='number of data loading workers (default: 16)') 315 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 316 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') 317 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 318 | parser.add_argument('--print-freq', default=200, type=int, help='print frequency') 319 | parser.add_argument('--output-dir', default='output/', type=str, help='path where to save') 320 | parser.add_argument('--log-dir', default='log/', type=str, 
help='path where to save') 321 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') 322 | parser.add_argument('--resume', default='', help='resume from checkpoint') 323 | 324 | args = parser.parse_args() 325 | 326 | return args 327 | 328 | 329 | if __name__ == "__main__": 330 | args = parse_args() 331 | main(args) 332 | -------------------------------------------------------------------------------- /datasets/MSRAction_all.list: -------------------------------------------------------------------------------- 1 | a20_s09_e03_sdepth 55 2 | a07_s02_e03_sdepth 48 3 | a20_s08_e01_sdepth 54 4 | a14_s01_e03_sdepth 40 5 | a01_s09_e02_sdepth 31 6 | a19_s06_e02_sdepth 49 7 | a13_s02_e03_sdepth 66 8 | a19_s02_e02_sdepth 34 9 | a19_s05_e03_sdepth 36 10 | a18_s08_e03_sdepth 50 11 | a04_s07_e01_sdepth 36 12 | a18_s03_e03_sdepth 53 13 | a20_s04_e02_sdepth 49 14 | a17_s05_e01_sdepth 26 15 | a18_s03_e01_sdepth 46 16 | a19_s02_e01_sdepth 45 17 | a06_s10_e03_sdepth 24 18 | a11_s02_e01_sdepth 40 19 | a08_s05_e02_sdepth 33 20 | a09_s02_e01_sdepth 51 21 | a17_s08_e03_sdepth 46 22 | a18_s02_e03_sdepth 45 23 | a13_s06_e02_sdepth 31 24 | a10_s05_e02_sdepth 21 25 | a19_s07_e03_sdepth 46 26 | a05_s06_e01_sdepth 60 27 | a10_s06_e01_sdepth 49 28 | a11_s03_e03_sdepth 39 29 | a07_s02_e01_sdepth 37 30 | a04_s09_e03_sdepth 35 31 | a08_s08_e03_sdepth 35 32 | a11_s01_e03_sdepth 40 33 | a10_s10_e03_sdepth 20 34 | a08_s05_e03_sdepth 26 35 | a07_s05_e01_sdepth 26 36 | a07_s08_e03_sdepth 51 37 | a09_s08_e01_sdepth 40 38 | a20_s04_e03_sdepth 49 39 | a17_s02_e01_sdepth 50 40 | a20_s07_e03_sdepth 56 41 | a09_s04_e03_sdepth 48 42 | a03_s07_e01_sdepth 35 43 | a12_s09_e03_sdepth 23 44 | a06_s05_e02_sdepth 28 45 | a03_s06_e03_sdepth 58 46 | a18_s10_e01_sdepth 33 47 | a05_s03_e01_sdepth 34 48 | a18_s07_e03_sdepth 51 49 | a08_s07_e02_sdepth 39 50 | a02_s07_e03_sdepth 48 51 | a04_s08_e01_sdepth 52 52 | a01_s07_e03_sdepth 44 53 | a07_s06_e01_sdepth 47 54 | a13_s06_e03_sdepth 
37 55 | a11_s10_e03_sdepth 25 56 | a08_s06_e03_sdepth 45 57 | a07_s10_e01_sdepth 33 58 | a17_s10_e01_sdepth 31 59 | a09_s07_e02_sdepth 35 60 | a03_s05_e01_sdepth 30 61 | a05_s07_e01_sdepth 35 62 | a01_s01_e01_sdepth 54 63 | a02_s09_e03_sdepth 37 64 | a19_s05_e01_sdepth 31 65 | a13_s08_e03_sdepth 57 66 | a19_s01_e03_sdepth 50 67 | a11_s04_e03_sdepth 38 68 | a05_s07_e02_sdepth 34 69 | a17_s07_e03_sdepth 45 70 | a04_s05_e01_sdepth 40 71 | a16_s10_e02_sdepth 37 72 | a09_s01_e03_sdepth 49 73 | a03_s08_e01_sdepth 29 74 | a19_s05_e02_sdepth 33 75 | a12_s05_e01_sdepth 16 76 | a01_s07_e01_sdepth 50 77 | a09_s06_e03_sdepth 58 78 | a14_s07_e02_sdepth 31 79 | a13_s10_e02_sdepth 31 80 | a02_s06_e01_sdepth 42 81 | a11_s03_e02_sdepth 44 82 | a04_s06_e03_sdepth 52 83 | a19_s06_e03_sdepth 51 84 | a17_s07_e01_sdepth 45 85 | a03_s10_e03_sdepth 39 86 | a18_s06_e02_sdepth 49 87 | a17_s04_e03_sdepth 50 88 | a01_s02_e03_sdepth 45 89 | a11_s10_e01_sdepth 28 90 | a18_s05_e02_sdepth 38 91 | a01_s09_e01_sdepth 45 92 | a02_s10_e01_sdepth 38 93 | a14_s07_e03_sdepth 30 94 | a15_s01_e02_sdepth 41 95 | a13_s08_e02_sdepth 57 96 | a19_s04_e01_sdepth 71 97 | a02_s08_e01_sdepth 66 98 | a18_s01_e02_sdepth 42 99 | a20_s09_e01_sdepth 50 100 | a01_s06_e01_sdepth 57 101 | a20_s10_e03_sdepth 37 102 | a10_s07_e01_sdepth 28 103 | a08_s05_e01_sdepth 28 104 | a12_s08_e02_sdepth 43 105 | a17_s09_e02_sdepth 37 106 | a04_s08_e03_sdepth 36 107 | a07_s01_e02_sdepth 42 108 | a12_s07_e01_sdepth 39 109 | a15_s07_e01_sdepth 42 110 | a16_s02_e02_sdepth 55 111 | a13_s03_e02_sdepth 52 112 | a11_s07_e02_sdepth 36 113 | a18_s05_e01_sdepth 40 114 | a20_s04_e01_sdepth 45 115 | a14_s07_e01_sdepth 34 116 | a08_s10_e01_sdepth 26 117 | a03_s09_e01_sdepth 34 118 | a06_s03_e01_sdepth 46 119 | a16_s05_e01_sdepth 51 120 | a06_s08_e03_sdepth 35 121 | a14_s10_e01_sdepth 26 122 | a09_s02_e03_sdepth 39 123 | a15_s07_e03_sdepth 38 124 | a07_s03_e02_sdepth 47 125 | a08_s02_e02_sdepth 34 126 | a20_s07_e02_sdepth 54 127 | a10_s03_e02_sdepth 
35 128 | a09_s05_e02_sdepth 24 129 | a05_s10_e01_sdepth 27 130 | a15_s08_e02_sdepth 29 131 | a17_s03_e01_sdepth 32 132 | a10_s10_e02_sdepth 18 133 | a12_s08_e03_sdepth 61 134 | a18_s04_e03_sdepth 51 135 | a19_s01_e01_sdepth 59 136 | a19_s03_e02_sdepth 55 137 | a15_s02_e02_sdepth 38 138 | a12_s04_e02_sdepth 31 139 | a08_s03_e03_sdepth 39 140 | a12_s08_e01_sdepth 53 141 | a10_s07_e03_sdepth 39 142 | a14_s05_e03_sdepth 30 143 | a15_s09_e02_sdepth 26 144 | a03_s09_e02_sdepth 41 145 | a07_s03_e01_sdepth 48 146 | a12_s10_e03_sdepth 34 147 | a20_s02_e02_sdepth 50 148 | a09_s10_e03_sdepth 28 149 | a19_s07_e02_sdepth 40 150 | a12_s02_e03_sdepth 47 151 | a04_s02_e02_sdepth 28 152 | a18_s05_e03_sdepth 34 153 | a06_s02_e03_sdepth 49 154 | a06_s09_e02_sdepth 30 155 | a06_s10_e01_sdepth 26 156 | a14_s08_e02_sdepth 31 157 | a12_s03_e03_sdepth 38 158 | a03_s08_e02_sdepth 31 159 | a17_s03_e02_sdepth 34 160 | a10_s05_e01_sdepth 30 161 | a14_s08_e01_sdepth 48 162 | a11_s06_e01_sdepth 42 163 | a10_s08_e01_sdepth 40 164 | a04_s10_e01_sdepth 30 165 | a03_s06_e02_sdepth 34 166 | a06_s07_e02_sdepth 33 167 | a11_s09_e01_sdepth 37 168 | a11_s10_e02_sdepth 26 169 | a15_s07_e02_sdepth 35 170 | a01_s06_e02_sdepth 38 171 | a12_s01_e02_sdepth 36 172 | a20_s08_e02_sdepth 54 173 | a05_s10_e03_sdepth 23 174 | a14_s01_e02_sdepth 35 175 | a08_s03_e02_sdepth 36 176 | a16_s06_e01_sdepth 38 177 | a19_s01_e02_sdepth 54 178 | a16_s02_e01_sdepth 42 179 | a02_s09_e01_sdepth 47 180 | a12_s04_e01_sdepth 44 181 | a04_s03_e01_sdepth 100 182 | a14_s02_e01_sdepth 32 183 | a15_s06_e02_sdepth 56 184 | a13_s04_e01_sdepth 48 185 | a12_s03_e01_sdepth 20 186 | a05_s09_e03_sdepth 37 187 | a13_s02_e01_sdepth 43 188 | a10_s02_e01_sdepth 29 189 | a19_s04_e02_sdepth 51 190 | a02_s01_e02_sdepth 39 191 | a09_s10_e02_sdepth 29 192 | a08_s08_e02_sdepth 31 193 | a10_s01_e03_sdepth 38 194 | a17_s05_e02_sdepth 33 195 | a15_s06_e01_sdepth 50 196 | a14_s08_e03_sdepth 38 197 | a10_s04_e01_sdepth 33 198 | a06_s09_e01_sdepth 36 199 | 
a13_s06_e01_sdepth 36 200 | a03_s03_e01_sdepth 41 201 | a02_s02_e01_sdepth 44 202 | a12_s07_e02_sdepth 44 203 | a04_s07_e03_sdepth 37 204 | a11_s08_e01_sdepth 64 205 | a09_s03_e02_sdepth 47 206 | a16_s07_e03_sdepth 40 207 | a05_s01_e01_sdepth 44 208 | a07_s07_e02_sdepth 42 209 | a01_s05_e02_sdepth 38 210 | a04_s01_e03_sdepth 23 211 | a13_s05_e01_sdepth 22 212 | a16_s01_e01_sdepth 49 213 | a09_s08_e03_sdepth 41 214 | a01_s08_e03_sdepth 62 215 | a03_s03_e03_sdepth 45 216 | a17_s08_e02_sdepth 44 217 | a04_s09_e02_sdepth 42 218 | a08_s04_e01_sdepth 39 219 | a10_s02_e03_sdepth 42 220 | a05_s09_e02_sdepth 28 221 | a08_s09_e01_sdepth 37 222 | a14_s10_e03_sdepth 22 223 | a08_s07_e03_sdepth 38 224 | a12_s09_e01_sdepth 26 225 | a04_s08_e02_sdepth 25 226 | a14_s04_e02_sdepth 36 227 | a14_s06_e03_sdepth 40 228 | a14_s04_e03_sdepth 34 229 | a17_s02_e03_sdepth 37 230 | a05_s05_e03_sdepth 31 231 | a01_s08_e02_sdepth 67 232 | a04_s10_e02_sdepth 24 233 | a14_s02_e03_sdepth 48 234 | a12_s07_e03_sdepth 51 235 | a04_s06_e01_sdepth 30 236 | a07_s09_e02_sdepth 38 237 | a02_s05_e03_sdepth 29 238 | a06_s02_e01_sdepth 40 239 | a20_s05_e01_sdepth 34 240 | a03_s07_e03_sdepth 36 241 | a14_s10_e02_sdepth 19 242 | a19_s10_e01_sdepth 32 243 | a16_s06_e02_sdepth 42 244 | a16_s10_e03_sdepth 34 245 | a14_s03_e03_sdepth 42 246 | a04_s01_e01_sdepth 31 247 | a01_s05_e03_sdepth 38 248 | a03_s05_e03_sdepth 32 249 | a11_s07_e03_sdepth 40 250 | a15_s08_e03_sdepth 33 251 | a03_s09_e03_sdepth 35 252 | a17_s09_e03_sdepth 38 253 | a13_s09_e01_sdepth 39 254 | a08_s09_e03_sdepth 41 255 | a16_s04_e01_sdepth 41 256 | a13_s01_e03_sdepth 46 257 | a10_s03_e03_sdepth 35 258 | a07_s05_e03_sdepth 35 259 | a18_s07_e01_sdepth 51 260 | a08_s07_e01_sdepth 33 261 | a09_s07_e01_sdepth 34 262 | a03_s02_e03_sdepth 47 263 | a01_s01_e02_sdepth 43 264 | a08_s08_e01_sdepth 41 265 | a16_s05_e02_sdepth 39 266 | a16_s08_e01_sdepth 53 267 | a11_s02_e03_sdepth 34 268 | a07_s10_e03_sdepth 28 269 | a02_s07_e01_sdepth 52 270 | 
a05_s05_e02_sdepth 34 271 | a08_s03_e01_sdepth 38 272 | a06_s02_e02_sdepth 36 273 | a09_s09_e03_sdepth 40 274 | a08_s04_e03_sdepth 39 275 | a17_s04_e02_sdepth 45 276 | a17_s01_e03_sdepth 57 277 | a05_s03_e02_sdepth 37 278 | a08_s06_e01_sdepth 37 279 | a02_s01_e03_sdepth 37 280 | a12_s03_e02_sdepth 35 281 | a06_s08_e01_sdepth 38 282 | a01_s07_e02_sdepth 49 283 | a18_s08_e01_sdepth 43 284 | a06_s07_e01_sdepth 41 285 | a06_s01_e02_sdepth 34 286 | a03_s03_e02_sdepth 36 287 | a02_s06_e03_sdepth 55 288 | a01_s08_e01_sdepth 58 289 | a18_s02_e02_sdepth 39 290 | a17_s02_e02_sdepth 48 291 | a04_s10_e03_sdepth 24 292 | a19_s08_e01_sdepth 46 293 | a13_s01_e02_sdepth 42 294 | a13_s07_e01_sdepth 33 295 | a18_s02_e01_sdepth 47 296 | a18_s06_e01_sdepth 31 297 | a12_s10_e02_sdepth 29 298 | a10_s03_e01_sdepth 37 299 | a14_s03_e01_sdepth 38 300 | a03_s06_e01_sdepth 68 301 | a08_s06_e02_sdepth 34 302 | a18_s09_e01_sdepth 44 303 | a01_s01_e03_sdepth 43 304 | a07_s01_e03_sdepth 38 305 | a17_s04_e01_sdepth 42 306 | a09_s08_e02_sdepth 37 307 | a05_s03_e03_sdepth 31 308 | a10_s02_e02_sdepth 25 309 | a11_s01_e01_sdepth 44 310 | a03_s08_e03_sdepth 30 311 | a12_s06_e03_sdepth 44 312 | a08_s01_e03_sdepth 31 313 | a20_s10_e01_sdepth 44 314 | a19_s03_e03_sdepth 57 315 | a05_s08_e01_sdepth 40 316 | a04_s01_e02_sdepth 41 317 | a20_s03_e01_sdepth 43 318 | a11_s04_e01_sdepth 35 319 | a15_s01_e01_sdepth 31 320 | a20_s01_e02_sdepth 44 321 | a20_s05_e03_sdepth 36 322 | a02_s03_e02_sdepth 41 323 | a15_s10_e03_sdepth 19 324 | a11_s03_e01_sdepth 41 325 | a03_s05_e02_sdepth 37 326 | a15_s09_e03_sdepth 30 327 | a13_s10_e01_sdepth 32 328 | a10_s09_e01_sdepth 32 329 | a10_s04_e02_sdepth 28 330 | a01_s02_e01_sdepth 55 331 | a15_s10_e01_sdepth 24 332 | a03_s07_e02_sdepth 36 333 | a11_s09_e03_sdepth 48 334 | a10_s06_e02_sdepth 49 335 | a18_s08_e02_sdepth 36 336 | a04_s02_e01_sdepth 45 337 | a06_s03_e03_sdepth 39 338 | a13_s10_e03_sdepth 35 339 | a07_s03_e03_sdepth 51 340 | a17_s10_e02_sdepth 35 341 | 
a14_s05_e02_sdepth 31 342 | a20_s07_e01_sdepth 71 343 | a06_s06_e02_sdepth 56 344 | a16_s07_e02_sdepth 37 345 | a06_s03_e02_sdepth 37 346 | a09_s06_e01_sdepth 40 347 | a11_s05_e02_sdepth 15 348 | a06_s01_e01_sdepth 39 349 | a08_s02_e01_sdepth 30 350 | a04_s09_e01_sdepth 44 351 | a13_s07_e02_sdepth 32 352 | a06_s09_e03_sdepth 37 353 | a09_s09_e02_sdepth 38 354 | a18_s03_e02_sdepth 13 355 | a01_s10_e02_sdepth 40 356 | a16_s08_e02_sdepth 57 357 | a01_s02_e02_sdepth 31 358 | a05_s08_e02_sdepth 32 359 | a01_s03_e03_sdepth 38 360 | a18_s10_e02_sdepth 33 361 | a05_s06_e02_sdepth 63 362 | a20_s08_e03_sdepth 60 363 | a05_s05_e01_sdepth 26 364 | a03_s01_e01_sdepth 40 365 | a17_s10_e03_sdepth 30 366 | a09_s01_e02_sdepth 42 367 | a17_s01_e01_sdepth 48 368 | a10_s06_e03_sdepth 47 369 | a19_s09_e03_sdepth 37 370 | a03_s10_e02_sdepth 38 371 | a13_s05_e03_sdepth 36 372 | a02_s02_e03_sdepth 46 373 | a15_s08_e01_sdepth 46 374 | a10_s05_e03_sdepth 20 375 | a05_s10_e02_sdepth 21 376 | a16_s04_e03_sdepth 49 377 | a06_s06_e01_sdepth 52 378 | a01_s10_e01_sdepth 45 379 | a20_s06_e03_sdepth 52 380 | a09_s03_e03_sdepth 40 381 | a20_s05_e02_sdepth 34 382 | a01_s05_e01_sdepth 38 383 | a06_s08_e02_sdepth 34 384 | a17_s06_e02_sdepth 38 385 | a11_s09_e02_sdepth 47 386 | a18_s01_e01_sdepth 43 387 | a12_s05_e02_sdepth 19 388 | a08_s10_e02_sdepth 26 389 | a13_s01_e01_sdepth 41 390 | a02_s08_e03_sdepth 62 391 | a16_s01_e03_sdepth 51 392 | a07_s05_e02_sdepth 29 393 | a05_s02_e03_sdepth 46 394 | a07_s04_e02_sdepth 47 395 | a15_s01_e03_sdepth 44 396 | a12_s06_e02_sdepth 52 397 | a09_s05_e03_sdepth 31 398 | a07_s02_e02_sdepth 46 399 | a10_s04_e03_sdepth 27 400 | a02_s10_e02_sdepth 32 401 | a13_s07_e03_sdepth 37 402 | a01_s06_e03_sdepth 52 403 | a09_s04_e01_sdepth 43 404 | a19_s10_e02_sdepth 31 405 | a11_s07_e01_sdepth 47 406 | a16_s08_e03_sdepth 55 407 | a15_s02_e01_sdepth 24 408 | a20_s06_e02_sdepth 44 409 | a16_s04_e02_sdepth 38 410 | a13_s05_e02_sdepth 27 411 | a12_s02_e01_sdepth 33 412 | 
a08_s02_e03_sdepth 37 413 | a16_s05_e03_sdepth 44 414 | a02_s08_e02_sdepth 60 415 | a15_s06_e03_sdepth 40 416 | a07_s06_e02_sdepth 36 417 | a03_s01_e02_sdepth 35 418 | a07_s09_e01_sdepth 41 419 | a14_s03_e02_sdepth 28 420 | a08_s09_e02_sdepth 33 421 | a12_s05_e03_sdepth 16 422 | a06_s07_e03_sdepth 37 423 | a07_s04_e01_sdepth 45 424 | a14_s06_e01_sdepth 38 425 | a02_s10_e03_sdepth 29 426 | a07_s09_e03_sdepth 41 427 | a02_s02_e02_sdepth 41 428 | a12_s06_e01_sdepth 57 429 | a02_s03_e01_sdepth 38 430 | a14_s09_e01_sdepth 18 431 | a16_s03_e01_sdepth 42 432 | a16_s03_e02_sdepth 47 433 | a20_s06_e01_sdepth 54 434 | a07_s10_e02_sdepth 22 435 | a10_s08_e03_sdepth 29 436 | a02_s09_e02_sdepth 42 437 | a20_s03_e03_sdepth 43 438 | a07_s06_e03_sdepth 67 439 | a09_s04_e02_sdepth 39 440 | a04_s07_e02_sdepth 36 441 | a11_s05_e01_sdepth 21 442 | a17_s05_e03_sdepth 30 443 | a11_s05_e03_sdepth 20 444 | a02_s05_e02_sdepth 31 445 | a09_s07_e03_sdepth 42 446 | a17_s03_e03_sdepth 46 447 | a05_s07_e03_sdepth 42 448 | a17_s09_e01_sdepth 36 449 | a16_s10_e01_sdepth 28 450 | a02_s07_e02_sdepth 46 451 | a04_s06_e02_sdepth 52 452 | a11_s01_e02_sdepth 43 453 | a14_s02_e02_sdepth 28 454 | a06_s05_e01_sdepth 34 455 | a20_s01_e01_sdepth 55 456 | a14_s01_e01_sdepth 46 457 | a01_s10_e03_sdepth 33 458 | a18_s04_e02_sdepth 51 459 | a08_s10_e03_sdepth 30 460 | a08_s04_e02_sdepth 37 461 | a11_s08_e03_sdepth 70 462 | a06_s10_e02_sdepth 29 463 | a18_s06_e03_sdepth 69 464 | a05_s02_e02_sdepth 37 465 | a10_s01_e01_sdepth 36 466 | a02_s05_e01_sdepth 46 467 | a07_s01_e01_sdepth 43 468 | a10_s09_e02_sdepth 28 469 | a09_s06_e02_sdepth 51 470 | a07_s07_e01_sdepth 40 471 | a03_s02_e02_sdepth 40 472 | a06_s01_e03_sdepth 42 473 | a13_s09_e02_sdepth 38 474 | a11_s08_e02_sdepth 72 475 | a14_s06_e02_sdepth 44 476 | a19_s07_e01_sdepth 44 477 | a18_s01_e03_sdepth 51 478 | a09_s02_e02_sdepth 37 479 | a03_s10_e01_sdepth 39 480 | a13_s04_e03_sdepth 38 481 | a10_s10_e01_sdepth 27 482 | a05_s01_e02_sdepth 76 483 | 
a16_s09_e01_sdepth 34 484 | a19_s10_e03_sdepth 34 485 | a12_s10_e01_sdepth 28 486 | a19_s08_e02_sdepth 50 487 | a11_s06_e02_sdepth 53 488 | a19_s06_e01_sdepth 53 489 | a01_s03_e02_sdepth 33 490 | a12_s04_e03_sdepth 35 491 | a09_s01_e01_sdepth 38 492 | a20_s10_e02_sdepth 38 493 | a17_s07_e02_sdepth 45 494 | a19_s04_e03_sdepth 46 495 | a18_s04_e01_sdepth 48 496 | a12_s01_e01_sdepth 49 497 | a14_s04_e01_sdepth 58 498 | a13_s09_e03_sdepth 255 499 | a13_s08_e01_sdepth 55 500 | a09_s10_e01_sdepth 28 501 | a04_s02_e03_sdepth 38 502 | a12_s01_e03_sdepth 28 503 | a09_s09_e01_sdepth 41 504 | a04_s05_e03_sdepth 27 505 | a16_s06_e03_sdepth 58 506 | a15_s10_e02_sdepth 26 507 | a11_s06_e03_sdepth 54 508 | a16_s09_e02_sdepth 42 509 | a14_s09_e03_sdepth 28 510 | a10_s08_e02_sdepth 33 511 | a12_s02_e02_sdepth 33 512 | a11_s02_e02_sdepth 36 513 | a17_s06_e01_sdepth 47 514 | a07_s08_e01_sdepth 50 515 | a09_s03_e01_sdepth 53 516 | a15_s09_e01_sdepth 32 517 | a06_s05_e03_sdepth 35 518 | a17_s01_e02_sdepth 47 519 | a01_s09_e03_sdepth 33 520 | a04_s05_e02_sdepth 20 521 | a13_s04_e02_sdepth 45 522 | a19_s02_e03_sdepth 40 523 | a04_s03_e02_sdepth 30 524 | a19_s09_e02_sdepth 39 525 | a18_s09_e02_sdepth 34 526 | a19_s03_e01_sdepth 45 527 | a07_s08_e02_sdepth 39 528 | a10_s01_e02_sdepth 29 529 | a08_s01_e02_sdepth 32 530 | a20_s09_e02_sdepth 53 531 | a19_s08_e03_sdepth 57 532 | a17_s06_e03_sdepth 42 533 | a12_s09_e02_sdepth 26 534 | a08_s01_e01_sdepth 30 535 | a16_s03_e03_sdepth 41 536 | a17_s08_e01_sdepth 44 537 | a05_s09_e01_sdepth 24 538 | a14_s09_e02_sdepth 27 539 | a20_s02_e03_sdepth 53 540 | a13_s02_e02_sdepth 30 541 | a02_s01_e01_sdepth 34 542 | a10_s09_e03_sdepth 29 543 | a05_s02_e01_sdepth 44 544 | a03_s01_e03_sdepth 31 545 | a18_s10_e03_sdepth 39 546 | a18_s09_e03_sdepth 39 547 | a02_s03_e03_sdepth 40 548 | a10_s07_e02_sdepth 37 549 | a14_s05_e01_sdepth 27 550 | a03_s02_e01_sdepth 32 551 | a09_s05_e01_sdepth 32 552 | a11_s04_e02_sdepth 37 553 | a19_s09_e01_sdepth 42 554 | 
# NOTE: in this dump the physical line that starts the class below also holds
# the end of the preceding data list and this module's import section. That
# text is not part of the class; it is carried over verbatim (gutter numbers
# included):
# a18_s07_e02_sdepth 45 555 | a16_s01_e02_sdepth 48 556 | a16_s09_e03_sdepth 49 557 | a16_s02_e03_sdepth 48 558 | a20_s01_e03_sdepth 53 559 | a13_s03_e03_sdepth 44 560 | a20_s03_e02_sdepth 34 561 | a05_s01_e03_sdepth 33 562 | a20_s02_e01_sdepth 56 563 | a02_s06_e02_sdepth 41 564 | a13_s03_e01_sdepth 44 565 | a05_s08_e03_sdepth 36 566 | a01_s03_e01_sdepth 45 567 | a16_s07_e01_sdepth 40 568 | -------------------------------------------------------------------------------- /models/CLR_Model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.models as models 8 | import torch.nn.functional as F 9 | import timm 10 | from timm.models.layers import trunc_normal_ 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | ROOT_DIR = os.path.dirname(BASE_DIR) 14 | sys.path.append(BASE_DIR) 15 | sys.path.append(ROOT_DIR) 16 | sys.path.append(os.path.join(ROOT_DIR, 'modules')) 17 | 18 | from Folding import * 19 | from PSTNet import * 20 | from transformer import * 21 | from pst_convolutions import PSTConv 22 | from chamfer_distance import ChamferDistance 23 | import pointnet2_utils 24 | import utils 25 | 26 |


class ContrastiveLearningModel(nn.Module):
    """Point-cloud-video encoder with a pre-training head or a linear classifier.

    With ``pretraining=True`` the forward pass masks one sub-clip per sample,
    predicts its tokens with a transformer and returns a contrastive loss
    together with three global embeddings. With ``pretraining=False`` only the
    encoder and the linear layer ``fc_out`` are built, and the forward pass
    returns class logits.
    """

    # Number of channel-erased copies produced per sample by similarity_aug.
    _NUM_ERASED_VIEWS = 10
    # Tokens/channels are sorted by ascending similarity; everything from this
    # fraction of the sorted list onwards counts as "high similarity".
    _HIGH_SIMILARITY_START = 0.8
    # Share of the high-similarity channels zeroed in each randomised copy.
    _RANDOM_ERASE_RATIO = 0.95

    def __init__(self,
                 radius=1.5,
                 nsamples=3*3,
                 representation_dim=1024,
                 num_classes=20,
                 temperature=0.1,
                 pretraining=True):
        """Build the encoder plus either the pre-training modules or ``fc_out``.

        Args:
            radius: forwarded to ``Encoder``.
            nsamples: forwarded to ``Encoder``.
            representation_dim: token / embedding width in pre-training.
            num_classes: output size of ``fc_out`` (linear-probe mode only).
            temperature: divisor of the similarity scores in the local loss.
            pretraining: selects which set of modules is created.
        """
        super(ContrastiveLearningModel, self).__init__()

        self.encoder = Encoder(radius=radius, nsamples=nsamples)

        self.pretraining = pretraining

        if self.pretraining:
            self.temperature = temperature
            self.token_dim = representation_dim

            # regression following P4Transformer
            # Flag: when truthy, forward() applies ReLU to the embedding.
            self.emb_relu = False
            self.depth = 3
            self.heads = 8
            self.dim_head = 128
            self.mlp_dim = 2048

            self.pos_embedding = nn.Conv1d(
                in_channels=4,
                out_channels=self.token_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True
            )
            self.transformer = Transformer(
                self.token_dim,
                self.depth,
                self.heads,
                self.dim_head,
                self.mlp_dim
            )
            self.mlp_head = nn.Sequential(
                nn.Linear(self.mlp_dim, self.mlp_dim, bias=False),
                nn.BatchNorm1d(self.mlp_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.mlp_dim, self.token_dim)
            )
            self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.token_dim))
            trunc_normal_(self.mask_token, std=.02)

            self.folding = FoldingDecoder(self.token_dim)

            self.v2_fc = nn.Linear(representation_dim, representation_dim)

            self.criterion_local = torch.nn.CrossEntropyLoss()
        else:
            self.fc_out = nn.Linear(2048, num_classes)

    def _global_embedding(self, features):
        """Pool features to one unit-length vector per leading index.

        Takes the mean over dim -2, then the max over the resulting dim -2,
        projects with ``v2_fc`` and L2-normalises the last dimension.
        """
        pooled = torch.mean(input=features, dim=-2, keepdim=False)
        pooled = torch.max(input=pooled, dim=-2, keepdim=False)[0]
        return F.normalize(self.v2_fc(pooled), dim=-1)

    def similarity_aug(self, new_features, Batchsize, Sub_clips, N_out, device):
        """Choose the sub-clip to mask and build channel-erased global views.

        Args:
            new_features: token features, reshaped here as ``[B, S, N, C]``
                with ``C == self.token_dim``.
            Batchsize: B.
            Sub_clips: S.
            N_out: N, tokens per sub-clip.
            device: device for the tensors created here.

        Returns:
            ``(mask_indx, mask_list, erase_global)``: per sample the index of
            the sub-clip holding most high-similarity tokens, the matching
            one-hot CPU mask of shape ``[S]``, and the embeddings of the
            channel-erased copies, shape ``[B, V, C]`` with
            ``V == _NUM_ERASED_VIEWS``. Uses the ``random`` module.
        """
        with torch.no_grad():
            features = new_features.clone().detach()
            global_feat = torch.mean(input=features, dim=-2, keepdim=False)       # [B, S, C]
            global_feat = torch.max(input=global_feat, dim=-2, keepdim=False)[0]  # [B, C]
            global_feat = F.normalize(global_feat, dim=-1)

            tokens = F.normalize(features, dim=-1)
            tokens = tokens.reshape((Batchsize, Sub_clips * N_out, self.token_dim))  # [B, S*N, C]

            mask_list = []
            mask_indx = []
            # [B, V, S, N, C]; view 0 erases every high-similarity channel,
            # the remaining views erase a random subset of them.
            mask_high_sim = torch.ones(
                (Batchsize, self._NUM_ERASED_VIEWS, Sub_clips, N_out, self.token_dim),
                dtype=torch.float32, device=device)
            for bi in range(Batchsize):
                # for Token Mask
                token_sim = torch.matmul(tokens[bi], global_feat[bi])  # [S*N]
                order = token_sim.argsort(dim=-1)
                top_tokens = order[int(Sub_clips * N_out * self._HIGH_SIMILARITY_START):]
                top_token_clip = (top_tokens / N_out).int()

                # Sub-clip with the most high-similarity tokens; ties go to the
                # lowest index, as in the original strict ">" scan.
                counts = [int((top_token_clip == sub_i).sum()) for sub_i in range(Sub_clips)]
                index = counts.index(max(counts))

                mask_indx.append(index)
                mask = torch.zeros(Sub_clips, dtype=torch.float32)  # [S]
                mask[index] = 1  # mask
                mask_list.append(mask)

                # for Channel Mask
                channel_score = tokens[bi] * global_feat[bi]  # [S*N, C]
                # argsort of argsort = rank of each channel within its token;
                # ranks are summed over tokens and channels ordered by the sum.
                rank = channel_score.argsort(dim=-1).argsort(dim=-1)
                channel_order = rank.sum(dim=0).argsort(dim=-1)
                high_similarity_idx = channel_order[int(self.token_dim * self._HIGH_SIMILARITY_START):]
                c_idx_num = high_similarity_idx.shape[0]
                mask_high_sim[bi, 0, :, :, high_similarity_idx] = 0  # mask high sim to 0

                for mci in range(1, mask_high_sim.shape[1]):
                    picked = random.sample(range(c_idx_num), int(c_idx_num * self._RANDOM_ERASE_RATIO))
                    c_index = torch.tensor(picked, dtype=torch.long, device=device)
                    high_sim_idx_random = torch.index_select(high_similarity_idx, 0, c_index)
                    mask_high_sim[bi, mci, :, :, high_sim_idx_random] = 0

            erased = features.unsqueeze(1) * mask_high_sim  # [B, V, S, N, C]

        # NOTE(review): the collapsed source does not show whether this
        # projection ran under no_grad. It is left outside so v2_fc stays
        # differentiable for any caller that wants it; forward() detaches the
        # result either way.
        erase_global = self._global_embedding(erased)

        return mask_indx, mask_list, erase_global

    def forward(self, clips):
        """Run pre-training or linear probing, depending on ``self.pretraining``.

        Args:
            clips: ``[B, S, L', N, 3]`` in pre-training mode (unpacked that way
                below); in linear-probe mode the tensor is passed to the
                encoder unchanged.

        Returns:
            Pre-training: ``(loss_local, acc1, acc5, view1_global,
            erase_global, regression_global)``, the last two detached.
            Linear probe: the output of ``fc_out``.
        """
        # Tensor.get_device() returns -1 for CPU tensors, which is not a usable
        # device argument; the torch.device object works on CPU and GPU alike.
        device = clips.device

        if self.pretraining:
            return self._forward_pretraining(clips, device)
        return self._forward_linear(clips)

    def _forward_pretraining(self, clips, device):
        """Mask one sub-clip per sample and score its predicted tokens."""
        Batchsize, Sub_clips, L_sub_clip, N_point, C_xyz = clips.shape  # [B, S, L', N, 3]
        clips = clips.reshape((-1, L_sub_clip, N_point, C_xyz))

        new_xys, new_features = self.encoder(clips)

        new_features = new_features.permute(0, 1, 3, 2)  # [B*S, L, N, C]
        _, L_out, N_out, C_ = new_features.shape
        new_features = self.mlp_head(new_features.reshape((-1, C_)))
        C_out = new_features.shape[-1]
        assert(C_out == self.token_dim)

        new_xys = new_xys.reshape((Batchsize, Sub_clips, L_out, N_out, C_xyz))            # [B, S, L, N, 3]
        new_features = new_features.reshape((Batchsize, Sub_clips, L_out, N_out, C_out))  # [B, S, L, N, C]
        assert(L_out == 1)  # By default, only the case where each sub-clip is aggregated into one frame is considered.

        new_xys = torch.squeeze(new_xys, dim=-3).contiguous()            # [B, S, N, 3]
        new_features = torch.squeeze(new_features, dim=-3).contiguous()  # [B, S, N, C]

        view1_global = self._global_embedding(new_features)

        # for masking
        mask_indx, mask_list, erase_global = self.similarity_aug(
            new_features, Batchsize, Sub_clips, N_out, device)

        # mask tokens
        bool_masked_pos = torch.stack(mask_list).to(device)  # [B, S]
        mask_token = self.mask_token.expand(Batchsize, Sub_clips, N_out, self.token_dim)  # [B, S, N, C]

        w = bool_masked_pos.unsqueeze(-1).unsqueeze(-1).type_as(mask_token)  # [B, S, 1, 1]
        masked_input_tokens = new_features * (1 - w) + mask_token * w         # [B, S, N, C]
        masked_input_tokens = masked_input_tokens.reshape((Batchsize, Sub_clips * N_out, C_out))  # [B, S*N, C]

        # regression following P4Transformer
        # Append the 1-based sub-clip index as a fourth coordinate.
        t = torch.arange(1, Sub_clips + 1, dtype=torch.float32, device=device)
        t = t.view(1, Sub_clips, 1, 1).expand(Batchsize, Sub_clips, N_out, 1)
        xyzts = torch.cat(tensors=(new_xys, t), dim=-1)                             # [B, S, N, 4]
        xyzts = xyzts.reshape((Batchsize, Sub_clips * N_out, xyzts.shape[-1]))      # [B, S*N, 4]
        xyzts = self.pos_embedding(xyzts.permute(0, 2, 1)).permute(0, 2, 1)         # [B, S*N, C]

        embedding = xyzts + masked_input_tokens  # [B, S*N, C]

        if self.emb_relu:
            # emb_relu is a boolean flag; the original called it like a module.
            embedding = F.relu(embedding)

        output = self.transformer(embedding)  # [B, S*N, C]
        output = output.reshape((Batchsize, Sub_clips, N_out, self.token_dim))  # [B, S, N, C]

        regression_global = self._global_embedding(output.clone().detach())

        # get labels of local features / xyz
        label_mask_feature = []
        label_mask_feature_neg = []
        label_mask_xyz = []
        mask_local_feature = []
        for bi in range(Batchsize):
            mask_i = mask_indx[bi]
            mask_local_feature.append(output[bi, mask_i, :, :])        # [N, C]
            label_mask_xyz.append(new_xys[bi, mask_i, :, :])           # [N, 3]
            label_mask_feature.append(new_features[bi, mask_i, :, :])  # [N, C]

            # All sub-clips except the masked one serve as negatives.
            i_dex = np.delete(np.arange(Sub_clips), mask_i)
            label_mask_feature_neg.append(new_features[bi, i_dex, :, :])  # [S-1, N, C]

        mask_local_feature = torch.stack(tensors=mask_local_feature, dim=0)          # [B, N, C]
        label_mask_xyz = torch.stack(tensors=label_mask_xyz, dim=0)                  # [B, N, 3]
        label_mask_feature = torch.stack(tensors=label_mask_feature, dim=0)          # [B, N, C]
        label_mask_feature_neg = torch.stack(tensors=label_mask_feature_neg, dim=0)  # [B, S-1, N, C]

        # matching
        mask_fold_xyz = self.folding(mask_local_feature)  # [B, N, 3]

        dist, idx = pointnet2_utils.three_nn(label_mask_xyz.contiguous(), mask_fold_xyz)  # (anchor, neighbor)
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm

        mask_local_feature = mask_local_feature.transpose(1, 2)  # [B, C, N]
        mask_local_feature = pointnet2_utils.three_interpolate(
            mask_local_feature.contiguous(), idx, weight)        # [B, C, N]
        mask_local_feature = mask_local_feature.transpose(1, 2)  # [B, N, C]

        # local similarity matrices
        mask_local_feature = mask_local_feature.reshape((-1, C_out))  # [B*N, C]
        mask_local_feature = F.normalize(mask_local_feature, dim=-1)

        label_mask_feature = label_mask_feature.reshape((-1, C_out))  # [B*N, C]
        label_mask_feature = F.normalize(label_mask_feature, dim=-1)

        label_mask_feature_neg = label_mask_feature_neg.reshape((-1, C_out))  # [B*(S-1)*N, C]
        label_mask_feature_neg = F.normalize(label_mask_feature_neg, dim=-1)
        # Positives first, so row i's positive sits in column i.
        label_mask_feature = torch.cat((label_mask_feature, label_mask_feature_neg), dim=0)  # [B*S*N, C]

        score_local = torch.matmul(mask_local_feature, label_mask_feature.transpose(0, 1))  # [B*N, B*S*N]
        score_local = score_local / self.temperature

        target_sim_local = torch.arange(score_local.size()[0]).to(device)
        loss_local = self.criterion_local(score_local, target_sim_local)

        acc1, acc5 = utils.accuracy(score_local, target_sim_local, topk=(1, 5))

        return loss_local, acc1, acc5, view1_global, erase_global.detach(), regression_global.detach()

    def _forward_linear(self, clips):
        """Encode ``clips``, pool the features and classify with ``fc_out``.

        If sub-clips are used as in pre-training, the original notes say to
        reshape ``clips`` to ``[B*S, L, N, 3]`` before the encoder and fold
        the features back to ``[B, S, C, N]`` (requires ``L_out == 1``)
        before the pooling below.
        """
        new_xys, new_features = self.encoder(clips)

        output = torch.mean(input=new_features, dim=-1, keepdim=False)  # [B, S, C]
        output = torch.max(input=output, dim=1, keepdim=False)[0]       # [B, C]

        # Just for linear probing on MSRAction3D
        output = self.fc_out(output)

        return output

# NOTE: the same physical line of the dump continues with the header of the
# next file. That text is carried over verbatim:
# 286 | -------------------------------------------------------------------------------- /modules/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Pointnet2 layers. 7 | Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch 8 | Extended with the following: 9 | 1. Uniform sampling in each local region (sample_uniformly) 10 | 2. Return sampled points indices to support votenet.
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for the PointNet++ set-abstraction layers.

    The base class only initialises ``npoint``, ``groupers`` and ``mlps`` to
    ``None``; subclasses are expected to fill them in before ``forward`` runs.
    """

    def __init__(self):
        super().__init__()
        self.npoint = None    # number of centroids to sample; None skips sampling
        self.groupers = None  # one grouping module per scale
        self.mlps = None      # one MLP per scale, parallel to ``groupers``

    def forward(self, xyz: torch.Tensor,
                features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            Descriptors of the features, handed unchanged to every grouper.
            NOTE(review): the sibling ``*Votes`` modules document this as
            channel-first (B, C, N) -- confirm against pointnet2_utils.

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz, or ``None`` when
            ``self.npoint`` is ``None``
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features
            descriptors (per-scale outputs concatenated along dim 1)
        """
        if self.npoint is not None:
            # Sample centroids, then gather their coordinates. The gather op
            # is fed the channel-first copy of xyz, so only build that copy
            # when sampling actually happens.
            centroid_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(
                xyz.transpose(1, 2).contiguous(), centroid_idx
            ).transpose(1, 2).contiguous()
        else:
            new_xyz = None

        pooled_per_scale = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = mlp(grouped)  # (B, mlp[-1], npoint, nsample)
            # Max over the whole neighbourhood axis, then drop that axis.
            pooled = F.max_pool2d(
                grouped, kernel_size=[1, grouped.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            pooled_per_scale.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)

        return new_xyz, torch.cat(pooled_per_scale, dim=1)
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""Pointnet set abstraction layer with multiscale grouping

    Parameters
    ----------
    npoint : int
        Number of features. When ``None`` every scale uses
        ``pointnet2_utils.GroupAll`` instead of a ball query.
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet before the global max_pool for each scale.
        The lists are copied; the caller's objects are never modified.
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the groupers; when true the first entry of every MLP
        spec is widened by 3.
    sample_uniformly : bool
        Forwarded to ``pointnet2_utils.QueryAndGroup``
    """

    def __init__(
            self,
            *,
            npoint: int,
            radii: List[float],
            nsamples: List[int],
            mlps: List[List[int]],
            bn: bool = True,
            use_xyz: bool = True,
            sample_uniformly: bool = False
    ):
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, mlp_spec in zip(radii, nsamples, mlps):
            if npoint is not None:
                grouper = pointnet2_utils.QueryAndGroup(
                    radius, nsample, use_xyz=use_xyz,
                    sample_uniformly=sample_uniformly
                )
            else:
                grouper = pointnet2_utils.GroupAll(use_xyz)
            self.groupers.append(grouper)

            # Work on a copy: widening the caller's list in place would add
            # another 3 every time the same config list is reused.
            mlp_spec = list(mlp_spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn))


class PointnetSAModule(PointnetSAModuleMSG):
    r"""Pointnet set abstraction layer

    Single-scale convenience wrapper: every argument is wrapped in a
    one-element list and handed to ``PointnetSAModuleMSG``.

    Parameters
    ----------
    npoint : int
        Number of features
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    mlp : list
        Spec of the pointnet before the global max_pool
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to ``PointnetSAModuleMSG``
    """

    def __init__(
            self,
            *,
            mlp: List[int],
            npoint: int = None,
            radius: float = None,
            nsample: int = None,
            bn: bool = True,
            use_xyz: bool = True
    ):
        super().__init__(
            mlps=[mlp],
            npoint=npoint,
            radii=[radius],
            nsamples=[nsample],
            bn=bn,
            use_xyz=use_xyz
        )
class PointnetSAModuleVotes(nn.Module):
    r''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG
    with extra support for returning point indices for getting their GT votes

    Parameters
    ----------
    mlp : list of int
        Spec of the shared MLP applied to the grouped features. The list is
        copied; the caller's object is never modified.
    npoint : int
        Number of centroids. When ``None`` the layer uses
        ``pointnet2_utils.GroupAll`` and does no sampling.
    radius : float
        Radius handed to ``pointnet2_utils.QueryAndGroup``
    nsample : int
        Number of samples per group; also the divisor of 'rbf' pooling
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the grouper; when true the first MLP entry is widened by 3
    pooling : str
        'max', 'avg' or 'rbf'
    sigma : float
        Bandwidth of the RBF kernel; defaults to ``radius / 2`` when a radius
        is given
    normalize_xyz : bool
        Forwarded to ``pointnet2_utils.QueryAndGroup``
    sample_uniformly : bool
        Forwarded to ``pointnet2_utils.QueryAndGroup``
    ret_unique_cnt : bool
        Forwarded to ``pointnet2_utils.QueryAndGroup``; when true ``forward``
        also returns the extra value produced by the grouper
    '''

    def __init__(
            self,
            *,
            mlp: List[int],
            npoint: int = None,
            radius: float = None,
            nsample: int = None,
            bn: bool = True,
            use_xyz: bool = True,
            pooling: str = 'max',
            sigma: float = None,  # for RBF pooling
            normalize_xyz: bool = False,  # normalize local XYZ with radius
            sample_uniformly: bool = False,
            ret_unique_cnt: bool = False
    ):
        super().__init__()

        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.pooling = pooling
        self.mlp_module = None
        self.use_xyz = use_xyz
        self.sigma = sigma
        # Only derive a default bandwidth when there is a radius to derive it
        # from; the group-all configuration leaves radius as None.
        if self.sigma is None and self.radius is not None:
            self.sigma = self.radius / 2
        self.normalize_xyz = normalize_xyz
        self.ret_unique_cnt = ret_unique_cnt

        if npoint is not None:
            self.grouper = pointnet2_utils.QueryAndGroup(
                radius, nsample,
                use_xyz=use_xyz, ret_grouped_xyz=True, normalize_xyz=normalize_xyz,
                sample_uniformly=sample_uniformly, ret_unique_cnt=ret_unique_cnt)
        else:
            self.grouper = pointnet2_utils.GroupAll(use_xyz, ret_grouped_xyz=True)

        # Copy before widening so the caller's list is left untouched.
        mlp_spec = list(mlp)
        if use_xyz and len(mlp_spec) > 0:
            mlp_spec[0] += 3
        self.mlp_module = pt_utils.SharedMLP(mlp_spec, bn=bn)

    def forward(self, xyz: torch.Tensor,
                features: torch.Tensor = None,
                inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the the features
        inds : torch.Tensor
            (B, npoint) tensor that stores index to the xyz points (values in 0-N-1)

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz, or ``None`` when
            ``self.npoint`` is ``None``
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        inds: torch.Tensor
            (B, npoint) tensor of the inds; the ``inds`` argument is returned
            unchanged when no sampling is performed
        unique_cnt
            Returned as a fourth value only when ``ret_unique_cnt`` is true
        """
        if self.npoint is None:
            # Group-all mode: there are no centroids to sample or gather.
            new_xyz = None
        else:
            if inds is None:
                inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            else:
                assert(inds.shape[1] == self.npoint)
            new_xyz = pointnet2_utils.gather_operation(
                xyz.transpose(1, 2).contiguous(), inds
            ).transpose(1, 2).contiguous()

        if not self.ret_unique_cnt:
            grouped_features, grouped_xyz = self.grouper(
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample)
        else:
            grouped_features, grouped_xyz, unique_cnt = self.grouper(
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample), (B,3,npoint,nsample), (B,npoint)

        new_features = self.mlp_module(
            grouped_features
        )  # (B, mlp[-1], npoint, nsample)
        if self.pooling == 'max':
            new_features = F.max_pool2d(
                new_features, kernel_size=[1, new_features.size(3)]
            )  # (B, mlp[-1], npoint, 1)
        elif self.pooling == 'avg':
            new_features = F.avg_pool2d(
                new_features, kernel_size=[1, new_features.size(3)]
            )  # (B, mlp[-1], npoint, 1)
        elif self.pooling == 'rbf':
            # Use radial basis function kernel for weighted sum of features (normalized by nsample and sigma)
            # Ref: https://en.wikipedia.org/wiki/Radial_basis_function_kernel
            rbf = torch.exp(-1 * grouped_xyz.pow(2).sum(1, keepdim=False) / (self.sigma**2) / 2)  # (B, npoint, nsample)
            new_features = torch.sum(new_features * rbf.unsqueeze(1), -1, keepdim=True) / float(self.nsample)  # (B, mlp[-1], npoint, 1)
        # NOTE(review): any other pooling value leaves the nsample axis
        # unpooled; the squeeze below then only removes it if its size is 1.
        new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

        if not self.ret_unique_cnt:
            return new_xyz, new_features, inds
        else:
            return new_xyz, new_features, inds, unique_cnt
class PointnetSAModuleMSGVotes(nn.Module):
    r''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG
    with extra support for returning point indices for getting their GT votes

    Parameters
    ----------
    mlps : list of list of int
        Spec of the shared MLP for each scale. The lists are copied; the
        caller's objects are never modified.
    npoint : int
        Number of centroids. When ``None`` every scale uses
        ``pointnet2_utils.GroupAll`` and no sampling is done.
    radii : list of float
        Radius of the ball query for each scale
    nsamples : list of int
        Number of samples in each ball query
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the groupers; when true the first entry of every MLP
        spec is widened by 3
    sample_uniformly : bool
        Forwarded to ``pointnet2_utils.QueryAndGroup``
    '''

    def __init__(
            self,
            *,
            mlps: List[List[int]],
            npoint: int,
            radii: List[float],
            nsamples: List[int],
            bn: bool = True,
            use_xyz: bool = True,
            sample_uniformly: bool = False
    ):
        super().__init__()

        assert(len(mlps) == len(nsamples) == len(radii))

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, mlp_spec in zip(radii, nsamples, mlps):
            if npoint is not None:
                grouper = pointnet2_utils.QueryAndGroup(
                    radius, nsample, use_xyz=use_xyz,
                    sample_uniformly=sample_uniformly
                )
            else:
                grouper = pointnet2_utils.GroupAll(use_xyz)
            self.groupers.append(grouper)

            # Copy before widening so a reused config list is not bumped by
            # 3 again on every construction.
            mlp_spec = list(mlp_spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn))

    def forward(self, xyz: torch.Tensor,
                features: torch.Tensor = None, inds: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the the features
        inds : torch.Tensor
            (B, npoint) tensor that stores index to the xyz points (values in 0-N-1)

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz, or ``None`` when
            ``self.npoint`` is ``None``
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        inds: torch.Tensor
            (B, npoint) tensor of the inds; the ``inds`` argument is returned
            unchanged when no sampling is performed
        """
        if self.npoint is None:
            # Group-all mode: nothing to sample or gather.
            new_xyz = None
        else:
            if inds is None:
                inds = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(
                xyz.transpose(1, 2).contiguous(), inds
            ).transpose(1, 2).contiguous()

        pooled_per_scale = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = mlp(grouped)  # (B, mlp[-1], npoint, nsample)
            pooled = F.max_pool2d(
                grouped, kernel_size=[1, grouped.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            pooled_per_scale.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)

        return new_xyz, torch.cat(pooled_per_scale, dim=1), inds
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one set to another

    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """

    def __init__(self, *, mlp: List[int], bn: bool = True):
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor,
            unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features;
            ``None`` broadcasts ``known_feats`` over all ``n`` positions
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to; may be
            ``None``
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is None:
            # No source positions: stretch the known features across every
            # unknown position instead of interpolating.
            batch, channels = known_feats.size()[0:2]
            interpolated_feats = known_feats.expand(
                batch, channels, unknown.size(1)
            )
        else:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            # Inverse-distance weights, normalised to sum to one over dim 2;
            # the epsilon keeps a zero distance from dividing by zero.
            inv_dist = 1.0 / (dist + 1e-8)
            weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )

        if unknow_feats is None:
            stacked = interpolated_feats
        else:
            stacked = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)

        # The MLP is applied to a 4-D tensor, so add and then drop a
        # trailing singleton axis.
        return self.mlp(stacked.unsqueeze(-1)).squeeze(-1)
class PointnetLFPModuleMSG(nn.Module):
    r''' Modified based on _PointnetSAModuleBase and PointnetSAModuleMSG
    learnable feature propagation layer.

    Parameters
    ----------
    mlps : list of list of int
        Spec of the shared MLP for each scale. The lists are copied; the
        caller's objects are never modified.
    radii : list of float
        Radius of the ball query for each scale
    nsamples : list of int
        Number of samples in each ball query
    post_mlp : list of int
        Spec of the MLP applied after pooling (and after the optional
        concatenation with ``features2``)
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the groupers; when true the first entry of every MLP
        spec is widened by 3
    sample_uniformly : bool
        Forwarded to ``pointnet2_utils.QueryAndGroup``
    '''

    def __init__(
            self,
            *,
            mlps: List[List[int]],
            radii: List[float],
            nsamples: List[int],
            post_mlp: List[int],
            bn: bool = True,
            use_xyz: bool = True,
            sample_uniformly: bool = False
    ):
        super().__init__()

        assert(len(mlps) == len(nsamples) == len(radii))

        self.post_mlp = pt_utils.SharedMLP(post_mlp, bn=bn)

        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, mlp_spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz,
                                              sample_uniformly=sample_uniformly)
            )
            # Copy before widening so a reused config list is not bumped by
            # 3 again on every construction.
            mlp_spec = list(mlp_spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn))

    def forward(self, xyz2: torch.Tensor, xyz1: torch.Tensor,
                features2: torch.Tensor, features1: torch.Tensor) -> torch.Tensor:
        r""" Propagate features from xyz1 to xyz2.
        Parameters
        ----------
        xyz2 : torch.Tensor
            (B, N2, 3) tensor of the xyz coordinates of the features
        xyz1 : torch.Tensor
            (B, N1, 3) tensor of the xyz coordinates of the features
        features2 : torch.Tensor
            (B, C2, N2) tensor of the descriptors of the the features;
            may be ``None``
        features1 : torch.Tensor
            (B, C1, N1) tensor of the descriptors of the the features

        Returns
        -------
        new_features : torch.Tensor
            (B, len(mlps) * post_mlp[-1], N2) tensor: the ``post_mlp`` output
            of every scale, concatenated along dim 1
        """
        propagated_per_scale = []

        for grouper, mlp in zip(self.groupers, self.mlps):
            # xyz2 supplies the query centres, xyz1/features1 the source set.
            grouped = grouper(xyz1, xyz2, features1)  # (B, C1, N2, nsample)
            grouped = mlp(grouped)  # (B, mlp[-1], N2, nsample)
            pooled = F.max_pool2d(
                grouped, kernel_size=[1, grouped.size(3)]
            )  # (B, mlp[-1], N2, 1)
            pooled = pooled.squeeze(-1)  # (B, mlp[-1], N2)

            if features2 is not None:
                pooled = torch.cat([pooled, features2],
                                   dim=1)  # (B, mlp[-1] + C2, N2)

            # post_mlp is applied to a 4-D tensor, so add a trailing axis;
            # it is removed once, after the final concatenation.
            propagated_per_scale.append(self.post_mlp(pooled.unsqueeze(-1)))

        return torch.cat(propagated_per_scale, dim=1).squeeze(-1)
# Manual smoke test: builds a two-scale set-abstraction layer, runs one
# forward pass and one backward pass, and prints the results. Every tensor
# and the module itself are moved with .cuda(), so a CUDA device is required.
if __name__ == "__main__":
    from torch.autograd import Variable
    # Seed both the CPU and all CUDA generators before drawing the inputs.
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True)
    # NOTE(review): the layer docstrings describe features as sharing N with
    # xyz, but the last axis here is 6 while xyz has 9 points -- confirm the
    # intended layout before relying on this script.
    xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True)

    # Two scales (radius 5.0 with 6 samples, radius 10.0 with 3 samples).
    # use_xyz is left at its default of True, so the constructor adds 3 to
    # the first entry of each MLP spec.
    test_module = PointnetSAModuleMSG(
        npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]]
    )
    test_module.cuda()
    print(test_module(xyz, xyz_feats))

    for _ in range(1):
        _, new_features = test_module(xyz, xyz_feats)
        # Back-propagate an all-ones upstream gradient of the output's shape.
        new_features.backward(
            torch.cuda.FloatTensor(*new_features.size()).fill_(1)
        )
        print(new_features)
        print(xyz.grad)