├── .gitattributes ├── src ├── .DS_Store ├── models │ ├── .DS_Store │ ├── __pycache__ │ │ ├── VAE.cpython-39.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── point_net.cpython-39.pyc │ │ ├── AutoEncoder.cpython-39.pyc │ │ ├── point_net_loss.cpython-39.pyc │ │ ├── PointCloudDecoder.cpython-39.pyc │ │ ├── PointCloudEncoder.cpython-39.pyc │ │ ├── PointNetDecoder.cpython-39.pyc │ │ ├── PointNetEncoder.cpython-39.pyc │ │ └── shapenet_dataset.cpython-39.pyc │ ├── utils.py │ ├── AutoEncoder.py │ ├── GAN.py │ ├── VAE.py │ ├── PointCloudEncoder.py │ ├── PointNetDecoder.py │ ├── PointNetEncoder.py │ └── PointCloudDecoder.py ├── utils │ ├── __pycache__ │ │ ├── train.cpython-39.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── create.cpython-39.pyc │ │ ├── visual.cpython-39.pyc │ │ ├── graph_dataset.cpython-39.pyc │ │ ├── graph_model.cpython-39.pyc │ │ ├── calculate_loss.cpython-39.pyc │ │ └── graph_construct.cpython-39.pyc │ ├── utils.py │ ├── calculate_loss.py │ └── train.py ├── data │ ├── __pycache__ │ │ ├── dataset.cpython-39.pyc │ │ ├── conversion.cpython-39.pyc │ │ └── shapenet_dataset.cpython-39.pyc │ ├── dataset.py │ └── shapenet_dataset.py └── evaluation │ ├── __pycache__ │ └── evaluate.cpython-39.pyc │ └── evaluate.py ├── models ├── .DS_Store └── checkpoints │ └── .DS_Store ├── LICENSE ├── README.md └── run.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-detectable=false 2 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/models/.DS_Store -------------------------------------------------------------------------------- /src/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/.DS_Store -------------------------------------------------------------------------------- /models/checkpoints/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/models/checkpoints/.DS_Store -------------------------------------------------------------------------------- /src/models/__pycache__/VAE.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/VAE.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/train.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/train.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- 
/src/data/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/data/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/create.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/create.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/visual.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/visual.cpython-39.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/conversion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/data/__pycache__/conversion.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/point_net.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/point_net.cpython-39.pyc -------------------------------------------------------------------------------- /src/evaluation/__pycache__/evaluate.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/evaluation/__pycache__/evaluate.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/AutoEncoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/AutoEncoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/graph_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/graph_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/graph_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/graph_model.cpython-39.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/shapenet_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/data/__pycache__/shapenet_dataset.cpython-39.pyc -------------------------------------------------------------------------------- 
/src/models/__pycache__/point_net_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/point_net_loss.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/calculate_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/calculate_loss.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/graph_construct.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/utils/__pycache__/graph_construct.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/PointCloudDecoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/PointCloudDecoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/PointCloudEncoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/PointCloudEncoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/PointNetDecoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/PointNetDecoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/PointNetEncoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/PointNetEncoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/__pycache__/shapenet_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mertyigit/PointNet-VAE/HEAD/src/models/__pycache__/shapenet_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3 3 | 4 | def PointsTo3DShape(points): 5 | pcd = o3.geometry.PointCloud() 6 | pcd.points = o3.utility.Vector3dVector(points) 7 | 8 | return o3.visualization.draw_plotly([pcd]) 9 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3 3 | 4 | 5 | def PointsTo3DShape(points): 6 | pcd = o3.geometry.PointCloud() 7 | pcd.points = o3.utility.Vector3dVector(points) 8 | 9 | return o3.visualization.draw_plotly([pcd]) 10 | 11 | 12 | def VisualizeEmbedding(embedding): 13 | import torch 14 | import matplotlib.pyplot as plt 15 | import seaborn as sns 16 | 17 | 
fig = plt.figure(figsize=(8, 6)) 18 | 19 | # The embedding is expected to be a 2D tensor 20 | tensor = embedding 21 | 22 | 23 | # Create a figure object with custom size 24 | 25 | # Plot the heatmap using Seaborn 26 | sns.heatmap(tensor, cmap='coolwarm', cbar=True) 27 | plt.title('Tensor Heatmap') 28 | 29 | plt.show() -------------------------------------------------------------------------------- /src/models/AutoEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | # Define the AutoEncoder model 7 | class AutoEncoder(nn.Module): 8 | def __init__(self, encoder, decoder, device, latent_dim): 9 | super(AutoEncoder, self).__init__() 10 | self.encoder = encoder 11 | self.decoder = decoder 12 | self.latent_dim = latent_dim 13 | self.device = device 14 | 15 | 16 | def forward(self, x): 17 | # Encode the input to a latent embedding, then decode it 18 | #embeddings = self.encoder(x) 19 | embeddings, _, _ = self.encoder(x) ## If PointNet Encoder is used 20 | reconstructed_x = self.decoder(embeddings) 21 | 22 | return reconstructed_x -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mert Sengul 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/models/GAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | # Define the VAE model 7 | class VAE(nn.Module): 8 | def __init__(self, encoder, decoder, device, latent_dim): 9 | super(VAE, self).__init__() 10 | self.encoder = encoder 11 | self.decoder = decoder 12 | self.latent_dim = latent_dim 13 | self.device = device 14 | 15 | # Define the layers for the mean and log-variance vectors 16 | self.fc_mu = nn.Linear(self.latent_dim, self.latent_dim) 17 | self.fc_logvar = nn.Linear(self.latent_dim, self.latent_dim) 18 | 19 | # self.fc_decode = nn.Linear(self.latent_dim, self.encoder.num_global_feats) 20 | 21 | def reparameterize(self, mu, logvar): 22 | std = torch.exp(0.5 * logvar) 23 | 24 | eps = torch.randn_like(std, device=self.device) 25 | 26 | z = mu + eps * std 27 | return z 28 | 29 | def forward(self, x): 30 | # Encode the input to get mu and logvar 31 | global_features, _, _ = self.encoder(x) ## If PointNet Encoder 32 | #global_features = self.encoder(x) ## If Convolution Encoder 33 | mu = self.fc_mu(global_features) 34 | logvar = self.fc_logvar(global_features) 35 | 36 | # Reparameterize and sample from the latent space 37 | z = self.reparameterize(mu, logvar) 38 | z.to(self.device) 39 | # Decode the sampled z to generate output 40 | 41 | reconstructed_x = self.decoder(z) 42 | 43 | return reconstructed_x, mu, logvar -------------------------------------------------------------------------------- /src/models/VAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | # Define the VAE model 7 | class VAE(nn.Module): 8 | def __init__(self, encoder, decoder, device, latent_dim): 9 | super(VAE, self).__init__() 10 | self.encoder = encoder 11 | self.decoder = decoder 12 | self.latent_dim = latent_dim 13 | self.device = device 14 | 15 | # Define the layers for the mean and log-variance vectors 16 | self.fc_mu = nn.Linear(self.latent_dim, self.latent_dim) 17 | self.fc_logvar = nn.Linear(self.latent_dim, self.latent_dim) 18 | 19 | # self.fc_decode = nn.Linear(self.latent_dim, self.encoder.num_global_feats) 20 | 21 | def reparameterize(self, mu, logvar): 22 | std = torch.exp(0.5 * logvar) 23 | 24 | eps = torch.randn_like(std, device=self.device) 25 | 26 | z = mu + eps * std 27 | return z 28 | 29 | def forward(self, x): 30 | # Encode the input to get mu and logvar 31 | global_features, _, _ = self.encoder(x) ## If PointNet Encoder 32 | #global_features = self.encoder(x) ## If Convolution Encoder 33 | mu = self.fc_mu(global_features) 34 | mu = self.fc_mu(mu) 35 | mu = self.fc_mu(mu) 36 | 37 | logvar = self.fc_logvar(global_features) 38 | logvar = self.fc_logvar(logvar) 39 | logvar = self.fc_logvar(logvar) 40 | 41 | # Reparameterize and sample from the latent space 42 | z = self.reparameterize(mu, logvar) 43 | z.to(self.device) 44 | # Decode the sampled z to generate output 45 | 46 | reconstructed_x = self.decoder(z) 47 | 48 | return reconstructed_x, mu, logvar -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational PointNet Encoder-Decoder for 3D Point Cloud Data 2 | 3 | Welcome 
to the Variational PointNet Encoder-Decoder repository! This project represents my endeavor to create a solution for processing and generating 3D point cloud data. It offers a range of encoding and decoding options to suit various application needs. 4 | 5 | [![Python](https://img.shields.io/badge/Python-3.9%2B-blue)](https://www.python.org/) 6 | [![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-orange)](https://pytorch.org/) 7 | [![GitHub License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) 8 | 9 | ## Encoding Options 10 | 11 | **1. PointNet++** 12 | - The PointNet++ encoder is a robust and widely-used choice for extracting features from 3D point cloud data. 13 | 14 | **2. Convolutional Encoders** 15 | - My implementation includes convolutional encoders, designed to leverage spatial dependencies within the point cloud. 16 | 17 | **3. Graph Neural Network (GNN) Encoding (Future Plan)** 18 | - Stay tuned for upcoming GNN encoding capabilities, which will further enhance your 3D data processing options. 19 | 20 | ## Decoding Options 21 | 22 | **1. Deconvoluting Decoder** 23 | - The deconvoluting decoder is a powerful tool for reconstructing 3D point clouds from latent representations. 24 | 25 | **2. Multi-Layer Perceptron (MLP) Decoder** 26 | - My repository includes an MLP decoder, offering flexibility and efficiency in point cloud generation. 27 | 28 | **3. Transformer-based Autoregressive Decoder (Future Plan)** 29 | - I am actively working on a state-of-the-art transformer-based autoregressive decoder for high-quality point cloud generation. Stay tuned for updates! 30 | 31 | ## Get Started 32 | 33 | To get started with my advanced Variational PointNet Encoder-Decoder, please refer to the documentation and instructions in the [Wiki](link-to-wiki) section. (under development) 34 | 35 | ## Contribution and Support 36 | 37 | For any issues or questions, feel free to open an [Issue](https://github.com/mertyigit/PointNet-VAE/issues) 38 | 39 | 40 | -------------------------------------------------------------------------------- /src/utils/calculate_loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Chamfer Loss code was adapted from Meta PyTorch3D package. 3 | ''' 4 | 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2): 13 | """ 14 | Compute the pairwise distance_tensor matrix between a and b which both have size [m, n, d]. The result is a tensor of 15 | size [m, n, n] whose entry [m, i, j] contains the distance_tensor between a[m, i, :] and b[m, j, :]. 16 | :param a: A tensor containing m batches of n points of dimension d. i.e. of size [m, n, d] 17 | :param b: A tensor containing m batches of n points of dimension d. i.e. of size [m, n, d] 18 | :param p: Norm to use for the distance_tensor 19 | :return: A tensor containing the pairwise distance_tensor between each pair of inputs in a batch. 20 | """ 21 | 22 | if len(a.shape) != 3: 23 | raise ValueError("Invalid shape for a. Must be [m, n, d] but got", a.shape) 24 | if len(b.shape) != 3: 25 | raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape) 26 | return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3) 27 | 28 | def chamfer(a, b): 29 | """ 30 | Compute the chamfer distance between two sets of vectors, a, and b 31 | :param a: A m-sized minibatch of point sets in R^d. i.e. 
shape [m, n_a, d] 32 | :param b: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_b, d] 33 | :return: A [m] shaped tensor storing the Chamfer distance between each minibatch entry 34 | """ 35 | M = pairwise_distances(a, b) 36 | dist1 = torch.mean(torch.sqrt(M.min(1)[0])) 37 | dist2 = torch.mean(torch.sqrt(M.min(2)[0])) 38 | return (dist1 + dist2) / 2.0 39 | 40 | 41 | def chamfer_distance(template: torch.Tensor, source: torch.Tensor): 42 | try: 43 | from .cuda.chamfer_distance import ChamferDistance 44 | cost_p0_p1, cost_p1_p0 = ChamferDistance()(template, source) 45 | cost_p0_p1 = torch.mean(torch.sqrt(cost_p0_p1)) 46 | cost_p1_p0 = torch.mean(torch.sqrt(cost_p1_p0)) 47 | chamfer_loss = (cost_p0_p1 + cost_p1_p0)/2.0 48 | except: 49 | chamfer_loss = chamfer(template, source) 50 | return chamfer_loss 51 | 52 | 53 | class ChamferDistanceLoss(nn.Module): 54 | def __init__(self): 55 | super(ChamferDistanceLoss, self).__init__() 56 | 57 | def forward(self, template, source): 58 | return chamfer_distance(template, source) -------------------------------------------------------------------------------- /src/models/PointCloudEncoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains classed comprising Point Net Architecture. Usage for each class can 3 | be found in main() at the bottom. 4 | 5 | TO use: Import Classification and Segmentation classes into desired script 6 | 7 | 8 | 9 | NOTE: 10 | This architecture does not cover Part Segmentation. Per the Point Net paper 11 | that is a different architecture and is not implemented here. 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | # ============================================================================ 20 | class PointCloudEncoder(nn.Module): 21 | def __init__(self, latent_dim, num_point=2500, point_dim=3, bn_decay=0.5): 22 | self.num_point = num_point 23 | self.point_dim = point_dim 24 | self.latent_dim = latent_dim 25 | super(PointCloudEncoder, self).__init__() 26 | 27 | self.conv1 = nn.Conv2d(1, 64, kernel_size=(1, point_dim), padding=0) 28 | self.bn1 = nn.BatchNorm2d(64) 29 | 30 | self.conv2 = nn.Conv2d(64, 64, kernel_size=(1, 1), padding=0) 31 | self.bn2 = nn.BatchNorm2d(64) 32 | 33 | self.conv3 = nn.Conv2d(64, 64, kernel_size=(1, 1), padding=0) 34 | self.bn3 = nn.BatchNorm2d(64) 35 | 36 | self.conv4 = nn.Conv2d(64, 128, kernel_size=(1, 1), padding=0) 37 | self.bn4 = nn.BatchNorm2d(128) 38 | 39 | self.conv5 = nn.Conv2d(128, self.latent_dim, kernel_size=(1, 1), padding=0) 40 | self.bn5 = nn.BatchNorm2d(self.latent_dim) 41 | 42 | self.fc1 = nn.Linear(self.latent_dim, self.latent_dim) 43 | self.bn_fc1 = nn.BatchNorm1d(self.latent_dim) 44 | 45 | 46 | def forward(self, x): 47 | # Encoder 48 | x = nn.functional.relu(self.bn1(self.conv1(x))) 49 | x = nn.functional.relu(self.bn2(self.conv2(x))) 50 | point_feat = nn.functional.relu(self.bn3(self.conv3(x))) 51 | x = nn.functional.relu(self.bn4(self.conv4(point_feat))) 52 | x = nn.functional.relu(self.bn5(self.conv5(x))) 53 | x = F.max_pool2d(x, kernel_size=(self.num_point, 1), padding=0) 54 | x = x.view(x.size(0), -1) 55 | x = nn.functional.relu(self.bn_fc1(self.fc1(x))) 56 | 57 | return x, point_feat, None 58 | 59 | 60 | # Test 61 | def main(): 62 | test_data = torch.rand(32, 3, 2500) 63 | 64 | ## test T-net 65 | tnet = Tnet(dim=3) 66 | transform = tnet(test_data) 67 | print(f'T-net output shape: {transform.shape}') 68 | 69 | ## test backbone 70 | pointfeat = 
PointNetBackbone(local_feat=False) 71 | out, _, _ = pointfeat(test_data) 72 | print(f'Global Features shape: {out.shape}') 73 | 74 | pointfeat = PointNetBackbone(local_feat=True) 75 | out, _, _ = pointfeat(test_data) 76 | print(f'Combined Features shape: {out.shape}') 77 | 78 | # test on single batch (should throw error if there is an issue) 79 | pointfeat = PointNetBackbone(local_feat=True).eval() 80 | out, _, _ = pointfeat(test_data[0, :, :].unsqueeze(0)) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | 86 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import Tensor 4 | from pathlib import Path 5 | from typing import List, Optional, Sequence, Union, Any, Callable 6 | from torchvision.datasets.folder import default_loader 7 | from pytorch_lightning import LightningDataModule 8 | from torch.utils.data import DataLoader, Dataset 9 | from torchvision import transforms 10 | import zipfile 11 | 12 | import torch_geometric.transforms as T 13 | from torch_geometric.datasets import ModelNet 14 | from torch_geometric.loader import DataLoader 15 | from torch_geometric.nn import MLP, fps, global_max_pool, radius 16 | from torch_geometric.nn.conv import PointConv 17 | 18 | 19 | # Custom Data Class for future applications 20 | class MyDataset(Dataset): 21 | def __init__(self): 22 | pass 23 | 24 | 25 | def __len__(self): 26 | pass 27 | 28 | def __getitem__(self, idx): 29 | pass 30 | 31 | 32 | class DataModelNet(): 33 | """ 34 | PyTorch Lightning data module 35 | 36 | Args: 37 | data_dir: 38 | train_batch_size: 39 | val_batch_size: 40 | patch_size: 41 | num_workers: 42 | pin_memory: 43 | """ 44 | 45 | def __init__( 46 | self, 47 | data_path: str, 48 | pre_transform, 49 | train_batch_size: int = 8, 50 | val_batch_size: int = 8, 51 | train_num_points: int = 1024, 52 | val_num_points: int = 1024, 53 | **kwargs, 54 | ): 55 | super().__init__() 56 | 57 | self.data_path = data_path 58 | self.train_batch_size = train_batch_size 59 | self.val_batch_size = val_batch_size 60 | self.train_num_points = train_num_points 61 | self.val_num_points = val_num_points 62 | self.pre_transform = pre_transform 63 | 64 | def setup(self, stage: Optional[str] = None) -> None: 65 | self.train_dataset = ModelNet( 66 | root=self.data_path, 67 | name=self.data_path[-2:], 68 | train=True, 69 | pre_transform=self.pre_transform, 70 | transform=T.SamplePoints(self.train_num_points), 71 | ) 72 | 73 | self.val_dataset = ModelNet( 74 | root=self.data_path, 75 | name=self.data_path[-2:], 76 | train=False, 77 | pre_transform=self.pre_transform, 78 | transform=T.SamplePoints(self.val_num_points), 79 | ) 80 | 81 | def train_dataloader(self) -> DataLoader: 82 | return DataLoader( 83 | self.train_dataset, 84 | batch_size=self.train_batch_size, 85 | shuffle=True, 86 | ) 87 | 88 | def val_dataloader(self) -> DataLoader: 89 | return DataLoader( 90 | self.val_dataset, 91 | batch_size=self.val_batch_size, 92 | shuffle=False, 93 | ) 94 | 95 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Train Validate Evaluation Script 3 | ''' 4 | 5 | import os 6 | import sys 7 | import re 8 | from glob import glob 9 | import time 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | import torch.nn as nn 14 | import 
torch.nn.functional as F 15 | import torchmetrics 16 | from torchmetrics.classification import MulticlassMatthewsCorrCoef 17 | from torch.utils.data import DataLoader 18 | import torch.optim as optim 19 | from torch.nn.functional import kl_div 20 | 21 | 22 | import open3d as o3 23 | 24 | import yaml 25 | import argparse 26 | 27 | import torch_geometric.transforms as T 28 | from torch_geometric.datasets import ModelNet 29 | from torch_geometric.loader import DataLoader 30 | from torch_geometric.nn import MLP, fps, global_max_pool, radius 31 | from torch_geometric.nn.conv import PointConv 32 | 33 | 34 | from src.models.PointNetEncoder import PointNetBackbone 35 | from src.utils.calculate_loss import ChamferDistanceLoss 36 | from src.models.PointCloudEncoder import PointCloudEncoder 37 | from src.models.PointCloudDecoder import PointCloudDecoder, PointCloudDecoderSelf, PointCloudDecoderMLP 38 | from src.models.AutoEncoder import AutoEncoder 39 | from src.models.VAE import VAE 40 | from src.data.dataset import DataModelNet 41 | from src.utils.utils import * 42 | from src.utils.train import Trainer 43 | 44 | from tqdm import tqdm 45 | 46 | import matplotlib as mpl 47 | import matplotlib.pyplot as plt 48 | 49 | ### HERE ARGS ### 50 | parser = argparse.ArgumentParser(description='Generic runner for VAE models') 51 | parser.add_argument('--config', '-c', 52 | dest="filename", 53 | metavar='FILE', 54 | help = 'path to the config file', 55 | default='configs/vae.yaml') 56 | 57 | args = parser.parse_args() 58 | with open(args.filename, 'r') as file: 59 | try: 60 | config = yaml.safe_load(file) 61 | except yaml.YAMLError as exc: 62 | print(exc) 63 | 64 | ################# 65 | 66 | ### REPRODUCIBILITY ### 67 | torch.seed = config['trainer_parameters']['manual_seed'] 68 | ####################### 69 | 70 | 71 | print('MPS is build: {}'.format(torch.backends.mps.is_built())) 72 | print('MPS Availability: {}'.format(torch.backends.mps.is_available())) 73 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' 74 | print('Device is set to :{}'.format(DEVICE)) 75 | 76 | 77 | # model hyperparameters 78 | EPOCHS = config['trainer_parameters']['epochs'] 79 | LR = config['trainer_parameters']['lr'] 80 | LATENT_DIM = config['model_parameters']['latent_dim'] 81 | DEVICE = config['model_parameters']['device'] 82 | 83 | #### LOAD DATA #### 84 | data = DataModelNet(**config["data_parameters"]) 85 | data.setup() 86 | train_dataloader = data.train_dataloader() 87 | val_dataloader = data.val_dataloader() 88 | ################### 89 | 90 | encoder = PointCloudEncoder(latent_dim=LATENT_DIM, num_point=config['data_parameters']['train_num_points']).to(DEVICE) 91 | #encoder = PointNetBackbone(num_points=NUM_POINTS, num_global_feats=LATENT_DIM, local_feat=False).to(DEVICE) 92 | decoder = PointCloudDecoderMLP(latent_dim=LATENT_DIM, num_hidden=3, num_point=config['data_parameters']['train_num_points']).to(DEVICE) 93 | #autoencoder = AutoEncoder(encoder, decoder, device=DEVICE, latent_dim=LATENT_DIM).to(DEVICE) 94 | vae = VAE(encoder, decoder, device=DEVICE, latent_dim=LATENT_DIM).to(DEVICE) 95 | 96 | model_run = Trainer(model=vae, 97 | criterion=ChamferDistanceLoss(), 98 | optimizer=optim.Adam(vae.parameters(), config['trainer_parameters']['lr']), 99 | **config['model_parameters'] 100 | ) 101 | 102 | model_run.fit(train_dataloader, val_dataloader, EPOCHS) 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- 
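Note: run.py reads a YAML config (default path `configs/vae.yaml`) that is not part of the tree above, so the exact file contents are unknown. Below is a minimal sketch of the structure that run.py, DataModelNet and Trainer appear to expect, written as the equivalent Python dict; the key names are taken from those files, while every value shown is only an illustrative assumption.

```python
# Hypothetical sketch of the config structure read by run.py, expressed as the
# equivalent Python dict (the actual file would be YAML, e.g. configs/vae.yaml).
# Key names come from run.py, DataModelNet and Trainer; values are assumptions.
config = {
    'trainer_parameters': {
        'manual_seed': 42,
        'epochs': 100,
        'lr': 1e-4,
    },
    # forwarded as Trainer(**config['model_parameters']); model, criterion and
    # optimizer are passed separately in run.py, so they must not appear here
    'model_parameters': {
        'latent_dim': 1024,
        'device': 'mps',                       # or 'cuda' / 'cpu'
        'encoder_type': 'ConvolutionEncoder',  # or 'PointNetEncoder'
        'model_type': 'VAE',                   # or 'AutoEncoder'
        'checkpoint': 'models/checkpoints',
        'experiment': 'vae_experiment_01',
        'kl_loss_weight': 1e-3,
    },
    # forwarded as DataModelNet(**config['data_parameters']); dataset.py derives
    # the ModelNet variant from the last two characters of data_path ('10'/'40')
    'data_parameters': {
        'data_path': 'data/ModelNet10',
        'train_batch_size': 8,
        'val_batch_size': 8,
        'train_num_points': 1024,
        'val_num_points': 1024,
        'pre_transform': None,  # DataModelNet expects a transform object here; a
                                # plain YAML value cannot express one (see note)
    },
}
```

Since `pre_transform` is forwarded to `DataModelNet` as an object, it likely cannot be supplied as a plain YAML scalar and would have to be constructed in code before `DataModelNet(**config['data_parameters'])` is called.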
/src/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Developed from scrtach by Mert Sengul. 3 | Please cite the repo if you readapt. 4 | ''' 5 | 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | import os 10 | 11 | class Evaluater: 12 | ''' 13 | Evaluater object. 14 | ''' 15 | def __init__( 16 | self, 17 | model, 18 | criterion, 19 | encoder_type, 20 | model_type, 21 | checkpoint, 22 | device, 23 | ): 24 | 25 | super().__init__() 26 | 27 | self.model = model 28 | self.criterion = criterion 29 | self.encoder_type = encoder_type 30 | self.model_type = model_type 31 | self.checkpoint = checkpoint 32 | self.device = device 33 | 34 | 35 | def evaluate(self, holdout_loader): 36 | # evaluate 37 | eval_loss, eval_rc_loss, eval_kl_loss = self._evaluate(holdout_loader) 38 | print('Loss: {} - Reconst Loss: {} - KL Loss: {}'.format(eval_loss, eval_rc_loss, eval_kl_loss)) 39 | 40 | def evaluate_data(self, data): 41 | _loss = [] 42 | _rc_loss = [] 43 | _kl_loss = [] 44 | kl_divergence = torch.zeros(1) 45 | 46 | self.model.load_state_dict(torch.load(self.checkpoint, map_location=self.device)) 47 | 48 | # put model in evaluation mode 49 | self.model.eval() 50 | 51 | with torch.no_grad(): 52 | 53 | points, target, batch_size = self._sanitizer(data) # No need to return 54 | 55 | points = points.to(self.device) 56 | target = target.to(self.device) 57 | 58 | if self.model_type == 'VAE': 59 | reconstructed_x, mu, logvar = self.model(points) 60 | kl_divergence = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0) 61 | 62 | elif self.model_type == 'AutoEncoder': 63 | reconstructed_x, _, _ = self.model(points) 64 | 65 | if self.encoder_type == 'ConvolutionEncoder': 66 | loss_reconstruction = self.criterion(reconstructed_x, points.squeeze(1)) 67 | 68 | elif self.encoder_type == 'PointNetEncoder': 69 | loss_reconstruction = self.criterion(reconstructed_x, points.transpose(2, 1)) 70 | 71 | loss = loss_reconstruction + kl_divergence 72 | 73 | epoch_loss = loss.item() 74 | rc_loss = loss_reconstruction.item() 75 | kl_loss = kl_divergence.item() 76 | 77 | print('Loss: {} - Reconst Loss: {} - KL Loss: {}'.format(epoch_loss, rc_loss, kl_loss)) 78 | return points, reconstructed_x 79 | 80 | def _evaluate(self, loader): 81 | _loss = [] 82 | _rc_loss = [] 83 | _kl_loss = [] 84 | kl_divergence = torch.zeros(1) 85 | 86 | self.model.load_state_dict(torch.load(self.checkpoint, map_location=self.device)) 87 | 88 | # put model in evaluation mode 89 | self.model.eval() 90 | 91 | with torch.no_grad(): 92 | for i, data in tqdm(enumerate(loader)): 93 | points, target, batch_size = self._sanitizer(data) # No need to return 94 | 95 | points = points.to(self.device) 96 | target = target.to(self.device) 97 | 98 | if self.model_type == 'VAE': 99 | reconstructed_x, mu, logvar = self.model(points) 100 | kl_divergence = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0) 101 | 102 | elif self.model_type == 'AutoEncoder': 103 | reconstructed_x, _, _ = self.model(points) 104 | 105 | if self.encoder_type == 'ConvolutionEncoder': 106 | loss_reconstruction = self.criterion(reconstructed_x, points.squeeze(1)) 107 | elif self.encoder_type == 'PointNetEncoder': 108 | loss_reconstruction = self.criterion(reconstructed_x, points.transpose(2, 1)) 109 | 110 | loss = loss_reconstruction + kl_divergence 111 | 112 | _loss.append(loss.item()) 113 | _rc_loss.append(loss_reconstruction.item()) 114 | 
_kl_loss.append(kl_divergence.item()) 115 | 116 | epoch_loss = np.mean(_loss) 117 | rc_loss = np.mean(_rc_loss) 118 | kl_loss = np.mean(_kl_loss) 119 | 120 | return epoch_loss, rc_loss, kl_loss 121 | 122 | 123 | 124 | def _sanitizer(self, data): 125 | ### Preparate the 3D cloud for encoder ### 126 | 127 | batch_size = data.y.shape[0] 128 | 129 | if self.encoder_type == 'ConvolutionEncoder': 130 | points = torch.stack([data[idx].pos for idx in range(batch_size)]).unsqueeze(1) ## If Convolution Encoder 131 | elif self.encoder_type == 'PointNetEncoder': 132 | points = torch.stack([data[idx].pos for idx in range(batch_size)]).transpose(2, 1) ## If PointNet Encoder 133 | 134 | targets = data.y 135 | 136 | return points, targets, batch_size -------------------------------------------------------------------------------- /src/models/PointNetDecoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains classed comprising Point Net Architecture. Usage for each class can 3 | be found in main() at the bottom. 4 | 5 | TO use: Import Classification and Segmentation classes into desired script 6 | 7 | 8 | 9 | NOTE: 10 | This architecture does not cover Part Segmentation. Per the Point Net paper 11 | that is a different architecture and is not implemented here. 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | # ============================================================================ 20 | # T-net (Spatial Transformer Network) 21 | class Tnet(nn.Module): 22 | ''' T-Net learns a Transformation matrix with a specified dimension ''' 23 | def __init__(self, dim, num_points=2500): 24 | super(Tnet, self).__init__() 25 | 26 | # dimensions for transform matrix 27 | self.dim = dim 28 | 29 | self.conv1 = nn.Conv1d(dim, 64, kernel_size=1) 30 | self.conv2 = nn.Conv1d(64, 128, kernel_size=1) 31 | self.conv3 = nn.Conv1d(128, 1024, kernel_size=1) 32 | 33 | self.linear1 = nn.Linear(1024, 512) 34 | self.linear2 = nn.Linear(512, 256) 35 | self.linear3 = nn.Linear(256, dim**2) 36 | 37 | self.bn1 = nn.BatchNorm1d(64) 38 | self.bn2 = nn.BatchNorm1d(128) 39 | self.bn3 = nn.BatchNorm1d(1024) 40 | self.bn4 = nn.BatchNorm1d(512) 41 | self.bn5 = nn.BatchNorm1d(256) 42 | 43 | self.max_pool = nn.MaxPool1d(kernel_size=num_points) 44 | 45 | 46 | def forward(self, x): 47 | bs = x.shape[0] 48 | 49 | # pass through shared MLP layers (conv1d) 50 | x = self.bn1(F.relu(self.conv1(x))) 51 | x = self.bn2(F.relu(self.conv2(x))) 52 | x = self.bn3(F.relu(self.conv3(x))) 53 | 54 | # max pool over num points 55 | x = self.max_pool(x).view(bs, -1) 56 | 57 | # pass through MLP 58 | x = self.bn4(F.relu(self.linear1(x))) 59 | x = self.bn5(F.relu(self.linear2(x))) 60 | x = self.linear3(x) 61 | 62 | # initialize identity matrix 63 | iden = torch.eye(self.dim, requires_grad=True).repeat(bs, 1, 1) 64 | if x.is_cuda: 65 | iden = iden.cuda() 66 | elif x.is_mps: 67 | iden = iden.to(torch.device('mps')) 68 | x = x.view(-1, self.dim, self.dim) + iden 69 | 70 | return x 71 | 72 | 73 | # ============================================================================ 74 | # Point Net Backbone (main Architecture) 75 | class PointNetDecoder(nn.Module): 76 | def __init__(self, num_points=2500, num_global_feats=1024, local_feat=True): 77 | super(PointNetDecoder, self).__init__() 78 | 79 | self.num_points = num_points 80 | self.num_global_feats = num_global_feats 81 | self.local_feat = local_feat 82 | 83 | # Transposed convolution layers 84 | self.deconv1 
= nn.ConvTranspose1d(self.num_global_feats, 128, kernel_size=1) 85 | self.deconv2 = nn.ConvTranspose1d(128, 64, kernel_size=1) 86 | self.deconv3 = nn.ConvTranspose1d(64, 64, kernel_size=1) 87 | self.deconv4 = nn.ConvTranspose1d(64, 64, kernel_size=1) 88 | self.deconv5 = nn.ConvTranspose1d(64, 3, kernel_size=1) 89 | 90 | # Transformation layer in the decoder (similar to T-Net) 91 | self.tnet_decoder = Tnet(dim=128, num_points=num_points) 92 | 93 | # Batch normalization for deconv layers 94 | self.bn1 = nn.BatchNorm1d(128) 95 | self.bn2 = nn.BatchNorm1d(64) 96 | self.bn3 = nn.BatchNorm1d(64) 97 | self.bn4 = nn.BatchNorm1d(64) 98 | 99 | def forward(self, x): 100 | bs = x.shape[0] 101 | 102 | # Depending on whether local features were concatenated or not 103 | if self.local_feat: 104 | global_features = x[:, :self.num_global_feats, :] 105 | else: 106 | global_features = x 107 | 108 | # Expand the global features to match the shape of the local features 109 | global_features = global_features.unsqueeze(-1).repeat(1, 1, self.num_points) 110 | 111 | # Pass through the transformation layer (T-Net) in the decoder 112 | transformation_matrix = self.tnet_decoder(global_features) 113 | 114 | # Apply the transformation matrix to the input (similar to your encoder) 115 | x = torch.bmm(x.transpose(2, 1), transformation_matrix).transpose(2, 1) 116 | 117 | # Continue with the decoder layers 118 | x = self.bn1(F.relu(self.deconv1(x))) 119 | x = self.bn2(F.relu(self.deconv2(x))) 120 | x = self.bn3(F.relu(self.deconv3(x))) 121 | x = self.bn4(F.relu(self.deconv4(x))) 122 | 123 | # Final layer to generate point coordinates 124 | x = self.deconv5(x) 125 | 126 | return x 127 | 128 | 129 | # Test 130 | def main(): 131 | test_data = torch.rand(32, 3, 2500) 132 | 133 | ## test T-net 134 | tnet = Tnet(dim=3) 135 | transform = tnet(test_data) 136 | print(f'T-net output shape: {transform.shape}') 137 | 138 | ## test backbone 139 | pointfeat = PointNetBackbone(local_feat=False) 140 | out, _, _ = pointfeat(test_data) 141 | print(f'Global Features shape: {out.shape}') 142 | 143 | pointfeat = PointNetBackbone(local_feat=True) 144 | out, _, _ = pointfeat(test_data) 145 | print(f'Combined Features shape: {out.shape}') 146 | 147 | # test on single batch (should throw error if there is an issue) 148 | pointfeat = PointNetBackbone(local_feat=True).eval() 149 | out, _, _ = pointfeat(test_data[0, :, :].unsqueeze(0)) 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | 155 | -------------------------------------------------------------------------------- /src/utils/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Developed from scrtach by Mert Sengul. 3 | Please cite the repo if you readapt. 4 | ''' 5 | 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | import os 10 | 11 | class Trainer: 12 | ''' 13 | Trainer object. 
14 | ''' 15 | def __init__( 16 | self, 17 | model, 18 | criterion, 19 | optimizer, 20 | encoder_type, 21 | model_type, 22 | checkpoint, 23 | experiment, 24 | device, 25 | latent_dim, 26 | kl_loss_weight=None, 27 | 28 | ): 29 | 30 | super().__init__() 31 | 32 | self.model = model 33 | self.criterion = criterion 34 | self.optimizer = optimizer 35 | self.encoder_type = encoder_type 36 | self.model_type = model_type 37 | self.checkpoint = checkpoint 38 | self.experiment = experiment 39 | self.device = device 40 | self.latent_dim = latent_dim 41 | self.kl_loss_weight = kl_loss_weight 42 | 43 | def fit(self, train_loader, val_loader, epochs): 44 | for epoch in tqdm(range(epochs)): 45 | # train 46 | train_loss, train_rc_loss, train_kl_loss = self._train(train_loader) 47 | print('Epoch: {} - Loss: {} - Reconst Loss: {} - KL Loss: {}'.format(epoch, train_loss, train_rc_loss, train_kl_loss)) 48 | 49 | # validate 50 | val_loss, val_rc_loss, val_kl_loss = self._validate(val_loader) 51 | print('Epoch: {} - Loss: {} - Reconst Loss: {} - KL Loss: {}'.format(epoch, val_loss, val_rc_loss, val_kl_loss)) 52 | 53 | #save model state 54 | self._save_checkpoint(train_loss, val_loss, epoch) 55 | 56 | def _save_checkpoint(self, train_loss, val_loss, epoch): 57 | path = '{}/{}'.format(self.checkpoint, self.experiment) 58 | if not os.path.isdir(path): 59 | os.mkdir(path) 60 | 61 | torch.save(self.model.state_dict(), '{}/checkpoint_{}.pth'.format(path, epoch)) 62 | 63 | def _sanitizer(self, data): 64 | ### Preparate the 3D cloud for encoder ### 65 | 66 | batch_size = data.y.shape[0] 67 | 68 | if self.encoder_type == 'ConvolutionEncoder': 69 | points = torch.stack([data[idx].pos for idx in range(batch_size)]).unsqueeze(1) ## If Convolution Encoder 70 | elif self.encoder_type == 'PointNetEncoder': 71 | points = torch.stack([data[idx].pos for idx in range(batch_size)]).transpose(2, 1) ## If PointNet Encoder 72 | 73 | targets = data.y 74 | 75 | return points, targets, batch_size 76 | 77 | 78 | def _train(self, loader): 79 | # put model in train mode 80 | _loss = [] 81 | _rc_loss = [] 82 | _kl_loss = [] 83 | kl_divergence = torch.zeros(1).to(self.device) 84 | self.model.to(self.device) 85 | self.model.train() 86 | 87 | for i, data in tqdm(enumerate(loader)): 88 | self.optimizer.zero_grad() 89 | 90 | points, target, batch_size = self._sanitizer(data) # No need to return 91 | 92 | points = points.to(self.device) 93 | target = target.to(self.device) 94 | 95 | if self.model_type == 'VAE': 96 | reconstructed_x, mu, logvar = self.model(points) 97 | kl_divergence = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0) 98 | 99 | elif self.model_type == 'AutoEncoder': 100 | reconstructed_x, _, _ = self.model(points) 101 | 102 | else: 103 | print('The model stype is not known!') 104 | 105 | if self.encoder_type == 'ConvolutionEncoder': 106 | loss_reconstruction = self.criterion(reconstructed_x, points.squeeze(1)) 107 | elif self.encoder_type == 'PointNetEncoder': 108 | loss_reconstruction = self.criterion(reconstructed_x, points.transpose(2, 1)) 109 | 110 | loss = loss_reconstruction + self.kl_loss_weight * kl_divergence 111 | loss.backward() 112 | self.optimizer.step() 113 | 114 | _loss.append(loss.item()) 115 | _rc_loss.append(loss_reconstruction.item()) 116 | _kl_loss.append(kl_divergence.item()) 117 | 118 | epoch_loss = np.mean(_loss) 119 | rc_loss = np.mean(_rc_loss) 120 | kl_loss = np.mean(_kl_loss) 121 | 122 | return epoch_loss, rc_loss, kl_loss 123 | 124 | def _validate(self, loader): 125 | 
_loss = [] 126 | _rc_loss = [] 127 | _kl_loss = [] 128 | kl_divergence = torch.zeros(1) 129 | # put model in evaluation mode 130 | self.model.eval() 131 | 132 | with torch.no_grad(): 133 | for i, data in tqdm(enumerate(loader)): 134 | points, target, batch_size = self._sanitizer(data) # No need to return 135 | 136 | points = points.to(self.device) 137 | target = target.to(self.device) 138 | 139 | if self.model_type == 'VAE': 140 | reconstructed_x, mu, logvar = self.model(points) 141 | kl_divergence = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0) 142 | 143 | elif self.model_type == 'AutoEncoder': 144 | reconstructed_x, _, _ = self.model(points) 145 | 146 | if self.encoder_type == 'ConvolutionEncoder': 147 | loss_reconstruction = self.criterion(reconstructed_x, points.squeeze(1)) 148 | elif self.encoder_type == 'PointNetEncoder': 149 | loss_reconstruction = self.criterion(reconstructed_x, points.transpose(2, 1)) 150 | 151 | loss = loss_reconstruction + kl_divergence 152 | 153 | _loss.append(loss.item()) 154 | _rc_loss.append(loss_reconstruction.item()) 155 | _kl_loss.append(kl_divergence.item()) 156 | 157 | epoch_loss = np.mean(_loss) 158 | rc_loss = np.mean(_rc_loss) 159 | kl_loss = np.mean(_kl_loss) 160 | 161 | return epoch_loss, rc_loss, kl_loss 162 | 163 | 164 | -------------------------------------------------------------------------------- /src/models/PointNetEncoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Contains classed comprising Point Net Architecture. Usage for each class can 3 | be found in main() at the bottom. 4 | 5 | TO use: Import Classification and Segmentation classes into desired script 6 | 7 | 8 | 9 | NOTE: 10 | This architecture does not cover Part Segmentation. Per the Point Net paper 11 | that is a different architecture and is not implemented here. 
12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | # ============================================================================ 20 | # T-net (Spatial Transformer Network) 21 | class Tnet(nn.Module): 22 | ''' T-Net learns a Transformation matrix with a specified dimension ''' 23 | def __init__(self, dim, num_points=2500): 24 | super(Tnet, self).__init__() 25 | 26 | # dimensions for transform matrix 27 | self.dim = dim 28 | 29 | self.conv1 = nn.Conv1d(dim, 64, kernel_size=1) 30 | self.conv2 = nn.Conv1d(64, 128, kernel_size=1) 31 | self.conv3 = nn.Conv1d(128, 1024, kernel_size=1) 32 | 33 | self.linear1 = nn.Linear(1024, 512) 34 | self.linear2 = nn.Linear(512, 256) 35 | self.linear3 = nn.Linear(256, dim**2) 36 | 37 | self.bn1 = nn.BatchNorm1d(64) 38 | self.bn2 = nn.BatchNorm1d(128) 39 | self.bn3 = nn.BatchNorm1d(1024) 40 | self.bn4 = nn.BatchNorm1d(512) 41 | self.bn5 = nn.BatchNorm1d(256) 42 | 43 | self.max_pool = nn.MaxPool1d(kernel_size=num_points) 44 | 45 | 46 | def forward(self, x): 47 | bs = x.shape[0] 48 | 49 | # pass through shared MLP layers (conv1d) 50 | x = self.bn1(F.relu(self.conv1(x))) 51 | x = self.bn2(F.relu(self.conv2(x))) 52 | x = self.bn3(F.relu(self.conv3(x))) 53 | 54 | # max pool over num points 55 | x = self.max_pool(x).view(bs, -1) 56 | 57 | # pass through MLP 58 | x = self.bn4(F.relu(self.linear1(x))) 59 | x = self.bn5(F.relu(self.linear2(x))) 60 | x = self.linear3(x) 61 | 62 | # initialize identity matrix 63 | iden = torch.eye(self.dim, requires_grad=True).repeat(bs, 1, 1) 64 | if x.is_cuda: 65 | iden = iden.cuda() 66 | elif x.is_mps: 67 | iden = iden.to(torch.device('mps')) 68 | x = x.view(-1, self.dim, self.dim) + iden 69 | 70 | return x 71 | 72 | 73 | # ============================================================================ 74 | # Point Net Backbone (main Architecture) 75 | class PointNetBackbone(nn.Module): 76 | ''' 77 | This is the main portion of Point Net before the classification and segmentation heads. 78 | The main function of this network is to obtain the local and global point features, 79 | which can then be passed to each of the heads to perform either classification or 80 | segmentation. The forward pass through the backbone includes both T-nets and their 81 | transformations, the shared MLPs, and the max pool layer to obtain the global features. 82 | 83 | The forward function either returns the global or combined (local and global features) 84 | along with the critical point index locations and the feature transformation matrix. The 85 | feature transformation matrix is used for a regularization term that will help it become 86 | orthogonal. (i.e. a rigid body transformation is an orthogonal transform and we would like 87 | to maintain orthogonality in high dimensional space). 
"An orthogonal transformations preserves 88 | the lengths of vectors and angles between them" 89 | ''' 90 | def __init__(self, num_points=2500, num_global_feats=1024, local_feat=True): 91 | ''' Initializers: 92 | num_points - number of points in point cloud 93 | num_global_feats - number of Global Features for the main 94 | Max Pooling layer 95 | local_feat - if True, forward() returns the concatenation 96 | of the local and global features 97 | ''' 98 | super(PointNetBackbone, self).__init__() 99 | 100 | # if true concat local and global features 101 | self.num_points = num_points 102 | self.num_global_feats = num_global_feats 103 | self.local_feat = local_feat 104 | 105 | # Spatial Transformer Networks (T-nets) 106 | self.tnet1 = Tnet(dim=3, num_points=num_points) 107 | self.tnet2 = Tnet(dim=64, num_points=num_points) 108 | 109 | # shared MLP 1 110 | self.conv1 = nn.Conv1d(3, 64, kernel_size=1) 111 | self.conv2 = nn.Conv1d(64, 64, kernel_size=1) 112 | 113 | # shared MLP 2 114 | self.conv3 = nn.Conv1d(64, 64, kernel_size=1) 115 | self.conv4 = nn.Conv1d(64, 128, kernel_size=1) 116 | self.conv5 = nn.Conv1d(128, self.num_global_feats, kernel_size=1) 117 | 118 | # batch norms for both shared MLPs 119 | self.bn1 = nn.BatchNorm1d(64) 120 | self.bn2 = nn.BatchNorm1d(64) 121 | self.bn3 = nn.BatchNorm1d(64) 122 | self.bn4 = nn.BatchNorm1d(128) 123 | self.bn5 = nn.BatchNorm1d(self.num_global_feats) 124 | 125 | # max pool to get the global features 126 | self.max_pool = nn.MaxPool1d(kernel_size=num_points, return_indices=True) 127 | 128 | 129 | def forward(self, x): 130 | 131 | # get batch size 132 | bs = x.shape[0] 133 | 134 | # pass through first Tnet to get transform matrix 135 | A_input = self.tnet1(x) 136 | 137 | # perform first transformation across each point in the batch 138 | x = torch.bmm(x.transpose(2, 1), A_input).transpose(2, 1) 139 | 140 | # pass through first shared MLP 141 | x = self.bn1(F.relu(self.conv1(x))) 142 | x = self.bn2(F.relu(self.conv2(x))) 143 | 144 | # get feature transform 145 | A_feat = self.tnet2(x) 146 | 147 | # perform second transformation across each (64 dim) feature in the batch 148 | x = torch.bmm(x.transpose(2, 1), A_feat).transpose(2, 1) 149 | 150 | # store local point features for segmentation head 151 | local_features = x.clone() 152 | 153 | # pass through second MLP 154 | x = self.bn3(F.relu(self.conv3(x))) 155 | x = self.bn4(F.relu(self.conv4(x))) 156 | x = self.bn5(F.relu(self.conv5(x))) 157 | 158 | # get global feature vector and critical indexes 159 | global_features, critical_indexes = self.max_pool(x) 160 | global_features = global_features.view(bs, -1) 161 | critical_indexes = critical_indexes.view(bs, -1) 162 | 163 | if self.local_feat: 164 | features = torch.cat((local_features, 165 | global_features.unsqueeze(-1).repeat(1, 1, self.num_points)), 166 | dim=1) 167 | 168 | return features, critical_indexes, A_feat 169 | 170 | else: 171 | return global_features, critical_indexes, A_feat 172 | 173 | 174 | # Test 175 | def main(): 176 | test_data = torch.rand(32, 3, 2500) 177 | 178 | ## test T-net 179 | tnet = Tnet(dim=3) 180 | transform = tnet(test_data) 181 | print(f'T-net output shape: {transform.shape}') 182 | 183 | ## test backbone 184 | pointfeat = PointNetBackbone(local_feat=False) 185 | out, _, _ = pointfeat(test_data) 186 | print(f'Global Features shape: {out.shape}') 187 | 188 | pointfeat = PointNetBackbone(local_feat=True) 189 | out, _, _ = pointfeat(test_data) 190 | print(f'Combined Features shape: {out.shape}') 191 | 192 | # test on single 
batch (should throw error if there is an issue) 193 | pointfeat = PointNetBackbone(local_feat=True).eval() 194 | out, _, _ = pointfeat(test_data[0, :, :].unsqueeze(0)) 195 | 196 | 197 | if __name__ == '__main__': 198 | main() 199 | 200 | -------------------------------------------------------------------------------- /src/data/shapenet_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Dataset for shapenet provides functionality for both Classification and Segmentation 3 | 4 | can be downloaded in Colab using the following lines 5 | # !wget -nv https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate 6 | # !unzip shapenetcore_partanno_segmentation_benchmark_v0.zip 7 | # !rm shapenetcore_partanno_segmentation_benchmark_v0.zip 8 | 9 | This dataloader is based on: 10 | https://github.com/intel-isl/Open3D-PointNet 11 | 12 | ''' 13 | 14 | import os 15 | import json 16 | import numpy as np 17 | import open3d as o3 18 | from PIL import Image 19 | import torch 20 | from torch.utils.data import Dataset 21 | 22 | class ShapenetDataset(Dataset): 23 | 24 | def __init__(self, root, split, npoints=2500, classification=False, class_choice=None, 25 | image=False, normalize=True): 26 | 27 | self.root = root 28 | self.split = split.lower() 29 | self.npoints = npoints 30 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') 31 | self.cat = {} 32 | self.classification = classification 33 | self.image = image 34 | self.normalize = normalize 35 | 36 | # Open the Category File and Map Folders to Categories 37 | with open(self.catfile, 'r') as f: 38 | for line in f: 39 | ls = line.strip().split() 40 | self.cat[ls[0]] = ls[1] 41 | 42 | # select specific categories from the dataset. 43 | # ex: Call in parameters "class_choice=["Airplane"]. 44 | if not class_choice is None: 45 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 46 | 47 | # for each category, assign the point, segmentation, and image. 
48 | self.meta = {} 49 | for item in self.cat: 50 | self.meta[item] = [] 51 | dir_point = os.path.join(self.root, self.cat[item], 'points') 52 | dir_seg = os.path.join(self.root, self.cat[item], 'points_label') 53 | dir_seg_img = os.path.join(self.root, self.cat[item], 'seg_img') 54 | 55 | # get train, valid, test splits from json files 56 | if self.split == 'train': 57 | split_file = os.path.join(self.root, 58 | r'train_test_split/shuffled_train_file_list.json') 59 | elif self.split == 'test': 60 | split_file = os.path.join(self.root, 61 | r'train_test_split/shuffled_test_file_list.json') 62 | elif (self.split == 'valid') or (self.split == 'val'): 63 | split_file = os.path.join(self.root, 64 | r'train_test_split/shuffled_val_file_list.json') 65 | 66 | with open(split_file, 'r') as f: 67 | split_data = json.load(f) 68 | 69 | # get point cloud file (.pts) names for current split 70 | pts_names = [] 71 | for token in split_data: 72 | if self.cat[item] in token: 73 | pts_names.append(token.split('/')[-1] + '.pts') 74 | 75 | 76 | # FOR EVERY POINT CLOUD FILE 77 | for fn in pts_names: 78 | token = (os.path.splitext(os.path.basename(fn))[0]) 79 | # add point cloud, segmentations, and image to class metadata dict 80 | self.meta[item].append((os.path.join(dir_point, token + '.pts'), 81 | os.path.join(dir_seg, token + '.seg'), 82 | os.path.join(dir_seg_img, token + '.png'))) 83 | 84 | # create list containing (item, points, segmentation points, segmentation image) 85 | self.datapath = [] 86 | for item in self.cat: 87 | for fn in self.meta[item]: 88 | self.datapath.append((item, fn[0], fn[1], fn[2])) 89 | 90 | self.classes = dict(zip(sorted(self.cat), range(len(self.cat)))) 91 | 92 | self.num_seg_classes = 0 93 | if not self.classification: # Take the Segmentation Labels 94 | for i in range(len(self.datapath)//50): 95 | # get number of seg classes in current item 96 | l = len(np.unique(np.loadtxt(self.datapath[i][2]).astype(np.uint8))) 97 | if l > self.num_seg_classes: 98 | self.num_seg_classes = l 99 | #print(self.num_seg_classes) 100 | 101 | 102 | def __getitem__(self, index): 103 | ''' 104 | Each element has the format "class, points, segmentation labels, segmentation image" 105 | ''' 106 | # Get one Element 107 | fn = self.datapath[index] 108 | 109 | # get its Class 110 | cls_ = self.classes[fn[0]] 111 | 112 | # Read the Point Cloud 113 | point_set = np.asarray(o3.io.read_point_cloud(fn[1], format='xyz').points,dtype=np.float32) 114 | 115 | # Read the Segmentation Data 116 | seg = np.loadtxt(fn[2]).astype(np.int64) 117 | 118 | # Read the Segmentation Image 119 | image = Image.open(fn[3]) 120 | 121 | # down sample the pont cloud 122 | if len(seg) > self.npoints: 123 | choice = np.random.choice(len(seg), self.npoints, replace=False) 124 | else: 125 | # case when there are less points than the desired number 126 | choice = np.random.choice(len(seg), self.npoints, replace=True) 127 | 128 | point_set = point_set[choice, :] 129 | seg = seg[choice] 130 | point_set = torch.from_numpy(point_set) 131 | seg = torch.from_numpy(seg) 132 | cls_ = torch.from_numpy(np.array([cls_]).astype(np.int64)) 133 | 134 | # add Gaussian noise to point set if not testing 135 | if self.split != 'test': 136 | # add N(0, 1/100) noise 137 | point_set += torch.randn(point_set.shape)/100 138 | 139 | # add random rotation to the point cloud 140 | point_set = self.random_rotate(point_set) 141 | 142 | # Normalize Point Cloud to (0, 1) 143 | if self.normalize: 144 | point_set = self.normalize_points(point_set) 145 | 146 | if 
147 |             if self.image:
148 |                 return point_set, cls_, image
149 |             else:
150 |                 return point_set, cls_
151 | 
152 |         else:
153 |             if self.image:
154 |                 return point_set, seg, image
155 |             else:
156 |                 return point_set, seg
157 | 
158 | 
159 |     @staticmethod
160 |     def random_rotate(points):
161 |         ''' Randomly rotates the point cloud about the vertical (y) axis.
162 |             Uncomment the code below to rotate about all three axes.
163 |             '''
164 |         # construct a randomly parameterized 3x3 rotation matrix
165 |         # phi = torch.FloatTensor(1).uniform_(-torch.pi, torch.pi)
166 |         theta = torch.FloatTensor(1).uniform_(-torch.pi, torch.pi)
167 |         # psi = torch.FloatTensor(1).uniform_(-torch.pi, torch.pi)
168 | 
169 |         # rot_x = torch.Tensor([
170 |         #     [1, 0, 0],
171 |         #     [0, torch.cos(phi), -torch.sin(phi)],
172 |         #     [0, torch.sin(phi), torch.cos(phi)]])
173 | 
174 |         rot_y = torch.Tensor([
175 |             [torch.cos(theta), 0, torch.sin(theta)],
176 |             [0, 1, 0],
177 |             [-torch.sin(theta), 0, torch.cos(theta)]])
178 | 
179 |         # rot_z = torch.Tensor([
180 |         #     [torch.cos(psi), -torch.sin(psi), 0],
181 |         #     [torch.sin(psi), torch.cos(psi), 0],
182 |         #     [0, 0, 1]])
183 | 
184 |         # rot = torch.matmul(rot_x, torch.matmul(rot_y, rot_z))
185 | 
186 |         return torch.matmul(points, rot_y)
187 | 
188 |     @staticmethod
189 |     def normalize_points(points):
190 |         ''' Perform min/max normalization on points.
191 |             Same as:
192 |             (x - min(x))/(max(x) - min(x))
193 |             '''
194 |         points = points - points.min(axis=0)[0]
195 |         points /= points.max(axis=0)[0]
196 | 
197 |         return points
198 | 
199 | 
200 |     def __len__(self):
201 |         return len(self.datapath)
202 | 
203 | 
--------------------------------------------------------------------------------
/src/models/PointCloudDecoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Contains the classes comprising the point cloud decoder architectures. Usage for each class can
3 | be found in main() at the bottom.
4 | 
5 | To use: import the desired decoder class into your script.
6 | 
7 | 
8 | 
9 | NOTE:
10 | This architecture does not cover Part Segmentation. Per the Point Net paper
11 | that is a different architecture and is not implemented here.
12 | '''
13 | 
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | 
18 | 
19 | # ============================================================================
20 | #class PointCloudDecoder(nn.Module):
21 | #    def __init__(self, latent_dim, num_hidden, num_point=2500, point_dim=3, bn_decay=0.5):
22 | #        self.num_point = num_point
23 | #        self.point_dim = point_dim
24 | #        self.latent_dim = latent_dim
25 | #
26 | #        super(PointCloudDecoder, self).__init__()
27 | #
28 | #        self.upconv1 = nn.ConvTranspose2d(int(self.latent_dim/2), 512, kernel_size=(2, 2), stride=(2, 2), padding=0)
29 | #        self.bn_upconv1 = nn.BatchNorm2d(512)
30 | #
31 | #        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=0)
32 | #        self.bn_upconv2 = nn.BatchNorm2d(256)
33 | #
34 | #        self.upconv = nn.ConvTranspose2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1)
35 | #        self.bn_upconv = nn.BatchNorm2d(256)
36 | #
37 | #        self.upconv3 = nn.ConvTranspose2d(256, 256, kernel_size=(4, 5), stride=(2, 3), padding=0)
38 | #        self.bn_upconv3 = nn.BatchNorm2d(256)
39 | #
40 | #        self.upconv4 = nn.ConvTranspose2d(256, 128, kernel_size=(5, 7), stride=(3, 3), padding=0)
41 | #        self.bn_upconv4 = nn.BatchNorm2d(128)
42 | #
43 | #        self.upconv5 = nn.ConvTranspose2d(128, 3, kernel_size=(1, 1), stride=(1, 1), padding=0)
44 | #
45 | #    def forward(self, x):
46 | #
47 | #        # UPCONV Decoder
48 | #        x = x.view(x.size(0), -1, 1, 2)
49 | #        x = self.bn_upconv1(nn.functional.relu(self.upconv1(x)))
50 | #        x = self.bn_upconv2(nn.functional.relu(self.upconv2(x)))
51 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
52 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
53 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
54 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
55 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
56 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
57 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
58 | #        x = self.bn_upconv(nn.functional.relu(self.upconv(x)))
59 | #        x = self.bn_upconv3(nn.functional.relu(self.upconv3(x)))
60 | #        x = self.bn_upconv4(nn.functional.relu(self.upconv4(x)))
61 | #        x = self.upconv5(x)
62 | #        x = x.view(x.size(0), -1, 3)
63 | #
64 | #        return x
65 | 
66 | class PointCloudDecoderMLP(nn.Module):
67 |     def __init__(self, latent_dim, num_hidden, num_point=2500, point_dim=3, bn_decay=0.5):
68 |         self.num_point = num_point
69 |         self.point_dim = point_dim
70 |         self.latent_dim = latent_dim
71 | 
72 |         super(PointCloudDecoderMLP, self).__init__()
73 | 
74 |         self.fc1 = nn.Linear(self.latent_dim, self.latent_dim*2)
75 |         self.fc2 = nn.Linear(self.latent_dim*2, self.latent_dim*4)
76 |         self.fc3 = nn.Linear(self.latent_dim*4, self.latent_dim*8)
77 |         self.fc4 = nn.Linear(self.latent_dim*8, self.latent_dim*16)
78 |         self.fc5 = nn.Linear(self.latent_dim*16, self.latent_dim*32)
79 |         self.fc6 = nn.Linear(self.latent_dim*32, self.latent_dim*32)
80 |         self.fcend = nn.Linear(self.latent_dim*32, int(self.point_dim*self.num_point))
81 | 
82 |     def forward(self, x):
83 |         # MLP Decoder
84 |         #x = x.view(x.size(0), self.latent_dim, 1)
85 |         x = nn.functional.relu(self.fc1(x))
86 |         x = nn.functional.relu(self.fc2(x))
87 |         x = nn.functional.relu(self.fc3(x))
88 |         x = nn.functional.relu(self.fc4(x))
89 |         x = nn.functional.relu(self.fc5(x))
90 |         x = nn.functional.relu(self.fc6(x))
91 |         x = self.fcend(x)
92 |         x = x.reshape(-1, self.num_point, self.point_dim)
93 |         return x
94 | 
95 | 
96 | class PointCloudDecoderSelf(nn.Module):
97 |     def __init__(self, latent_dim, num_hidden, num_point=2500, point_dim=3, bn_decay=0.5):
98 |         self.num_point = num_point
99 |         self.point_dim = point_dim
100 |         self.latent_dim = latent_dim
101 | 
102 |         super(PointCloudDecoderSelf, self).__init__()
103 | 
104 |         self.upconv1 = nn.ConvTranspose1d(self.latent_dim, self.latent_dim, kernel_size=3, stride=3, padding=1)
105 |         self.bn_upconv1 = nn.BatchNorm1d(int(self.latent_dim))
106 |         self.upconv2 = nn.ConvTranspose1d(self.latent_dim, self.latent_dim*2, kernel_size=3, stride=3, padding=1)
107 |         self.bn_upconv2 = nn.BatchNorm1d(int(self.latent_dim*2))
108 |         self.upconv3 = nn.ConvTranspose1d(self.latent_dim*2, self.latent_dim*4, kernel_size=3, stride=3, padding=1)
109 |         self.bn_upconv3 = nn.BatchNorm1d(int(self.latent_dim*4))
110 |         self.upconv4 = nn.ConvTranspose1d(self.latent_dim*4, self.latent_dim*8, kernel_size=3, stride=3, padding=1)
111 |         self.bn_upconv4 = nn.BatchNorm1d(int(self.latent_dim*8))
112 |         self.upconv5 = nn.ConvTranspose1d(self.latent_dim*8, self.latent_dim*16, kernel_size=3, stride=3, padding=1)
113 |         self.bn_upconv5 = nn.BatchNorm1d(int(self.latent_dim*16))
114 |         self.upconv6 = nn.ConvTranspose1d(self.latent_dim*16, self.latent_dim*32, kernel_size=3, stride=3, padding=1)
115 |         self.bn_upconv6 = nn.BatchNorm1d(int(self.latent_dim*32))
116 |         self.upconv7 = nn.ConvTranspose1d(self.latent_dim*32, self.latent_dim*32, kernel_size=3, stride=3, padding=0)
117 | 
118 |     def forward(self, x):
119 |         # UPCONV Decoder
120 |         x = x.view(x.size(0), self.latent_dim, 1)
121 |         x = (nn.functional.relu(self.upconv1(x)))
122 |         x = (nn.functional.relu(self.upconv2(x)))
123 |         x = (nn.functional.relu(self.upconv3(x)))
124 |         x = (nn.functional.relu(self.upconv4(x)))
125 |         x = (nn.functional.relu(self.upconv5(x)))
126 |         x = (nn.functional.relu(self.upconv6(x)))
127 |         x = self.upconv7(x)
128 | 
129 |         return x
130 | 
131 | class PointCloudDecoder(nn.Module):
132 |     def __init__(self, latent_dim, num_hidden, num_point=2500, point_dim=3, bn_decay=0.5):
133 |         self.num_point = num_point
134 |         self.point_dim = point_dim
135 |         self.latent_dim = latent_dim
136 | 
137 |         super(PointCloudDecoder, self).__init__()
138 | 
139 |         self.upconv1 = nn.ConvTranspose2d(int(self.latent_dim/2), int(self.latent_dim), kernel_size=(3, 3), stride=(2, 2), padding=0)
140 |         self.bn_upconv1 = nn.BatchNorm2d(int(self.latent_dim))
141 | 
142 |         self.upconv2 = nn.ConvTranspose2d(int(self.latent_dim), int(self.latent_dim*4), kernel_size=(3, 3), stride=(2, 2), padding=0)
143 |         self.bn_upconv2 = nn.BatchNorm2d(int(self.latent_dim*4))
144 | 
145 |         self.upconv3 = nn.ConvTranspose2d(int(self.latent_dim*4), int(self.latent_dim*8), kernel_size=(3, 3), stride=(2, 2), padding=0)
146 |         self.bn_upconv3 = nn.BatchNorm2d(int(self.latent_dim*8))
147 | 
148 |         self.upconv4 = nn.ConvTranspose2d(int(self.latent_dim*8), 512, kernel_size=(3, 3), stride=(2, 2), padding=0)
149 |         self.bn_upconv4 = nn.BatchNorm2d(512)
150 | 
151 |         self.fcconv = nn.ConvTranspose2d(512, 512, kernel_size=(1, 1), stride=(1, 1), padding=0)
152 |         self.bn_fcconv = nn.BatchNorm2d(512)
153 | 
154 |         self.upconv5 = nn.ConvTranspose2d(512, 3, kernel_size=(1, 1), stride=(1, 1), padding=0)
155 |         self.bn_upconv5 = nn.BatchNorm2d(3)
156 | 
157 |     def forward(self, x):
158 | 
159 |         # UPCONV Decoder
160 |         x = x.view(x.size(0), -1, 1, 2)
161 |         x = self.bn_upconv1(nn.functional.relu(self.upconv1(x)))
162 |         x = self.bn_upconv2(nn.functional.relu(self.upconv2(x)))
163 |         x = self.bn_upconv3(nn.functional.relu(self.upconv3(x)))
164 |         x = self.bn_upconv4(nn.functional.relu(self.upconv4(x)))
165 |         # the same 1x1 conv / batch-norm block is applied six times (weights are shared across these calls)
166 |         x = self.bn_fcconv(nn.functional.relu(self.fcconv(x)))
167 |         x = self.bn_fcconv(nn.functional.relu(self.fcconv(x)))
168 |         x = self.bn_fcconv(nn.functional.relu(self.fcconv(x)))
169 |         x = self.bn_fcconv(nn.functional.relu(self.fcconv(x)))
170 |         x = self.bn_fcconv(nn.functional.relu(self.fcconv(x)))
171 |         x = self.bn_fcconv(nn.functional.relu(self.fcconv(x)))
172 |         x = self.upconv5(x)
173 |         x = x.view(x.size(0), -1, 3)
174 | 
175 |         return x
176 | 
177 | 
178 | 
179 | 
180 | # Test
181 | def main():
182 |     test_latent = torch.rand(32, 64)
183 | 
184 |     ## test MLP decoder
185 |     decoder = PointCloudDecoderMLP(latent_dim=64, num_hidden=0)
186 |     out = decoder(test_latent)
187 |     print(f'MLP decoder output shape: {out.shape}')
188 | 
189 |     ## test upconv decoder (the latent code is reshaped to (B, latent_dim/2, 1, 2), so latent_dim must be even)
190 |     decoder = PointCloudDecoder(latent_dim=64, num_hidden=0)
191 |     out = decoder(test_latent)
192 |     print(f'Upconv decoder output shape: {out.shape}')
193 | 
194 |     ## test 1D upconv decoder
195 |     decoder = PointCloudDecoderSelf(latent_dim=64, num_hidden=0)
196 |     out = decoder(test_latent)
197 |     print(f'Self decoder output shape: {out.shape}')
198 | 
199 |     # test on a single sample (should throw an error if there is an issue)
200 |     decoder = PointCloudDecoderMLP(latent_dim=64, num_hidden=0).eval()
201 |     out = decoder(test_latent[0, :].unsqueeze(0))
202 | 
203 | 
204 | if __name__ == '__main__':
205 |     main()
206 | 
--------------------------------------------------------------------------------
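A minimal usage sketch tying the two files above together (not part of the repository): it assumes the ShapeNet part benchmark has been extracted to the path shown, that the src packages are importable from the working directory, and the latent size, batch size, and class choice are illustrative only.

import torch
from torch.utils.data import DataLoader

from src.data.shapenet_dataset import ShapenetDataset
from src.models.PointCloudDecoder import PointCloudDecoderMLP

# load airplane point clouds; classification=True yields (points, label) pairs
dataset = ShapenetDataset(
    root='shapenetcore_partanno_segmentation_benchmark_v0',  # assumed extraction path
    split='train',
    npoints=2500,
    classification=True,
    class_choice=['Airplane'],
)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
points, labels = next(iter(loader))      # points: (8, 2500, 3), labels: (8, 1)

# decode stand-in latent codes back to point clouds of the same size
decoder = PointCloudDecoderMLP(latent_dim=128, num_hidden=0, num_point=2500)
z = torch.randn(points.size(0), 128)     # random codes in place of a trained encoder
reconstruction = decoder(z)              # (8, 2500, 3)
print(points.shape, reconstruction.shape)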