├── LICENSE
├── README.md
├── ckpt
│   ├── argo_gru.pth
│   ├── argo_lstm.pth
│   ├── kitti_gru.pth
│   └── kitti_lstm.pth
├── data
│   ├── argo_loader.py
│   └── kitti_loader.py
├── model
│   ├── PointUtils
│   │   ├── setup.py
│   │   └── src
│   │       ├── cuda_utils.h
│   │       ├── furthest_point_sampling.cpp
│   │       ├── furthest_point_sampling_gpu.cu
│   │       ├── furthest_point_sampling_gpu.h
│   │       ├── interpolate.cpp
│   │       ├── interpolate_gpu.cu
│   │       ├── interpolate_gpu.h
│   │       └── point_utils_api.cpp
│   ├── layers.py
│   ├── models.py
│   └── utils.py
├── requirements.txt
├── scripts
│   ├── eval.sh
│   └── train.sh
├── test.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Intelligent Sensing, Perception and Computing Group

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## MoNet: Motion-based Point Cloud Prediction Network

### Environment

- PyTorch 1.7.1
- CUDA 11.1
- [pytorch3d](https://github.com/facebookresearch/pytorch3d)
- [EMD](https://github.com/daerduoCarey/PyTorchEMD)

Please run the following commands to install `point_utils`:
```
cd model/PointUtils
python setup.py install
```

Please check `requirements.txt` for the remaining dependencies.

### Datasets
The two datasets should be organized as follows:
#### KITTI odometry dataset
```
DATA_ROOT
├── 00
│   ├── velodyne
│   ├── calib.txt
├── 01
├── ...
```
#### Argoverse dataset
```
DATA_ROOT
├── train1
│   ├── 043aeba7-14e5-3cde-8a5c-639389b6d3a6
│   │   ├── lidar
│   │   ├── poses
│   │   ├── ...
│   ├── ...
├── train2
├── train3
├── train4
├── val
├── test
```

### Evaluation

Please run `scripts/eval.sh` to evaluate the proposed MoNet on the two datasets using the pretrained models provided in `ckpt`. `ROOT`, `CKPT`, `GPU` and `RNN` in the script should be modified accordingly.

### Train

If you want to train the network, please run `scripts/train.sh`, and remember to modify `ROOT`, `CKPT_DIR` and `RUNNAME`.

Note that we use [wandb](https://www.wandb.com/) to record the training procedure; if you do not want to use it, drop the `--use_wandb` flag in `train.sh`.
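For orientation, a minimal sketch of how the pieces fit together in code, using the dataset and model interfaces defined in `data/kitti_loader.py` and `model/models.py`. The argument values, sequence choice and checkpoint loading below are illustrative assumptions; see `test.py` and the scripts for the actual entry points:

```
from argparse import Namespace

import torch
from torch.utils.data import DataLoader

from data.kitti_loader import KittiDataset
from model.models import MoNet

# Hypothetical settings; npoints must be divisible by 128 (the deepest
# downsampling factor used inside MoNet).
args = Namespace(rnn='GRU', npoints=16384, input_num=5, pred_num=5)

dataset = KittiDataset(root='/path/to/kitti/sequences', npoints=args.npoints,
                       input_num=args.input_num, pred_num=args.pred_num, seqs=['08'])
loader = DataLoader(dataset, batch_size=1, shuffle=False)

model = MoNet(args).cuda().eval()
# model.load_state_dict(torch.load('ckpt/kitti_gru.pth'))  # checkpoint layout assumed

with torch.no_grad():
    input_pc, gt_pc = next(iter(loader))  # [B,T_in,3,N] and [B,T_out,3,N]
    pred_list = model(input_pc.cuda())    # list of T_out tensors, each [B,3,N]
```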
### Citation
If you find this project useful for your work, please consider citing:
```
@ARTICLE{Lu_MoNet_2021,
  author={Lu, Fan and Chen, Guang and Li, Zhijun and Zhang, Lijun and Liu, Yinlong and Qu, Sanqing and Knoll, Alois},
  journal={IEEE Transactions on Intelligent Transportation Systems},
  title={MoNet: Motion-Based Point Cloud Prediction Network},
  year={2021},
  volume={},
  number={},
  pages={1-11}
}
```

--------------------------------------------------------------------------------
/ckpt/argo_gru.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/MoNet/573a76ff3aacf14efe828ffeb69fe61d3a1f3df5/ckpt/argo_gru.pth

--------------------------------------------------------------------------------
/ckpt/argo_lstm.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/MoNet/573a76ff3aacf14efe828ffeb69fe61d3a1f3df5/ckpt/argo_lstm.pth

--------------------------------------------------------------------------------
/ckpt/kitti_gru.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/MoNet/573a76ff3aacf14efe828ffeb69fe61d3a1f3df5/ckpt/kitti_gru.pth

--------------------------------------------------------------------------------
/ckpt/kitti_lstm.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispc-lab/MoNet/573a76ff3aacf14efe828ffeb69fe61d3a1f3df5/ckpt/kitti_lstm.pth

--------------------------------------------------------------------------------
/data/argo_loader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os
import glob
import pandas as pd
from plyfile import PlyData
from torch.utils.data import Dataset

class ArgoDataset(Dataset):

    def __init__(self, root, npoints, input_num, pred_num, seqs):
        super(ArgoDataset, self).__init__()

        self.root = root
        self.seqs = seqs
        self.input_num = input_num
        self.pred_num = pred_num
        self.npoints = npoints

        self.dataset = self.make_dataset()

    def get_cloud(self, filename):
        plydata = PlyData.read(filename)
        data = plydata.elements[0].data
        data_pd = pd.DataFrame(data)
        data_np = np.zeros(data_pd.shape, dtype=np.float32)
        property_names = data[0].dtype.names
        for i, name in enumerate(property_names):
            data_np[:,i] = data_pd[name]
        pc = data_np[:,:3]
        N = pc.shape[0]
        if N >= self.npoints:
            sample_idx = np.random.choice(N, self.npoints, replace=False)
        else:
            sample_idx = np.concatenate((np.arange(N), np.random.choice(N, self.npoints-N, replace=True)), axis=-1)
        pc = pc[sample_idx, :].astype('float32')
        pc = torch.from_numpy(pc).t()
        return pc

    def make_dataset(self):
        dataset = []

        for seq in self.seqs:
            dirs = os.listdir(os.path.join(self.root, seq))
            dirs = sorted(dirs)
            for curr_dir in dirs:
                names = os.listdir(os.path.join(self.root, seq, curr_dir, 'lidar'))
                names = sorted(names)
                max_ind = len(names)
                interval = self.input_num + self.pred_num
                ini_index = 0
                while (ini_index < max_ind - interval):
                    paths = []

                    for j in range(interval):
                        curr_path = os.path.join(self.root, seq, curr_dir, 'lidar', names[j+ini_index])
                        paths.append(curr_path)

                    ini_index += interval
                    dataset.append(paths)

        return dataset

    def __getitem__(self, index):
        paths = self.dataset[index]

        input_pc_list = []

        for i in range(self.input_num):
            input_pc_path = paths[i]
            input_pc = self.get_cloud(input_pc_path)
            input_pc_list.append(input_pc)

        input_pc = torch.stack(input_pc_list, dim=0)

        output_pc_list = []

        for i in range(self.input_num, self.input_num+self.pred_num):
            output_pc_path = paths[i]
            output_pc = self.get_cloud(output_pc_path)
            output_pc_list.append(output_pc)
        output_pc = torch.stack(output_pc_list, dim=0)

        return input_pc, output_pc

    def __len__(self):
        return len(self.dataset)

--------------------------------------------------------------------------------
/data/kitti_loader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os
import glob
from torch.utils.data import Dataset

class KittiDataset(Dataset):
    '''
    Multi-sequence dataset for the KITTI odometry benchmark
    Parameters:
        root: dir of the KITTI dataset (sequences/)
        npoints: number of randomly sampled points from the raw points
        input_num: number of input point clouds
        pred_num: number of predicted point clouds
        seqs: sequence list
    '''
    def __init__(self, root, npoints, input_num, pred_num, seqs):
        super(KittiDataset, self).__init__()

        self.root = root
        self.seqs = seqs
        self.input_num = input_num
        self.pred_num = pred_num
        self.npoints = npoints
        self.dataset = self.make_dataset()

    def make_dataset(self):
        dataset = []
        for seq in self.seqs:
            dataroot = os.path.join(self.root, seq, 'velodyne')
            datapath = glob.glob(os.path.join(dataroot, '*.bin'))
            datapath = sorted(datapath)
            max_ind = len(datapath)
            ini_index = 0
            interval = self.input_num + self.pred_num
            while (ini_index < max_ind - interval):
                paths = []
                for i in range(interval):
                    # glob returns full paths, so keep only the filename and store
                    # the path relative to root; __getitem__ joins it with self.root.
                    curr_path = os.path.join(seq, 'velodyne', os.path.basename(datapath[ini_index+i]))
                    paths.append(curr_path)
                ini_index += interval
                dataset.append(paths)
        return dataset

    def get_cloud(self, filename):
        pc = np.fromfile(filename, dtype=np.float32, count=-1).reshape([-1,4])
        N = pc.shape[0]
        if N >= self.npoints:
            sample_idx = np.random.choice(N, self.npoints, replace=False)
        else:
            sample_idx = np.concatenate((np.arange(N), np.random.choice(N, self.npoints-N, replace=True)), axis=-1)
        pc = pc[sample_idx, :3].astype('float32')
        pc = torch.from_numpy(pc).t()
        return pc

    def __getitem__(self, index):
        paths = self.dataset[index]

        input_pc_list = []
        for i in range(self.input_num):
            input_pc_name = paths[i]
            input_pc = self.get_cloud(os.path.join(self.root, input_pc_name))
            input_pc_list.append(input_pc)
        input_pc = torch.stack(input_pc_list, dim=0)

        output_pc_list = []
        for i in range(self.input_num, self.input_num+self.pred_num):
            output_pc_name = paths[i]
            output_pc = self.get_cloud(os.path.join(self.root, output_pc_name))
            output_pc_list.append(output_pc)
        output_pc = torch.stack(output_pc_list, dim=0)

        return input_pc, output_pc

    def __len__(self):
        return len(self.dataset)
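Both loaders return a pair of tensors per item: the input clip of shape [input_num, 3, npoints] and the target clip of shape [pred_num, 3, npoints], channels-first per frame. `make_dataset` partitions each sequence into disjoint, non-overlapping clips of `input_num + pred_num` consecutive frames. A minimal illustration of that windowing, with hypothetical values:

```
# With input_num=5 and pred_num=5, a 23-frame sequence yields two clips;
# trailing frames that cannot fill a whole window are discarded.
frames = list(range(23))
interval = 5 + 5
clips = [frames[i:i + interval] for i in range(0, len(frames) - interval, interval)]
assert clips == [list(range(0, 10)), list(range(10, 20))]
```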
--------------------------------------------------------------------------------
/model/PointUtils/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='point_utils',
    ext_modules=[
        CUDAExtension('point_utils_cuda', [
            'src/point_utils_api.cpp',

            'src/furthest_point_sampling.cpp',
            'src/furthest_point_sampling_gpu.cu',
            'src/interpolate.cpp',
            'src/interpolate_gpu.cu',
        ],
        extra_compile_args={
            'cxx': ['-g'],
            'nvcc': ['-O2']
        })
    ],
    cmdclass={'build_ext': BuildExtension}
)

--------------------------------------------------------------------------------
/model/PointUtils/src/cuda_utils.h:
--------------------------------------------------------------------------------
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <cmath>
#include <torch/extension.h>

#define TOTAL_THREADS 1024
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
    CHECK_CUDA(x);               \
    CHECK_CONTIGUOUS(x)

/***
 * calculate proper thread number
 * If work_size < TOTAL_THREADS, number = work_size (2^n)
 * Else number = TOTAL_THREADS
 ***/
inline int opt_n_threads(int work_size) {
    // log2(work_size)
    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
    // 1 * 2^(pow_2)
    return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1);
}

#endif

--------------------------------------------------------------------------------
/model/PointUtils/src/furthest_point_sampling.cpp:
--------------------------------------------------------------------------------
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>

#include "furthest_point_sampling_gpu.h"

// extern THCState *state;

int gather_points_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
    const float *points = points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *out = out_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
    return 1;
}

int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();
    float *grad_points = grad_points_tensor.data<float>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
    return 1;
}

int furthest_point_sampling_wrapper(int b, int n, int m,
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
    return 1;
}

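// The weighted variant below mirrors furthest_point_sampling_wrapper but also
// passes a per-point weight tensor; the kernel scales each candidate's squared
// distance by its weight, biasing sampling toward high-weight points.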
int weighted_furthest_point_sampling_wrapper(int b, int n, int m,
    at::Tensor points_tensor, at::Tensor weights_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weights = weights_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    weighted_furthest_point_sampling_kernel_launcher(b, n, m, points, weights, temp, idx, stream);
    return 1;
}

--------------------------------------------------------------------------------
/model/PointUtils/src/furthest_point_sampling_gpu.cu:
--------------------------------------------------------------------------------
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "furthest_point_sampling_gpu.h"

__global__ void gather_points_kernel_fast(int b, int c, int n, int m,
    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
    // points: [B,C,N]
    // idx: [B,M]

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
    // Pointer to current point
    out += bs_idx * c * m + c_idx * m + pt_idx;  // curr batch + channels + points
    idx += bs_idx * m + pt_idx;                  // curr batch + points
    points += bs_idx * c * n + c_idx * n;        // batch + channels
    out[0] = points[idx[0]];                     // copy channel c_idx of the idx[0]-th input point
}

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *points, const int *idx, float *out, cudaStream_t stream) {
    // points: [B,C,N]
    // idx: [B,npoints]
    cudaError_t err;
    // dim3 is a type to assign dimensions
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // DIVUP: npoints/THREADS_PER_BLOCK
    dim3 threads(THREADS_PER_BLOCK);                       // others assign to 1

    gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m,
    const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
    // grad_out: [B,C,M]
    // idx: [B,M]
    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;

    grad_out += bs_idx * c * m + c_idx * m + pt_idx;
    idx += bs_idx * m + pt_idx;
    grad_points += bs_idx * c * n + c_idx * n;

    atomicAdd(grad_points + idx[0], grad_out[0]);  // accumulate the gradient at the indexed input position
}

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
    // grad_out: [B,C,npoints]
    // idx: [B,npoints]

    cudaError_t err;
    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
    dim3 threads(THREADS_PER_BLOCK);

    gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2) {
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

// The kernel body is written from the perspective of a single thread;
// grid size and block size are all defined in the launcher
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
    // dataset: [B,N,3]
    // temp: [B,N]
    // idxs: [B,M]
    // All global memory

    if (m <= 0) return;
    // assign shared memory
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    // Point to curr batch (blockIdx of current thread of this kernel)
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;

    // threadIdx of current thread
    int tid = threadIdx.x;
    const int stride = block_size;  // number of threads in one block

    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;  // Initialize index

    __syncthreads();
    // loop m times for m sampled points
    for (int j = 1; j < m; j++) {
        // printf("curr index: %d\n", j);
        int besti = 0;
        float best = -1;
        // Coordinate of last point
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        // Each thread strides over the points, computing distances in parallel
        for (int k = tid; k < n; k += stride) {
            // calculate distance with the other point
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            float d2 = min(d, temp[k]);
            temp[k] = d2;                   // update temp distance
            besti = d2 > best ? k : besti;  // If d2 > best, besti = k (idx)
            best = d2 > best ? d2 : best;   // If d2 > best, best = d2 (distance)
        }
        // dists[tid] stores the largest distance found by this thread
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();  // wait for all threads to finish computing the distances
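        // Block-wide tree reduction over the shared-memory buffers: each round
        // halves the number of active threads, merging (distance, index) pairs,
        // until dists[0]/dists_i[0] hold the farthest remaining point.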
        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        // All threads update a single new point (old).
        old = dists_i[0];  // update last point index
        if (tid == 0)
            idxs[j] = old;
    }
}

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
    // dataset: [B,N,3]
    // temp: [B,N]

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);  // compute proper thread number

    switch (n_threads) {
        // Call kernel functions: Func<<<Dg, Db, Ns, s>>>
        // Dg: grid size (how many blocks in the grid)
        // Db: block size (how many threads in the block)
        // Ns: memory for shared values, default 0
        // s: stream
        case 1024:
            furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 512:
            furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 256:
            furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 128:
            furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 64:
            furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 32:
            furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 16:
            furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 8:
            furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 4:
            furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 2:
            furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 1:
            furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        default:
            furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
    }
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

template <unsigned int block_size>
__global__ void weighted_furthest_point_sampling_kernel(int b, int n, int m,
    const float *__restrict__ dataset, const float *__restrict__ weights, float *__restrict__ temp, int *__restrict__ idxs) {
    // dataset: [B,N,3]
    // weights: [B,N]
    // temp: [B,N]

    if (m <= 0) return;

    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    dataset += batch_index * n * 3;
    weights += batch_index * n;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;

    __syncthreads();

    for (int j = 1; j < m; j++) {

        int besti = 0;
        float best = -1;

        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];

        float w1 = weights[old];

        for (int k = tid; k < n; k += stride) {
            float x2, y2, z2, w2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];
            w2 = weights[k];

            float d = w2 * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1));
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }

        // All threads update a single new point (old).
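        // After the reduction, dists_i[0] holds the block-wide argmax, i.e. the
        // point with the largest weighted distance to the already-selected set;
        // temp keeps each point's running minimum distance for later rounds.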
        old = dists_i[0];  // update last point index
        if (tid == 0)
            idxs[j] = old;
    }
}

void weighted_furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, const float *weights, float *temp, int *idxs, cudaStream_t stream) {

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);  // compute proper thread number

    switch (n_threads) {
        // Call kernel functions: Func<<<Dg, Db, Ns, s>>>
        // Dg: grid size (how many blocks in the grid)
        // Db: block size (how many threads in the block)
        // Ns: memory for shared values, default 0
        // s: stream
        case 1024:
            weighted_furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 512:
            weighted_furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 256:
            weighted_furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 128:
            weighted_furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 64:
            weighted_furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 32:
            weighted_furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 16:
            weighted_furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 8:
            weighted_furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 4:
            weighted_furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 2:
            weighted_furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        case 1:
            weighted_furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs); break;
        default:
            weighted_furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, weights, temp, idxs);
    }
    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

--------------------------------------------------------------------------------
/model/PointUtils/src/furthest_point_sampling_gpu.h:
--------------------------------------------------------------------------------
#ifndef _FURTHEST_POINT_SAMPLING_H
#define _FURTHEST_POINT_SAMPLING_H

#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>

int gather_points_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);

void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *points, const int *idx, float *out, cudaStream_t stream);

int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);

void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);

int furthest_point_sampling_wrapper(int b, int n, int m,
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

int weighted_furthest_point_sampling_wrapper(int b, int n, int m,
    at::Tensor points_tensor, at::Tensor weights_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);

void furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, float *temp, int *idxs, cudaStream_t stream);

void weighted_furthest_point_sampling_kernel_launcher(int b, int n, int m,
    const float *dataset, const float *weights, float *temp, int *idxs, cudaStream_t stream);

#endif

--------------------------------------------------------------------------------
/model/PointUtils/src/interpolate.cpp:
--------------------------------------------------------------------------------
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"

void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
}

void three_interpolate_wrapper_fast(int b, int c, int m, int n,
    at::Tensor points_tensor,
    at::Tensor idx_tensor,
    at::Tensor weight_tensor,
    at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
}

void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
    at::Tensor grad_out_tensor,
    at::Tensor idx_tensor,
    at::Tensor weight_tensor,
    at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
}
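The squared distances produced by `three_nn` are turned into normalized inverse-distance weights before `three_interpolate` is called (see `PointNet2FeaturePropagator` in `model/layers.py`). A small numeric illustration of that weighting:

```
import torch
# dist: squared distances to the 3 nearest known points, shape [B, N, 3]
dist = torch.tensor([[[1.0, 4.0, 4.0]]])
inverse_dist = 1.0 / (dist + 1e-8)
weights = inverse_dist / inverse_dist.sum(dim=2, keepdim=True)
# weights ~ [[[0.667, 0.167, 0.167]]]: nearer neighbors contribute more
```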
--------------------------------------------------------------------------------
/model/PointUtils/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "interpolate_gpu.h"

__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown,
    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {

    int bs_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || pt_idx >= n) return;

    unknown += bs_idx * n * 3 + pt_idx * 3;
    known += bs_idx * m * 3;
    dist2 += bs_idx * n * 3 + pt_idx * 3;
    idx += bs_idx * n * 3 + pt_idx * 3;

    float ux = unknown[0];
    float uy = unknown[1];
    float uz = unknown[2];

    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
    int besti1 = 0, besti2 = 0, besti3 = 0;

    for (int k = 0; k < m; ++k) {
        float x = known[k * 3 + 0];
        float y = known[k * 3 + 1];
        float z = known[k * 3 + 2];
        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
        if (d < best1) {
            best3 = best2;
            besti3 = besti2;
            best2 = best1;
            besti2 = besti1;
            best1 = d;
            besti1 = k;
        }
        else if (d < best2) {
            best3 = best2;
            besti3 = besti2;
            best2 = d;
            besti2 = k;
        }
        else if (d < best3) {
            best3 = d;
            besti3 = k;
        }
    }
    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
}

void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
    const float *known, float *dist2, int *idx, cudaStream_t stream) {

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);
    dim3 threads(THREADS_PER_BLOCK);

    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points,
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    weight += bs_idx * n * 3 + pt_idx * 3;
    points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;
    out += bs_idx * c * n + c_idx * n;

    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
}

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
    // points: [B,C,M]
    // idx: [B,N,3]
    // weight: [B,N,3]

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
    dim3 threads(THREADS_PER_BLOCK);  // points for one thread to process
    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {

    int bs_idx = blockIdx.z;
    int c_idx = blockIdx.y;
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;

    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
    weight += bs_idx * n * 3 + pt_idx * 3;
    grad_points += bs_idx * c * m + c_idx * m;
    idx += bs_idx * n * 3 + pt_idx * 3;

    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
}

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
    // grad_out: [B,C,N]
    // weight: [B,N,3]

    cudaError_t err;
    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);
    dim3 threads(THREADS_PER_BLOCK);

    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

--------------------------------------------------------------------------------
/model/PointUtils/src/interpolate_gpu.h:
--------------------------------------------------------------------------------
#ifndef INTERPOLATE_GPU_H
#define INTERPOLATE_GPU_H

#include <torch/serialize/tensor.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>

void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);

void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
    const float *known, float *dist2, int *idx, cudaStream_t stream);

void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);

void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n,
    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);

void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);

void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out,
    const int *idx, const float *weight, float *grad_points, cudaStream_t stream);

#endif

--------------------------------------------------------------------------------
/model/PointUtils/src/point_utils_api.cpp:
--------------------------------------------------------------------------------
#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "furthest_point_sampling_gpu.h"
#include "interpolate_gpu.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");

    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
    m.def("weighted_furthest_point_sampling_wrapper", &weighted_furthest_point_sampling_wrapper, "weighted_furthest_point_sampling_wrapper");

    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
}

--------------------------------------------------------------------------------
/model/layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from model.utils import furthest_point_sample, gather_operation, three_nn, three_interpolate
from pytorch3d.ops import knn_points, knn_gather

def knn_group(points1, points2, k):
    '''
    For each point in points1, query the k nearest neighbors from points2.

    Input:
        points1: [B,3,M] (anchor points)
        points2: [B,3,N] (query points)
    Output:
        nn_group: [B,4,M,k]
        nn_idx: [B,M,k]
    '''
    points1 = points1.permute(0,2,1).contiguous() # [B,M,3]
    points2 = points2.permute(0,2,1).contiguous() # [B,N,3]
    _, nn_idx, nn_points = knn_points(points1, points2, K=k, return_nn=True)
    points1_expand = points1.unsqueeze(2).repeat(1,1,k,1)
    rela_nn = nn_points - points1_expand # [B,M,k,3]
    rela_dist = torch.norm(rela_nn, dim=-1, keepdim=True) # [B,M,k,1]
    nn_group = torch.cat((rela_nn, rela_dist), dim=-1) # [B,M,k,4]
    nn_group = nn_group.permute(0,3,1,2).contiguous()
    return nn_group, nn_idx

class MotionAlign(nn.Module):
    '''
    Input:
        points1: [B,3,N]
        points2: [B,3,N]
        motion1: [B,C,N]
    Output:
        aligned_motion: [B,C,N]
    '''
    def __init__(self, k, feature_size):
        super(MotionAlign, self).__init__()
        self.k = k
        self.feature_size = feature_size
        self.mlps = nn.Sequential(nn.Conv2d(self.feature_size+4, int(self.feature_size/2), kernel_size=1, bias=True),
                                  nn.Conv2d(int(self.feature_size/2), int(self.feature_size/4), kernel_size=1, bias=True))

    def forward(self, points1, points2, motion1):
        nn_group, nn_idx = knn_group(points2, points1, self.k)
        features2 = knn_gather(motion1.permute(0,2,1), nn_idx).permute(0,3,1,2).contiguous()
        weights = self.mlps(torch.cat([nn_group,features2],dim=1))
        weights = torch.max(weights, dim=1, keepdim=True)[0]
        weights = F.softmax(weights, dim=-1) # [B,1,N,k]
        aligned_motion = torch.sum(torch.mul(features2, weights.repeat(1,self.feature_size,1,1)),dim=-1)
        return aligned_motion

class MotionGRU(nn.Module):
    '''
    Parameters:
        k: k nearest neighbors
        content_size: content feature size of current frame
        motion_size: motion feature size of current frame
        hidden_size: output feature size of hidden state
    Input:
        H0: [B,C,N] (hidden state of last frame)
        points0: [B,3,N] (point coordinates of last frame)
        points1: [B,3,N] (point coordinates of current frame)
        contents1: [B,C1,N] (content features of current frame)
        motions1: [B,C2,N] (motion features of current frame)

    Output:
        H1: hidden state
    '''
    def __init__(self, k, content_size, motion_size, hidden_size):
        super(MotionGRU, self).__init__()
        self.k = k
        self.feature_size = content_size + motion_size
        self.hidden_size = hidden_size

        self.mlp_R = nn.Sequential(nn.Conv2d(self.hidden_size+self.feature_size+4, self.hidden_size, kernel_size=1, bias=True))
        self.mlp_Z = nn.Sequential(nn.Conv2d(self.hidden_size+self.feature_size+4, self.hidden_size, kernel_size=1, bias=True))

        self.mlp_H1_0 = nn.Sequential(nn.Conv2d(self.hidden_size+4, self.hidden_size, kernel_size=1, bias=True))
        self.mlp_H1_1 = nn.Sequential(nn.Conv1d(self.hidden_size+self.feature_size, self.hidden_size, kernel_size=1, bias=True))

    def forward(self, H0, points0, points1, contents1, motions1):

        features1 = torch.cat([contents1, motions1], dim=1)

        nn_group, nn_idx = knn_group(points1, points0, self.k) # [B,4+C,N,k]
        nn_H0 = knn_gather(H0.permute(0,2,1), nn_idx).permute(0,3,1,2).contiguous()
        features1_expand = features1.unsqueeze(-1).repeat(1,1,1,self.k)

        gate_R = self.mlp_R(torch.cat((nn_group,nn_H0,features1_expand),dim=1)) # [B,C,N,k]
        gate_R = torch.sigmoid(torch.max(gate_R, dim=-1, keepdim=False)[0]) # [B,C,N]

        gate_Z = self.mlp_Z(torch.cat((nn_group,nn_H0,features1_expand),dim=1)) # [B,C,N,k]
        gate_Z = torch.sigmoid(torch.max(gate_Z, dim=-1, keepdim=False)[0]) # [B,C,N]

        H1_0 = self.mlp_H1_0(torch.cat((nn_group,nn_H0), dim=1)) # [B,C,N,k]
        H1_0 = torch.max(H1_0, dim=-1, keepdim=False)[0] # [B,C,N]

        H1_1 = torch.tanh(self.mlp_H1_1(torch.cat((features1,torch.mul(gate_R,H1_0)),dim=1)))
        H1 = torch.mul(gate_Z,H1_0) + torch.mul(1.0-gate_Z,H1_1)

        return H1

class MotionLSTM(nn.Module):
    '''
    Parameters:
        k: k nearest neighbors
        content_size: content feature size of current frame
        motion_size: motion feature size of current frame
        hidden_size: output feature size of hidden state
    Input:
        H0: [B,C,N] (hidden state of last frame)
        C0: [B,C,N] (cell state of last frame)
        points0: [B,3,N] (point coordinates of last frame)
        points1: [B,3,N] (point coordinates of current frame)
        contents1: [B,C1,N] (content features of current frame)
        motions1: [B,C2,N] (motion features of current frame)

    Output:
        H1: hidden state
        C1: cell state
    '''
    def __init__(self, k, content_size, motion_size, hidden_size):
        super(MotionLSTM, self).__init__()
        self.k = k
        self.feature_size = content_size + motion_size
        self.hidden_size = hidden_size

        self.mlp_I = nn.Sequential(nn.Conv2d(self.hidden_size+self.feature_size+4, self.hidden_size, kernel_size=1, bias=True))
        self.mlp_F = nn.Sequential(nn.Conv2d(self.hidden_size+self.feature_size+4, self.hidden_size, kernel_size=1, bias=True))
        self.mlp_O = nn.Sequential(nn.Conv2d(self.hidden_size+self.feature_size+4, self.hidden_size, kernel_size=1, bias=True))

        self.mlp_C0 = nn.Sequential(nn.Conv2d(self.hidden_size+4, self.hidden_size, kernel_size=1, bias=True))

        self.mlp_C1_1 = nn.Sequential(nn.Conv2d(self.hidden_size+self.feature_size+4, self.hidden_size, kernel_size=1, bias=True))

    def forward(self, H0, C0, points0, points1, contents1, motions1):

        features1 = torch.cat([contents1, motions1], dim=1)

        nn_group, nn_idx = knn_group(points1, points0, self.k) # [B,4+C,N,k]
        nn_H0 = knn_gather(H0.permute(0,2,1), nn_idx).permute(0,3,1,2).contiguous() # [B,C,N,k]
        nn_C0 = knn_gather(C0.permute(0,2,1), nn_idx).permute(0,3,1,2).contiguous() # [B,C,N,k]

        features1 = features1.unsqueeze(-1).repeat(1,1,1,self.k)

        gate_I = self.mlp_I(torch.cat((nn_group,nn_H0,features1), dim=1)) # [B,C,N,k]
        gate_I = torch.sigmoid(torch.max(gate_I, dim=-1, keepdim=False)[0]) # [B,C,N]

        gate_F = self.mlp_F(torch.cat((nn_group,nn_H0,features1), dim=1))
        gate_F = torch.sigmoid(torch.max(gate_F, dim=-1, keepdim=False)[0]) # [B,C,N]

        gate_O = self.mlp_O(torch.cat((nn_group,nn_H0,features1), dim=1))
        gate_O = torch.sigmoid(torch.max(gate_O, dim=-1, keepdim=False)[0]) # [B,C,N]

        C1_0 = self.mlp_C0(torch.cat((nn_group, nn_C0), dim=1))
        C1_0 = torch.max(C1_0, dim=-1, keepdim=False)[0] # [B,C,N]

        C1_1 = self.mlp_C1_1(torch.cat((nn_group,nn_H0,features1), dim=1))
        C1_1 = torch.tanh(torch.max(C1_1, dim=-1, keepdim=False)[0]) # [B,C,N]

        C1 = torch.mul(gate_F, C1_0) + torch.mul(gate_I, C1_1) # [B,C,N]
        H1 = torch.mul(gate_O, torch.tanh(C1)) # [B,C,N]

        return H1, C1

class FurthestPointsSample(nn.Module):
    '''
    Furthest point sampling
    Parameters:
        npoints: number of sampled points
    Input:
        x: [B,3,N]
    Output:
        fps_points: [B,3,npoints]
    '''
    def __init__(self, npoints):
        super(FurthestPointsSample, self).__init__()
        self.npoints = npoints

    def forward(self, x):
        fps_points_ind = furthest_point_sample(x.permute(0,2,1).contiguous(), self.npoints)
        fps_points = gather_operation(x, fps_points_ind)

        return fps_points

class ContentEncoder(nn.Module):
    '''
    Parameters:
        npoints: number of sampled points
        k: k nearest neighbors
        in_channels: input feature channels (C_in)
        out_channels: output feature channels (C_out)
        fps: True/False
        knn: True/False
    Input:
        points: [B,3,N]
        features: [B,C_in,N]
    Output:
        fps_points: [B,3,npoints]
        output_features: [B,C_out,npoints]
    '''
    def __init__(self, npoints, k, in_channels, out_channels, radius, fps=True, knn=True):
        super(ContentEncoder, self).__init__()
        self.k = k
        self.fps = fps
        self.knn = knn
        self.furthest_points_sample = FurthestPointsSample(npoints)
        self.radius = radius

        layers = []
        out_channels = [in_channels+4, *out_channels]
        for i in range(1, len(out_channels)):
            layers += [nn.Conv2d(out_channels[i-1], out_channels[i], kernel_size=1, bias=True),
                       nn.ReLU()]
        self.conv = nn.Sequential(*layers)

    def forward(self, points, features):

        fps_points = self.furthest_points_sample(points) # [B,3,npoints]
        if self.knn:
            nn_group, nn_idx = knn_group(fps_points, points, self.k) # [B,4,npoints,k]
            if features is not None:
                nn_features = knn_gather(features.permute(0,2,1), nn_idx).permute(0,3,1,2).contiguous() # [B,C_in,npoints,k]
        else:
            raise NotImplementedError

        if features is not None:
            new_features = torch.cat([nn_group, nn_features], dim=1) # [B,C_in+4,npoints,k]
        else:
            new_features = nn_group
        new_features = self.conv(new_features) # [B,C_out,npoints,k]
        out_features = torch.max(new_features, dim=-1, keepdim=False)[0] # [B,C_out,npoints]

        return fps_points, out_features

class MotionEncoder(nn.Module):
    '''
    Parameters:
        k: k nearest neighbors
        in_channels: input feature channels
        out_channels: output feature channels
    Input:
        points1: [B,3,N]
        features1: [B,C,N]
        points2: [B,3,N]
        features2: [B,C,N]
    Output:
        motions: [B,C_out,N]
    '''
    def __init__(self, k, in_channels, out_channels):
        super(MotionEncoder, self).__init__()

        self.k = k

        layers = []

        out_channels = [2*in_channels+4, *out_channels]
        for i in range(1, len(out_channels)):
            layers += [nn.Conv2d(out_channels[i-1], out_channels[i], kernel_size=1, bias=True),
                       nn.ReLU()]
        self.conv = nn.Sequential(*layers)

    def forward(self, points1, features1, points2, features2):

        nn_group, nn_idx = knn_group(points1, points2, self.k)
        nn_features2 = knn_gather(features2.permute(0,2,1), nn_idx).permute(0,3,1,2).contiguous()
        new_features = torch.cat([nn_group, nn_features2, features1.unsqueeze(3).repeat(1,1,1,self.k)],dim=1) # [B,4+C+C,N,k]
        new_features = self.conv(new_features) # [B,C_out,N,k]
        motions = torch.max(new_features,dim=-1)[0]

        return motions

class PointNet2FeaturePropagator(nn.Module):
    '''
    Parameters:
        in_channels1: input feature channels 1
        in_channels2: input feature channels 2
        out_channels: output feature channels
    Input:
        xyz: [B,N,3]
        xyz_prev: [B,N,3]
        features: [B,C,N]
        features_prev: [B,C,N]
    '''

    def __init__(self, in_channels1, in_channels2, out_channels, batchnorm=True):
        super(PointNet2FeaturePropagator, self).__init__()

        self.layer_dims = out_channels

        unit_pointnets = []
        in_channels = in_channels1 + in_channels2
        for out_channel in out_channels:
            unit_pointnets.append(
                nn.Conv1d(in_channels, out_channel, 1))

            if batchnorm:
                unit_pointnets.append(nn.BatchNorm1d(out_channel))

            unit_pointnets.append(nn.ReLU())
            in_channels = out_channel

        self.unit_pointnet = nn.Sequential(*unit_pointnets)

    def forward(self, xyz, xyz_prev, features=None, features_prev=None):
        """
        Args:
            xyz (torch.Tensor): shape = (batch_size, num_points, 3)
                The 3D coordinates of each point at current layer,
                computed during feature extraction (i.e. set abstraction).
            xyz_prev (torch.Tensor|None): shape = (batch_size, num_points_prev, 3)
                The 3D coordinates of each point from the previous feature
                propagation layer (corresponding to the next layer during
                feature extraction).
                This value can be None (i.e. for the very first propagator layer).
            features (torch.Tensor|None): shape = (batch_size, num_features, num_points)
                The features of each point at current layer,
                computed during feature extraction (i.e. set abstraction).
            features_prev (torch.Tensor|None): shape = (batch_size, num_features_prev, num_points_prev)
                The features of each point from the previous feature
                propagation layer (corresponding to the next layer during
                feature extraction).
        Returns:
            (torch.Tensor): shape = (batch_size, num_features_out, num_points)
        """
        num_points = xyz.shape[1]
        if xyz_prev is None:  # Very first feature propagation layer
            # broadcast the (global) previous features to every point
            new_features = features_prev.expand(
                *(list(features_prev.shape[:2]) + [num_points]))

        else:
            dist, idx = three_nn(xyz, xyz_prev)
            # shape = (batch_size, num_points, 3), (batch_size, num_points, 3)
            inverse_dist = 1.0 / (dist + 1e-8)
            total_inverse_dist = torch.sum(inverse_dist, dim=2, keepdim=True)
            weights = inverse_dist / total_inverse_dist
            new_features = three_interpolate(features_prev, idx, weights)
            # shape = (batch_size, num_features_prev, num_points)

        if features is not None:
            new_features = torch.cat([new_features, features], dim=1)

        return self.unit_pointnet(new_features)

--------------------------------------------------------------------------------
/model/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from model.layers import MotionEncoder, ContentEncoder, MotionLSTM, MotionGRU, MotionAlign, PointNet2FeaturePropagator

class MoNet(nn.Module):
    '''
    Input:
        x: [B,T,3+C,N]
    '''
    def __init__(self, args):
        super(MoNet, self).__init__()

        self.rnn = args.rnn
        self.npoints = args.npoints
        self.pred_num = args.pred_num
        self.input_num = args.input_num

        self.C1 = 64
        self.C2 = 128
        self.C3 = 256

        self.content_encoder_1 = ContentEncoder(npoints=int(self.npoints/32), k=32, in_channels=0, \
            out_channels=[int(self.C1/2),int(self.C1/2),self.C1], radius=0.5+1e-6, fps=True, knn=True)
        self.content_encoder_2 = ContentEncoder(npoints=int(self.npoints/64), k=16, in_channels=self.C1, \
            out_channels=[self.C1,self.C1,self.C2], radius=1.0+1e-6, fps=True, knn=True)
        self.content_encoder_3 = ContentEncoder(npoints=int(self.npoints/128), k=8, in_channels=self.C2, \
            out_channels=[self.C2,self.C2,self.C3], radius=2.0+1e-6, fps=True, knn=True)

        self.motion_encoder_1 = MotionEncoder(16, in_channels=self.C1, \
            out_channels=[self.C1,self.C1,self.C1])
        self.motion_encoder_2 = MotionEncoder(8, in_channels=self.C2, \
            out_channels=[self.C2,self.C2,self.C2])
        self.motion_encoder_3 = MotionEncoder(8, in_channels=self.C3, \
            out_channels=[self.C3,self.C3,self.C3])

        self.motion_align_1 = MotionAlign(16, self.C1)
        self.motion_align_2 = MotionAlign(16, self.C2)
        self.motion_align_3 = MotionAlign(16, self.C3)

        if self.rnn == 'LSTM':
            self.motion_rnn_1 = MotionLSTM(16, self.C1, self.C1, 2*self.C1)
            self.motion_rnn_2 = MotionLSTM(16, self.C2, self.C2, 2*self.C2)
            self.motion_rnn_3 = MotionLSTM(16, self.C3, self.C3, 2*self.C3)
        elif self.rnn == 'GRU':
            self.motion_rnn_1 = MotionGRU(16, self.C1, self.C1, 2*self.C1)
            self.motion_rnn_2 = MotionGRU(16, self.C2, self.C2, 2*self.C2)
            self.motion_rnn_3 = MotionGRU(16, self.C3, self.C3, 2*self.C3)
        else:
            raise NotImplementedError('Unsupported RNN type: %s' % self.rnn)

        self.fp2 = PointNet2FeaturePropagator(2*self.C2, 2*self.C3, [2*self.C2], batchnorm=False)
        self.fp1 = PointNet2FeaturePropagator(2*self.C1, 2*self.C2, [2*self.C2], batchnorm=False)
        self.fp0 = PointNet2FeaturePropagator(0, 2*self.C2, [2*self.C2], batchnorm=False)

        self.classifier1 = nn.Conv1d(in_channels=2*self.C2, out_channels=128, kernel_size=1, bias=False)
        self.classifier2 = nn.Conv1d(in_channels=128, out_channels=3, kernel_size=1, bias=False)

    def forward(self, x):

        B = x.shape[0]
        T = x.shape[1]

        # Embedding pipeline

        # Content encoder for input point clouds
        points_list_0 = []

        points_list_1 = []
        contents_list_1 = []
        points_list_2 = []
        contents_list_2 = []
        points_list_3 = []
        contents_list_3 = []

        for idx in range(self.input_num):
            points = x[:,idx,:,:].squeeze(1)
            points = points[:,:3,:].contiguous()
            points_list_0.append(points)

            points_1, contents_1 = self.content_encoder_1(points, None)
            points_2, contents_2 = self.content_encoder_2(points_1, contents_1)
            points_3, contents_3 = self.content_encoder_3(points_2, contents_2)

            points_list_1.append(points_1)
            contents_list_1.append(contents_1)
            points_list_2.append(points_2)
            contents_list_2.append(contents_2)
            points_list_3.append(points_3)
            contents_list_3.append(contents_3)

        # Motion encoder for input point clouds
        motion_list_1 = []
        motion_list_2 = []
        motion_list_3 = []

        for idx in range(self.input_num-1):
            motions_1 = self.motion_encoder_1(points_list_1[idx], contents_list_1[idx], \
                points_list_1[idx+1], contents_list_1[idx+1])
            motions_2 = self.motion_encoder_2(points_list_2[idx], contents_list_2[idx], \
                points_list_2[idx+1], contents_list_2[idx+1])
            motions_3 = self.motion_encoder_3(points_list_3[idx], contents_list_3[idx], \
                points_list_3[idx+1], contents_list_3[idx+1])

            motion_list_1.append(motions_1)
            motion_list_2.append(motions_2)
            motion_list_3.append(motions_3)

        # Initialize states for RNN
        if self.rnn == 'GRU':
111 |             last_H1 = torch.zeros((B,2*self.C1,int(self.npoints/32)),dtype=torch.float32).cuda()
112 |             last_H2 = torch.zeros((B,2*self.C2,int(self.npoints/64)),dtype=torch.float32).cuda()
113 |             last_H3 = torch.zeros((B,2*self.C3,int(self.npoints/128)),dtype=torch.float32).cuda()
114 |         elif self.rnn == 'LSTM':
115 |             last_C1 = torch.zeros((B,2*self.C1,int(self.npoints/32)),dtype=torch.float32).cuda()
116 |             last_C2 = torch.zeros((B,2*self.C2,int(self.npoints/64)),dtype=torch.float32).cuda()
117 |             last_C3 = torch.zeros((B,2*self.C3,int(self.npoints/128)),dtype=torch.float32).cuda()
118 |             last_H1 = torch.zeros((B,2*self.C1,int(self.npoints/32)),dtype=torch.float32).cuda()
119 |             last_H2 = torch.zeros((B,2*self.C2,int(self.npoints/64)),dtype=torch.float32).cuda()
120 |             last_H3 = torch.zeros((B,2*self.C3,int(self.npoints/128)),dtype=torch.float32).cuda()
121 |         else:
122 |             raise NotImplementedError('Unknown RNN type: {}'.format(self.rnn))
123 |
124 |         curr_points_1 = torch.zeros_like(points_list_1[0], dtype=torch.float32).cuda()
125 |         last_points_1 = torch.zeros_like(points_list_1[0], dtype=torch.float32).cuda()
126 |         curr_points_2 = torch.zeros_like(points_list_2[0], dtype=torch.float32).cuda()
127 |         last_points_2 = torch.zeros_like(points_list_2[0], dtype=torch.float32).cuda()
128 |         curr_points_3 = torch.zeros_like(points_list_3[0], dtype=torch.float32).cuda()
129 |         last_points_3 = torch.zeros_like(points_list_3[0], dtype=torch.float32).cuda()
130 |
131 |         for idx in range(self.input_num-1):
132 |
133 |             curr_motions_1 = motion_list_1[idx]
134 |             curr_contents_1 = contents_list_1[idx]
135 |             curr_motions_2 = motion_list_2[idx]
136 |             curr_contents_2 = contents_list_2[idx]
137 |             curr_motions_3 = motion_list_3[idx]
138 |             curr_contents_3 = contents_list_3[idx]
139 |
140 |             curr_points_1 = points_list_1[idx]
141 |             curr_points_2 = points_list_2[idx]
142 |             curr_points_3 = points_list_3[idx]
143 |
144 |             if idx == 0:
145 |                 last_points_1 = torch.zeros_like(points_list_1[0], dtype=torch.float32).cuda()
146 |                 last_points_2 = torch.zeros_like(points_list_2[0], dtype=torch.float32).cuda()
147 |                 last_points_3 = torch.zeros_like(points_list_3[0], dtype=torch.float32).cuda()
148 |             else:
149 |                 last_points_1 = points_list_1[idx-1]
150 |                 last_points_2 = points_list_2[idx-1]
151 |                 last_points_3 = points_list_3[idx-1]
152 |
153 |             if self.rnn == 'LSTM':
154 |
155 |                 last_H1, last_C1 = self.motion_rnn_1(last_H1, last_C1, last_points_1, curr_points_1, curr_contents_1, curr_motions_1)
156 |                 last_H2, last_C2 = self.motion_rnn_2(last_H2, last_C2, last_points_2, curr_points_2, curr_contents_2, curr_motions_2)
157 |                 last_H3, last_C3 = self.motion_rnn_3(last_H3, last_C3, last_points_3, curr_points_3, curr_contents_3, curr_motions_3)
158 |
159 |
160 |             elif self.rnn == 'GRU':
161 |                 last_H1 = self.motion_rnn_1(last_H1, last_points_1, curr_points_1, curr_contents_1, curr_motions_1)
162 |                 last_H2 = self.motion_rnn_2(last_H2, last_points_2, curr_points_2, curr_contents_2, curr_motions_2)
163 |                 last_H3 = self.motion_rnn_3(last_H3, last_points_3, curr_points_3, curr_contents_3, curr_motions_3)
164 |
165 |             else:
166 |                 raise NotImplementedError('Unknown RNN type: {}'.format(self.rnn))
167 |
168 |         # Inference pipeline
169 |
170 |         # Initialization for inference
171 |         last_points_1 = points_list_1[-2]
172 |         last_points_2 = points_list_2[-2]
173 |         last_points_3 = points_list_3[-2]
174 |
175 |         curr_points_1 = points_list_1[-1]
176 |         curr_points_2 = points_list_2[-1]
177 |         curr_points_3 = points_list_3[-1]
178 |
179 |         last_contents_1 = contents_list_1[-2]
180 |         last_contents_2 = contents_list_2[-2]
181 |         last_contents_3 = contents_list_3[-2]
182 |
183 |         curr_contents_1 = contents_list_1[-1]
184 |         curr_contents_2 = contents_list_2[-1]
185 |         curr_contents_3 = contents_list_3[-1]
186 |
187 |         curr_points_0 = points_list_0[-1]
188 |
189 |         pred_points_list = []
190 |
191 |         for idx in range(self.pred_num):
192 |
193 |             curr_motions_1 = self.motion_align_1(last_points_1, curr_points_1, curr_motions_1)
194 |             curr_motions_2 = self.motion_align_2(last_points_2, curr_points_2, curr_motions_2)
195 |             curr_motions_3 = self.motion_align_3(last_points_3, curr_points_3, curr_motions_3)
196 |
197 |             if self.rnn == 'LSTM':
198 |                 last_H1, last_C1 = self.motion_rnn_1(last_H1, last_C1, last_points_1, curr_points_1, curr_contents_1, curr_motions_1)
199 |                 last_H2, last_C2 = self.motion_rnn_2(last_H2, last_C2, last_points_2, curr_points_2, curr_contents_2, curr_motions_2)
200 |                 last_H3, last_C3 = self.motion_rnn_3(last_H3, last_C3, last_points_3, curr_points_3, curr_contents_3, curr_motions_3)
201 |
202 |             elif self.rnn == 'GRU':
203 |                 last_H1 = self.motion_rnn_1(last_H1, last_points_1, curr_points_1, curr_contents_1, curr_motions_1)
204 |                 last_H2 = self.motion_rnn_2(last_H2, last_points_2, curr_points_2, curr_contents_2, curr_motions_2)
205 |                 last_H3 = self.motion_rnn_3(last_H3, last_points_3, curr_points_3, curr_contents_3, curr_motions_3)
206 |
207 |             else:
208 |                 raise NotImplementedError('Unknown RNN type: {}'.format(self.rnn))
209 |
210 |             # decoder
211 |             l2_feat = self.fp2(curr_points_2.permute(0,2,1).contiguous(), curr_points_3.permute(0,2,1).contiguous(), last_H2, last_H3)
212 |             l1_feat = self.fp1(curr_points_1.permute(0,2,1).contiguous(), curr_points_2.permute(0,2,1).contiguous(), last_H1, l2_feat)
213 |             l0_feat = self.fp0(curr_points_0.permute(0,2,1).contiguous(), curr_points_1.permute(0,2,1).contiguous(), None, l1_feat)
214 |
215 |             pred_flow = self.classifier2(self.classifier1(l0_feat))
216 |
217 |             pred_points = curr_points_0 + pred_flow
218 |             pred_points_list.append(pred_points)
219 |
220 |             curr_points_0 = pred_points
221 |
222 |             last_points_1 = curr_points_1
223 |             last_points_2 = curr_points_2
224 |             last_points_3 = curr_points_3
225 |             last_contents_1 = curr_contents_1
226 |             last_contents_2 = curr_contents_2
227 |             last_contents_3 = curr_contents_3
228 |
229 |             curr_points_1, curr_contents_1 = self.content_encoder_1(curr_points_0, None)
230 |             curr_points_2, curr_contents_2 = self.content_encoder_2(curr_points_1, curr_contents_1)
231 |             curr_points_3, curr_contents_3 = self.content_encoder_3(curr_points_2, curr_contents_2)
232 |
233 |             curr_motions_1 = self.motion_encoder_1(last_points_1, last_contents_1, curr_points_1, curr_contents_1)
234 |             curr_motions_2 = self.motion_encoder_2(last_points_2, last_contents_2, curr_points_2, curr_contents_2)
235 |             curr_motions_3 = self.motion_encoder_3(last_points_3, last_contents_3, curr_points_3, curr_contents_3)
236 |
237 |         return pred_points_list
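The forward pass therefore consumes `input_num` frames, encodes content and motion at three scales, rolls the recurrent states forward, and decodes one scene-flow field per predicted frame, feeding each prediction back in autoregressively. A minimal smoke test, assuming the CUDA extensions are built and a GPU is available; the `Namespace` fields mirror the argparse options used by `train.py`/`test.py`:
```
import torch
from argparse import Namespace
from model.models import MoNet

args = Namespace(rnn='GRU', npoints=16384, input_num=5, pred_num=5)
net = MoNet(args).cuda().eval()

x = torch.rand(1, args.input_num, 3, args.npoints).cuda()  # [B,T,3,N]
with torch.no_grad():
    pred_list = net(x)  # list of pred_num tensors, each [B,3,N]
print(len(pred_list), pred_list[0].shape)
```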
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.autograd import Variable, grad
5 | from torch.autograd import Function
6 | from torch.autograd.function import once_differentiable
7 |
8 | from typing import Tuple, Union
9 |
10 | import point_utils_cuda
11 |
12 | from pytorch3d.loss import chamfer_distance
13 | from pytorch3d.ops import knn_points, knn_gather
14 | import random
15 | import emd_cuda
16 |
17 | class FurthestPointSampling(Function):
18 |     @staticmethod
19 |     def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
20 |         '''
21 |         ctx:
22 |         xyz: [B,N,3]
23 |         npoint: int
24 |         '''
25 |         assert xyz.is_contiguous()
26 |
27 |         B, N, _ = xyz.size()
28 |         output = torch.cuda.IntTensor(B, npoint)
29 |         temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
30 |
31 |         point_utils_cuda.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
32 |         return output
33 |
34 |     @staticmethod
35 |     def backward(ctx, a=None):
36 |         return None, None
37 |
38 | furthest_point_sample = FurthestPointSampling.apply
39 |
40 | class WeightedFurthestPointSampling(Function):
41 |     @staticmethod
42 |     def forward(ctx, xyz: torch.Tensor, weights: torch.Tensor, npoint: int) -> torch.Tensor:
43 |         '''
44 |         ctx:
45 |         xyz: [B,N,3]
46 |         weights: [B,N]
47 |         npoint: int
48 |         '''
49 |         assert xyz.is_contiguous()
50 |         assert weights.is_contiguous()
51 |         B, N, _ = xyz.size()
52 |         output = torch.cuda.IntTensor(B, npoint)
53 |         temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
54 |
55 |         point_utils_cuda.weighted_furthest_point_sampling_wrapper(B, N, npoint, xyz, weights, temp, output)
56 |         return output
57 |
58 |     @staticmethod
59 |     def backward(ctx, a=None):
60 |         return None, None, None
61 |
62 | weighted_furthest_point_sample = WeightedFurthestPointSampling.apply
63 |
64 | class GatherOperation(Function):
65 |     @staticmethod
66 |     def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
67 |         '''
68 |         ctx
69 |         features: [B,C,N]
70 |         idx: [B,npoint]
71 |         '''
72 |         assert features.is_contiguous()
73 |         assert idx.is_contiguous()
74 |
75 |         B, npoint = idx.size()
76 |         _, C, N = features.size()
77 |         output = torch.cuda.FloatTensor(B, C, npoint)
78 |
79 |         point_utils_cuda.gather_points_wrapper(B, C, N, npoint, features, idx, output)
80 |
81 |         ctx.for_backwards = (idx, C, N)
82 |         return output
83 |
84 |     @staticmethod
85 |     def backward(ctx, grad_out):
86 |         idx, C, N = ctx.for_backwards
87 |         B, npoint = idx.size()
88 |         grad_features = Variable(torch.cuda.FloatTensor(B,C,N).zero_())
89 |         grad_out_data = grad_out.data.contiguous()
90 |         point_utils_cuda.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
91 |         return grad_features, None
92 |
93 | gather_operation = GatherOperation.apply
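`furthest_point_sample` and `gather_operation` are the usual PointNet++ downsampling pair: the first returns int32 indices of a spatially well-spread subset, the second gathers those columns out of a channel-first feature tensor. A small sketch, assuming the compiled `point_utils` extension and a CUDA device; sizes are illustrative:
```
import torch
from model.utils import furthest_point_sample, gather_operation

B, N, npoint = 2, 4096, 512
xyz = torch.rand(B, N, 3).cuda().contiguous()  # [B,N,3], must be contiguous
idx = furthest_point_sample(xyz, npoint)       # [B,npoint] int32 indices
feats = xyz.transpose(1, 2).contiguous()       # [B,3,N] channel-first view
xyz_down = gather_operation(feats, idx)        # [B,3,npoint]
```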
94 |
95 | class ThreeNN(Function):
96 |
97 |     @staticmethod
98 |     def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
99 |         """
100 |         Find the three nearest neighbors of unknown in known
101 |         :param ctx:
102 |         :param unknown: (B, N, 3)
103 |         :param known: (B, M, 3)
104 |         :return:
105 |             dist: (B, N, 3) l2 distance to the three nearest neighbors
106 |             idx: (B, N, 3) index of 3 nearest neighbors
107 |         """
108 |         assert unknown.is_contiguous()
109 |         assert known.is_contiguous()
110 |
111 |         B, N, _ = unknown.size()
112 |         m = known.size(1)
113 |         dist2 = torch.cuda.FloatTensor(B, N, 3)
114 |         idx = torch.cuda.IntTensor(B, N, 3)
115 |
116 |         point_utils_cuda.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
117 |         return torch.sqrt(dist2), idx
118 |
119 |     @staticmethod
120 |     def backward(ctx, a=None, b=None):
121 |         return None, None
122 |
123 | three_nn = ThreeNN.apply
124 |
125 | class ThreeInterpolate(Function):
126 |
127 |     @staticmethod
128 |     def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
129 |         """
130 |         Performs weighted linear interpolation on 3 features
131 |         :param ctx:
132 |         :param features: (B, C, M) Features descriptors to be interpolated from
133 |         :param idx: (B, n, 3) three nearest neighbors of the target features in features
134 |         :param weight: (B, n, 3) weights
135 |         :return:
136 |             output: (B, C, N) tensor of the interpolated features
137 |         """
138 |         assert features.is_contiguous()
139 |         assert idx.is_contiguous()
140 |         assert weight.is_contiguous()
141 |
142 |         B, c, m = features.size()
143 |         n = idx.size(1)
144 |         ctx.three_interpolate_for_backward = (idx, weight, m)
145 |         output = torch.cuda.FloatTensor(B, c, n)
146 |
147 |         point_utils_cuda.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
148 |         return output
149 |
150 |     @staticmethod
151 |     def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
152 |         """
153 |         :param ctx:
154 |         :param grad_out: (B, C, N) tensor with gradients of outputs
155 |         :return:
156 |             grad_features: (B, C, M) tensor with gradients of features
157 |             None:
158 |             None:
159 |         """
160 |         idx, weight, m = ctx.three_interpolate_for_backward
161 |         B, c, n = grad_out.size()
162 |
163 |         grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
164 |         grad_out_data = grad_out.data.contiguous()
165 |
166 |         point_utils_cuda.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
167 |         return grad_features, None, None
168 |
169 | three_interpolate = ThreeInterpolate.apply
170 |
171 | def batch_chamfer_distance(pc1, pc2):
172 |     '''
173 |     Input:
174 |         pc1: [B,3,N]
175 |         pc2: [B,3,N]
176 |     '''
177 |     pc1 = pc1.permute(0,2,1).contiguous()
178 |     pc2 = pc2.permute(0,2,1).contiguous()
179 |     dist_batch, _ = chamfer_distance(pc1, pc2, batch_reduction='mean', point_reduction='mean')
180 |     return dist_batch
181 |
182 | def multi_frame_chamfer_loss(pc1, pc2_list):
183 |     '''
184 |     Calculate chamfer distance over a consecutive point cloud stream
185 |     Input:
186 |         pc1: [B,T,3,N]
187 |         pc2_list: a list of [B,3,N]
188 |     '''
189 |     pred_num = len(pc2_list)
190 |     l_total = 0
191 |     for i in range(pred_num):
192 |         curr_pc1 = pc1[:,i,:,:].squeeze(1).contiguous()
193 |         curr_pc2 = pc2_list[i]
194 |
195 |         curr_chamfer_dist = batch_chamfer_distance(curr_pc1, curr_pc2)
196 |         l_total += curr_chamfer_dist
197 |     l_chamfer = l_total/pred_num
198 |     return l_chamfer
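`batch_chamfer_distance` scores a single pair of channel-first clouds, and `multi_frame_chamfer_loss` averages that score across a predicted sequence, which is the criterion `train.py` applies to MoNet's output list. A shape sketch with random clouds, assuming pytorch3d is installed; sizes are illustrative:
```
import torch
from model.utils import batch_chamfer_distance, multi_frame_chamfer_loss

B, T, N = 2, 5, 1024
gt = torch.rand(B, T, 3, N).cuda()                      # ground-truth future frames
preds = [torch.rand(B, 3, N).cuda() for _ in range(T)]  # one predicted cloud per frame

d_single = batch_chamfer_distance(gt[:, 0], preds[0])   # scalar tensor
loss = multi_frame_chamfer_loss(gt, preds)              # mean over the T frames
```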
199 |
200 | class EarthMoverDistanceFunction(torch.autograd.Function):
201 |     @staticmethod
202 |     def forward(ctx, xyz1, xyz2):
203 |         xyz1 = xyz1.contiguous()
204 |         xyz2 = xyz2.contiguous()
205 |         assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
206 |         match = emd_cuda.approxmatch_forward(xyz1, xyz2)
207 |         cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
208 |         ctx.save_for_backward(xyz1, xyz2, match)
209 |         return cost
210 |
211 |     @staticmethod
212 |     def backward(ctx, grad_cost):
213 |         xyz1, xyz2, match = ctx.saved_tensors
214 |         grad_cost = grad_cost.contiguous()
215 |         grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
216 |         return grad_xyz1, grad_xyz2
217 |
218 | def earth_mover_distance(xyz1, xyz2, transpose=True):
219 |     """Earth Mover Distance (Approx)
220 |
221 |     Args:
222 |         xyz1 (torch.Tensor): (b, 3, n1)
223 |         xyz2 (torch.Tensor): (b, 3, n2)
224 |         transpose (bool): whether to transpose inputs as it might be BCN format.
225 |             Extensions only support BNC format.
226 |
227 |     Returns:
228 |         cost (torch.Tensor): (b)
229 |
230 |     """
231 |     if xyz1.dim() == 2:
232 |         xyz1 = xyz1.unsqueeze(0)
233 |     if xyz2.dim() == 2:
234 |         xyz2 = xyz2.unsqueeze(0)
235 |     if transpose:
236 |         xyz1 = xyz1.transpose(1, 2)
237 |         xyz2 = xyz2.transpose(1, 2)
238 |     cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
239 |     return cost
240 |
241 | def EMD(pc1, pc2):
242 |     '''
243 |     Input:
244 |         pc1: [1,3,M]
245 |         pc2: [1,3,M]
246 |     Ret:
247 |         d: torch.float32
248 |     '''
249 |     pc1 = pc1.permute(0,2,1).contiguous()
250 |     pc2 = pc2.permute(0,2,1).contiguous()
251 |     d = earth_mover_distance(pc1, pc2, transpose=False)
252 |     d = torch.mean(d)/pc1.shape[1]
253 |     return d
254 |
255 | def set_seed(seed):
256 |     '''
257 |     Set random seed for torch, numpy and python
258 |     '''
259 |     random.seed(seed)
260 |     np.random.seed(seed)
261 |     torch.manual_seed(seed)
262 |     if torch.cuda.is_available():
263 |         torch.cuda.manual_seed(seed)
264 |         torch.cuda.manual_seed_all(seed)
265 |
266 |     torch.backends.cudnn.benchmark=False
267 |     torch.backends.cudnn.deterministic=True
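`EMD` wraps the approximate earth-mover matching from the external `emd_cuda` extension (PyTorchEMD) and normalizes the match cost by the point count, which is how `test.py` reports it. A usage sketch, assuming the extension is built and using two equal-sized clouds to match the `[1,3,M]` convention in the docstring:
```
import torch
from model.utils import EMD

pc1 = torch.rand(1, 3, 1024).cuda()
pc2 = torch.rand(1, 3, 1024).cuda()
d = EMD(pc1, pc2)  # scalar torch.float32, match cost averaged per point
print(d.item())
```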
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | plyfile
4 | tqdm
5 | wandb
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | python test.py --batch_size 1 --seed 1 --gpu GPU \
2 |     --root DATA_ROOT --npoints 16384 --rnn RNN --input_num 5 --pred_num 5 \
3 |     --dataset DATASET --ckpt CKPT
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | python train.py --batch_size 4 --epochs 100 --lr 0.001 --seed 1 --gpu GPU \
2 |     --root DATA_ROOT --npoints 16384 --rnn RNN --input_num 5 --pred_num 5 --dataset DATASET \
3 |     --ckpt_dir ./ckpt --multi_gpu --runname RUNNAME --wandb_dir ../wandb --use_wandb
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 |
7 | from data.kitti_loader import KittiDataset
8 | from data.argo_loader import ArgoDataset
9 |
10 | from model.models import MoNet
11 |
12 | from model.utils import batch_chamfer_distance, multi_frame_chamfer_loss, set_seed, EMD
13 |
14 | from tqdm import tqdm
15 | import argparse
16 |
17 |
18 | def parse_args():
19 |     parser = argparse.ArgumentParser(description='MoNet')
20 |
21 |     parser.add_argument('--batch_size', type=int, default=1)
22 |     parser.add_argument('--seed', type=int, default=1)
23 |     parser.add_argument('--gpu', type=str, default='1')
24 |     parser.add_argument('--multi_gpu', action='store_true')
25 |     parser.add_argument('--root', type=str, default='')
26 |     parser.add_argument('--npoints', type=int, default=16384)
27 |     parser.add_argument('--use_wandb', action='store_true')
28 |     parser.add_argument('--runname', type=str, default='')
29 |     parser.add_argument('--rnn', type=str, default='', help='LSTM/GRU')
30 |     parser.add_argument('--pred_num', type=int, default=5)
31 |     parser.add_argument('--input_num', type=int, default=5)
32 |     parser.add_argument('--dataset', type=str, default='kitti')
33 |     parser.add_argument('--ckpt', type=str, default='')
34 |
35 |     return parser.parse_args()
36 |
37 | def test(args):
38 |     if args.dataset == 'kitti':
39 |         test_seqs = ['08','09','10']
40 |         test_dataset = KittiDataset(args.root, args.npoints, args.input_num, args.pred_num, test_seqs)
41 |     elif args.dataset == 'argoverse':
42 |         test_seqs = ['test']
43 |         test_dataset = ArgoDataset(args.root, args.npoints, args.input_num, args.pred_num, test_seqs)
44 |     else:
45 |         raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))
46 |
47 |     test_loader = DataLoader(test_dataset,
48 |                              batch_size=args.batch_size,
49 |                              num_workers=4,
50 |                              shuffle=False,
51 |                              pin_memory=True,
52 |                              drop_last=True)
53 |
54 |     net = MoNet(args)
55 |     net.cuda()
56 |     net.load_state_dict(torch.load(args.ckpt))
57 |
58 |     count = 0
59 |     l_chamfer_list = [0.0] * args.pred_num
60 |     l_emd_list = [0.0] * args.pred_num
61 |
62 |     pbar = tqdm(enumerate(test_loader))
63 |
64 |     net.eval()
65 |
66 |     with torch.no_grad():
67 |         for i, data in pbar:
68 |             input_pc, output_pc = data
69 |             input_pc = input_pc.cuda()
70 |             output_pc = output_pc.cuda()
71 |
72 |             pred_pc = net(input_pc)
73 |
74 |             for t in range(args.pred_num):
75 |                 l_chamfer_one = batch_chamfer_distance(output_pc[:,t,:3,:].squeeze(1), pred_pc[t])
76 |                 l_emd_one = EMD(output_pc[:,t,:3,:].squeeze(1), pred_pc[t])
77 |                 l_chamfer_list[t] += l_chamfer_one.item()
78 |                 l_emd_list[t] += l_emd_one.item()
79 |
80 |
81 |             count += 1
82 |
83 |     for t in range(args.pred_num):
84 |         l_chamfer_list[t] = l_chamfer_list[t]/count
85 |         l_emd_list[t] = l_emd_list[t]/count
86 |
87 |     print('Chamfer Distance:', l_chamfer_list)
88 |     print('Average Chamfer Distance:', np.mean(np.array(l_chamfer_list)))
89 |     print('Earth Mover Distance:', l_emd_list)
90 |     print('Average Earth Mover Distance:', np.mean(np.array(l_emd_list)))
91 |
92 | if __name__ == '__main__':
93 |     args = parse_args()
94 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
95 |     set_seed(args.seed)
96 |
97 |     test(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | import torch.optim as optim
7 | from torch.optim.lr_scheduler import StepLR
8 |
9 | from data.kitti_loader import KittiDataset
10 | from data.argo_loader import ArgoDataset
11 |
12 | from model.models import MoNet
13 |
14 | from model.utils import batch_chamfer_distance, multi_frame_chamfer_loss, set_seed
15 |
16 | from tqdm import tqdm
17 | import argparse
18 |
19 | def parse_args():
20 |     parser = argparse.ArgumentParser(description='MoNet')
21 |
22 |     parser.add_argument('--batch_size', type=int, default=2)
23 |     parser.add_argument('--epochs', type=int, default=100)
24 |     parser.add_argument('--lr', type=float, default=0.001)
25 |     parser.add_argument('--seed', type=int, default=1)
26 |     parser.add_argument('--gpu', type=str, default='1')
27 |     parser.add_argument('--multi_gpu', action='store_true')
28 |     parser.add_argument('--root', type=str, default='')
29 |     parser.add_argument('--npoints', type=int, default=16384)
30 |     parser.add_argument('--use_wandb', action='store_true')
31 |     parser.add_argument('--runname', type=str, default='')
32 |     parser.add_argument('--rnn', type=str, default='', help='LSTM/GRU')
33 |     parser.add_argument('--pred_num', type=int, default=5)
34 |     parser.add_argument('--input_num', type=int, default=5)
35 |     parser.add_argument('--dataset', type=str, default='kitti')
36 |     parser.add_argument('--ckpt_dir', type=str, default='')
37 |     parser.add_argument('--wandb_dir', type=str, default='')
38 |
39 |     return parser.parse_args()
40 |
41 | def validation(args, net):
42 |     if args.dataset == 'kitti':
43 |         val_seqs = ['06','07']
44 |         val_dataset = KittiDataset(args.root, args.npoints, args.input_num, args.pred_num, val_seqs)
45 |     elif args.dataset == 'argoverse':
46 |         val_seqs = ['val']
47 |         val_dataset = ArgoDataset(args.root, args.npoints, args.input_num, args.pred_num, val_seqs)
48 |     else:
49 |         raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))
50 |     val_loader = DataLoader(val_dataset,
51 |                             batch_size=args.batch_size,
52 |                             num_workers=4,
53 |                             shuffle=True,
54 |                             pin_memory=True,
55 |                             drop_last=True)
56 |     net.eval()
57 |
58 |     total_val_loss = 0
59 |     count = 0
60 |     pbar = tqdm(enumerate(val_loader))
61 |     with torch.no_grad():
62 |         for i, data in pbar:
63 |             input_pc, output_pc = data
64 |             input_pc = input_pc.cuda()
65 |             output_pc = output_pc.cuda()
66 |
67 |             pred_pc = net(input_pc)
68 |
69 |             loss = multi_frame_chamfer_loss(output_pc[:,:,:3,:], pred_pc)
70 |             total_val_loss += loss.item()
71 |             count += 1
72 |
73 |     total_val_loss = total_val_loss/count
74 |     return total_val_loss
75 |
76 | def train(args):
77 |     if args.dataset == 'kitti':
78 |         train_seqs = ['00','01','02','03','04','05']
79 |         train_dataset = KittiDataset(args.root, args.npoints, args.input_num, args.pred_num, train_seqs)
80 |     elif args.dataset == 'argoverse':
81 |         train_seqs = ['train1', 'train2', 'train3', 'train4']
82 |         train_dataset = ArgoDataset(args.root, args.npoints, args.input_num, args.pred_num, train_seqs)
83 |     else:
84 |         raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))
85 |
86 |     train_loader = DataLoader(train_dataset,
87 |                               batch_size=args.batch_size,
88 |                               num_workers=4,
89 |                               shuffle=True,
90 |                               pin_memory=True,
91 |                               drop_last=True)
92 |
93 |     net = MoNet(args)
94 |
95 |     if args.use_wandb:
96 |         wandb.watch(net)
97 |     if args.multi_gpu:
98 |         net = torch.nn.DataParallel(net)
99 |     net.cuda()
100 |
101 |     optimizer = optim.Adam(net.parameters(), lr=args.lr)
102 |     scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
103 |
104 |     best_train_loss = float('inf')
105 |     best_val_loss = float('inf')
106 |     best_train_epoch = 0
107 |     best_val_epoch = 0
108 |
109 |     for epoch in tqdm(range(args.epochs)):
110 |
111 |         net.train()
112 |         count = 0
113 |         total_loss = 0
114 |         pbar = tqdm(enumerate(train_loader))
115 |
116 |         for i, data in pbar:
117 |             input_pc, output_pc = data
118 |             input_pc = input_pc.cuda()
119 |             output_pc = output_pc.cuda()
120 |
121 |             optimizer.zero_grad()
122 |             pred_pc = net(input_pc)
123 |
124 |             loss = multi_frame_chamfer_loss(output_pc[:,:,:3,:], pred_pc)
125 |             loss.backward()
126 |             torch.nn.utils.clip_grad_norm_(net.parameters(),max_norm=5.0)
127 |             optimizer.step()
128 |
129 |             count += 1
130 |             total_loss += loss.item()
131 |
132 |             if i % 10 == 0:
133 |                 pbar.set_description('Train Epoch:{}[{}/{}({:.0f}%)]\tLoss: {:.6f}'.format(
134 |                     epoch+1, i, len(train_loader), 100. * i/len(train_loader), loss.item()
135 |                 ))
136 |
137 |         total_loss = total_loss/count
138 |         total_val_loss = validation(args, net)
139 |
140 |         if args.use_wandb:
141 |             wandb.log({"train loss":total_loss, "val loss":total_val_loss})
142 |
143 |         print('\n Epoch {} finished. Training loss: {:.4f} Validation loss: {:.4f}'.\
144 |             format(epoch+1, total_loss, total_val_loss))
145 |
146 |         ckpt_dir = os.path.join(args.ckpt_dir, 'ckpt_'+args.runname)
147 |         if not os.path.exists(ckpt_dir):
148 |             os.makedirs(ckpt_dir)
149 |
150 |         if total_loss < best_train_loss:
151 |             if args.multi_gpu:
152 |                 torch.save(net.module.state_dict(), os.path.join(ckpt_dir, 'best_train.pth'))
153 |             else:
154 |                 torch.save(net.state_dict(), os.path.join(ckpt_dir, 'best_train.pth'))
155 |             best_train_loss = total_loss
156 |             best_train_epoch = epoch + 1
157 |
158 |         if total_val_loss < best_val_loss:
159 |             if args.multi_gpu:
160 |                 torch.save(net.module.state_dict(), os.path.join(ckpt_dir, 'best_val.pth'))
161 |             else:
162 |                 torch.save(net.state_dict(), os.path.join(ckpt_dir, 'best_val.pth'))
163 |             best_val_loss = total_val_loss
164 |             best_val_epoch = epoch + 1
165 |
166 |         print('Best train epoch: {} Best train loss: {:.4f} Best val epoch: {} Best val loss: {:.4f}'.format(
167 |             best_train_epoch, best_train_loss, best_val_epoch, best_val_loss
168 |         ))
169 |         scheduler.step()
170 |
171 | if __name__ == '__main__':
172 |     args = parse_args()
173 |
174 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
175 |     set_seed(args.seed)
176 |
177 |     if args.use_wandb:
178 |         import wandb
179 |         wandb.init(config=args, project='MoNet', name=args.dataset+'_'+args.runname, dir=args.wandb_dir)
180 |     train(args)
--------------------------------------------------------------------------------
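For quick experiments outside the provided scripts, the released checkpoints in `ckpt` can be loaded directly. A minimal sketch, assuming the hyperparameters used in `scripts/eval.sh` (`npoints=16384`, 5 input and 5 predicted frames) and an RNN type that matches the chosen weights:
```
import torch
from argparse import Namespace
from model.models import MoNet

args = Namespace(rnn='GRU', npoints=16384, input_num=5, pred_num=5)
net = MoNet(args).cuda()
net.load_state_dict(torch.load('ckpt/kitti_gru.pth'))
net.eval()
```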