├── configs ├── config_closed_splines.yml ├── config_open_splines.yml ├── config_parsenet.yml ├── config_parsenet_e2e.yml ├── config_parsenet_normals.yml ├── config_test_closed_splines.yml ├── config_test_open_splines.yml ├── config_test_parsenet.yml └── config_test_parsenet_normals.yml ├── download_dataset.sh ├── environment.yml ├── generate_predictions.py ├── images └── parsenet-gallery.jpg ├── original_complete_environment.yml ├── read_config.py ├── readme.md ├── readme_data.md ├── render_options.json ├── src ├── PointNet.py ├── VisUtils.py ├── __init__.py ├── approximation.py ├── augment_utils.py ├── bezier.py ├── color_utils.py ├── curve_utils.py ├── data_utils.py ├── dataset.py ├── dataset_segments.py ├── eval_utils.py ├── fitting_optimization.py ├── fitting_utils.py ├── guard.py ├── loss.py ├── mean_shift.py ├── model.py ├── primitive_forward.py ├── primitives.py ├── residual_utils.py ├── segment_loss.py ├── segment_utils.py ├── test_fitting_utils.py ├── test_utils.py └── utils.py ├── test.py ├── test_closed_control_points.py ├── test_open_splines.py ├── train_closed_control_points.py ├── train_open_splines.py ├── train_parsenet.py └── train_parsenet_e2e.py /configs/config_closed_splines.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "train_closed_spline_{}_{}_{}_bt_{}_lr_{}_trsz_{}_tsz_{}_wght_{}" 5 | 6 | # Dataset path 7 | dataset = "data/spline/closed_splines.h5" 8 | 9 | # path to the pre-trained model 10 | pretrain_model_path = "" 11 | 12 | normals = False 13 | 14 | proportion = 1.0 15 | 16 | num_train=28000 17 | num_val=3000 18 | num_test=3000 19 | 20 | num_points=700 21 | 22 | loss_weight=0.9 23 | 24 | batch_size = 36 25 | 26 | num_epochs = 150 27 | grid_size = 20 28 | optim = adam 29 | accum = 4 30 | 31 | # Learing rate 32 | lr = 0.001 33 | 34 | # Whether to schedule the learning rate or not 35 | lr_sch = True 36 | 37 | # Number of epochs to wait before decaying the 
learning rate. 38 | patience = 8 39 | 40 | mode = 1 -------------------------------------------------------------------------------- /configs/config_open_splines.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | 5 | model_path = "temp_{}_{}_{}_bt_{}_lr_{}_trsz_{}_tsz_{}_wght_{}" 6 | 7 | # Dataset path 8 | dataset = "data/spline/open_splines.h5" 9 | 10 | # path to the pre-trained model 11 | pretrain_model_path = "" 12 | 13 | # Whether to input the normals or not 14 | normals = False 15 | 16 | proportion = 1.0 17 | 18 | # number of training instance 19 | num_train=3200 20 | num_val=3000 21 | num_test=3000 22 | num_points=700 23 | loss_weight=0.9 24 | 25 | batch_size = 36 26 | 27 | num_epochs = 150 28 | grid_size = 20 29 | 30 | optim = adam 31 | 32 | accum = 4 33 | 34 | # Learing rate 35 | lr = 0.001 36 | 37 | # Whether to schedule the learning rate or not 38 | lr_sch = True 39 | 40 | # Number of epochs to wait before decaying the learning rate. 
41 | patience = 8 42 | 43 | mode = 0 -------------------------------------------------------------------------------- /configs/config_parsenet.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "train_parsenet_{}_lr_{}_trsz_{}_tsz_{}_wght_{}_mode_{}" 5 | 6 | # Dataset path 7 | dataset = "" 8 | 9 | # Whether to load a pretrained model or not 10 | preload_model = False 11 | 12 | # pre-trained model path 13 | pretrain_model_path = "" 14 | 15 | # Whether to input the normals or not 16 | normals = False 17 | 18 | proportion = 1.0 19 | 20 | # number of training instance 21 | num_train=24000 22 | num_val=4000 23 | num_test=4000 24 | num_points=1000 25 | loss_weight=100 26 | 27 | num_epochs = 100 28 | grid_size = 20 29 | 30 | # batch size, based on the GPU memory 31 | batch_size = 8 32 | 33 | # Optimization 34 | optim = adam 35 | 36 | # Epsilon for the RL training, not applicable in Supervised training 37 | accum = 4 38 | 39 | # l2 Weight decay 40 | weight_decay = 0.0 41 | 42 | # dropout for Decoder network 43 | dropout = 0.2 44 | 45 | # Learing rate 46 | lr = 0.01 47 | 48 | # Encoder dropout 49 | encoder_drop = 0.2 50 | 51 | # Whether to schedule the learning rate or not 52 | lr_sch = True 53 | 54 | # Number of epochs to wait before decaying the learning rate. 
55 | patience = 8 56 | 57 | mode = 0 58 | -------------------------------------------------------------------------------- /configs/config_parsenet_e2e.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "train_parsenet_e2e_{}_lr_{}_trsz_{}_tsz_{}_wght_{}_mode_{}" 5 | 6 | # Dataset path 7 | dataset = "" 8 | 9 | # Whether to load a pretrained model or not 10 | preload_model = False 11 | 12 | # pre-trained model path 13 | pretrain_model_path = "parsenet_with_normals.pth" 14 | 15 | # Whether to input the normals or not 16 | normals = True 17 | 18 | proportion = 1.0 19 | 20 | # number of training instance 21 | num_train=24000 22 | num_val=4000 23 | num_test=4000 24 | num_points=1000 25 | loss_weight=100 26 | 27 | num_epochs = 100 28 | grid_size = 20 29 | 30 | # batch size, based on the GPU memory 31 | batch_size = 1 32 | 33 | # Optimization 34 | optim = adam 35 | 36 | # Epsilon for the RL training, not applicable in Supervised training 37 | accum = 4 38 | 39 | # l2 Weight decay 40 | weight_decay = 0.0 41 | 42 | # dropout for Decoder network 43 | dropout = 0.2 44 | 45 | # Learing rate 46 | lr = 0.0001 47 | 48 | # Encoder dropout 49 | encoder_drop = 0.2 50 | 51 | # Whether to schedule the learning rate or not 52 | lr_sch = True 53 | 54 | # Number of epochs to wait before decaying the learning rate. 
55 | patience = 8 56 | 57 | mode = 5 58 | -------------------------------------------------------------------------------- /configs/config_parsenet_normals.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "temp_{}_lr_{}_trsz_{}_tsz_{}_wght_{}_mode_{}" 5 | 6 | # Dataset path 7 | dataset = "" 8 | 9 | # Whether to load a pretrained model or not 10 | preload_model = False 11 | 12 | # pre-trained model path 13 | pretrain_model_path = "" 14 | 15 | # Whether to input the normals or not 16 | normals = True 17 | 18 | proportion = 1.0 19 | 20 | # number of training instance 21 | num_train=24000 22 | num_val=4000 23 | num_test=4000 24 | num_points=1000 25 | loss_weight=100 26 | 27 | num_epochs = 100 28 | grid_size = 20 29 | 30 | # batch size, based on the GPU memory 31 | batch_size = 8 32 | 33 | # Optimization 34 | optim = adam 35 | 36 | # Epsilon for the RL training, not applicable in Supervised training 37 | accum = 4 38 | 39 | # l2 Weight decay 40 | weight_decay = 0.0 41 | 42 | # dropout for Decoder network 43 | dropout = 0.2 44 | 45 | # Learing rate 46 | lr = 0.01 47 | 48 | # Encoder dropout 49 | encoder_drop = 0.2 50 | 51 | # Whether to schedule the learning rate or not 52 | lr_sch = True 53 | 54 | # Number of epochs to wait before decaying the learning rate. 
55 | patience = 8 56 | 57 | mode = 5 58 | -------------------------------------------------------------------------------- /configs/config_test_closed_splines.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "temp_{}_{}_{}_bt_{}_lr_{}_trsz_{}_tsz_{}_wght_{}" 5 | 6 | # Dataset path 7 | dataset = "data/spline/closed_splines.h5" 8 | 9 | # path to the pre-trained model 10 | pretrain_model_path = "closed_spline.pth" 11 | 12 | normals = False 13 | 14 | proportion = 1.0 15 | 16 | num_train=28000 17 | num_val=3000 18 | num_test=3000 19 | 20 | num_points=700 21 | 22 | loss_weight=0.9 23 | 24 | batch_size = 1 25 | 26 | num_epochs = 150 27 | grid_size = 20 28 | optim = adam 29 | accum = 4 30 | 31 | # Learing rate 32 | lr = 0.001 33 | 34 | # Whether to schedule the learning rate or not 35 | lr_sch = True 36 | 37 | # Number of epochs to wait before decaying the learning rate. 38 | patience = 8 39 | 40 | mode = 1 -------------------------------------------------------------------------------- /configs/config_test_open_splines.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | 5 | model_path = "temp_{}_{}_{}_bt_{}_lr_{}_trsz_{}_tsz_{}_wght_{}" 6 | 7 | # Dataset path 8 | dataset = "data/spline/open_splines.h5" 9 | 10 | # path to the pre-trained model 11 | pretrain_model_path = "open_spline.pth" 12 | 13 | # Whether to input the normals or not 14 | normals = False 15 | 16 | proportion = 1.0 17 | 18 | # number of training instance 19 | num_train=3200 20 | num_val=3000 21 | num_test=3000 22 | num_points=700 23 | loss_weight=0.9 24 | 25 | batch_size = 1 26 | 27 | num_epochs = 150 28 | grid_size = 20 29 | 30 | optim = adam 31 | 32 | accum = 4 33 | 34 | # Learing rate 35 | lr = 0.001 36 | 37 | # Whether to schedule the learning rate or not 38 | lr_sch = True 39 | 40 | # Number of epochs to wait before decaying the learning rate. 
41 | patience = 8 42 | 43 | mode = 0 -------------------------------------------------------------------------------- /configs/config_test_parsenet.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "temp_{}_lr_{}_trsz_{}_tsz_{}_wght_{}_mode_{}" 5 | 6 | # Dataset path 7 | dataset = "" 8 | 9 | # Whether to load a pretrained model or not 10 | preload_model = False 11 | 12 | # model name to load for testing the performance of parsenet. 13 | pretrain_model_path = "parsenet_without_normals.pth" 14 | 15 | # Whether to input the normals or not 16 | normals = False 17 | 18 | proportion = 1.0 19 | 20 | # number of training instance 21 | num_train=24000 22 | num_val=4000 23 | num_test=4000 24 | num_points=1000 25 | loss_weight=100 26 | 27 | num_epochs = 10000 28 | grid_size = 20 29 | 30 | # batch size, based on the GPU memory 31 | batch_size = 1 32 | 33 | # Optimization 34 | optim = adam 35 | 36 | # Epsilon for the RL training, not applicable in Supervised training 37 | accum = 4 38 | 39 | # l2 Weight decay 40 | weight_decay = 0.0 41 | 42 | # dropout for Decoder network 43 | dropout = 0.2 44 | 45 | # Learing rate 46 | lr = 0.01 47 | 48 | # Encoder dropout 49 | encoder_drop = 0.2 50 | 51 | # Whether to schedule the learning rate or not 52 | lr_sch = True 53 | 54 | # Number of epochs to wait before decaying the learning rate. 55 | patience = 8 56 | 57 | mode = 0 58 | -------------------------------------------------------------------------------- /configs/config_test_parsenet_normals.yml: -------------------------------------------------------------------------------- 1 | comment="" 2 | 3 | [train] 4 | model_path = "temp_{}_lr_{}_trsz_{}_tsz_{}_wght_{}_mode_{}" 5 | 6 | # Dataset path 7 | dataset = "" 8 | 9 | # Whether to load a pretrained model or not 10 | preload_model = False 11 | 12 | # model name to load for testing the performance of parsenet. 
13 | pretrain_model_path = "parsenet_with_normals.pth" 14 | 15 | # Whether to input the normals or not 16 | normals = True 17 | 18 | proportion = 1.0 19 | 20 | # number of training instance 21 | num_train=24000 22 | num_val=4000 23 | num_test=4000 24 | num_points=1000 25 | loss_weight=100 26 | 27 | num_epochs = 10000 28 | grid_size = 20 29 | 30 | # batch size, based on the GPU memory 31 | batch_size = 1 32 | 33 | # Optimization 34 | optim = adam 35 | 36 | # Epsilon for the RL training, not applicable in Supervised training 37 | accum = 4 38 | 39 | # l2 Weight decay 40 | weight_decay = 0.0 41 | 42 | # dropout for Decoder network 43 | dropout = 0.2 44 | 45 | # Learing rate 46 | lr = 0.01 47 | 48 | # Encoder dropout 49 | encoder_drop = 0.2 50 | 51 | # Whether to schedule the learning rate or not 52 | lr_sch = True 53 | 54 | # Number of epochs to wait before decaying the learning rate. 55 | patience = 8 56 | 57 | mode = 5 58 | -------------------------------------------------------------------------------- /download_dataset.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading dataset" 2 | #wget http://neghvar.cs.umass.edu/public_data/parsenet/data.zip 3 | wget http://neghvar.cs.umass.edu/public_data/parsenet/predictions.h5 4 | echo "unzipping" 5 | #unzip data.zip 6 | mkdir logs 7 | mkdir logs/results 8 | mkdir logs/results/parsenet_with_normals.pth 9 | mkdir logs/results/parsenet_with_normals.pth/results 10 | mv predictions.h5 logs/results/parsenet_with_normals.pth/results/predictions.h5 11 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: parsenet 2 | channels: 3 | - loopbio 4 | - defaults 5 | - pytorch 6 | dependencies: 7 | - lapsolver=1.0.2=py36h4af77d5_0 8 | - python=3.6.6=h6e4f718_2 9 | - pip: 10 | - geomdl==5.2.9 11 | - h5py==2.10.0 12 | - lap==0.4.0 13 | - 
open3d==0.9.0.0 14 | - scikit-image==0.16.2 15 | - scikit-learn==0.22.1 16 | - scipy==1.4.1 17 | - six==1.13.0 18 | - sklearn==0.0 19 | - tensorboard-logger==0.1.0 20 | - torch==1.2.0 21 | - torchvision==0.4.0 22 | - transforms3d==0.3.1 23 | - trimesh==2.31.38 24 | - pip==20.0.2 -------------------------------------------------------------------------------- /generate_predictions.py: -------------------------------------------------------------------------------- 1 | from open3d import * 2 | import h5py 3 | import sys 4 | import logging 5 | import json 6 | import os 7 | from shutil import copyfile 8 | import numpy as np 9 | import torch.optim as optim 10 | import torch.utils.data 11 | from torch.autograd import Variable 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from src.PointNet import PrimitivesEmbeddingDGCNGn 14 | from matplotlib import pyplot as plt 15 | from src.utils import visualize_uv_maps, visualize_fitted_surface 16 | from src.utils import chamfer_distance 17 | from read_config import Config 18 | from src.utils import fit_surface_sample_points 19 | from src.dataset_segments import Dataset 20 | from torch.utils.data import DataLoader 21 | from src.utils import chamfer_distance 22 | from src.segment_loss import EmbeddingLoss 23 | from src.segment_utils import cluster 24 | import time 25 | from src.segment_loss import ( 26 | EmbeddingLoss, 27 | primitive_loss, 28 | evaluate_miou, 29 | ) 30 | from src.segment_utils import to_one_hot, SIOU_matched_segments 31 | from src.utils import visualize_point_cloud_from_labels, visualize_point_cloud 32 | from src.dataset import generator_iter 33 | from src.mean_shift import MeanShift 34 | from src.segment_utils import SIOU_matched_segments 35 | from src.residual_utils import Evaluation 36 | import time 37 | from src.primitives import SaveParameters 38 | 39 | # Use only one gpu. 
# Script body of generate_predictions.py.
#
# Runs a pretrained ParSeNet segmentation model over the test split and stores
# the predicted segment ids and primitive types in
#     logs/results/<pretrain_model_path>/results/predictions.h5
#
# Usage: python generate_predictions.py <config file>

if len(sys.argv) < 2:
    sys.exit("usage: python generate_predictions.py <config.yml>")

# Use only one gpu.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = Config(sys.argv[1])
if_normals = config.normals

userspace = ""
Loss = EmbeddingLoss(margin=1.0)

# Number of input channels for each supported mode:
#   0 -> points only (xyz), 5 -> points and normals (xyz + nxnynz).
MODE_TO_CHANNELS = {0: 3, 5: 6}
if config.mode not in MODE_TO_CHANNELS:
    # Previously `model` was simply left undefined here and the script failed
    # later with a NameError; fail early with an actionable message instead.
    raise ValueError(
        "Unsupported mode {}; expected one of {}".format(
            config.mode, sorted(MODE_TO_CHANNELS)
        )
    )
# NOTE(review): `config.normals` and `config.mode` are independent settings;
# the shipped configs pair normals=False with mode 0 and normals=True with
# mode 5. Nothing here enforces that pairing - confirm whether it should.

model = PrimitivesEmbeddingDGCNGn(
    embedding=True,
    emb_size=128,
    primitives=True,
    num_primitives=10,
    loss_function=Loss.triplet_loss,
    mode=config.mode,
    num_channels=MODE_TO_CHANNELS[config.mode],
)

# NOTE(review): neither object below is read again in this script; they are
# kept only in case their constructors have side effects - confirm and remove.
saveparameters = SaveParameters()
ms = MeanShift()

# The attribute must be set on the bare model before it is wrapped.
model.l_permute = np.arange(10000)
model = torch.nn.DataParallel(model, device_ids=[0])
model.cuda()

# Normals are always loaded because the loop below unpacks them from every
# sample; whether they are fed to the network is decided by `if_normals`.
dataset = Dataset(
    config.batch_size,
    config.num_train,
    config.num_val,
    config.num_test,
    normals=True,
    primitives=True,
    if_train_data=False,
    prefix=userspace,
)

get_test_data = dataset.get_test(
    align_canonical=True, anisotropic=False, if_normal_noise=True
)

loader = generator_iter(get_test_data, int(1e10))
# collate_fn is the identity, so each item drawn from the loader is a
# one-element list wrapping whatever the generator yielded.
get_test_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

results_dir = userspace + "logs/results/{}/results/".format(config.pretrain_model_path)
os.makedirs(results_dir, exist_ok=True)

evaluation = Evaluation()
model.eval()

# Arguments forwarded to Evaluation.guard_mean_shift.
iterations = 50
quantile = 0.015

model.load_state_dict(
    torch.load(userspace + "logs/pretrained_models/" + config.pretrain_model_path)
)

PredictedLabels = []
PredictedPrims = []

# NOTE(review): the original loop stopped right after its fourth iteration
# (`if val_b_id == 3: break`), so at most four shapes are written. The cap is
# preserved here; raise or remove MAX_SHAPES to process the whole test split.
MAX_SHAPES = 4
num_batches = config.num_test // config.batch_size - 1

for _ in range(min(num_batches, MAX_SHAPES)):
    points_, labels, normals, primitives_ = next(get_test_data)[0]
    points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
    normals = torch.from_numpy(normals).cuda()
    labels_gpu = torch.from_numpy(labels).cuda()

    with torch.no_grad():
        # Concatenate along the last axis when normals are used, then move
        # that axis to position 1 before calling the model.
        net_input = torch.cat([points, normals], 2) if if_normals else points
        embedding, primitives_log_prob, _ = model(
            net_input.permute(0, 2, 1), labels_gpu, True
        )

    # Arg-max over the first axis of the first sample's primitive scores.
    pred_primitives = torch.max(primitives_log_prob[0], 0)[1].data.cpu().numpy()
    embedding = torch.nn.functional.normalize(embedding[0].T, p=2, dim=1)
    _, _, cluster_ids = evaluation.guard_mean_shift(
        embedding, quantile, iterations, kernel_type="gaussian"
    )
    num_clusters = np.unique(cluster_ids.data.cpu().numpy()).shape[0]
    weights = to_one_hot(cluster_ids, num_clusters)
    cluster_ids = cluster_ids.data.cpu().numpy()

    # The returned IoU values are not stored; the call is kept so the script
    # still performs the same matching work as before.
    SIOU_matched_segments(
        labels[0],
        cluster_ids,
        pred_primitives,
        primitives_[0],
        weights,
    )
    PredictedLabels.append(cluster_ids)
    PredictedPrims.append(pred_primitives)

with h5py.File(os.path.join(results_dir, "predictions.h5"), "w") as hf:
    hf.create_dataset(name="seg_id", data=np.stack(PredictedLabels, 0))
    hf.create_dataset(name="pred_primitives", data=np.stack(PredictedPrims, 0))
/images/parsenet-gallery.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hippogriff/parsenet-codebase/223ed936b456b3a6cc3259222303b93143c3fee2/images/parsenet-gallery.jpg -------------------------------------------------------------------------------- /original_complete_environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_1.2 2 | channels: 3 | - loopbio 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - ca-certificates=2019.11.27=0 8 | - certifi=2019.11.28=py36_0 9 | - lapsolver=1.0.2=py36h4af77d5_0 10 | - libedit=3.1.20181209=hc058e9b_0 11 | - libffi=3.2.1=hd88cf55_4 12 | - libgcc-ng=9.1.0=hdf63c60_0 13 | - libstdcxx-ng=9.1.0=hdf63c60_0 14 | - ncurses=6.1=he6710b0_1 15 | - openssl=1.0.2u=h7b6447c_0 16 | - python=3.6.6=h6e4f718_2 17 | - readline=7.0=h7b6447c_5 18 | - sqlite=3.30.1=h7b6447c_0 19 | - tk=8.6.8=hbc83047_0 20 | - wheel=0.33.6=py36_0 21 | - xz=5.2.4=h14c3975_4 22 | - zlib=1.2.11=h7b6447c_3 23 | - pip: 24 | - appdirs==1.4.3 25 | - attrs==19.3.0 26 | - backcall==0.1.0 27 | - bleach==3.1.0 28 | - configobj==5.0.6 29 | - cycler==0.10.0 30 | - decorator==4.4.1 31 | - defusedxml==0.6.0 32 | - distlib==0.3.0 33 | - entrypoints==0.3 34 | - filelock==3.0.12 35 | - fuckit==4.8.1 36 | - geomdl==5.2.9 37 | - h5py==2.10.0 38 | - imageio==2.6.1 39 | - importlib-metadata==1.3.0 40 | - importlib-resources==1.4.0 41 | - ipdb==0.12.3 42 | - ipykernel==5.1.3 43 | - ipython==7.11.1 44 | - ipython-genutils==0.2.0 45 | - ipywidgets==7.5.1 46 | - jedi==0.15.2 47 | - jinja2==2.10.3 48 | - joblib==0.14.1 49 | - jsonschema==3.2.0 50 | - jupyter-client==5.3.4 51 | - jupyter-core==4.6.1 52 | - kiwisolver==1.1.0 53 | - lap==0.4.0 54 | - markupsafe==1.1.1 55 | - matplotlib==3.1.2 56 | - mistune==0.8.4 57 | - more-itertools==8.0.2 58 | - nbconvert==5.6.1 59 | - nbformat==4.4.0 60 | - networkx==2.4 61 | - notebook==6.0.2 62 | - 
class Config(object):
    """Attribute-style view over an experiment configuration file.

    The file is parsed with ``ConfigObj``. The top-level ``comment`` entry and
    a fixed set of entries from the ``[train]`` section are copied onto the
    instance with the appropriate Python type. The parsed ``ConfigObj`` itself
    stays available as ``self.config``.
    """

    def __init__(self, filename: str):
        """
        Read from a config file.

        Every entry read below is looked up unconditionally, so each one must
        be present in the file.

        :param filename: name of the file to read from
        """

        self.filename = filename
        config = ConfigObj(self.filename)
        self.config = config
        train = config["train"]

        # Comments on the experiments running
        self.comment = config["comment"]

        # Model name and location to store
        self.model_path = train["model_path"]

        # path to the model
        self.pretrain_model_path = train["pretrain_model_path"]

        # Whether normals are given to the network as input
        self.normals = train.as_bool("normals")

        # number of training / validation / test examples
        self.num_train = train.as_int("num_train")
        self.num_val = train.as_int("num_val")
        self.num_test = train.as_int("num_test")
        self.num_points = train.as_int("num_points")
        self.grid_size = train.as_int("grid_size")
        # Weight to the loss function for stretching
        self.loss_weight = train.as_float("loss_weight")

        # dataset (note: stored as ``dataset_path``, read from key "dataset")
        self.dataset_path = train["dataset"]

        # Proportion of train dataset to use
        self.proportion = train.as_float("proportion")

        # Number of epochs to run during training
        # (note: stored as ``epochs``, read from key "num_epochs")
        self.epochs = train.as_int("num_epochs")

        # batch size, based on the GPU memory
        self.batch_size = train.as_int("batch_size")

        # Integer mode switch interpreted by the calling script.
        # generate_predictions.py uses 0 for points only and 5 for points
        # plus normals.
        # NOTE(review): the spline configs use 0/1 - confirm their meaning
        # against the spline training scripts.
        self.mode = train.as_int("mode")

        # Learning rate
        self.lr = train.as_float("lr")

        # Number of epochs to wait before decaying the learning rate.
        self.patience = train.as_int("patience")

        # Name of the optimizer (the shipped configs all use "adam")
        self.optim = train["optim"]

        # NOTE(review): an earlier comment described this as an RL epsilon;
        # the name and integer type suggest a gradient-accumulation step
        # count - confirm against the training scripts.
        self.accum = train.as_int("accum")

        # Whether to schedule the learning rate or not
        self.lr_sch = train.as_bool("lr_sch")

    def write_config(self, filename=None):
        """
        Write the details of the experiment in the form of a config file.
        This will be used to keep track of what experiments are running and
        what parameters have been used.

        :param filename: destination path. When omitted, the path currently
            set on the underlying ``ConfigObj`` is left untouched, so the
            no-argument call used by this module's ``__main__`` code no
            longer raises ``TypeError``.
        :return: None
        """
        if filename is not None:
            self.config.filename = filename
        self.config.write()

    def get_all_attribute(self):
        """
        This function prints all the values of the attributes, just to cross
        check whether all the data types are correct.

        :return: Nothing, just printing
        """
        for attr, value in vars(self).items():
            print(attr, value)
41 | 42 | #### SplineNet 43 | 44 | * To train open SplineNet (with 2 gpus): 45 | 46 | ```python 47 | python train_open_splines.py configs/config_open_splines.yml 48 | ``` 49 | 50 | * To test open SplineNet: 51 | 52 | ```python 53 | python test_open_splines.py configs/config_test_open_splines.yml 54 | ``` 55 | 56 | * To train closed SplineNet (with 2 gpus): 57 | 58 | ```python 59 | python train_closed_control_points.py configs/config_closed_splines.yml 60 | ``` 61 | 62 | * To test closed SplineNet: 63 | 64 | ```python 65 | python test_closed_control_points.py configs/config_test_closed_splines.yml 66 | ``` 67 | 68 | 69 | 70 | #### ParSeNet 71 | 72 | - To train ParseNet with only points as input (with 4 gpus): 73 | 74 | ``` 75 | python train_parsenet.py configs/config_parsenet.yml 76 | ``` 77 | 78 | * To train ParseNet with points and normals as input (with 4 gpus): 79 | 80 | ``` 81 | python train_parsenet.py configs/config_parsenet_normals.yml 82 | ``` 83 | 84 | * To train ParseNet in an end to end manner (note that you need to first pretrain the above models), then specify the path to the trained model in `configs/config_parsenet_e2e.yml` (with 2 gpus). Further note that, this part of the training requires dynamic amount of gpu memory because a shape can have variable number of segment and corresponding number of fitting module. Training is done using Nvidia m40 (24 Gb gpu). 85 | 86 | ``` 87 | python train_parsenet_e2e.py configs/config_parsenet_e2e.yml 88 | ``` 89 | 90 | * Testing can be done using `test.py` 91 | ``` 92 | python test.py 0 3998 93 | ``` 94 | ------ 95 | 96 | 97 | 98 | ### Acknowledgements 99 | 100 | 1. This project takes inspiration of designing network architecture from the code base provided by Wang et.al.: https://github.com/WangYueFt/dgcnn 101 | 2. 
We also thank Sebastian for timely release and advice on ABC dataset: https://deep-geometry.github.io/abc-dataset/ 102 | 103 | ------ 104 | 105 | 106 | 107 | ### Citation 108 | 109 | ``` 110 | @misc{sharma2020parsenet, 111 | title={ParSeNet: A Parametric Surface Fitting Network for 3D Point Clouds}, 112 | author={Gopal Sharma and Difan Liu and Evangelos Kalogerakis and Subhransu Maji and Siddhartha Chaudhuri and Radomír Měch}, 113 | year={2020}, 114 | eprint={2003.12181}, 115 | archivePrefix={arXiv}, 116 | primaryClass={cs.CV} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /readme_data.md: -------------------------------------------------------------------------------- 1 | To download the dataset: 2 | ``` 3 | bash download_dataset.sh 4 | ``` 5 | This will download and unzip the dataset in the data folder and will unzip pre-trained models. 6 | The data directory has the following organization. 7 | 8 | ``` 9 | ├── shapes 10 | │   ├── all_ids.txt 11 | │   ├── face_data_release.zip 12 | │   ├── meshes.zip 13 | │   ├── test_data.h5 14 | │   ├── test_ids.txt 15 | │   ├── train_data.h5 16 | │   ├── train_ids.txt 17 | │   ├── val_data.h5 18 | │   └── val_ids.txt 19 | └── spline 20 | ├── closed_splines.h5 21 | ├── open_splines.h5 22 | └── simple_less_thn_20.zip 23 | ``` 24 | 25 | * `meshes.zip`: contains all the meshes used in the parsenet 26 | experiments. *Note* that these models are taken from ABC dataset. We 27 | pre-processed shapes to separate disconnected meshes into different 28 | meshes. For that reason you will notice that the names of the shapes are 29 | of the format `shapeid_index.json`, where `shapeid` is the id of the 30 | model from ABC dataset and `index` is the index of the disconnected 31 | part. 32 | 33 | * `train_data.h5`, `test_data.h5` and `val_data.h5`: contain points, 34 | normals, segment index and primitive type index for each shape. 
35 | Please refer to `src/dataset_segments.py` for how to load these h5 36 | files. Note that there are 10 possible primitive types: 37 | circle, sphere, plane, cone, cylinder, open spline, closed spline, 38 | revolution, extrusion and `extra`. Revolution, extrusion and extra are treated 39 | as b-spline primitives because a b-spline can also approximate these patches. Excluding 40 | shapes with these extra surface patches would have resulted in a very small dataset. 41 | More specifically: 42 | 1. [0, 6, 7, 9] indices correspond to closed b-spline. 43 | 2. [2, 8] indices correspond to open b-spline. 44 | 3. [1] index corresponds to plane. 45 | 4. [3] corresponds to cone. 46 | 5. [4] corresponds to cylinder. 47 | 6. [5] corresponds to sphere. 48 | 49 | * `face_data_release.zip`: contains txt files for each shape in the 50 | above dataset. Specifically, it contains the segment id and 51 | primitive types for each shape. 52 | 53 | * `train_ids.txt`, `val_ids.txt` and `test_ids.txt`: contain the shape ids for 54 | the different splits. `all_ids.txt` contains the list of ids for all shapes. 55 | 56 | * `closed_splines.h5`: contains points and control points for closed 57 | splines. Please refer to `src/dataset.py` for more details on how to 58 | load points and splits. `open_splines.h5` is for open splines. 
59 | -------------------------------------------------------------------------------- /render_options.json: -------------------------------------------------------------------------------- 1 | { 2 | "background_color" : [ 1, 1, 1 ], 3 | "class_name" : "RenderOption", 4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], 5 | "image_max_depth" : 3000, 6 | "image_stretch_option" : 0, 7 | "interpolation_option" : 0, 8 | "light0_color" : [ 1, 1, 1 ], 9 | "light0_diffuse_power" : 0.26000000000000003, 10 | "light0_position" : [ 0, 0, 2 ], 11 | "light0_specular_power" : 0.01000000000000001, 12 | "light0_specular_shininess" : 1, 13 | "light1_color" : [ 1, 1, 1 ], 14 | "light1_diffuse_power" : 0.26000000000000003, 15 | "light1_position" : [ 0, 0, 2 ], 16 | "light1_specular_power" : 0.01000000000000001, 17 | "light1_specular_shininess" : 1, 18 | "light2_color" : [ 1, 1, 1 ], 19 | "light2_diffuse_power" : 0.26000000000000003, 20 | "light2_position" : [ 0, 0, -2 ], 21 | "light2_specular_power" : 0.01000000000000001, 22 | "light2_specular_shininess" : 1, 23 | "light3_color" : [ 1, 1, 1 ], 24 | "light3_diffuse_power" : 0.26000000000000003, 25 | "light3_position" : [ 0, 0, -2 ], 26 | "light3_specular_power" : 0.01000000000000001, 27 | "light3_specular_shininess" : 1, 28 | "light_ambient_color" : [ 0.2, 0.2, 0.2 ], 29 | "light_on" : true, 30 | "mesh_color_option" : 1, 31 | "mesh_shade_option" : 1, 32 | "mesh_show_back_face" : true, 33 | "mesh_show_wireframe" : true, 34 | "point_color_option" : 9, 35 | "point_show_normal" : false, 36 | "point_size" : 2, 37 | "show_coordinate_frame" : false, 38 | "version_major" : 1, 39 | "version_minor" : 0 40 | } 41 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Hippogriff/parsenet-codebase/223ed936b456b3a6cc3259222303b93143c3fee2/src/__init__.py -------------------------------------------------------------------------------- /src/augment_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # TODO cite the original author from where you have taken the code 5 | 6 | 7 | def rotate_point_cloud(batch_data): 8 | """ Randomly rotate the point clouds to augument the dataset 9 | rotation is per shape based along up direction 10 | Input: 11 | BxNx3 array, original batch of point clouds 12 | Return: 13 | BxNx3 array, rotated batch of point clouds 14 | """ 15 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 16 | for k in range(batch_data.shape[0]): 17 | rotation_angle = np.random.uniform() * 2 * np.pi 18 | cosval = np.cos(rotation_angle) 19 | sinval = np.sin(rotation_angle) 20 | rotation_matrix = np.array([[cosval, 0, sinval], 21 | [0, 1, 0], 22 | [-sinval, 0, cosval]]) 23 | shape_pc = batch_data[k, ...] 24 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 25 | return rotated_data.astype(np.float32) 26 | 27 | 28 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 29 | """ Rotate the point cloud along up direction with certain angle. 30 | Input: 31 | BxNx3 array, original batch of point clouds 32 | Return: 33 | BxNx3 array, rotated batch of point clouds 34 | """ 35 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 36 | for k in range(batch_data.shape[0]): 37 | # rotation_angle = np.random.uniform() * 2 * np.pi 38 | cosval = np.cos(rotation_angle) 39 | sinval = np.sin(rotation_angle) 40 | rotation_matrix = np.array([[cosval, 0, sinval], 41 | [0, 1, 0], 42 | [-sinval, 0, cosval]]) 43 | shape_pc = batch_data[k, ...] 44 | rotated_data[k, ...] 
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 45 | return rotated_data.astype(np.float32) 46 | 47 | 48 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.30): 49 | """ Randomly perturb the point clouds by small rotations 50 | Input: 51 | BxNx3 array, original batch of point clouds 52 | Return: 53 | BxNx3 array, rotated batch of point clouds 54 | """ 55 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 56 | for k in range(batch_data.shape[0]): 57 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 58 | Rx = np.array([[1, 0, 0], 59 | [0, np.cos(angles[0]), -np.sin(angles[0])], 60 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 61 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 62 | [0, 1, 0], 63 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 64 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 65 | [np.sin(angles[2]), np.cos(angles[2]), 0], 66 | [0, 0, 1]]) 67 | R = np.dot(Rz, np.dot(Ry, Rx)) 68 | shape_pc = batch_data[k, ...] 69 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 70 | return rotated_data.astype(np.float32) 71 | 72 | 73 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 74 | """ Randomly jitter points. jittering is per point. 75 | Input: 76 | BxNx3 array, original batch of point clouds 77 | Return: 78 | BxNx3 array, jittered batch of point clouds 79 | """ 80 | B, N, C = batch_data.shape 81 | assert (clip > 0) 82 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip) 83 | jittered_data += batch_data 84 | return jittered_data.astype(np.float32) 85 | 86 | 87 | def shift_point_cloud(batch_data, shift_range=0.1): 88 | """ Randomly shift point cloud. Shift is per point cloud. 
89 | Input: 90 | BxNx3 array, original batch of point clouds 91 | Return: 92 | BxNx3 array, shifted batch of point clouds 93 | """ 94 | B, N, C = batch_data.shape 95 | shifts = np.random.uniform(-shift_range, shift_range, (B, 3)) 96 | for batch_index in range(B): 97 | batch_data[batch_index, :, :] += shifts[batch_index, :] 98 | return batch_data.astype(np.float32) 99 | 100 | 101 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.2): 102 | """ Randomly scale the point cloud. Scale is per point cloud. 103 | Input: 104 | BxNx3 array, original batch of point clouds 105 | Return: 106 | BxNx3 array, scaled batch of point clouds 107 | """ 108 | B, N, C = batch_data.shape 109 | scales = np.random.uniform(scale_low, scale_high, B) 110 | for batch_index in range(B): 111 | batch_data[batch_index, :, :] *= scales[batch_index] 112 | return batch_data 113 | 114 | 115 | class Augment: 116 | def __init__(self, ): 117 | pass 118 | 119 | def augment(self, batch_data): 120 | if np.random.random() > 0.7: 121 | batch_data = rotate_perturbation_point_cloud(batch_data) 122 | if np.random.random() > 0.7: 123 | batch_data = jitter_point_cloud(batch_data) 124 | if np.random.random() > 0.7: 125 | batch_data = shift_point_cloud(batch_data, 0.05) 126 | if np.random.random() > 0.7: 127 | batch_data = random_scale_point_cloud(batch_data) 128 | return batch_data 129 | -------------------------------------------------------------------------------- /src/bezier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import comb 3 | 4 | 5 | def bernstein_polynomial(n): 6 | """ 7 | n: degree of the polynomial 8 | """ 9 | N = np.ones(n + 1, dtype=np.int32) * n 10 | K = np.arange(n + 1) 11 | basis = comb(N, K) 12 | return basis.reshape((1, n + 1)) 13 | 14 | 15 | def bernstein_tensor(t, basis): 16 | """ 17 | t: L x 1 18 | basis: 1 x n + 1 19 | """ 20 | n = basis.shape[1] - 1 21 | T = [] 22 | for i in range(n + 
1): 23 | T.append((t ** i) * ((1.0 - t) ** (n - i))) 24 | T = np.concatenate(T, 1) 25 | basis_tensor = T * basis 26 | return basis_tensor 27 | 28 | 29 | basis = bernstein_polynomial(3) 30 | t = np.random.random((100, 1)) 31 | basis_u = bernstein_tensor(t, basis) 32 | 33 | t = np.random.random((100, 1)) 34 | basis_v = bernstein_tensor(t, basis) 35 | 36 | p = np.array([[0, 0, 0], 37 | [0.33, 0, 0.5], 38 | [0.66, 0, 0.5], 39 | [1, 0, 0], 40 | 41 | [0, 0.33, 0.5], 42 | [0.33, 0.33, 1], 43 | [0.66, 0.33, 1], 44 | [1, 0.33, 0.5], 45 | 46 | [0, 0.66, -0.5], 47 | [0.33, 0.66, -1], 48 | [0.66, 0.66, -1], 49 | [1, 0.66, -0.5], 50 | 51 | [0, 1, 0], 52 | [0.33, 1, 0.5], 53 | [0.66, 1, 0.5], 54 | [1, 1, 0]]) 55 | 56 | cp = p.reshape((4, 4, 3)) 57 | 58 | points = [] 59 | for i in range(3): 60 | points.append(np.matmul(np.matmul(basis_u, cp[:, :, i]), np.transpose(basis_v))) 61 | 62 | points = np.stack(points, 2) 63 | -------------------------------------------------------------------------------- /src/color_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | 6 | 7 | # initialize the weighs of the network for Convolutional layers and batchnorm layers 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | m.weight.data.normal_(1.0, 0.02) 14 | m.bias.data.fill_(0) 15 | 16 | 17 | def adjust_learning_rate(optimizer, epoch, phase): 18 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 19 | if epoch % phase == (phase - 1): 20 | for param_group in optimizer.param_groups: 21 | param_group["lr"] = param_group["lr"] / 10.0 22 | 23 | 24 | class AverageValueMeter(object): 25 | """Computes and stores the average and current value""" 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | 
self.avg = 0 33 | self.sum = 0 34 | self.count = 0.0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | CHUNK_SIZE = 150 44 | lenght_line = 60 45 | 46 | 47 | def my_get_n_random_lines(path, n=5): 48 | MY_CHUNK_SIZE = lenght_line * (n + 2) 49 | lenght = os.stat(path).st_size 50 | with open(path, "r") as file: 51 | file.seek(random.randint(400, lenght - MY_CHUNK_SIZE)) 52 | chunk = file.read(MY_CHUNK_SIZE) 53 | lines = chunk.split(os.linesep) 54 | return lines[1: n + 1] 55 | 56 | 57 | def get_random_color(pastel_factor=0.5): 58 | return [ 59 | (x + pastel_factor) / (1.0 + pastel_factor) 60 | for x in [random.uniform(0, 1.0) for i in [1, 2, 3]] 61 | ] 62 | 63 | 64 | def color_distance(c1, c2): 65 | return sum([abs(x[0] - x[1]) for x in zip(c1, c2)]) 66 | 67 | 68 | def generate_new_color(existing_colors, pastel_factor=0.5): 69 | max_distance = None 70 | best_color = None 71 | for i in range(0, 100): 72 | color = get_random_color(pastel_factor=pastel_factor) 73 | if not existing_colors: 74 | return color 75 | best_distance = min([color_distance(color, c) for c in existing_colors]) 76 | if not max_distance or best_distance > max_distance: 77 | max_distance = best_distance 78 | best_color = color 79 | return best_color 80 | 81 | 82 | # Example: 83 | def get_colors(num_colors=10): 84 | colors = [] 85 | for i in range(0, num_colors): 86 | colors.append(generate_new_color(colors, pastel_factor=0.9)) 87 | for i in range(0, num_colors): 88 | for j in range(0, 3): 89 | colors[i][j] = int(colors[i][j] * 256) 90 | colors[i].append(255) 91 | return colors 92 | 93 | 94 | # CODE from 3D R2N2 95 | def image_transform(img, crop_x, crop_y, crop_loc=None, color_tint=None): 96 | """ 97 | Takes numpy.array img 98 | """ 99 | 100 | # Slight translation 101 | if not crop_loc: 102 | crop_loc = [np.random.randint(0, crop_y), np.random.randint(0, crop_x)] 103 | 104 | if crop_loc: 105 | cr, 
cc = crop_loc 106 | height, width, _ = img.shape 107 | img_h = height - crop_y 108 | img_w = width - crop_x 109 | img = img[cr: cr + img_h, cc: cc + img_w] 110 | # depth = depth[cr:cr+img_h, cc:cc+img_w] 111 | 112 | if np.random.rand() > 0.5: 113 | img = img[:, ::-1, ...] 114 | 115 | return img 116 | 117 | 118 | def crop_center(im, new_height, new_width): 119 | height = im.shape[0] # Get dimensions 120 | width = im.shape[1] 121 | left = (width - new_width) // 2 122 | top = (height - new_height) // 2 123 | right = (width + new_width) // 2 124 | bottom = (height + new_height) // 2 125 | return im[top:bottom, left:right] 126 | 127 | 128 | def add_random_color_background(im, color_range): 129 | r, g, b = [ 130 | np.random.randint(color_range[i][0], color_range[i][1] + 1) for i in range(3) 131 | ] 132 | if isinstance(im, Image.Image): 133 | im = np.array(im) 134 | 135 | if im.shape[2] > 3: 136 | # If the image has the alpha channel, add the background 137 | alpha = (np.expand_dims(im[:, :, 3], axis=2) == 0).astype(np.float) 138 | im = im[:, :, :3] 139 | bg_color = np.array([[[r, g, b]]]) 140 | im = alpha * bg_color + (1 - alpha) * im 141 | 142 | return im 143 | 144 | 145 | def preprocess_img(im, train=True): 146 | # add random background 147 | # im = add_random_color_background(im, cfg.TRAIN.NO_BG_COLOR_RANGE if train else 148 | # cfg.TEST.NO_BG_COLOR_RANGE) 149 | 150 | # If the image has alpha channel, remove it. 
151 | im_rgb = np.array(im)[:, :, :3].astype(np.float32) 152 | if train: 153 | t_im = image_transform(im_rgb, 17, 17) 154 | else: 155 | t_im = crop_center(im_rgb, 224, 224) 156 | 157 | # Scale image 158 | t_im = t_im / 255.0 159 | 160 | return t_im 161 | 162 | 163 | if __name__ == "__main__": 164 | # To make your color choice reproducible, uncomment the following line: 165 | # random.seed(10) 166 | 167 | colors = get_colors(10) 168 | print("Your colors:", colors) 169 | -------------------------------------------------------------------------------- /src/curve_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script contains utility function to draw surfaces 3 | """ 4 | 5 | import numpy as np 6 | from geomdl import BSpline, NURBS 7 | from geomdl import fitting 8 | from geomdl import multi 9 | from geomdl.visualization import VisMPL 10 | from matplotlib import cm 11 | 12 | 13 | class DrawSurfs: 14 | def __init__(self): 15 | """ 16 | Given surfaces from features files from ABC dataset, 17 | load it into geomdl object or samples points on the surfaces 18 | of primitives, depending on the case. Defines utility to sample 19 | points form the surface of splines and primitives. 20 | """ 21 | self.function_dict = { 22 | "Sphere": self.draw_sphere, 23 | "BSpline": self.draw_nurbspatch, 24 | "Cylinder": self.draw_cylinder, 25 | "Cone": self.draw_cone, 26 | "Torus": self.draw_torus, 27 | "Plane": self.draw_plane, 28 | } 29 | 30 | def load_shape(self, shape): 31 | """ 32 | Takes a list containing surface in feature file format, and returns 33 | a list of sampled points on the surface of primitive/splines. 
34 | """ 35 | Points = [] 36 | for surf in shape: 37 | function = self.function_dict[surf["type"]] 38 | points = function(surf) 39 | Points.append(points) 40 | Points = np.concatenate(Points, 0) 41 | return Points 42 | 43 | def draw_plane(self, surf): 44 | l = np.array(surf["location"]) 45 | x = np.array(surf["x_axis"]) 46 | y = np.array(surf["y_axis"]) 47 | parameters = np.array(surf["vert_parameters"]) 48 | u_min, v_min = np.min(parameters, 0) 49 | u_max, v_max = np.max(parameters, 0) 50 | u, v = np.meshgrid(np.arange(u_min, u_max, 0.1), np.arange(v_min, v_max, 0.1)) 51 | plane = ( 52 | l 53 | + np.expand_dims(u.flatten(), 1) * x.reshape((1, 3)) 54 | + np.expand_dims(v.flatten(), 1) * y.reshape((1, 3)) 55 | ) 56 | return plane 57 | 58 | def draw_cylinder(self, surf): 59 | l = np.array(surf["location"]) 60 | x = np.array(surf["x_axis"]).reshape((1, 3)) 61 | y = np.array(surf["y_axis"]).reshape((1, 3)) 62 | z = np.array(surf["z_axis"]).reshape((1, 3)) 63 | r = np.array(surf["radius"]) 64 | parameters = np.array(surf["vert_parameters"]) 65 | u_min, v_min = np.min(parameters, 0) 66 | u_max, v_max = np.max(parameters, 0) 67 | u, v = np.meshgrid(np.arange(0, 3.14 * 2, 0.1), np.arange(v_min, v_max, 0.1)) 68 | u = np.expand_dims(u.flatten(), 1) 69 | v = np.expand_dims(v.flatten(), 1) 70 | temp = np.cos(u) * r * x 71 | cylinder = l + np.cos(u) * r * x + np.sin(u) * r * y + v * z 72 | return cylinder 73 | 74 | def draw_sphere(self, surf): 75 | l = np.array(surf["location"]) 76 | x = np.array(surf["x_axis"]).reshape((1, 3)) 77 | y = np.array(surf["y_axis"]).reshape((1, 3)) 78 | r = np.array(surf["radius"]) 79 | z = np.cross(x, y) 80 | parameters = np.array(surf["vert_parameters"]) 81 | u_min, v_min = np.min(parameters, 0) 82 | u_max, v_max = np.max(parameters, 0) 83 | u, v = np.meshgrid(np.arange(u_min, u_max, 0.3), np.arange(v_min, v_max, 0.3)) 84 | u = np.expand_dims(u.flatten(), 1) 85 | v = np.expand_dims(v.flatten(), 1) 86 | 87 | sphere = l + r * np.cos(v) * (np.cos(u) 
* x + np.sin(u) * y) + r * np.sin(v) * z 88 | return sphere 89 | 90 | def draw_cone(self, surf): 91 | l = np.array(surf["location"]) 92 | x = np.array(surf["x_axis"]).reshape((1, 3)) 93 | y = np.array(surf["y_axis"]).reshape((1, 3)) 94 | z = np.array(surf["z_axis"]).reshape((1, 3)) 95 | r = np.array(surf["radius"]) 96 | a = np.array(surf["angle"]) 97 | 98 | parameters = np.array(surf["vert_parameters"]) 99 | u_min, v_min = np.min(parameters, 0) 100 | u_max, v_max = np.max(parameters, 0) 101 | u, v = np.meshgrid(np.arange(u_min, u_max, 0.1), np.arange(v_min, v_max, 0.1)) 102 | u = np.expand_dims(u.flatten(), 1) 103 | v = np.expand_dims(v.flatten(), 1) 104 | 105 | cone = ( 106 | l 107 | + (r + v * np.sin(a)) * (np.cos(u) * x + np.sin(u) * y) 108 | + v * np.cos(a) * z 109 | ) 110 | return cone 111 | 112 | def draw_torus(self, surf): 113 | l = np.array(surf["location"]) 114 | x = np.array(surf["x_axis"]).reshape((1, 3)) 115 | y = np.array(surf["y_axis"]).reshape((1, 3)) 116 | z = np.array(surf["z_axis"]).reshape((1, 3)) 117 | r_max = np.array(surf["max_radius"]) 118 | r_min = np.array(surf["min_radius"]) 119 | 120 | parameters = np.array(data["surfaces"][5]["vert_parameters"]) 121 | u_min, v_min = np.min(parameters, 0) 122 | u_max, v_max = np.max(parameters, 0) 123 | u, v = np.meshgrid(np.arange(u_min, u_max, 0.3), np.arange(v_min, v_max, 0.3)) 124 | u = np.expand_dims(u.flatten(), 1) 125 | v = np.expand_dims(v.flatten(), 1) 126 | cone = ( 127 | l 128 | + (r_max + r_min * np.cos(v)) * (np.cos(u) * x + np.sin(u) * y) 129 | + (r_min) * np.sin(v) * z 130 | ) 131 | return cone 132 | 133 | def load_spline_curve(self, spline): 134 | crv = BSpline.Curve() 135 | crv.degree = spline["degree"] 136 | crv.ctrlpts = spline["poles"] 137 | crv.knotvector = spline["knots"] 138 | return crv 139 | 140 | def load_spline_surf(self, spline): 141 | # Create a BSpline surface 142 | if spline["v_rational"] or spline["u_rational"]: 143 | surf = NURBS.Surface() 144 | control_points = 
np.array(spline["poles"]) 145 | size_u, size_v = control_points.shape[0], control_points.shape[1] 146 | 147 | # Set degrees 148 | surf.degree_u = spline["u_degree"] 149 | surf.degree_v = spline["v_degree"] 150 | 151 | # Set control points 152 | surf.ctrlpts2d = np.concatenate([control_points, 153 | np.ones((size_u, size_v, 1))], 2).tolist() 154 | surf.knotvector_v = spline["v_knots"] 155 | surf.knotvector_u = spline["u_knots"] 156 | 157 | weights = spline["weights"] 158 | l = [] 159 | for i in weights: 160 | l += i 161 | surf.weights = l 162 | return surf 163 | 164 | else: 165 | surf = BSpline.Surface() 166 | 167 | # Set degrees 168 | surf.degree_u = spline["u_degree"] 169 | surf.degree_v = spline["v_degree"] 170 | 171 | # Set control points 172 | surf.ctrlpts2d = spline["poles"] 173 | 174 | # Set knot vectors 175 | surf.knotvector_u = spline["u_knots"] 176 | surf.knotvector_v = spline["v_knots"] 177 | return surf 178 | 179 | def draw_nurbspatch(self, surf): 180 | surf = self.load_spline_surf(surf) 181 | return surf.evalpts 182 | 183 | def vis_spline_curve(self, crv): 184 | crv.vis = VisMPL.VisCurve3D() 185 | crv.render() 186 | 187 | def vis_spline_surf(self, surf): 188 | surf.vis = VisMPL.VisSurface() 189 | surf.render() 190 | 191 | def vis_multiple_spline_surf(self, surfs): 192 | mcrv = multi.SurfaceContainer([surf, surf1]) 193 | mcrv.vis = VisMPL.VisSurface() 194 | mcrv.render() 195 | 196 | def sample_points_bspline_surface(self, spline, N): 197 | parameters = np.random.random((N, 2)) 198 | points = spline.evaluate_list(parameters) 199 | return np.array(points) 200 | 201 | def regular_parameterization(self, grid_u, grid_v): 202 | nx, ny = (grid_u, grid_v) 203 | x = np.linspace(0, 1, nx) 204 | y = np.linspace(0, 1, ny) 205 | xv, yv = np.meshgrid(x, y) 206 | xv = np.expand_dims(xv.transpose().flatten(), 1) 207 | yv = np.expand_dims(yv.transpose().flatten(), 1) 208 | parameters = np.concatenate([xv, yv], 1) 209 | return parameters 210 | 211 | def 
boundary_parameterization(self, grid_u): 212 | u = np.arange(grid_u) 213 | zeros = np.zeros(grid_u) 214 | ones = np.ones(grid_u) 215 | 216 | parameters = [np.stack([zeros, u], 1)] 217 | parameters += [np.stack([np.arange(1, grid_u), np.zeros(grid_u - 1)], 1)] 218 | parameters += [np.stack([np.arange(1, grid_u), np.ones(grid_u - 1) * (grid_u - 1)], 1)] 219 | parameters += [np.stack([np.ones(grid_u - 2) * (grid_u - 1), np.arange(1, grid_u - 1)], 1)] 220 | parameters = np.concatenate(parameters, 0) 221 | return parameters / (grid_u - 1) 222 | 223 | 224 | class PlotSurface: 225 | def __init__(self, abstract_class="vtk"): 226 | self.abstract_class = abstract_class 227 | if abstract_class == "plotly": 228 | from geomdl.visualization.VisPlotly import VisSurface 229 | elif abstract_class == "vtk": 230 | from geomdl.visualization.VisVTK import VisSurface 231 | self.VisSurface = VisSurface 232 | 233 | def plot(self, surf, colormap=None): 234 | surf.vis = self.VisSurface() 235 | if colormap: 236 | surf.render(colormap=cm.cool) 237 | else: 238 | surf.render() 239 | 240 | 241 | def fit_surface(points, size_u, size_v, degree_u=3, degree_v=3, regular_grids=False): 242 | fitted_surface = fitting.approximate_surface( 243 | points, 244 | size_u=size_u, 245 | size_v=size_v, 246 | degree_u=degree_u, 247 | degree_v=degree_v, 248 | ctrlpts_size_u=10, 249 | ctrlpts_size_v=10, 250 | ) 251 | 252 | if regular_grids: 253 | parameters = regular_parameterization(25, 25) 254 | else: 255 | parameters = np.random.random((3000, 2)) 256 | fitted_points = fitted_surface.evaluate_list(parameters) 257 | return fitted_surface, fitted_points 258 | 259 | 260 | def regular_parameterization(grid_u, grid_v): 261 | nx, ny = (grid_u, grid_v) 262 | x = np.linspace(0, 1, nx) 263 | y = np.linspace(0, 1, ny) 264 | xv, yv = np.meshgrid(x, y) 265 | xv = np.expand_dims(xv.transpose().flatten(), 1) 266 | yv = np.expand_dims(yv.transpose().flatten(), 1) 267 | parameters = np.concatenate([xv, yv], 1) 268 | return 
parameters 269 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_stats(data, max_surfaces, max_control_points): 5 | # see if the number of surfaces are less than the threshold 6 | if len(data) > max_surfaces: 7 | return [0, None] 8 | 9 | contain_spline = False 10 | types = [] 11 | for surf in data: 12 | types.append(surf["type"]) 13 | if "BSpline" in types: 14 | contain_spline = True 15 | else: 16 | return [0, None] 17 | 18 | # remove extra meta data 19 | for d in data: 20 | for key in ["vert_parameters", "face_indices", "coefficients", "vert_indices"]: 21 | if key in d.keys(): 22 | del d[key] 23 | 24 | new_data = [] 25 | for index, surf in enumerate(data): 26 | # removing the unnecessary information about exact poles and keeping just the counts 27 | new_data.append(surf) 28 | if surf["type"] == "BSpline": 29 | new_data[-1]["poles"] = np.array(surf["poles"]).shape 30 | new_data[-1]["u_knots"] = np.array(surf["u_knots"]).shape 31 | new_data[-1]["v_knots"] = np.array(surf["v_knots"]).shape 32 | new_data[-1]["weights"] = np.array(surf["weights"]).shape 33 | 34 | ctrl_p_shape = [] 35 | for surf in data: 36 | if surf["type"] == "BSpline": 37 | ctrl_p_shape.append(np.array(surf["weights"]).reshape(1, 2)) 38 | 39 | ctrl_p_shape = np.concatenate(ctrl_p_shape, 0) 40 | valid_splines = np.sum((ctrl_p_shape < max_control_points)) 41 | 42 | valid_shapes = False 43 | if valid_splines == ctrl_p_shape.shape[0] * ctrl_p_shape.shape[1]: 44 | valid_shapes = True 45 | 46 | return [valid_shapes, new_data] 47 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | from src.augment_utils import 
Augment 6 | from src.curve_utils import DrawSurfs 7 | 8 | augment = Augment() 9 | 10 | EPS = np.finfo(np.float32).eps 11 | 12 | 13 | class generator_iter(Dataset): 14 | """This is a helper function to be used in the parallel data loading using Pytorch 15 | DataLoader class""" 16 | 17 | def __init__(self, generator, train_size): 18 | self.generator = generator 19 | self.train_size = train_size 20 | 21 | def __len__(self): 22 | return self.train_size 23 | 24 | def __getitem__(self, idx): 25 | return next(self.generator) 26 | 27 | 28 | class DataSetControlPointsPoisson: 29 | def __init__(self, path, batch_size, size_u=20, size_v=20, splits={}, closed=False): 30 | """ 31 | :param path: path to h5py file that stores the dataset 32 | :param batch_size: batch size 33 | :param num_points: number of 34 | :param size_u: 35 | :param size_v: 36 | :param splits: 37 | """ 38 | self.path = path 39 | self.batch_size = batch_size 40 | self.size_u = size_u 41 | self.size_v = size_v 42 | all_files = [] 43 | 44 | count = 0 45 | self.train_size = splits["train"] 46 | self.val_size = splits["val"] 47 | self.test_size = splits["test"] 48 | 49 | # load the points and control points 50 | with h5py.File(path, "r") as hf: 51 | points = np.array(hf.get(name="points")).astype(np.float32) 52 | control_points = np.array(hf.get(name="controlpoints")).astype(np.float32) 53 | 54 | np.random.seed(0) 55 | List = np.arange(points.shape[0]) 56 | np.random.shuffle(List) 57 | points = points[List] 58 | control_points = control_points[List] 59 | if closed: 60 | # closed spline has different split 61 | self.train_points = points[0:28000] 62 | self.val_points = points[28000:31000] 63 | self.test_points = points[31000:] 64 | 65 | self.train_control_points = control_points[0:28000] 66 | self.val_control_points = control_points[28000:31000] 67 | self.test_control_points = control_points[31000:] 68 | else: 69 | self.train_points = points[0:50000] 70 | self.val_points = points[50000:60000] 71 | self.test_points 
= points[60000:] 72 | 73 | self.train_control_points = control_points[0:50000] 74 | self.val_control_points = control_points[50000:60000] 75 | self.test_control_points = control_points[60000:] 76 | 77 | self.draw = DrawSurfs() 78 | 79 | def rotation_matrix_a_to_b(self, A, B): 80 | """ 81 | Finds rotation matrix from vector A in 3d to vector B 82 | in 3d. 83 | B = R @ A 84 | """ 85 | cos = np.dot(A, B) 86 | sin = np.linalg.norm(np.cross(B, A)) 87 | u = A 88 | v = B - np.dot(A, B) * A 89 | v = v / (np.linalg.norm(v) + EPS) 90 | w = np.cross(B, A) 91 | w = w / (np.linalg.norm(w) + EPS) 92 | F = np.stack([u, v, w], 1) 93 | G = np.array([[cos, -sin, 0], 94 | [sin, cos, 0], 95 | [0, 0, 1]]) 96 | 97 | # B = R @ A 98 | try: 99 | R = F @ G @ np.linalg.inv(F) 100 | except: 101 | R = np.eye(3, dtype=np.float32) 102 | return R 103 | 104 | def load_train_data(self, if_regular_points=False, align_canonical=False, anisotropic=False, if_augment=False): 105 | while True: 106 | for batch_id in range(self.train_size // self.batch_size - 1): 107 | Points = [] 108 | Parameters = [] 109 | controlpoints = [] 110 | scales = [] 111 | RS = [] 112 | 113 | for i in range(self.batch_size): 114 | points = self.train_points[batch_id * self.batch_size + i] 115 | mean = np.mean(points, 0) 116 | points = points - mean 117 | 118 | if align_canonical: 119 | S, U = self.pca_numpy(points) 120 | smallest_ev = U[:, np.argmin(S)] 121 | R = self.rotation_matrix_a_to_b(smallest_ev, np.array([1, 0, 0])) 122 | # rotate input points such that the minor principal 123 | # axis aligns with x axis. 
124 | points = R @ points.T 125 | points = points.T 126 | RS.append(R) 127 | 128 | if anisotropic: 129 | std = np.abs(np.max(points, 0) - np.min(points, 0)) 130 | std = std.reshape((1, 3)) 131 | points = points / (std + EPS) 132 | else: 133 | std = np.max(np.max(points, 0) - np.min(points, 0)) 134 | points = points / std 135 | 136 | scales.append(std) 137 | Points.append(points) 138 | cntrl_point = self.train_control_points[batch_id * self.batch_size + i] 139 | cntrl_point = cntrl_point - mean.reshape((1, 1, 3)) 140 | 141 | if align_canonical: 142 | cntrl_point = cntrl_point.reshape((self.size_u * self.size_v, 3)) 143 | cntrl_point = R @ cntrl_point.T 144 | cntrl_point = np.reshape(cntrl_point.T, (self.size_u, self.size_v, 3)) 145 | 146 | if anisotropic: 147 | cntrl_point = cntrl_point / (std.reshape((1, 1, 3)) + EPS) 148 | else: 149 | cntrl_point = cntrl_point / std 150 | controlpoints.append(cntrl_point) 151 | controlpoints = np.stack(controlpoints, 0) 152 | Points = np.stack(Points, 0) 153 | if if_augment: 154 | Points = augment.augment(Points) 155 | Points = Points.astype(np.float32) 156 | yield [Points, None, controlpoints, scales, RS] 157 | 158 | def load_val_data(self, if_regular_points=False, align_canonical=False, anisotropic=False, if_augment=False): 159 | while True: 160 | for batch_id in range(self.val_size // self.batch_size - 1): 161 | Points = [] 162 | Parameters = [] 163 | controlpoints = [] 164 | scales = [] 165 | RS = [] 166 | for i in range(self.batch_size): 167 | points = self.val_points[batch_id * self.batch_size + i] 168 | mean = np.mean(points, 0) 169 | points = points - mean 170 | 171 | if align_canonical: 172 | S, U = self.pca_numpy(points) 173 | smallest_ev = U[:, np.argmin(S)] 174 | R = self.rotation_matrix_a_to_b(smallest_ev, np.array([1, 0, 0])) 175 | # rotate input points such that the minor principal 176 | # axis aligns with x axis. 
177 | points = R @ points.T 178 | points = points.T 179 | RS.append(R) 180 | 181 | if anisotropic: 182 | std = np.abs(np.max(points, 0) - np.min(points, 0)) 183 | std = std.reshape((1, 3)) 184 | points = points / (std + EPS) 185 | else: 186 | std = np.max(np.max(points, 0) - np.min(points, 0)) 187 | points = points / std 188 | 189 | scales.append(std) 190 | Points.append(points) 191 | cntrl_point = self.val_control_points[batch_id * self.batch_size + i] 192 | cntrl_point = cntrl_point - mean.reshape((1, 1, 3)) 193 | 194 | if align_canonical: 195 | cntrl_point = cntrl_point.reshape((self.size_u * self.size_v, 3)) 196 | cntrl_point = R @ cntrl_point.T 197 | cntrl_point = np.reshape(cntrl_point.T, (self.size_u, self.size_v, 3)) 198 | 199 | if anisotropic: 200 | cntrl_point = cntrl_point / (std.reshape((1, 1, 3)) + EPS) 201 | else: 202 | cntrl_point = cntrl_point / std 203 | controlpoints.append(cntrl_point) 204 | controlpoints = np.stack(controlpoints, 0) 205 | Points = np.stack(Points, 0) 206 | if if_augment: 207 | Points = augment.augment(Points) 208 | Points = Points.astype(np.float32) 209 | yield [Points, None, controlpoints, scales, RS] 210 | 211 | def load_test_data(self, if_regular_points=False, align_canonical=False, anisotropic=False, if_augment=False): 212 | for batch_id in range(self.test_size // self.batch_size): 213 | Points = [] 214 | controlpoints = [] 215 | scales = [] 216 | RS = [] 217 | for i in range(self.batch_size): 218 | points = self.test_points[batch_id * self.batch_size + i] 219 | mean = np.mean(points, 0) 220 | points = points - mean 221 | 222 | if align_canonical: 223 | S, U = self.pca_numpy(points) 224 | smallest_ev = U[:, np.argmin(S)] 225 | R = self.rotation_matrix_a_to_b(smallest_ev, np.array([1, 0, 0])) 226 | # rotate input points such that the minor principal 227 | # axis aligns with x axis. 
228 | points = R @ points.T 229 | points = points.T 230 | RS.append(R) 231 | 232 | if anisotropic: 233 | std = np.abs(np.max(points, 0) - np.min(points, 0)) 234 | std = std.reshape((1, 3)) 235 | points = points / (std + EPS) 236 | else: 237 | std = np.max(np.max(points, 0) - np.min(points, 0)) 238 | # points = points / std 239 | 240 | scales.append(std) 241 | Points.append(points) 242 | cntrl_point = self.test_control_points[batch_id * self.batch_size + i] 243 | cntrl_point = cntrl_point - mean.reshape((1, 1, 3)) 244 | 245 | if align_canonical: 246 | cntrl_point = cntrl_point.reshape((self.size_u * self.size_v, 3)) 247 | cntrl_point = R @ cntrl_point.T 248 | cntrl_point = np.reshape(cntrl_point.T, (self.size_u, self.size_v, 3)) 249 | if anisotropic: 250 | cntrl_point = cntrl_point / (std.reshape((1, 1, 3)) + EPS) 251 | else: 252 | cntrl_point = cntrl_point / std 253 | controlpoints.append(cntrl_point) 254 | 255 | controlpoints = np.stack(controlpoints, 0) 256 | Points = np.stack(Points, 0) 257 | if if_augment: 258 | Points = augment.augment(Points) 259 | Points = Points.astype(np.float32) 260 | yield [Points, None, controlpoints, scales, RS] 261 | 262 | def pca_torch(self, X): 263 | covariance = torch.transpose(X, 1, 0) @ X 264 | S, U = torch.eig(covariance, eigenvectors=True) 265 | return S, U 266 | 267 | def pca_numpy(self, X): 268 | S, U = np.linalg.eig(X.T @ X) 269 | return S, U 270 | -------------------------------------------------------------------------------- /src/dataset_segments.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script defines dataset loading for the segmentation task on ABC dataset. 
class Dataset:
    """
    Batched access to the point cloud shapes stored in HDF5 files.

    The files ``data/shapes/{train,val,test}_data.h5`` (relative to ``prefix``)
    are read once at construction time and every shape is centred at its mean.
    The ``get_*`` generators then yield normalized batches forever.
    """

    def __init__(self,
                 batch_size,
                 train_size=None,
                 val_size=None,
                 test_size=None,
                 normals=False,
                 primitives=False,
                 if_train_data=True,
                 prefix=""):
        """
        Dataset of point cloud from ABC dataset.
        :param batch_size: number of shapes in every yielded batch.
        :param train_size: keep only the first ``train_size`` training shapes,
        None keeps all of them. ``val_size`` and ``test_size`` work alike.
        :param normals: if True, the "normals" dataset of each file is loaded.
        :param primitives: if True, the "prim" dataset of each file is loaded.
        :param if_train_data: since training dataset is large and consumes RAM,
        we can optionally choose to not load it.
        :param prefix: string prepended to the relative data paths.
        """
        self.batch_size = batch_size
        self.normals = normals
        self.primitives = primitives
        self.augment_routines = [rotate_perturbation_point_cloud, jitter_point_cloud, shift_point_cloud,
                                 random_scale_point_cloud, rotate_point_cloud]

        if if_train_data:
            self._load_split("train", train_size, prefix)
        self._load_split("val", val_size, prefix)
        self._load_split("test", test_size, prefix)

    def _load_split(self, split, size, prefix):
        """Read one split from disk and store it as ``self.<split>_*`` attributes."""
        with h5py.File(prefix + "data/shapes/{}_data.h5".format(split), "r") as hf:
            points = np.array(hf.get("points"))
            labels = np.array(hf.get("labels"))
            # Optional arrays are only read (and only stored) when requested,
            # so that the defaults normals=False / primitives=False work.
            normals = np.array(hf.get("normals")) if self.normals else None
            primitives = np.array(hf.get("prim")) if self.primitives else None

        points = points[0:size].astype(np.float32)
        # centre every shape at the origin
        setattr(self, split + "_points", points - np.mean(points, 1, keepdims=True))
        setattr(self, split + "_labels", labels[0:size])
        if normals is not None:
            setattr(self, split + "_normals", normals[0:size].astype(np.float32))
        if primitives is not None:
            setattr(self, split + "_primitives", primitives[0:size])

    def _add_normal_noise(self, points, normals):
        """Return points displaced along their normals by clipped gaussian noise."""
        noise = normals * np.clip(np.random.randn(1, points.shape[1], 1) * 0.01, a_min=-0.01, a_max=0.01)
        return points + noise.astype(np.float32)

    def _normalize_batch(self, points, normals, anisotropic, align_canonical):
        """In place: optionally align every shape with the x axis, then rescale it by its extent."""
        for j in range(points.shape[0]):
            if align_canonical:
                S, U = self.pca_numpy(points[j])
                smallest_ev = U[:, np.argmin(S)]
                R = self.rotation_matrix_a_to_b(smallest_ev, np.array([1, 0, 0]))
                # rotate input points such that the minor principal
                # axis aligns with x axis.
                points[j] = (R @ points[j].T).T
                if normals is not None:
                    normals[j] = (R @ normals[j].T).T

            extent = np.max(points[j], 0) - np.min(points[j], 0)
            if anisotropic:
                points[j] = points[j] / (extent.reshape((1, 3)) + EPS)
                # TODO make the same changes to normals also.
            else:
                points[j] = points[j] / (np.max(extent) + EPS)

    def _iterate_eval(self, all_points, all_labels, all_normals, all_primitives,
                      anisotropic, align_canonical, if_normal_noise):
        """Shared endless batch generator behind get_test and get_val."""
        bs = self.batch_size
        num_batches = all_points.shape[0] // bs
        while True:
            for i in range(num_batches):
                batch = slice(i * bs, (i + 1) * bs)
                # Basic slices are views: copy them so that the in place
                # normalization below never modifies the stored dataset.
                points = all_points[batch].copy()
                labels = all_labels[batch]
                normals = all_normals[batch].copy() if all_normals is not None else None

                if if_normal_noise and normals is not None:
                    points = self._add_normal_noise(points, normals)

                self._normalize_batch(points, normals, anisotropic, align_canonical)
                primitives = all_primitives[batch] if all_primitives is not None else None
                yield [points, labels, normals, primitives]

    def get_train(self, randomize=False, augment=False, anisotropic=False, align_canonical=False,
                  if_normal_noise=False):
        """
        Endless generator over training batches.
        :param randomize: shuffle the shapes at the start of every pass.
        :param augment: apply one randomly chosen augmentation routine per batch.
        :param anisotropic: scale every axis by its own extent instead of
        scaling by the largest extent.
        :param align_canonical: rotate every shape such that its minor
        principal axis aligns with the x axis.
        :param if_normal_noise: displace points along their normals; ignored
        when normals were not loaded.
        :return: yields [points, labels, normals or None, primitives or None]
        """
        train_size = self.train_points.shape[0]
        bs = self.batch_size
        while True:
            order = np.arange(train_size)
            if randomize:
                np.random.shuffle(order)
            # fancy indexing copies, the stored arrays stay untouched
            train_points = self.train_points[order]
            train_labels = self.train_labels[order]
            train_normals = self.train_normals[order] if self.normals else None
            train_primitives = self.train_primitives[order] if self.primitives else None

            for i in range(train_size // bs):
                batch = slice(i * bs, (i + 1) * bs)
                points = train_points[batch]
                labels = train_labels[batch]
                normals = train_normals[batch] if self.normals else None

                if augment:
                    routine = self.augment_routines[np.random.choice(np.arange(len(self.augment_routines)))]
                    points = routine(points)

                # Same guard as get_test / get_val: without normals there is
                # nothing to displace the points along.
                if if_normal_noise and self.normals:
                    points = self._add_normal_noise(points, normals)

                self._normalize_batch(points, normals, anisotropic, align_canonical)
                primitives = train_primitives[batch] if self.primitives else None
                yield [points, labels, normals, primitives]

    def get_test(self, randomize=False, anisotropic=False, align_canonical=False, if_normal_noise=False):
        """
        Endless generator over test batches, in storage order.
        :param randomize: accepted for symmetry with get_train, not used.
        :return: yields [points, labels, normals or None, primitives or None]
        """
        yield from self._iterate_eval(self.test_points,
                                      self.test_labels,
                                      self.test_normals if self.normals else None,
                                      self.test_primitives if self.primitives else None,
                                      anisotropic, align_canonical, if_normal_noise)

    def get_val(self, randomize=False, anisotropic=False, align_canonical=False, if_normal_noise=False):
        """
        Endless generator over validation batches, in storage order.
        :param randomize: accepted for symmetry with get_train, not used.
        :return: yields [points, labels, normals or None, primitives or None]
        """
        yield from self._iterate_eval(self.val_points,
                                      self.val_labels,
                                      self.val_normals if self.normals else None,
                                      self.val_primitives if self.primitives else None,
                                      anisotropic, align_canonical, if_normal_noise)

    def normalize_points(self, points, normals, anisotropic=False):
        """
        Centre, perturb along the normals, align and rescale a single shape.
        :param points: N x 3 array.
        :param normals: N x 3 array.
        :return: normalized points and rotated normals, both float32.
        """
        points = points - np.mean(points, 0, keepdims=True)
        noise = normals * np.clip(np.random.randn(points.shape[0], 1) * 0.01, a_min=-0.01, a_max=0.01)
        points = points + noise.astype(np.float32)

        S, U = self.pca_numpy(points)
        smallest_ev = U[:, np.argmin(S)]
        R = self.rotation_matrix_a_to_b(smallest_ev, np.array([1, 0, 0]))
        # rotate input points such that the minor principal
        # axis aligns with x axis.
        points = (R @ points.T).T
        normals = (R @ normals.T).T
        std = np.max(points, 0) - np.min(points, 0)
        if anisotropic:
            points = points / (std.reshape((1, 3)) + EPS)
        else:
            points = points / (np.max(std) + EPS)
        return points.astype(np.float32), normals.astype(np.float32)

    def rotation_matrix_a_to_b(self, A, B):
        """
        Finds rotation matrix from vector A in 3d to vector B
        in 3d.
        B = R @ A
        Falls back to the identity when the helper frame is singular.
        """
        cos = np.dot(A, B)
        sin = np.linalg.norm(np.cross(B, A))
        u = A
        v = B - np.dot(A, B) * A
        v = v / (np.linalg.norm(v) + EPS)
        w = np.cross(B, A)
        w = w / (np.linalg.norm(w) + EPS)
        F = np.stack([u, v, w], 1)
        G = np.array([[cos, -sin, 0],
                      [sin, cos, 0],
                      [0, 0, 1]])
        # B = R @ A
        try:
            R = F @ G @ np.linalg.inv(F)
        except np.linalg.LinAlgError:
            # F is singular, e.g. when A and B are (anti-)parallel.
            R = np.eye(3, dtype=np.float32)
        return R

    def pca_numpy(self, X):
        """Return eigenvalues and eigenvectors (as columns) of X.T @ X."""
        S, U = np.linalg.eig(X.T @ X)
        return S, U
def to_one_hot(target, maxx=50):
    """
    Convert integer labels to a one hot float tensor.

    The tensor is created on the GPU when CUDA is available and on the CPU
    otherwise.
    :param target: numpy array of N integer labels, each smaller than maxx.
    :param maxx: number of columns of the one hot encoding.
    :return: tensor of shape N x maxx holding a single 1 per row.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    indices = torch.from_numpy(target.astype(np.int64)).to(device)
    # allocate directly on the target device instead of copying afterwards
    target_one_hot = torch.zeros((indices.shape[0], maxx), device=device)
    return target_one_hot.scatter_(1, indices.unsqueeze(1), 1)
def relaxed_iou(pred, gt, max_clusters=50):
    """
    Relaxed IOU between every predicted and every ground truth cluster:
    dot / (sum(pred) + sum(gt) - dot), computed per cluster pair.

    :param pred: tensor of shape batch_size x N x K.
    :param gt: tensor indexed like pred (presumably batch_size x N x K as
    well, TODO confirm against the callers).
    :param max_clusters: not used, kept for backward compatibility.
    :return: nested list, cost[b][k1][k2] is a zero dimensional tensor.
    :raises ValueError: if any relaxed IOU is negative (the original code
    dropped into an interactive debugger at this point).
    """
    batch_size, N, K = pred.shape

    norms_p = torch.sum(pred, 1)
    norms_g = torch.sum(gt, 1)
    cost = []

    for b in range(batch_size):
        dots = pred[b].transpose(1, 0) @ gt[b]
        # all cluster pairs at once instead of a K x K python loop
        r_iou = dots / (norms_p[b].unsqueeze(1) + norms_g[b].unsqueeze(0) - dots + 1e-10)
        if bool((r_iou[:K, :K] < 0).any()):
            raise ValueError("negative relaxed IOU for batch element {}".format(b))
        cost.append([[r_iou[k1, k2] for k2 in range(K)] for k1 in range(K)])
    return cost
def separate_losses(distance, gt_points, lamb=1.0):
    """
    The idea is to define losses for geometric primitives and splines separately.
    :param distance: dictionary containing residual loss for all the geometric
    primitives and splines, distance[k] = [primitive name, loss tensor, ...].
    Entries with a loss larger than 1 are overwritten in place with 0.1.
    :param gt_points: dictionary containing ground truth points for matched
    points, used to ignore loss for the surfaces with smaller than threshold points
    :param lamb: weight applied to the spline losses inside the total loss.
    :return: [total loss tensor, mean geometric loss or None, mean spline
    loss or None]
    """
    Loss = []
    geometric_loss = []
    spline_loss = []
    # TODO remove parts that are way off from the ground truth points
    for v in sorted(gt_points.keys()):
        # surfaces without points or with less than 100 points are ignored
        if gt_points[v] is None:
            continue
        if gt_points[v].shape[0] < 100:
            continue
        if distance[v][1] > 1:
            # most probably a degenerate case
            # give a fixed error for this, on the device of the original loss.
            distance[v][1] = distance[v][1].new_full((), 0.1)

        if distance[v][0] in ["closed-spline", "open-spline"]:
            spline_loss.append(distance[v][1].item())
            Loss.append(distance[v][1] * lamb)
        else:
            geometric_loss.append(distance[v][1].item())
            Loss.append(distance[v][1])

    try:
        Loss = torch.mean(torch.stack(Loss))
    except RuntimeError:
        # torch.stack rejects an empty list (nothing survived the filters);
        # report a zero loss in that case.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        Loss = torch.zeros(1, device=device)

    geometric_loss = np.mean(geometric_loss) if len(geometric_loss) > 0 else None
    spline_loss = np.mean(spline_loss) if len(spline_loss) > 0 else None
    return [Loss, geometric_loss, spline_loss]
def IOU_simple(data):
    """
    Compute segment IOU and primitive type IOU for one shape.

    :param data: dictionary of per shape predictions. The keys read here are
    "primitive_dict", "points", "seg_id", "labels", "primitives" and the
    optional "pred_primitives" and "weights".
    :return: tuple (segment IOU, primitive type IOU)
    """
    # NOTE(review): three values are unpacked from SIOU_matched_segments
    # below - confirm this matches the helper's current return arity.
    primitive_ids = {"torus": 0,
                     "plane": 1,
                     "cone": 3,
                     "cylinder": 4,
                     "sphere": 5,
                     "open-spline": 2,
                     "closed-spline": 9}

    parameters = data["primitive_dict"]
    # points of segments without a fitted primitive keep the id 0
    pred_primitives = np.zeros(data["points"].shape[0])

    if data.get("pred_primitives") is not None:
        pred_primitives = data["pred_primitives"]
    else:
        # derive the per point primitive type from the fitted segments
        for segment, params in parameters.items():
            pred_primitives[data["seg_id"] == segment] = primitive_ids[params[0]]

    if data.get("weights") is not None:
        weights = data["weights"]
    else:
        num_segments = np.unique(data["seg_id"]).shape[0]
        weights = to_one_hot(data["seg_id"], num_segments).data.cpu().numpy()

    s_iou, p_iou, _ = SIOU_matched_segments(data["labels"],
                                            data["seg_id"],
                                            pred_primitives,
                                            data["primitives"],
                                            weights)
    return s_iou, p_iou
def remove_unassigned(data):
    """
    For un assigned points, assign the nearest neighbors label.

    Points whose "seg_id" equals 100 are treated as unassigned. Each of them
    receives the segment id of its nearest assigned point. data["seg_id"] is
    modified in place.
    :param data: dictionary with "points" (N x 3 numpy array) and "seg_id"
    (numpy array of N segment ids).
    :return: the same dictionary.
    """
    seg_id = data['seg_id']
    unassigned_index = seg_id == 100
    assigned_ids = np.nonzero(np.logical_not(unassigned_index))[0]
    if not unassigned_index.any() or assigned_ids.size == 0:
        # nothing to relabel, or no labelled point to copy a label from
        return data

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    points = torch.from_numpy(data['points'].astype(np.float32)).to(device)
    queries = points[torch.from_numpy(np.nonzero(unassigned_index)[0]).to(device)]
    candidates = points[torch.from_numpy(assigned_ids).to(device)]

    # Only unassigned -> assigned distances are needed; this avoids the full
    # N x N matrix and the masking of the diagonal / unassigned columns.
    dst_matrix = torch.sum((torch.unsqueeze(queries, 1) - torch.unsqueeze(candidates, 0)) ** 2, 2)
    nearest_index = torch.min(dst_matrix, 1)[1].data.cpu().numpy()

    seg_id[unassigned_index] = seg_id[assigned_ids[nearest_index]]
    return data
class Arap:
    """As rigid as possible deformation of a grid mesh towards input points."""

    def __init__(self, size_u=31, size_v=30):
        """
        :param size_u: number of grid points in u direction.
        :param size_v: number of grid points in v direction.
        """
        self.size_u = size_u
        self.size_v = size_v
        grid_shape = [size_u, size_v]
        # flat (row major) vertex ids of the grid points used as handles
        self.indices = [np.ravel_multi_index(ij, grid_shape)
                        for ij in self.get_boundary_indices(size_u, size_v)]

    def deform(self, recon_points, gt_points, viz=False):
        """
        Tessellate the reconstructed grid, match it against gt_points and
        deform the mesh as rigidly as possible such that the handle vertices
        move onto their matched points.
        TODO: better way to do it is do maximal matching between all points and use
        only the boundary points as the pivot points.
        :param recon_points: grid points, reshapeable to size_u x size_v x 3.
        :param gt_points: points the mesh is pulled towards.
        :param viz: if True, show meshes, handles and gt_points with open3d.
        :return: the deformed mesh.
        """
        # result unused: the reshape only fails early on a wrong point count
        recon_points.reshape((self.size_u, self.size_v, 3))
        source_mesh = tessalate_points(recon_points, self.size_u, self.size_v)

        mesh, handle_ids, handle_pos = self.generate_handles(source_mesh,
                                                             self.indices,
                                                             gt_points,
                                                             np.array(source_mesh.vertices))
        handle_ids = np.array(handle_ids, dtype=np.int32)
        handle_pos = open3d.utility.Vector3dVector(handle_pos)

        mesh_prime = mesh.deform_as_rigid_as_possible(
            open3d.utility.IntVector(handle_ids), handle_pos, max_iter=500)

        if viz:
            pcd = visualize_point_cloud(gt_points)
            mesh_prime.compute_vertex_normals()
            mesh.paint_uniform_color((1, 0, 0))
            handles = open3d.geometry.PointCloud()
            handles.points = handle_pos
            handles.paint_uniform_color((0, 1, 0))
            open3d.visualization.draw_geometries([mesh, mesh_prime, handles, pcd])
        return mesh_prime

    def get_boundary_indices(self, m, n):
        """Return the (i, j) pairs of an m x n grid that lie in the first or last column."""
        return [(i, j) for i in range(m) for j in range(n) if j == 0 or j == n - 1]

    def generate_handles(self, mesh, indices, input_points, recon_points):
        """
        Compute target positions for the handle vertices.
        :return: mesh, handle vertex ids and the list of handle positions.
        """
        offsets = self.define_matching(input_points, recon_points) - recon_points
        vertices = np.asarray(mesh.vertices)
        handle_positions = [vertices[idx] + offsets[idx] for idx in indices]
        return mesh, indices, handle_positions

    def define_matching(self, input, out):
        """
        Match every point of out with a distinct point of a random subset of
        input by solving a dense assignment problem on pairwise distances.
        """
        # Input points need to at least 1.2 times more than output points
        subset = np.random.choice(np.arange(input.shape[0]), int(1.2 * out.shape[0]), replace=False)
        input = input[subset]

        rids, cids = solve_dense(scipy.spatial.distance.cdist(out, input))
        return input[cids]
self.fitting.fit_plane_torch( 159 | points=points, 160 | normals=normals, 161 | weights=weights, 162 | ids=ids, 163 | ) 164 | self.fitting.parameters[ids] = ["plane", axis.reshape((3, 1)), distance] 165 | if sample_points: 166 | # Project points on the surface 167 | new_points = project_to_plane( 168 | points, axis, distance.item() 169 | ) 170 | 171 | new_points = self.fitting.sample_plane( 172 | distance.item(), 173 | axis.data.cpu().numpy(), 174 | mean=torch.mean(new_points, 0).data.cpu().numpy(), 175 | ) 176 | return new_points 177 | else: 178 | None 179 | 180 | def forward_pass_cone(self, points, normals, weights, ids, sample_points=False): 181 | try: 182 | apex, axis, theta = self.fitting.fit_cone_torch( 183 | points, 184 | normals, 185 | weights=weights, 186 | ids=ids, 187 | ) 188 | except: 189 | import ipdb; 190 | ipdb.set_trace() 191 | 192 | self.fitting.parameters[ids] = ["cone", apex.reshape((1, 3)), axis.reshape((3, 1)), theta] 193 | if sample_points: 194 | new_points, new_normals = self.fitting.sample_cone_trim( 195 | apex.data.cpu().numpy().reshape(3), axis.data.cpu().numpy().reshape(3), theta.item(), 196 | points.data.cpu().numpy() 197 | ) 198 | # new_points = project_to_point_cloud(points, new_points) 199 | return new_points 200 | else: 201 | None 202 | 203 | def forward_pass_cylinder(self, points, normals, weights, ids, sample_points=False): 204 | a, center, radius = self.fitting.fit_cylinder_torch( 205 | points, 206 | normals, 207 | weights, 208 | ids=ids, 209 | ) 210 | self.fitting.parameters[ids] = ["cylinder", a, center, radius] 211 | if sample_points: 212 | new_points, new_normals = self.fitting.sample_cylinder_trim( 213 | radius.item(), 214 | center.data.cpu().numpy().reshape(3), 215 | a.data.cpu().numpy().reshape(3), 216 | points.data.cpu().numpy(), 217 | N=10000, 218 | ) 219 | 220 | # new_points = project_to_point_cloud(points, new_points) 221 | return new_points 222 | else: 223 | return None 224 | 225 | def forward_pass_sphere(self, points, 
def guard_sqrt(x, minimum=1e-5):
    """
    Square root of x after clamping x from below to minimum, so that the
    argument of the square root is never smaller than minimum.
    """
    return torch.sqrt(torch.clamp(x, min=minimum))
def roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad=None):
    """
    Rolls the tensor by certain shifts along certain dimension.

    A positive shift moves the last ``shift`` entries to the front, a
    negative shift moves the first ``-shift`` entries to the back. The index
    tensors are created on the device of x, so both directions work for CPU
    and for CUDA tensors.
    :param fill_pad: not used, kept for backward compatibility.
    """
    if 0 == shift:
        return x

    size = x.size(dim)
    # position at which the tensor is cut before the two parts are swapped
    split = -shift if shift < 0 else size - shift
    head = x.index_select(dim, torch.arange(split, device=x.device))
    tail = x.index_select(dim, torch.arange(split, size, device=x.device))
    return torch.cat([tail, head], dim=dim)
def control_points_loss(output, control_points, grid_size):
    """
    Mean squared error between predicted and target control points, without
    trying any permutation of the control point grid.
    :param output: network output, viewed as N x grid_size x grid_size x 3
    :param control_points: N x grid_size x grid_size x 3
    :param grid_size: size of the control point grid in both directions
    :return: scalar loss, averaged over shapes and over all coordinates
    """
    num_shapes = output.shape[0]
    predicted = output.view(num_shapes, grid_size, grid_size, 3)
    # summed squared error of every shape
    squared_error = torch.sum((predicted - control_points) ** 2, (1, 2, 3))
    return torch.mean(squared_error) / (grid_size * grid_size * 3)
def spline_reconstruction_loss(nu, nv, output, points, config, sqrt=False):
    """
    Spline reconstruction loss: evaluates the spline surface defined by the
    predicted control points and compares it with the input points using
    `chamfer_distance`.
    :param nu: basis functions in u direction, grid_size_u x num control
    points in u.
    :param nv: basis functions in v direction, grid_size_v x num control
    points in v.
    :param output: predicted control points, reshaped to
    batch x nu.shape[1] x nv.shape[1] x 3.
    :param points: input points, permuted with (0, 2, 1) before being
    compared with the reconstruction.
    :param config: kept for backward compatibility. The batch size is now
    read from `output` itself, so a last batch smaller than
    config.batch_size no longer breaks the reshape.
    :param sqrt: forwarded to `chamfer_distance`.
    :return: (distance, reconstructed points of size
    batch x (grid_size_u * grid_size_v) x 3)
    """
    batch_size = output.shape[0]
    grid_size_u, c_size_u = nu.shape
    grid_size_v, c_size_v = nv.shape

    output = output.reshape(batch_size, c_size_u, c_size_v, 3)
    points = points.permute(0, 2, 1)

    # Evaluate S = nu @ P @ nv^T for every sample and coordinate at once.
    # planes: batch x 3 x c_size_u x c_size_v
    planes = output.permute(0, 3, 1, 2)
    surface = torch.matmul(torch.matmul(nu, planes), torch.transpose(nv, 1, 0))

    # batch x grid_size_u x grid_size_v x 3, then flattened to a point list.
    # The u and v sampling densities are allowed to differ.
    reconst_points = surface.permute(0, 2, 3, 1)
    reconst_points = reconst_points.reshape(batch_size, grid_size_u * grid_size_v, 3)

    dist = chamfer_distance(reconst_points, points, sqrt=sqrt)
    return dist, reconst_points
Returns uniform knots, given the number of control points in u and v directions and 193 | their degrees. 194 | """ 195 | u = v = np.arange(0., 1, 1 / grid_size) 196 | 197 | knots_u = [0.0] * degree_u + np.arange(0, 1.01, 1 / (control_points_u - degree_u)).tolist() + [1.0] * degree_u 198 | knots_v = [0.0] * degree_v + np.arange(0, 1.01, 1 / (control_points_v - degree_v)).tolist() + [1.0] * degree_v 199 | 200 | nu = [] 201 | nu = np.zeros((u.shape[0], control_points_u)) 202 | for i in range(u.shape[0]): 203 | for j in range(0, control_points_u): 204 | nu[i, j] = basis_function_one(degree_u, knots_u, j, u[i]) 205 | 206 | nv = np.zeros((v.shape[0], control_points_v)) 207 | for i in range(v.shape[0]): 208 | for j in range(0, control_points_v): 209 | nv[i, j] = basis_function_one(degree_v, knots_v, j, v[i]) 210 | return nu, nv 211 | 212 | 213 | def laplacian_loss(output, gt, dist_type="l2"): 214 | """ 215 | Computes the laplacian of the input and output grid and defines 216 | regression loss. 217 | :param output: predicted control points grid. Makes sure the orientation/ 218 | permutation of this output grid matches with the ground truth orientation. 219 | This is done by finding the least cost orientation during training. 220 | :param gt: gt control points grid. 
def basis_function_one(degree, knot_vector, span, knot):
    """ Evaluates a single B-spline basis function at one parameter value.

    Follows Algorithm 2.4 from The NURBS Book by Piegl & Tiller.

    :param degree: degree, :math:`p`
    :type degree: int
    :param knot_vector: knot vector
    :type knot_vector: list, tuple
    :param span: knot span, :math:`i`
    :type span: int
    :param knot: knot or parameter, :math:`u`
    :type knot: float
    :return: basis function, :math:`N_{i,p}`
    :rtype: float
    """
    last = len(knot_vector) - 1

    # The first and the last basis functions are one at the ends of the knot vector
    if span == 0 and knot == knot_vector[0]:
        return 1.0
    if span == last - degree - 1 and knot == knot_vector[last]:
        return 1.0

    # The parameter lies outside the support of this basis function
    if knot < knot_vector[span] or knot >= knot_vector[span + degree + 1]:
        return 0.0

    table = [0.0] * (degree + span + 1)

    # Degree zero: indicator of the knot interval containing the parameter
    for offset in range(degree + 1):
        if knot_vector[span + offset] <= knot < knot_vector[span + offset + 1]:
            table[offset] = 1.0

    # Fill the triangular table one degree at a time
    for level in range(1, degree + 1):
        carry = 0.0
        if table[0] != 0.0:
            carry = ((knot - knot_vector[span]) * table[0]) / (
                knot_vector[span + level] - knot_vector[span]
            )

        for offset in range(degree - level + 1):
            left = knot_vector[span + offset + 1]
            right = knot_vector[span + offset + level + 1]
            upper = table[offset + 1]

            # A zero entry contributes nothing, which also avoids 0 / 0
            if upper == 0.0:
                table[offset] = carry
                carry = 0.0
                continue

            scaled = upper / (right - left)
            table[offset] = carry + (right - knot) * scaled
            carry = (knot - left) * scaled
    return table[0]
class MeanShift:
    def __init__(self, ):
        """
        Differentiable mean shift clustering inspired from
        https://arxiv.org/pdf/1712.08273.pdf
        """
        pass

    def mean_shift(self, X, num_samples, quantile, iterations, kernel_type="gaussian", bw=None, nms=True):
        """
        Complete function to do mean shift clustering on the input X
        :param num_samples: number of samples to consider for band width
        calculation
        :param X: input, N x d
        :param quantile: to be used for computing number of nearest
        neighbors, 0.05 works fine.
        :param iterations: number of mean shift iterations.
        :param kernel_type: "gaussian", anything else selects the
        epanechnikov kernel.
        :param bw: band width. When None it is estimated from X.
        :param nms: whether to run non max suppression on the shifted points.
        :return: (new_X, bw) when nms is False, otherwise
        (new_X, center, bw, new_labels).
        """
        if bw is None:
            with torch.no_grad():
                bw = self.compute_bandwidth(X, num_samples, quantile)

                # avoid numerical issues.
                bw = torch.clamp(bw, min=0.003)
        new_X, _ = self.mean_shift_(X, b=bw, iterations=iterations, kernel_type=kernel_type)
        if not nms:
            return new_X, bw

        with torch.no_grad():
            _, indices, new_labels = self.nms(new_X, X, b=bw)
        center = new_X[indices]

        return new_X, center, bw, new_labels

    @staticmethod
    def _kernel_values(queries, X, bw, gaussian):
        """
        Kernel matrix between unit norm `queries` and unit norm `X`.
        For unit vectors the squared distance is 2 - 2 * <q, x>.
        """
        dist = 2.0 - 2.0 * queries @ torch.transpose(X, 1, 0)
        if gaussian:
            # TODO Normalization is still remaining.
            return guard_exp(- dist / (bw ** 2) / 2)
        # epanechnikov
        dist = 3 / 4 * (1 - dist / (bw ** 2))
        return torch.nn.functional.relu(dist)

    def mean_shift_(self, X, b, iterations=10, kernel_type="gaussian"):
        """
        Differentiable mean shift clustering.
        X are assumed to lie on the hyper sphere, and thus are normalized
        to have unit norm. This is done for computational
        efficiency and will not work if the assumptions are violated.
        :param X: N x d, points to be clustered
        :param b: bandwidth
        :param iterations: number of iterations
        :param kernel_type: "gaussian", anything else selects the
        epanechnikov kernel.
        :return: (shifted points, X)
        """
        # initialize all the points as the seed points
        new_X = X.clone()
        delta = 1
        for _ in range(iterations):
            K = self._kernel_values(new_X, X, b, kernel_type == "gaussian")
            D = 1 / (torch.sum(K, 1, keepdim=True))

            # K: N x N, X: N x d, D: N x 1
            M = (K @ X) * D - new_X
            new_X = new_X + delta * M

            # re-normalize it to lie on unit hyper-sphere.
            new_X = new_X / torch.norm(new_X, dim=1, p=2, keepdim=True)
        # new_X: center of the clusters
        return new_X, X

    def guard_mean_shift(self, embedding, quantile, iterations, kernel_type="gaussian"):
        """
        Sometimes, if the band width is small, the number of clusters can be
        larger than 50, but we would like to keep at most 50 clusters as it
        is the max number in our dataset. In that case the quantile is
        doubled, which increases the band width and decreases the number of
        clusters, and the clustering is run again.
        :return: (center, bandwidth, cluster_ids)
        """
        while True:
            _, center, bandwidth, cluster_ids = self.mean_shift(
                embedding, 5000, quantile, iterations, kernel_type=kernel_type
            )
            if torch.unique(cluster_ids).shape[0] > 49:
                quantile *= 2
            else:
                break
        return center, bandwidth, cluster_ids

    def kernel(self, X, kernel_type, bw):
        """
        Kernel matrix of X with itself.
        Assuming that the feature vectors in X are normalized.
        :param kernel_type: "gaussian" or "epa"
        :raises ValueError: for any other kernel type (previously this
        failed with an unbound local variable).
        """
        if kernel_type == "gaussian":
            # TODO not considering the normalization factor
            return self._kernel_values(X, X, bw, True)
        if kernel_type == "epa":
            return self._kernel_values(X, X, bw, False)
        raise ValueError("unknown kernel type: {}".format(kernel_type))

    def compute_bandwidth(self, X, num_samples, quantile):
        """
        Compute the bandwidth for mean shift clustering.
        Assuming the X is normalized to lie on hypersphere.
        :param X: input data, N x d
        :param num_samples: number of samples to be used
        for computing distance, <= N
        :param quantile: nearest neighbors used for computing
        the bandwidth.
        """
        N = X.shape[0]
        L = np.arange(N)
        np.random.shuffle(L)
        X = X[L[0:num_samples]]
        dist = 2 - 2 * X @ torch.transpose(X, 1, 0)

        # The number of neighbors has to stay within [1, number of rows
        # actually sampled]; otherwise topk or the [:, -1] below fails
        # (tiny quantile, num_samples > N, or a quantile doubled past 1).
        K = int(quantile * num_samples)
        K = min(max(K, 1), X.shape[0])
        top_k = torch.topk(dist, k=K, dim=1, largest=False)[0]

        max_top_k = guard_sqrt(top_k[:, -1], 1e-6)

        return torch.mean(max_top_k)

    def nms(self, centers, X, b):
        """
        Non max suppression.
        :param centers: center of clusters, one per point in X
        :param X: points to be clustered
        :param b: band width used to get the centers
        :return: (pruned centers, indices of the kept centers, label of
        every point in X w.r.t. the pruned centers)
        """
        # Allocate on the device of the centers so that CPU tensors work too.
        device = centers.device

        membership = 2.0 - 2.0 * centers @ torch.transpose(X, 1, 0)

        # which cluster center is closer to the points
        membership = torch.min(membership, 0)[1]

        # Find the unique clusters which is closer to at least one point
        uniques, counts_ = np.unique(membership.data.cpu().numpy(), return_counts=True)

        # count of the number of points belonging to unique cluster ids above
        counts = torch.from_numpy(counts_.astype(np.float32)).to(device)

        # Contains the count of number of points belonging to a
        # unique cluster
        num_mem_cluster = torch.zeros((X.shape[0]), device=device)
        num_mem_cluster[uniques] = counts

        # distance of clusters from each other
        dist = 2.0 - 2.0 * centers @ torch.transpose(centers, 1, 0)

        # find the nearest neighbors to each cluster based on some threshold
        # TODO this could be b ** 2
        cluster_nbrs = (dist < b).float()

        # every non empty cluster votes for its most populated neighbor
        votes = cluster_nbrs[uniques] * num_mem_cluster.reshape((1, -1))
        cluster_center_ids = torch.unique(torch.max(votes, 1)[1])
        # pruned centers
        centers = centers[cluster_center_ids]

        # assign labels to the input points
        # It is assumed that the embeddings lie on the hypersphere and are normalized
        temp = centers @ torch.transpose(X, 1, 0)
        labels = torch.max(temp, 0)[1]
        return centers, cluster_center_ids, labels

    def pdist(self, x, y):
        """
        Pairwise squared euclidean distance.
        :param x: N x d
        :param y: M x d
        :return: N x M
        """
        x = torch.unsqueeze(x, 1)
        y = torch.unsqueeze(y, 0)
        dist = torch.sum((x - y) ** 2, 2)
        return dist
def knn(x, k):
    """
    Indices of the k nearest neighbors of every point, computed
    independently for every sample of the batch (one sample at a time).
    :param x: batch_size x dims x num_points
    :param k: number of neighbors. Each point is at zero distance from
    itself, so it is part of its own neighborhood.
    :return: batch_size x num_points x k indices, closest first.
    """
    with torch.no_grad():
        negative_sq_dists = []
        for sample in x.split(1, dim=0):
            # -||a - b||^2 = -||a||^2 + 2 <a, b> - ||b||^2
            cross = -2 * torch.matmul(sample.transpose(2, 1), sample)
            sq_norm = torch.sum(sample ** 2, dim=1, keepdim=True)
            negative_sq_dists.append(-sq_norm - cross - sq_norm.transpose(2, 1))

        # batch_size x num_points x num_points
        stacked = torch.stack(negative_sq_dists, 0).squeeze(1)

        # the largest negative distances belong to the closest points
        _, idx = stacked.topk(k=k, dim=-1)
    return idx
class DGCNNControlPoints(nn.Module):
    # Output channels of the four edge convolution layers for every mode.
    _EDGE_CONV_WIDTHS = {
        0: (64, 64, 128, 256),
        1: (128, 256, 256, 512),
    }

    def __init__(self, num_control_points, num_points=40, mode=0):
        """
        Control points prediction network. Takes points as input
        and outputs control points grid.
        :param num_control_points: size of the control points grid.
        :param num_points: number of nearest neighbors used in DGCNN.
        :param mode: 0 or 1, selects the width of the edge convolution layers.
        :raises ValueError: if mode is not supported. Previously such a
        module was built without any layer and only failed in forward.
        """
        super(DGCNNControlPoints, self).__init__()
        if mode not in self._EDGE_CONV_WIDTHS:
            raise ValueError("mode should be one of {}, got {}".format(
                sorted(self._EDGE_CONV_WIDTHS), mode))

        self.k = num_points
        self.mode = mode
        widths = self._EDGE_CONV_WIDTHS[mode]

        # Attribute names and creation order are kept as before so that
        # state dict keys and weight initialization stay the same.
        self.bn1 = nn.BatchNorm2d(widths[0])
        self.bn2 = nn.BatchNorm2d(widths[1])
        self.bn3 = nn.BatchNorm2d(widths[2])
        self.bn4 = nn.BatchNorm2d(widths[3])
        self.bn5 = nn.BatchNorm1d(1024)
        self.drop = 0.0

        # The graph feature concatenates (neighbor - center, center), so
        # every edge convolution sees twice the channels of its input.
        self.conv1 = self._edge_conv(6, widths[0], self.bn1)
        self.conv2 = self._edge_conv(widths[0] * 2, widths[1], self.bn2)
        self.conv3 = self._edge_conv(widths[1] * 2, widths[2], self.bn3)
        self.conv4 = self._edge_conv(widths[2] * 2, widths[3], self.bn4)

        # Operates on the concatenation of the four intermediate features.
        self.conv5 = nn.Sequential(nn.Conv1d(sum(widths), 1024, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))

        self.controlpoints = num_control_points
        self.conv6 = torch.nn.Conv1d(1024, 1024, 1)
        self.conv7 = torch.nn.Conv1d(1024, 1024, 1)

        # Predicts the entire control points grid.
        self.conv8 = torch.nn.Conv1d(1024, 3 * (self.controlpoints ** 2), 1)

        self.bn6 = nn.BatchNorm1d(1024)
        self.bn7 = nn.BatchNorm1d(1024)

        self.tanh = nn.Tanh()

    @staticmethod
    def _edge_conv(in_channels, out_channels, batch_norm):
        """1 x 1 convolution followed by batch norm and leaky relu."""
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                             batch_norm,
                             nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x, weights=None):
        """
        :param x: input points; the first layer expects 6 channels after
        the graph feature, i.e. B x 3 x N input.
        :param weights: weights of size B x N. NOTE(review): they are
        reshaped to 1 x 1 x -1 and broadcast over the batch, which only
        lines up with B x N when B is 1 -- confirm with the callers.
        :return: B x (num_control_points ** 2) x 3 control points, in [-1, 1].
        """
        batch_size = x.size(0)

        # Stack of edge convolutions, each followed by a max over neighbors.
        features = []
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(get_graph_feature(x, k=self.k))
            x = x.max(dim=-1, keepdim=False)[0]
            features.append(x)

        x = torch.cat(features, dim=1)

        x = self.conv5(x)

        if isinstance(weights, torch.Tensor):
            weights = weights.reshape((1, 1, -1))
            x = x * weights

        # global feature: max over all the points
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)

        x1 = torch.unsqueeze(x1, 2)

        # Dropout follows the train / eval state of the module. With the
        # default self.drop = 0.0 it is a no-op in both states.
        x = F.dropout(F.relu(self.bn6(self.conv6(x1))), self.drop, training=self.training)

        x = F.dropout(F.relu(self.bn7(self.conv7(x))), self.drop, training=self.training)
        x = self.conv8(x)
        x = self.tanh(x[:, :, 0])

        x = x.view(batch_size, self.controlpoints * self.controlpoints, 3)
        return x
class EmbeddingLoss:
    def __init__(self, margin=1.0, if_mean_shift=False):
        """
        Defines loss function to train embedding network.
        :param margin: margin to be used in triplet loss.
        :param if_mean_shift: bool, whether to use mean shift
        iterations. This is only used in end to end training.
        """
        self.margin = margin
        self.if_mean_shift = if_mean_shift

    def triplet_loss(self, output, labels: np.ndarray, iterations=5):
        """
        Triplet loss
        :param output: output embedding from the network. size: B x 128 x N
        where B is the batch size, 128 is the dim size and N is the number of points.
        :param labels: B x N
        :param iterations: number of mean shift iterations, only used
        when if_mean_shift is set.
        :return: tensor with a single element. Every intermediate tensor
        is created on the device of `output`, so the loss also works on CPU.
        """
        max_segments = 5
        batch_size = output.shape[0]
        N = output.shape[2]
        device = output.device

        # Requires grad so that the result can be back propagated even
        # when no shape of the batch contributes to it.
        loss_diff = torch.zeros(1, device=device, requires_grad=True)

        # B x N x dim, every embedding has unit norm
        output = output.permute(0, 2, 1)
        output = torch.nn.functional.normalize(output, p=2, dim=2)

        if self.if_mean_shift:
            new_output = []
            for b in range(batch_size):
                new_X, bw = meanshift.mean_shift(output[b], 4000,
                                                 0.015, iterations=iterations,
                                                 nms=False)
                new_output.append(new_X)
            output = torch.stack(new_output, 0)

        # Sample, with replacement, the same number of points from every
        # segment of every shape.
        sampled_points = {}
        for i in range(batch_size):
            sampled_points[i] = {}
            p = labels[i]
            unique_labels = np.unique(p)

            # number of points from each cluster.
            num_sample_points = min([N // unique_labels.shape[0] + 1, 30])
            for l in unique_labels:
                # point indices that belong to a certain cluster.
                sampled_indices = np.where(p == l)[0]
                sampled_points[i][l] = np.random.choice(
                    sampled_indices,
                    num_sample_points,
                    replace=True)

        sampled_predictions = {
            i: {k: output[i, v, :] for k, v in sampled_points[i].items()}
            for i in range(batch_size)
        }

        only_one_segments = 0
        for i in range(batch_size):
            keys = sorted(sampled_predictions[i].keys())
            len_keys = len(keys)
            num_iterations = min([max_segments * max_segments, len_keys * len_keys])
            normalization = 0
            if len_keys == 1:
                # no negative can be formed with a single segment
                only_one_segments += 1
                continue

            loss_shape = torch.zeros(1, device=device, requires_grad=True)
            for _ in range(num_iterations):
                k1 = np.random.choice(len_keys, 1)[0]
                k2 = np.random.choice(len_keys, 1)[0]
                if k1 == k2:
                    continue
                normalization += 1

                pred1 = sampled_predictions[i][keys[k1]]
                pred2 = sampled_predictions[i][keys[k2]]

                Anchor = pred1.unsqueeze(1)
                Pos = pred1.unsqueeze(0)
                Neg = pred2.unsqueeze(0)

                diff_pos = torch.sum(torch.pow((Anchor - Pos), 2), 2)
                diff_neg = torch.sum(torch.pow((Anchor - Neg), 2), 2)
                constraint = torch.nn.functional.relu(diff_pos - diff_neg + self.margin)

                # remove diagonals corresponding to same points in anchors
                loss = torch.sum(constraint) - constraint.trace()

                # number of violated constraints; + 1 avoids division by zero
                satisfied = (torch.sum(constraint > 0) + 1.0).float()

                loss_shape = loss_shape + loss / satisfied.detach()

            loss_shape = loss_shape / (normalization + 1e-8)
            loss_diff = loss_diff + loss_shape
        loss_diff = loss_diff / (batch_size - only_one_segments + 1e-8)
        return loss_diff
def primitive_loss(pred, gt):
    """
    Loss for primitive type prediction: applies the module level
    `nllloss` (a torch.nn.NLLLoss instance) to the arguments.
    :param pred: predictions; NLLLoss expects log-probabilities here.
    :param gt: ground truth class indices.
    :return: whatever `nllloss(pred, gt)` returns.
    """
    return nllloss(pred, gt)
def cluster_prob(embedding, centers, band_width=None):
    """
    Returns cluster probabilities.

    This function used to be defined twice, so the two argument softmax
    version was silently shadowed by the band width version. Both are now
    reachable through a single definition.
    :param embedding: N x 128, embedding for each point
    :param centers: C x 128, embedding for centers
    :param band_width: when None, a softmax over the dot products with
    the centers is returned. Otherwise a gaussian kernel of the distance
    2 - 2 * <center, embedding> is used, with band_width as the
    denominator of the exponent (it is not squared).
    :return: N x C array whose rows sum to one when band_width is None,
    C x N array of kernel values otherwise.
    """
    if band_width is None:
        # should of size N x C
        dot_p = np.dot(centers, embedding.transpose()).transpose()

        # subtracting the row maximum leaves the softmax unchanged and
        # avoids overflow in the exponential.
        exp_p = np.exp(dot_p - np.max(dot_p, 1, keepdims=True))
        return exp_p / np.sum(exp_p, 1, keepdims=True)

    dist = 2 - 2 * centers @ embedding.T
    prob = np.exp(-dist / 2 / (band_width)) / np.sqrt(2 * np.pi * band_width)
    return prob
def mean_IOU_one_sample(pred, gt, C):
    """
    Mean intersection over union of one sample, averaged over the label
    ids 0 .. C - 1.
    :param pred: predicted label of every point
    :param gt: ground truth label of every point
    :param C: number of label ids to average over
    :return: mean IOU. A label that is absent from both pred and gt
    contributes eps / eps = 1.
    """
    eps = np.finfo(np.float32).eps

    def label_iou(label_idx):
        in_gt = (gt == label_idx)
        in_pred = (pred == label_idx)
        intersection = np.sum(np.logical_and(in_gt, in_pred)) + eps
        union = np.sum(np.logical_or(in_gt, in_pred)) + eps
        return intersection / union

    return sum((label_iou(label_idx) for label_idx in range(C)), 0.0) / C
def mean_IOU_primitive_segment(matching, predicted_labels, labels, pred_prim, gt_prim, min_gt_points=100):
    """
    Primitive type IOU, this is calculated over the segment level.
    First the predicted segments are matched with ground truth segments,
    then IOU is calculated over these segments.
    :param matching: for every shape, the pair (rows, cols) of matched
    predicted and ground truth segment ids.
    :param predicted_labels: B x N, pred label id for segments
    :param labels: B x N, gt label id for segments
    :param pred_prim: B x K, pred primitive type for each of the predicted segments
    :param gt_prim: B x N, gt primitive type for each point
    :param min_gt_points: ground truth segments with fewer points than
    this are ignored (defaults to the previously hard coded 100).
    :return: (mean segment IOU, mean primitive type accuracy over the
    matched segments, list of [gt type, predicted type] pairs of the last
    shape in the batch). The means are nan for a shape without any valid
    matched pair.
    :raises IndexError: if a matched predicted segment id has no entry in
    pred_prim. This used to drop into an interactive debugger.
    """
    batch_size = labels.shape[0]
    IOU = []
    IOU_prim = []
    # defined up front so that an empty batch returns an empty list
    iou_b_prims = []

    for b in range(batch_size):
        iou_b = []
        iou_b_prim = []
        iou_b_prims = []
        rows, cols = matching[b]
        for r, c in zip(rows, cols):
            pred_indices = predicted_labels[b] == r
            gt_indices = labels[b] == c
            num_gt = np.sum(gt_indices)

            # use only matched segments for evaluation
            if (num_gt == 0) or (np.sum(pred_indices) == 0):
                continue

            # also remove the gt labels that are very small in number
            if num_gt < min_gt_points:
                continue

            iou = np.sum(np.logical_and(pred_indices, gt_indices)) / (
                    np.sum(np.logical_or(pred_indices, gt_indices)) + 1e-8)
            iou_b.append(iou)

            # evaluation of primitive type prediction performance
            gt_prim_type_k = gt_prim[b][gt_indices][0]
            try:
                predicted_prim_type_k = pred_prim[b][r]
            except IndexError as err:
                raise IndexError(
                    "no predicted primitive type for segment {} of shape {}".format(r, b)
                ) from err

            iou_b_prim.append(gt_prim_type_k == predicted_prim_type_k)
            iou_b_prims.append([gt_prim_type_k, predicted_prim_type_k])

        # find the mean of IOU over this shape
        IOU.append(np.mean(iou_b))
        IOU_prim.append(np.mean(iou_b_prim))
    return np.mean(IOU), np.mean(IOU_prim), iou_b_prims
def to_one_hot(target, maxx=50, device_id=0):
    """
    Converts integer ids into a one hot matrix.

    :param target: N integer ids, numpy array or torch tensor. Every id has
    to be in [0, maxx).
    :param maxx: number of columns of the one hot matrix
    :param device_id: cuda device used for numpy input when cuda is available.
    For tensor input the result is created on the device of the tensor.
    :return: N x maxx float tensor with a single 1 per row
    """
    if isinstance(target, np.ndarray):
        target = torch.from_numpy(target.astype(np.int64))
        # Only move to the gpu when there is one; cpu-only hosts stay on cpu.
        if torch.cuda.is_available():
            target = target.cuda(device_id)

    N = target.shape[0]
    # scatter_ needs the index and the destination on the same device.
    target_one_hot = torch.zeros((N, maxx), device=target.device)
    target_t = target.unsqueeze(1)
    target_one_hot = target_one_hot.scatter_(1, target_t.long(), 1)
    return target_one_hot
def relaxed_iou_fast(pred, gt, max_clusters=50):
    """
    Computes the relaxed (soft) IOU between every predicted cluster and
    every ground truth cluster, for all shapes of the batch at once.

    Entry [b, i, j] is dot(p_i, g_j) / (sum(p_i) + sum(g_j) - dot(p_i, g_j) + 1e-7),
    where p_i and g_j are the membership columns of shape b.

    :param pred: B x N x K, membership of every point in the predicted clusters
    :param gt: B x N x K, membership of every point in the gt clusters
    :param max_clusters: unused, kept for backward compatibility
    :return: B x K x K tensor of relaxed IOU scores, on the device of the inputs
    """
    # Soft intersection: dots[b, i, j] = sum_n pred[b, n, i] * gt[b, n, j]
    dots = pred.transpose(1, 2) @ gt

    # Soft cluster sizes, shaped to broadcast against dots.
    norms_p = torch.unsqueeze(torch.sum(pred, 1), 2)
    norms_g = torch.unsqueeze(torch.sum(gt, 1), 1)

    # The epsilon guards against empty cluster pairs (0 / 0).
    return dots / (norms_p + norms_g - dots + 1e-7)
44 | # normals = normals.astype(np.float32) 45 | weights = np.ones((1000, 1), dtype=np.float32) 46 | center, radius = fitting.fit_sphere_torch(torch.from_numpy(points), None, torch.from_numpy(weights)) 47 | # center, radius = fitting.fit_sphere_numpy(points, weights) 48 | print(center, radius) 49 | 50 | new_points, new_normals = fitting.sample_sphere(radius.item(), center.data.numpy().reshape(3)) 51 | 52 | visualize_point_cloud(points, normals=normals, viz=True) 53 | colors = np.zeros((200, 3)) 54 | colors[0:100, 0] = 1 55 | 56 | visualize_point_cloud(np.concatenate([points, new_points], 0), normals=np.concatenate([normals, new_normals], 0), 57 | colors=colors, viz=True) 58 | 59 | 60 | # normals = normals.astype(np.float32) 61 | def grad_check_sphere(): 62 | points, normals = fitting.sample_sphere(1, np.array([0, 0, 0])) 63 | points = points.astype(np.float64) 64 | weights = torch.from_numpy(np.ones((100, 1), dtype=np.float64)) 65 | weights.requires_grad = True 66 | 67 | def func(weights): 68 | center, radius = fitting.fit_sphere_torch(torch.from_numpy(points), weights) 69 | return torch.mean(center) 70 | 71 | print(gradcheck(func, weights)) 72 | 73 | 74 | def grad_check_cone(): 75 | points, normals = fitting.sample_cone(np.array([0.10, 1.0, 2.0]), 76 | np.array([1, 1, 0]), np.pi / 3) 77 | 78 | weights = torch.from_numpy(np.ones((1000, 1), dtype=np.float64)) 79 | weights.requires_grad = True 80 | 81 | def func_apex(weights): 82 | apex, axis, theta = fitting.fit_cone_torch(torch.from_numpy(points), 83 | torch.from_numpy(normals), 84 | weights) 85 | return torch.mean(theta) 86 | 87 | print(gradcheck(func_apex, weights)) 88 | 89 | 90 | def grad_check_cone(): 91 | points, normals = fitting.sample_cone(np.array([0.10, 1.0, 2.0]), 92 | np.array([1, 1, 0]), np.pi / 3) 93 | 94 | weights = torch.from_numpy(np.ones((1000, 1), dtype=np.float64)) 95 | weights.requires_grad = True 96 | 97 | def func_apex(weights): 98 | apex, axis, theta = 
fitting.fit_cone_torch(torch.from_numpy(points), 99 | torch.from_numpy(normals), 100 | weights) 101 | return torch.mean(theta) 102 | 103 | print(gradcheck(func_apex, weights)) 104 | 105 | 106 | def grad_check_cylinder(): 107 | points, normals = fitting.sample_cylinder(1, np.array([1, 2, 0.3]), np.array([1, 0, 1]) / np.sqrt(2)) 108 | 109 | points = points.astype(np.float64) 110 | normals = normals.astype(np.float64) 111 | weights = torch.from_numpy(np.ones((100, 1), dtype=np.float64)) 112 | weights.requires_grad = True 113 | 114 | def func_axis(weights): 115 | axis, center, radius = fitting.fit_cylinder_torch(torch.from_numpy(points), 116 | torch.from_numpy(normals), weights) 117 | # print(axis, center, radius) 118 | return torch.mean(axis) 119 | 120 | print(gradcheck(func_axis, weights)) 121 | -------------------------------------------------------------------------------- /src/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.fitting_utils import to_one_hot 4 | from src.mean_shift import MeanShift 5 | from src.segment_utils import SIOU_matched_segments 6 | from src.utils import chamfer_distance 7 | from src.utils import fit_surface_sample_points 8 | 9 | ms = MeanShift() 10 | 11 | 12 | def convert_to_one_hot(data): 13 | """ 14 | Given a tensor of N x D, converts it into one_hot 15 | by filling zeros to non max along every row. 
def IOU_from_embeddings(embedding, labels, primitives_log_prob, primitives, quantile, iterations=20):
    """
    Starting from embedding, it first clusters every shape with mean shift
    and then calculates the segmentation and primitive type IOU scores.

    :param embedding: per point embedding; the last two axes are swapped and
    the result is L2 normalized along the last axis (presumably B x D x N on
    input -- TODO confirm against the caller)
    :param labels: gt segment ids, indexed per shape
    :param primitives_log_prob: primitive type scores; after swapping the last
    two axes the argmax over the last axis is used as predicted type
    :param primitives: gt primitive types, tensor indexed per shape
    :param quantile: forwarded to the mean shift clustering
    :param iterations: forwarded to the mean shift clustering
    :return: [seg_IOUs, prim_IOUs], each a list holding one single element
    list per shape
    """
    B = embedding.shape[0]
    embedding = embedding.permute(0, 2, 1)
    primitives_log_prob = primitives_log_prob.permute(0, 2, 1)

    embedding = torch.nn.functional.normalize(embedding, p=2, dim=2)
    seg_IOUs = []
    prim_IOUs = []
    # most likely primitive type of every point
    primitives_log_prob = torch.max(primitives_log_prob, 2)[1]
    primitives_log_prob = primitives_log_prob.data.cpu().numpy()

    for b in range(B):
        center, bandwidth, cluster_ids = ms.guard_mean_shift(embedding[b], quantile, iterations)
        weight = center @ torch.transpose(embedding[b], 1, 0)
        # NOTE(review): convert_to_one_hot takes the max along every row of
        # this (centers x points) matrix -- confirm that this is intended.
        weight = convert_to_one_hot(weight)
        # SIOU_matched_segments returns four values and reads weights.device,
        # so the weights are passed as a tensor and all four are unpacked.
        s_iou, p_iou, _, _ = SIOU_matched_segments(
            labels[b],
            cluster_ids.data.cpu().numpy(),
            primitives_log_prob[b],
            primitives[b].data.cpu().numpy(),
            weight.T,
        )
        seg_IOUs.append([s_iou])
        prim_IOUs.append([p_iou])
    return [seg_IOUs, prim_IOUs]
def rotation_matrix_a_to_b(A, B):
    """
    Finds rotation matrix from vector A in 3d to vector B
    in 3d.
    B = R @ A

    np.dot(A, B) is used directly as the cosine of the angle, so A and B are
    expected to be unit vectors.

    :param A: source vector, array of 3 values
    :param B: target vector, array of 3 values
    :return: 3 x 3 matrix R. When A and B are parallel the helper basis is
    singular and the identity (float32) is returned instead.
    """
    EPS = 1e-8
    cos = np.dot(A, B)
    sin = np.linalg.norm(np.cross(B, A))

    # Basis: A itself, the part of B orthogonal to A, and the rotation axis.
    u = A
    v = B - np.dot(A, B) * A
    v = v / (np.linalg.norm(v) + EPS)
    w = np.cross(B, A)
    w = w / (np.linalg.norm(w) + EPS)
    F = np.stack([u, v, w], 1)

    # Planar rotation expressed in the basis F.
    G = np.array([[cos, -sin, 0],
                  [sin, cos, 0],
                  [0, 0, 1]])
    # B = R @ A
    try:
        R = F @ G @ np.linalg.inv(F)
    except np.linalg.LinAlgError:
        # F cannot be inverted when v and w vanish (A parallel to B).
        R = np.eye(3, dtype=np.float32)
    return R
def sample_mesh_torch(
        v1, v2, v3, n, face_normals=[], rgb1=[], rgb2=[], rgb3=[], norms=False, rgb=False
):
    """
    Samples points from a triangle mesh given its vertices. Faces are drawn
    with probability proportional to their (regularized) area and one point
    is sampled uniformly inside every drawn face.

    :param v1: first vertex of the face, N x 3 tensor
    :param v2: second vertex of the face, N x 3 tensor
    :param v3: third vertex of the face, N x 3 tensor
    :param n: number of points to be sampled
    :param face_normals: N x 3 tensor of face normals; required, it is
    indexed with the sampled face ids
    :param rgb1: unused
    :param rgb2: unused
    :param rgb3: unused
    :param norms: unused
    :param rgb: unused
    :return: sampled points (n x 3 numpy array) and the normals of the faces
    they were sampled from (n x 3 numpy array)
    """
    areas = 0.5 * torch.norm(torch.cross(v2 - v1, v3 - v1, dim=1), dim=1)
    # To avoid zero areas
    areas = areas + torch.min(areas) + 1e-8

    # np.random.choice checks that p sums to 1 in double precision, which
    # float32 probabilities can miss; normalize in float64 instead.
    probabilities = areas.data.cpu().numpy().astype(np.float64)
    probabilities = probabilities / np.sum(probabilities)
    face_ids = np.random.choice(np.arange(len(probabilities)), size=n, p=probabilities)

    v1 = v1[face_ids]
    v2 = v2[face_ids]
    v3 = v3[face_ids]

    # (n, 1) the 1 is for broadcasting
    u = np.random.rand(n, 1)
    v = np.random.rand(n, 1)
    # Reflect samples that fall outside the triangle back into it.
    is_a_problem = u + v > 1

    u[is_a_problem] = 1 - u[is_a_problem]
    v[is_a_problem] = 1 - v[is_a_problem]

    # Match dtype and device of the vertices before mixing with tensors.
    u = torch.from_numpy(u).to(v1)
    v = torch.from_numpy(v).to(v1)
    sample_points = (v1 * u) + (v2 * v) + ((1 - (u + v)) * v3)
    sample_points = sample_points.data.cpu().numpy()

    sample_point_normals = face_normals[face_ids].data.cpu().numpy()

    return sample_points, sample_point_normals
probabilities = areas / np.sum(areas) 139 | 140 | face_ids = np.random.choice(np.arange(len(areas)), size=n, p=probabilities) 141 | 142 | v1 = v1[face_ids] 143 | v2 = v2[face_ids] 144 | v3 = v3[face_ids] 145 | 146 | # (n, 1) the 1 is for broadcasting 147 | u = np.random.rand(n, 1) 148 | v = np.random.rand(n, 1) 149 | is_a_problem = u + v > 1 150 | 151 | u[is_a_problem] = 1 - u[is_a_problem] 152 | v[is_a_problem] = 1 - v[is_a_problem] 153 | sample_points = (v1 * u) + (v2 * v) + ((1 - (u + v)) * v3) 154 | sample_points = sample_points.astype(np.float32) 155 | 156 | sample_rgb = [] 157 | sample_normals = [] 158 | 159 | if rgb: 160 | v1_rgb = rgb1[face_ids, :] 161 | v2_rgb = rgb2[face_ids, :] 162 | v3_rgb = rgb3[face_ids, :] 163 | 164 | sample_rgb = (v1_rgb * u) + (v2_rgb * v) + ((1 - (u + v)) * v3_rgb) 165 | 166 | if norms: 167 | sample_point_normals = face_normals[face_ids] 168 | sample_point_normals = sample_point_normals.astype(np.float32) 169 | return sample_points, sample_point_normals, sample_rgb, face_ids 170 | else: 171 | return sample_points, sample_rgb, face_ids 172 | 173 | 174 | def triangle_area_multi(v1, v2, v3): 175 | """ v1, v2, v3 are (N,3) arrays. 
def visualize_uv_maps(
        output, root_path="data/uvmaps/", iter=0, grid_size=20, viz=False
):
    """
    Visualizes uv maps using the output of the network. For every sample a
    figure with three panels is written to root_path: the sum of both uv
    channels as an image, the grid lines of the map, and a scatter plot of
    the uv coordinates. The reshaped map is also stored as .npy.

    :param output: network output with B rows; every row is reshaped to
    (grid_size, grid_size, 2)
    :param root_path: directory for the plots, created if missing
    :param iter: iteration counter, file index is iter * B + sample index
    :param grid_size: side length of the uv grid
    :param viz: if True the figure is also shown on screen
    :return: None
    """
    os.makedirs(root_path, exist_ok=True)
    B = output.shape[0]
    for index in range(B):
        figure, a = plt.subplots(1, 3)
        uvmap = output[index, :].reshape((grid_size, grid_size, 2))
        a[0].imshow(np.sum(uvmap, 2))

        for ind in range(grid_size):
            a[1].plot(uvmap[ind, :, 1])
        for ind in range(grid_size):
            a[1].plot(uvmap[:, ind, 0])
        temp = output[index, :].reshape((grid_size ** 2, 2))
        a[2].scatter(temp[:, 0], temp[:, 1])

        # Save before showing: once an interactive window has been closed
        # the figure may be gone and the saved image would be empty.
        figure.savefig("{}/plots_iter_{}.png".format(root_path, iter * B + index))
        if viz:
            plt.show()
        plt.close("all")
        np.save("{}/plots_iter_{}.npy".format(root_path, iter * B + index), uvmap)
def fit_surface_sample_points(output, points, grid_size, regular_grids=False):
    """
    Fits a surface to every shape of the batch. The predicted uv values are
    first assigned one-to-one to the nodes of a regular grid_size x grid_size
    grid on the unit square (linear assignment on squared distances), the
    points are reordered accordingly and handed to fit_surface. The sampled
    points of every fit are centered at their mean.

    :param output: predicted uv parameters, one entry per shape
    :param points: input points, one entry per shape
    :param grid_size: number of grid nodes along each direction
    :param regular_grids: forwarded to fit_surface
    :return: stacked centered points of all fits, and the list of fitted surfaces
    """
    axis = np.linspace(0, 1, grid_size)
    grid_u, grid_v = np.meshgrid(axis, axis)
    # (grid_size ** 2) x 2 table of grid node coordinates
    par = np.stack([grid_u.transpose().flatten(), grid_v.transpose().flatten()], 1)

    predicted_points = []
    fitted_surfaces = []
    for index in range(output.shape[0]):
        uv = output[index]
        # TODO include the optimal rotation matrix
        # squared distance between every predicted uv and every grid node
        C = np.square(uv[:, None] - par[None]).sum(2)
        _, _, assignment = lap.lapjv(C)
        ordered = points[index][assignment]
        fitted_surface, fitted_points = fit_surface(
            ordered, grid_size, grid_size, 2, 2, regular_grids
        )
        centered = fitted_points - np.mean(fitted_points, 0, keepdims=True)
        predicted_points.append(centered)
        fitted_surfaces.append(fitted_surface)
    return np.stack(predicted_points, 0), fitted_surfaces
def chamfer_distance_one_side(pred, gt, side=1):
    """
    Computes average one sided chamfer distance (squared euclidean) between
    prediction and groundtruth.

    :param pred: Prediction: B x N x 3, numpy array or torch tensor
    :param gt: ground truth: B x M x 3, numpy array or torch tensor
    :param side: 0 averages, over the predicted points, the distance to the
    closest gt point; 1 averages, over the gt points, the distance to the
    closest predicted point
    :return: scalar tensor, mean over points and batch
    :raises ValueError: if side is neither 0 nor 1
    """
    if side not in (0, 1):
        raise ValueError("side must be 0 or 1, got {!r}".format(side))

    # numpy inputs go to the gpu only when there is one
    if isinstance(pred, np.ndarray):
        pred = torch.from_numpy(pred.astype(np.float32))
        if torch.cuda.is_available():
            pred = pred.cuda()

    if isinstance(gt, np.ndarray):
        gt = torch.from_numpy(gt.astype(np.float32))
        if torch.cuda.is_available():
            gt = gt.cuda()

    pred = torch.unsqueeze(pred, 1)
    gt = torch.unsqueeze(gt, 2)

    # B x M x N matrix of squared distances
    diff = pred - gt
    diff = torch.sum(diff ** 2, 3)
    if side == 0:
        cd = torch.mean(torch.min(diff, 1)[0], 1)
    else:
        cd = torch.mean(torch.min(diff, 2)[0], 1)
    cd = torch.mean(cd)
    return cd
def grad_norm(model):
    """
    Checks whether the gradients of the model are broken.

    The 2-norms of the gradients of all parameters are added up and the sum
    is tested. Parameters that have no gradient are skipped.

    :param model: object with a parameters() method yielding tensors
    :return: True if the summed gradient norm is NaN or infinite, False
    otherwise (also when no parameter has a gradient)
    """
    total_norm = None
    for p in model.parameters():
        # e.g. frozen parameters, or no backward pass has been run yet
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm(2)
        total_norm = param_norm if total_norm is None else total_norm + param_norm

    if total_norm is None:
        return False
    total_norm = total_norm.item()
    return np.isnan(total_norm) or np.isinf(total_norm)
| import torch 11 | from src.fitting_utils import ( 12 | to_one_hot, 13 | ) 14 | import os 15 | from src.segment_utils import SIOU_matched_segments 16 | from src.utils import chamfer_distance_single_shape 17 | from src.segment_utils import sample_from_collection_of_mesh 18 | from src.primitives import SaveParameters 19 | from src.dataset_segments import Dataset 20 | from src.residual_utils import Evaluation 21 | import sys 22 | 23 | start = int(sys.argv[1]) 24 | end = int(sys.argv[2]) 25 | prefix = "" 26 | 27 | dataset = Dataset( 28 | 1, 29 | 24000, 30 | 4000, 31 | 4000, 32 | normals=True, 33 | primitives=True, 34 | if_train_data=False, 35 | prefix=prefix 36 | ) 37 | 38 | 39 | def continuous_labels(labels_): 40 | new_labels = np.zeros_like(labels_) 41 | for index, value in enumerate(np.sort(np.unique(labels_))): 42 | new_labels[labels_ == value] = index 43 | return new_labels 44 | 45 | 46 | # root_path = "data/shapes/test_data.h5" 47 | root_path = prefix + "data/shapes/test_data.h5" 48 | 49 | with h5py.File(root_path, "r") as hf: 50 | # N x 3 51 | test_points = np.array(hf.get("points")) 52 | 53 | # N x 1 54 | test_labels = np.array(hf.get("labels")) 55 | 56 | # N x 3 57 | test_normals = np.array(hf.get("normals")) 58 | 59 | # N x 1 60 | test_primitives = np.array(hf.get("prim")) 61 | 62 | method_name = "parsenet_with_normals.pth" 63 | 64 | root_path = prefix + "logs/results/{}/results/predictions.h5".format(method_name) 65 | print(root_path) 66 | with h5py.File(root_path, "r") as hf: 67 | print(list(hf.keys())) 68 | test_cluster_ids = np.array(hf.get("seg_id")).astype(np.int32) 69 | test_pred_primitives = np.array(hf.get("pred_primitives")) 70 | 71 | prim_ids = {} 72 | prim_ids[11] = "torus" 73 | prim_ids[1] = "plane" 74 | prim_ids[2] = "open-bspline" 75 | prim_ids[3] = "cone" 76 | prim_ids[4] = "cylinder" 77 | prim_ids[5] = "sphere" 78 | prim_ids[6] = "other" 79 | prim_ids[7] = "revolution" 80 | prim_ids[8] = "extrusion" 81 | prim_ids[9] = "closed-bspline" 82 | 
83 | saveparameters = SaveParameters() 84 | 85 | root_path = "/mnt/nfs/work1/kalo/gopalsharma/Projects/surfacefitting/logs_curve_fitting/outputs/{}/" 86 | 87 | all_pred_meshes = [] 88 | all_input_points = [] 89 | all_input_labels = [] 90 | all_input_normals = [] 91 | all_cluster_ids = [] 92 | evaluation = Evaluation() 93 | all_segments = [] 94 | 95 | os.makedirs("../logs_curve_fitting/results/{}/results/".format(method_name), exist_ok=True) 96 | 97 | test_res = [] 98 | test_s_iou = [] 99 | test_p_iou = [] 100 | s_k_1s = [] 101 | s_k_2s = [] 102 | p_k_1s = [] 103 | p_k_2s = [] 104 | s_ks = [] 105 | p_ks = [] 106 | test_cds = [] 107 | 108 | for i in range(start, end): 109 | bw = 0.01 110 | points = test_points[i].astype(np.float32) 111 | normals = test_normals[i].astype(np.float32) 112 | 113 | labels = test_labels[i].astype(np.int32) 114 | labels = continuous_labels(labels) 115 | 116 | cluster_ids = test_cluster_ids[i].astype(np.int32) 117 | cluster_ids = continuous_labels(cluster_ids) 118 | weights = to_one_hot(cluster_ids, np.unique(cluster_ids).shape[0]) 119 | 120 | points, normals = dataset.normalize_points(points, normals) 121 | torch.cuda.empty_cache() 122 | with torch.no_grad(): 123 | # if_visualize=True, will give you all segments 124 | # if_sample=True will return segments as trimmed meshes 125 | # if_optimize=True will optimize the spline surface patches 126 | _, parameters, newer_pred_mesh = evaluation.residual_eval_mode( 127 | torch.from_numpy(points).cuda(), 128 | torch.from_numpy(normals).cuda(), 129 | labels, 130 | cluster_ids, 131 | test_primitives[i], 132 | test_pred_primitives[i], 133 | weights.T, 134 | bw, 135 | sample_points=True, 136 | if_optimize=False, 137 | if_visualize=True, 138 | epsilon=0.1) 139 | 140 | torch.cuda.empty_cache() 141 | s_iou, p_iou, _, _ = SIOU_matched_segments( 142 | labels, 143 | cluster_ids, 144 | test_pred_primitives[i], 145 | test_primitives[i], 146 | weights, 147 | ) 148 | 149 | test_s_iou.append(s_iou) 150 | 
"""test_closed_control_points.py: evaluate the closed-spline control point network.

Usage: ``python test_closed_control_points.py CONFIG_FILE``

For every test batch the script predicts a control point grid with
``DGCNNControlPoints``, measures the Chamfer distance (after
``optimize_close_spline`` refinement when ``if_optimize`` is set), the
permutation regression loss and the Laplacian loss, and, when ``visualize``
is set, writes ground-truth, predicted and optimized surfaces as ``.ply``
meshes to ``logs/results/<pretrain_model_path>/``.
"""
import os
import sys

import numpy as np
import open3d
import torch.utils.data
from open3d import *
from torch.autograd import Variable
from torch.utils.data import DataLoader

from read_config import Config
from src.VisUtils import tessalate_points
from src.dataset import DataSetControlPointsPoisson
from src.dataset import generator_iter
from src.fitting_utils import sample_points_from_control_points_
from src.fitting_utils import up_sample_points_torch_in_range
from src.loss import control_points_permute_reg_loss
from src.loss import laplacian_loss
from src.loss import (
    uniform_knot_bspline,
    spline_reconstruction_loss,
)
from src.model import DGCNNControlPoints
from src.primitive_forward import optimize_close_spline
from src.utils import chamfer_distance_single_shape

config = Config(sys.argv[1])

print(config.mode)
control_decoder = DGCNNControlPoints(20, num_points=10, mode=config.mode)
control_decoder = torch.nn.DataParallel(control_decoder)
control_decoder.cuda()

# The optimization and mesh export below handle one shape at a time.
config.batch_size = 1
split_dict = {"train": config.num_train, "val": config.num_val, "test": config.num_test}

dataset = DataSetControlPointsPoisson(
    path=config.dataset_path,
    batch_size=config.batch_size,
    splits=split_dict,
    size_v=config.grid_size,
    size_u=config.grid_size,
    closed=True
)

nu, nv = uniform_knot_bspline(20, 20, 3, 3, 30)
nu = torch.from_numpy(nu.astype(np.float32)).cuda()
nv = torch.from_numpy(nv.astype(np.float32)).cuda()

# We want to gather the regular grid points for tessellation
align_canonical = True
anisotropic = True
if_augmentation = False
if_rand_num_points = False
if_upsample = False
visualize = True
if_optimize = True

result_dir = "logs/results/{}/".format(config.pretrain_model_path)
os.makedirs(result_dir, exist_ok=True)

config.num_points = 700

get_test_data = dataset.load_test_data(
    if_regular_points=True, align_canonical=align_canonical, anisotropic=anisotropic, if_augment=if_augmentation
)
loader = generator_iter(get_test_data, int(1e10))
# batch_size=1 with an identity collate_fn yields a one-element list per
# step, which is why every `next(...)` below is followed by `[0]`.
get_test_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

state_dict = torch.load("logs/pretrained_models/" + config.pretrain_model_path)
# The training script only wraps the model in DataParallel on multi-GPU
# machines, so single-GPU checkpoints lack the "module." key prefix that the
# wrapped decoder expects. Add it so both kinds of checkpoint load.
if state_dict and not any(key.startswith("module.") for key in state_dict):
    state_dict = {"module." + key: value for key, value in state_dict.items()}
control_decoder.load_state_dict(state_dict)

test_reg = []
test_cd = []
test_lap = []

control_decoder.eval()

# NOTE(review): the "- 1" drops the final batch; kept as in the original.
for val_b_id in range(config.num_test // config.batch_size - 1):
    points_, parameters, control_points, scales, RS = next(get_test_data)[0]

    control_points = Variable(
        torch.from_numpy(control_points.astype(np.float32))
    ).cuda()

    points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
    points = points.permute(0, 2, 1)

    if if_rand_num_points:
        rand_num_points = config.num_points + np.random.choice(np.arange(-200, 200), 1)[0]
    else:
        rand_num_points = config.num_points

    with torch.no_grad():
        # Feed a random subset of the input points to the network.
        L = np.arange(points.shape[2])
        np.random.shuffle(L)
        new_points = points[:, :, L[0:rand_num_points]]

        if if_upsample:
            new_points = up_sample_points_torch_in_range(new_points[0].permute(1, 0), 1200, 1800).permute(1, 0)
            new_points = torch.unsqueeze(new_points, 0)

        output = control_decoder(new_points)

    for b in range(config.batch_size):
        # re-aligning back to original orientation for better comparison
        if anisotropic:
            s = torch.from_numpy(scales[b].astype(np.float32)).cuda()
            output[b] = output[b] * s.reshape(1, 3) / torch.max(s)
            points[b] = points[b] * s.reshape(3, 1) / torch.max(s)
            control_points[b] = (
                control_points[b] * s.reshape(1, 1, 3) / torch.max(s)
            )

    # Chamfer Distance loss, between predicted and GT surfaces
    cd, reconstructed_points = spline_reconstruction_loss(
        nu, nv, output, points, config, sqrt=True
    )

    # Stays None when optimization is disabled so that the mesh export below
    # can skip the optimized surface instead of failing on an unbound name.
    optimized_points = None
    if if_optimize:
        # Repeat the first row of the 30x30 grid at the end (31x30) before
        # refinement. Index 0 is the only item: batch_size is forced to 1.
        temp = reconstructed_points[0].reshape((30, 30, 3))
        temp = torch.cat([temp, temp[0:1]], 0)
        temp = torch.unsqueeze(temp, 0)

        new_points = optimize_close_spline(temp, points.permute(0, 2, 1))
        optimized_points = new_points.clone()
        cd = chamfer_distance_single_shape(new_points[0], points[0].permute(1, 0), sqrt=True)

    l_reg, permute_cp = control_points_permute_reg_loss(
        output, control_points, config.grid_size
    )
    laplac_loss = laplacian_loss(
        output.reshape((config.batch_size, config.grid_size, config.grid_size, 3)),
        permute_cp,
        dist_type="l2",
    )

    test_reg.append(l_reg.data.cpu().numpy())
    test_cd.append(cd.data.cpu().numpy())
    test_lap.append(laplac_loss.data.cpu().numpy())

    print(val_b_id, cd.item())
    if visualize:
        reconstructed_points = reconstructed_points.data.cpu().numpy()
        control_points = control_points.reshape((config.batch_size, 400, 3))
        for b in range(config.batch_size):
            mesh_id = val_b_id * config.batch_size + b

            temp = reconstructed_points[b].reshape((30, 30, 3))
            temp = np.concatenate([temp, temp[0:1]], 0)
            pred_mesh = tessalate_points(temp, 31, 30)
            pred_mesh.paint_uniform_color([1, 0.0, 0])

            # A single item (control_points[b:b + 1]) is sampled, so the
            # result is read at index 0 rather than at the batch index.
            gt_points = sample_points_from_control_points_(nu, nv, control_points[b:b + 1], 1).data.cpu().numpy()
            temp = gt_points[0].reshape((30, 30, 3))
            gt_points = np.concatenate([temp, temp[0:1]], 0)
            gt_mesh = tessalate_points(gt_points, 31, 30)

            open3d.io.write_triangle_mesh(
                "logs/results/{}/gt_{}.ply".format(config.pretrain_model_path, mesh_id),
                gt_mesh,
            )
            open3d.io.write_triangle_mesh(
                "logs/results/{}/pred_{}.ply".format(config.pretrain_model_path, mesh_id),
                pred_mesh,
            )

            if optimized_points is not None:
                # Use a separate name for the extended grid so that
                # `optimized_points` itself is not overwritten in the loop.
                temp = optimized_points[0].reshape((31, 30, 3))
                closed_grid = torch.cat([temp, temp[0:1]], 0)
                optim_mesh = tessalate_points(closed_grid.data.cpu().numpy(), 32, 30)
                open3d.io.write_triangle_mesh(
                    "logs/results/{}/optim_{}.ply".format(config.pretrain_model_path, mesh_id),
                    optim_mesh,
                )

results = {}
results["test_reg"] = str(np.mean(test_reg))
results["test_cd"] = str(np.mean(test_cd))
results["test_lap"] = str(np.mean(test_lap))

print(
    "Test Reg Loss: {}, Test CD Loss: {}, Test Lap: {}".format(
        np.mean(test_reg), np.mean(test_cd), np.mean(test_lap)
    )
)
"""test_open_splines.py: evaluate the open-spline control point network.

Usage: ``python test_open_splines.py CONFIG_FILE``

For every test batch the script predicts a control point grid with
``DGCNNControlPoints``, measures the Chamfer distance (after
``optimize_open_spline`` refinement when ``if_optimize`` is set), the
permutation regression loss and the Laplacian loss, and, when
``if_save_meshes`` is set, writes ground-truth and predicted (and optimized)
surfaces as ``.ply`` meshes to ``logs/results/<pretrain_model_path>/``.
"""
import os
import sys

import numpy as np
import open3d
import torch.utils.data
from open3d import *
from torch.autograd import Variable
from torch.utils.data import DataLoader

from read_config import Config
from src.VisUtils import tessalate_points
from src.dataset import DataSetControlPointsPoisson
from src.dataset import generator_iter
from src.fitting_utils import sample_points_from_control_points_
from src.fitting_utils import up_sample_points_torch_in_range
from src.loss import control_points_permute_reg_loss
from src.loss import laplacian_loss
from src.loss import (
    uniform_knot_bspline,
    spline_reconstruction_loss,
)
from src.model import DGCNNControlPoints
from src.primitive_forward import optimize_open_spline

config = Config(sys.argv[1])

control_decoder = DGCNNControlPoints(20, num_points=10, mode=config.mode)
control_decoder = torch.nn.DataParallel(control_decoder)
control_decoder.cuda()

split_dict = {"train": config.num_train, "val": config.num_val, "test": config.num_test}

dataset = DataSetControlPointsPoisson(
    config.dataset_path,
    config.batch_size,
    splits=split_dict,
    size_v=config.grid_size,
    size_u=config.grid_size)

# Basis for the predicted 20x20 control grid, sampled at 30 points.
nu, nv = uniform_knot_bspline(20, 20, 3, 3, 30)
nu = torch.from_numpy(nu.astype(np.float32)).cuda()
nv = torch.from_numpy(nv.astype(np.float32)).cuda()

# Basis used to sample the optimized surface (see the if_optimize branch).
nu_3, nv_3 = uniform_knot_bspline(30, 30, 3, 3, 50)
nu_3 = torch.from_numpy(nu_3.astype(np.float32)).cuda()
nv_3 = torch.from_numpy(nv_3.astype(np.float32)).cuda()

align_canonical = True
anisotropic = True
if_augment = False
if_rand_points = False
if_optimize = False
if_save_meshes = True
if_upsample = False

get_test_data = dataset.load_test_data(
    if_regular_points=True, align_canonical=align_canonical, anisotropic=anisotropic,
    if_augment=if_augment)
loader = generator_iter(get_test_data, int(1e10))
# batch_size=1 with an identity collate_fn yields a one-element list per
# step, which is why every `next(...)` below is followed by `[0]`.
get_test_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

state_dict = torch.load("logs/pretrained_models/" + config.pretrain_model_path)
# The training script only wraps the model in DataParallel on multi-GPU
# machines, so single-GPU checkpoints lack the "module." key prefix that the
# wrapped decoder expects. Add it so both kinds of checkpoint load.
if state_dict and not any(key.startswith("module.") for key in state_dict):
    state_dict = {"module." + key: value for key, value in state_dict.items()}
control_decoder.load_state_dict(state_dict)

os.makedirs(
    "logs/results/{}/".format(config.pretrain_model_path),
    exist_ok=True,
)

test_reg = []
test_cd = []
test_lap = []
config.num_points = 700
control_decoder.eval()

# NOTE(review): the "- 2" drops the last two batches; kept as in the original.
for val_b_id in range(config.num_test // config.batch_size - 2):
    points_, parameters, control_points, scales, RS = next(get_test_data)[0]
    control_points = Variable(
        torch.from_numpy(control_points.astype(np.float32))
    ).cuda()

    points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
    points = points.permute(0, 2, 1)

    with torch.no_grad():
        if if_rand_points:
            num_points = config.num_points + np.random.choice(np.arange(-200, 200), 1)[0]
        else:
            num_points = config.num_points
        # Feed a random subset of the input points to the network.
        L = np.arange(points.shape[2])
        np.random.shuffle(L)
        new_points = points[:, :, L[0:num_points]]

        if if_upsample:
            new_points = up_sample_points_torch_in_range(new_points[0].permute(1, 0), 800, 1200).permute(1, 0)
            new_points = torch.unsqueeze(new_points, 0)
        output = control_decoder(new_points)

    for b in range(config.batch_size):
        # re-aligning back to original orientation for better comparison
        if anisotropic:
            s = torch.from_numpy(scales[b].astype(np.float32)).cuda()
            output[b] = output[b] * s.reshape(1, 3) / torch.max(s)
            points[b] = points[b] * s.reshape(3, 1) / torch.max(s)
            control_points[b] = (
                control_points[b] * s.reshape(1, 1, 3) / torch.max(s)
            )

    # Chamfer Distance loss, between predicted and GT surfaces
    cd, reconstructed_points = spline_reconstruction_loss(
        nu, nv, output, points, config, sqrt=True
    )

    if if_optimize:
        new_points = optimize_open_spline(reconstructed_points, points.permute(0, 2, 1))

        cd, optimized_points = spline_reconstruction_loss(nu_3, nv_3, new_points, points, config, sqrt=True)
        optimized_points = optimized_points.data.cpu().numpy()

    l_reg, permute_cp = control_points_permute_reg_loss(
        output, control_points, config.grid_size
    )

    laplac_loss = laplacian_loss(
        output.reshape((config.batch_size, config.grid_size, config.grid_size, 3)),
        permute_cp,
        dist_type="l2",
    )

    test_reg.append(l_reg.data.cpu().numpy())
    test_cd.append(cd.data.cpu().numpy())
    test_lap.append(laplac_loss.data.cpu().numpy())
    print(val_b_id)

    if if_save_meshes:
        reconstructed_points = reconstructed_points.data.cpu().numpy()
        reg_points = sample_points_from_control_points_(nu, nv, control_points, config.batch_size,
                                                        input_size_u=20, input_size_v=20).data.cpu().numpy()

        # Save the predictions.
        for b in range(config.batch_size):
            # One index per shape. Using only the batch index would make the
            # items of a batch overwrite each other when batch_size > 1.
            mesh_id = val_b_id * config.batch_size + b

            if align_canonical:
                # to bring back into canonical orientation.
                rotation_inv = np.linalg.inv(RS[b])
                reconstructed_points[b] = (rotation_inv @ reconstructed_points[b].T).T
                reg_points[b] = (rotation_inv @ reg_points[b].T).T

                if if_optimize:
                    optimized_points[b] = (rotation_inv @ optimized_points[b].T).T

            pred_mesh = tessalate_points(reconstructed_points[b], 30, 30)
            pred_mesh.paint_uniform_color([1, 0, 0])

            gt_mesh = tessalate_points(reg_points[b], 30, 30)

            open3d.io.write_triangle_mesh(
                "logs/results/{}/gt_{}.ply".format(config.pretrain_model_path, mesh_id),
                gt_mesh,
            )
            open3d.io.write_triangle_mesh(
                "logs/results/{}/pred_{}.ply".format(config.pretrain_model_path, mesh_id),
                pred_mesh,
            )

            if if_optimize:
                optim_mesh = tessalate_points(optimized_points[b], 50, 50)
                open3d.io.write_triangle_mesh(
                    "logs/results/{}/optim_{}.ply".format(config.pretrain_model_path, mesh_id),
                    optim_mesh,
                )

results = {}
results["test_reg"] = str(np.mean(test_reg))
results["test_cd"] = str(np.mean(test_cd))
results["test_lap"] = str(np.mean(test_lap))
print(results)
print(
    "Test Reg Loss: {}, Test CD Loss: {}, Test Lap: {}".format(
        np.mean(test_reg), np.mean(test_cd), np.mean(test_lap)
    )
)
"""train_closed_control_points.py: train the closed-spline control point network.

Usage: ``python train_closed_control_points.py CONFIG_FILE``

Each epoch trains ``DGCNNControlPoints`` on a weighted sum of the permutation
regression loss and the one-sided spline reconstruction (Chamfer) loss, then
evaluates on the validation generator. The learning rate is reduced when the
validation Chamfer distance plateaus, and the weights are saved to
``logs/trained_models/<model_name>.pth`` whenever that distance improves.
"""
import json
import logging
import os
import sys
from shutil import copyfile

import numpy as np
import torch.optim as optim
import torch.utils.data
from tensorboard_logger import configure, log_value
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from read_config import Config
from src.dataset import DataSetControlPointsPoisson
from src.dataset import generator_iter
from src.loss import control_points_permute_closed_reg_loss
from src.loss import laplacian_loss
from src.loss import (
    uniform_knot_bspline,
    spline_reconstruction_loss_one_sided,
)
from src.model import DGCNNControlPoints
from src.utils import rescale_input_outputs

np.set_printoptions(precision=4)

config = Config(sys.argv[1])
model_name = config.model_path.format(
    config.mode,
    config.num_points,
    config.loss_weight,
    config.batch_size,
    config.lr,
    config.num_train,
    config.num_test,
    config.loss_weight,
)

print(model_name)

# Every artefact below is written under logs/. Create the folders up front so
# that a fresh checkout does not fail with FileNotFoundError.
for log_dir in ("logs/tensorboard", "logs/logs", "logs/configs", "logs/scripts", "logs/trained_models"):
    os.makedirs(log_dir, exist_ok=True)

configure("logs/tensorboard/{}".format(model_name), flush_secs=5)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s:%(name)s:%(message)s")
file_handler = logging.FileHandler(
    "logs/logs/{}.log".format(model_name), mode="w"
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(handler)

# Keep a copy of the configuration and of this script next to the logs.
with open(
        "logs/configs/{}_config.json".format(model_name), "w"
) as file:
    json.dump(vars(config), file)
source_file = __file__
destination_file = "logs/scripts/{}_{}".format(
    model_name, os.path.basename(__file__)
)
copyfile(source_file, destination_file)

control_decoder = DGCNNControlPoints(20, num_points=10, mode=config.mode)
if torch.cuda.device_count() > 1:
    control_decoder = torch.nn.DataParallel(control_decoder)

control_decoder.cuda()

split_dict = {"train": config.num_train, "val": config.num_val, "test": config.num_test}

dataset = DataSetControlPointsPoisson(
    path=config.dataset_path,
    batch_size=config.batch_size,
    splits=split_dict,
    size_v=config.grid_size,
    size_u=config.grid_size,
    closed=True
)

align_canonical = True
anisotropic = True
if_augmentation = True
if_rand_num_points = True

get_train_data = dataset.load_train_data(
    if_regular_points=True, align_canonical=align_canonical, anisotropic=anisotropic, if_augment=if_augmentation
)
get_val_data = dataset.load_val_data(
    if_regular_points=True, align_canonical=align_canonical, anisotropic=anisotropic
)

# batch_size=1 with an identity collate_fn yields a one-element list per
# step, which is why every `next(...)` below is followed by `[0]`.
loader = generator_iter(get_train_data, int(1e10))
get_train_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

loader = generator_iter(get_val_data, int(1e10))
get_val_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

optimizer = optim.Adam(control_decoder.parameters(), lr=config.lr)
scheduler = ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10, verbose=True, min_lr=3e-5
)

nu, nv = uniform_knot_bspline(20, 20, 3, 3, 30)
nu = torch.from_numpy(nu.astype(np.float32)).cuda()
nv = torch.from_numpy(nv.astype(np.float32)).cuda()

train_iters = config.num_train // config.batch_size

prev_test_cd = 1e8
for e in range(config.epochs):
    train_reg = []
    train_cd = []
    train_lap = []
    control_decoder.train()
    for train_b_id in range(train_iters):
        torch.cuda.empty_cache()
        optimizer.zero_grad()
        points_, parameters, control_points, scales, _ = next(get_train_data)[0]
        control_points = Variable(
            torch.from_numpy(control_points.astype(np.float32))
        ).cuda()

        points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
        points = points.permute(0, 2, 1)

        if if_rand_num_points:
            # Vary the number of input points from one iteration to the next.
            rand_num_points = config.num_points + np.random.choice(np.arange(-300, 1300), 1)[0]
        else:
            rand_num_points = config.num_points

        output = control_decoder(points[:, :, 0:rand_num_points])
        if anisotropic:
            # rescale all tensors to original dimensions for evaluation
            scales, output, points, control_points = rescale_input_outputs(scales, output, points, control_points,
                                                                           config.batch_size)

        # Chamfer Distance loss, between predicted and GT surfaces
        cd, reconstructed_points = spline_reconstruction_loss_one_sided(
            nu, nv, output, points, config
        )

        # permute_cp has the best permutation of gt control points grid
        l_reg, permute_cp = control_points_permute_closed_reg_loss(
            output, control_points, config.grid_size, 20
        )

        # Logged only: the Laplacian term is not part of the training loss.
        laplac_loss = laplacian_loss(
            output.reshape((config.batch_size, config.grid_size, config.grid_size, 3)),
            permute_cp,
            dist_type="l2",
        )

        loss = l_reg * config.loss_weight + (cd) * (1 - config.loss_weight)
        loss.backward()
        train_cd.append(cd.data.cpu().numpy())
        train_reg.append(l_reg.data.cpu().numpy())
        train_lap.append(laplac_loss.data.cpu().numpy())
        optimizer.step()

        global_step = train_b_id + e * train_iters
        log_value("cd", cd.data.cpu().numpy(), global_step)
        log_value("l_reg", l_reg.data.cpu().numpy(), global_step)
        log_value("l_lap", laplac_loss.data.cpu().numpy(), global_step)
        print(
            "\rEpoch: {} iter: {}, loss: {}".format(
                e, train_b_id, loss.item()
            ),
            end="",
        )

    test_reg = []
    test_cd = []
    test_lap = []
    control_decoder.eval()

    # NOTE(review): this loop reads the validation generator but is sized by
    # config.num_test; kept as in the original, confirm whether num_val was
    # intended.
    for val_b_id in range(config.num_test // config.batch_size - 1):
        torch.cuda.empty_cache()
        points_, parameters, control_points, scales, _ = next(get_val_data)[0]

        control_points = Variable(
            torch.from_numpy(control_points.astype(np.float32))
        ).cuda()
        points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()

        points = points.permute(0, 2, 1)
        with torch.no_grad():
            output = control_decoder(points[:, :, 0:config.num_points])
            if anisotropic:
                # rescale all tensors to original dimensions for evaluation
                scales, output, points, control_points = rescale_input_outputs(scales, output, points, control_points,
                                                                               config.batch_size)

            # Chamfer Distance loss, between predicted and GT surfaces
            cd, reconstructed_points = spline_reconstruction_loss_one_sided(
                nu, nv, output, points, config
            )
            l_reg, permute_cp = control_points_permute_closed_reg_loss(
                output, control_points, config.grid_size, 20
            )
            laplac_loss = laplacian_loss(
                output.reshape((config.batch_size, config.grid_size, config.grid_size, 3)),
                permute_cp,
                dist_type="l2",
            )

        test_reg.append(l_reg.data.cpu().numpy())
        test_cd.append(cd.data.cpu().numpy())
        test_lap.append(laplac_loss.data.cpu().numpy())

    print("\n")
    logger.info(
        "Epoch: {}/{} => Tr lreg: {}, Ts loss: {}, Tr CD: {}, Ts CD: {}, Tr lap: {}, Ts lap: {}".format(
            e,
            config.epochs,
            np.mean(train_reg),
            np.mean(test_reg),
            np.mean(train_cd),
            np.mean(test_cd),
            np.mean(train_lap),
            np.mean(test_lap),
        )
    )

    log_value("train_cd", np.mean(train_cd), e)
    log_value("test_cd", np.mean(test_cd), e)
    log_value("train_reg", np.mean(train_reg), e)
    log_value("test_reg", np.mean(test_reg), e)

    # Both the LR schedule and checkpointing follow the validation Chamfer distance.
    scheduler.step(np.mean(test_cd))
    if prev_test_cd > np.mean(test_cd):
        logger.info("CD improvement, saving model at epoch: {}".format(e))
        prev_test_cd = np.mean(test_cd)
        torch.save(
            control_decoder.state_dict(),
            "logs/trained_models/{}.pth".format(model_name),
        )
"""train_open_splines.py: train the open-spline control point network.

Usage: ``python train_open_splines.py CONFIG_FILE``

Each epoch trains ``DGCNNControlPoints`` on a weighted sum of the permutation
regression loss, the one-sided spline reconstruction (Chamfer) loss and the
Laplacian loss, then evaluates on the validation generator. The learning rate
is reduced when the validation Chamfer distance plateaus, and the weights are
saved to ``logs/trained_models/<model_name>.pth`` whenever that distance
improves.
"""
import json
import logging
import os
import sys
from shutil import copyfile

import numpy as np
import torch.optim as optim
import torch.utils.data
from tensorboard_logger import configure, log_value
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from read_config import Config
from src.dataset import DataSetControlPointsPoisson
from src.dataset import generator_iter
from src.loss import (
    control_points_permute_reg_loss,
)
from src.loss import laplacian_loss
from src.loss import (
    uniform_knot_bspline,
    spline_reconstruction_loss_one_sided,
)
from src.model import DGCNNControlPoints
from src.utils import rescale_input_outputs

np.set_printoptions(precision=4)

config = Config(sys.argv[1])

model_name = config.model_path.format(
    config.mode,
    config.num_points,
    config.loss_weight,
    config.batch_size,
    config.lr,
    config.num_train,
    config.num_test,
    config.loss_weight,
)

print("Model name: ", model_name)
print(config.config)

# Every artefact below is written under logs/. Create the folders up front so
# that a fresh checkout does not fail with FileNotFoundError.
for log_dir in ("logs/tensorboard", "logs/logs", "logs/configs", "logs/scripts", "logs/trained_models"):
    os.makedirs(log_dir, exist_ok=True)

configure("logs/tensorboard/{}".format(model_name), flush_secs=5)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s:%(name)s:%(message)s")
file_handler = logging.FileHandler(
    "logs/logs/{}.log".format(model_name), mode="w"
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(handler)

# Keep a copy of the configuration and of this script next to the logs.
with open(
        "logs/configs/{}_config.json".format(model_name), "w"
) as file:
    json.dump(vars(config), file)

source_file = __file__
# Stored under logs/scripts/ like the other artefacts of this script, instead
# of the directory outside the repository that was used before.
destination_file = "logs/scripts/{}_{}".format(
    model_name, os.path.basename(__file__)
)
copyfile(source_file, destination_file)

control_decoder = DGCNNControlPoints(20, num_points=10, mode=config.mode)

if torch.cuda.device_count() > 1:
    control_decoder = torch.nn.DataParallel(control_decoder)
control_decoder.cuda()

split_dict = {"train": config.num_train, "val": config.num_val, "test": config.num_test}

align_canonical = True
anisotropic = True
if_augment = True

dataset = DataSetControlPointsPoisson(
    config.dataset_path,
    config.batch_size,
    splits=split_dict,
    size_v=config.grid_size,
    size_u=config.grid_size)

get_train_data = dataset.load_train_data(
    if_regular_points=True, align_canonical=align_canonical, anisotropic=anisotropic, if_augment=if_augment
)

get_val_data = dataset.load_val_data(
    if_regular_points=True, align_canonical=align_canonical, anisotropic=anisotropic
)

# batch_size=1 with an identity collate_fn yields a one-element list per
# step, which is why every `next(...)` below is followed by `[0]`.
loader = generator_iter(get_train_data, int(1e10))
get_train_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

loader = generator_iter(get_val_data, int(1e10))
get_val_data = iter(
    DataLoader(
        loader,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x,
        num_workers=0,
        pin_memory=False,
    )
)

optimizer = optim.Adam(control_decoder.parameters(), lr=config.lr)

scheduler = ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10, verbose=True, min_lr=3e-5
)

nu, nv = uniform_knot_bspline(20, 20, 3, 3, 40)
nu = torch.from_numpy(nu.astype(np.float32)).cuda()
nv = torch.from_numpy(nv.astype(np.float32)).cuda()

train_iters = config.num_train // config.batch_size

prev_test_cd = 1e8
for e in range(config.epochs):
    train_reg = []
    train_cd = []
    train_lap = []
    control_decoder.train()
    for train_b_id in range(train_iters):
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        points_, parameters, control_points, scales, _ = next(get_train_data)[0]
        control_points = Variable(
            torch.from_numpy(control_points.astype(np.float32))
        ).cuda()

        points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
        points = points.permute(0, 2, 1)

        # Sample random number of points to make network robust to density.
        rand_num_points = config.num_points + np.random.choice(np.arange(-300, 1300), 1)[0]

        output = control_decoder(points[:, :, 0:rand_num_points])

        if anisotropic:
            # rescale all tensors to original dimensions for evaluation
            scales, output, points, control_points = rescale_input_outputs(scales, output, points, control_points,
                                                                           config.batch_size)

        # Chamfer Distance loss, between predicted and GT surfaces
        cd, reconstructed_points = spline_reconstruction_loss_one_sided(
            nu, nv, output, points, config
        )

        # Permutation Regression Loss
        # permute_cp has the best permutation of gt control points grid
        l_reg, permute_cp = control_points_permute_reg_loss(
            output, control_points, config.grid_size
        )

        laplac_loss = laplacian_loss(
            output.reshape((config.batch_size, config.grid_size, config.grid_size, 3)),
            permute_cp,
            dist_type="l2",
        )

        loss = l_reg * config.loss_weight + (cd + laplac_loss) * (
                1 - config.loss_weight
        )

        loss.backward()
        train_cd.append(cd.data.cpu().numpy())
        train_reg.append(l_reg.data.cpu().numpy())
        train_lap.append(laplac_loss.data.cpu().numpy())
        optimizer.step()

        global_step = train_b_id + e * train_iters
        log_value("cd", cd.data.cpu().numpy(), global_step)
        log_value("l_reg", l_reg.data.cpu().numpy(), global_step)
        log_value("l_lap", laplac_loss.data.cpu().numpy(), global_step)
        print(
            "\rEpoch: {} iter: {}, loss: {}".format(
                e, train_b_id, loss.item()
            ),
            end="",
        )

    test_reg = []
    test_cd = []
    test_lap = []
    control_decoder.eval()

    # NOTE(review): this loop reads the validation generator but is sized by
    # config.num_test; kept as in the original, confirm whether num_val was
    # intended.
    for val_b_id in range(config.num_test // config.batch_size - 1):
        torch.cuda.empty_cache()
        points_, parameters, control_points, scales, _ = next(get_val_data)[0]

        control_points = Variable(
            torch.from_numpy(control_points.astype(np.float32))
        ).cuda()
        points = Variable(torch.from_numpy(points_.astype(np.float32))).cuda()
        points = points.permute(0, 2, 1)
        with torch.no_grad():
            output = control_decoder(points[:, :, 0:config.num_points])
            if anisotropic:
                scales, output, points, control_points = rescale_input_outputs(scales, output, points, control_points,
                                                                               config.batch_size)

            # Chamfer Distance loss, between predicted and GT surfaces
            cd, reconstructed_points = spline_reconstruction_loss_one_sided(
                nu, nv, output, points, config
            )

            l_reg, permute_cp = control_points_permute_reg_loss(
                output, control_points, config.grid_size
            )
            laplac_loss = laplacian_loss(
                output.reshape((config.batch_size, config.grid_size, config.grid_size, 3)),
                permute_cp,
                dist_type="l2",
            )

        test_reg.append(l_reg.data.cpu().numpy())
        test_cd.append(cd.data.cpu().numpy())
        test_lap.append(laplac_loss.data.cpu().numpy())

    print("\n")
    logger.info(
        "Epoch: {}/{} => Tr lreg: {}, Ts loss: {}, Tr CD: {}, Ts CD: {}, Tr lap: {}, Ts lap: {}".format(
            e,
            config.epochs,
            np.mean(train_reg),
            np.mean(test_reg),
            np.mean(train_cd),
            np.mean(test_cd),
            np.mean(train_lap),
            np.mean(test_lap),
        )
    )

    log_value("train_cd", np.mean(train_cd), e)
    log_value("test_cd", np.mean(test_cd), e)
    log_value("train_reg", np.mean(train_reg), e)
    log_value("test_reg", np.mean(test_reg), e)

    # Both the LR schedule and checkpointing follow the validation Chamfer distance.
    scheduler.step(np.mean(test_cd))
    if prev_test_cd > np.mean(test_cd):
        logger.info("CD improvement, saving model at epoch: {}".format(e))
        prev_test_cd = np.mean(test_cd)
        torch.save(
            control_decoder.state_dict(),
            "logs/trained_models/{}.pth".format(model_name),
        )
# ---------------------------------------------------------------------------
# Logging: the file handler gets the timestamped format, the stdout handler
# keeps logging's default format.
# ---------------------------------------------------------------------------
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s:%(name)s:%(message)s")
file_handler = logging.FileHandler(
    "logs/logs/{}.log".format(model_name), mode="w"
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(handler)

# Keep a copy of the configuration and of this script next to the run's logs.
with open(
    "logs/configs/{}_config.json".format(model_name), "w"
) as file:
    json.dump(vars(config), file)
source_file = __file__
destination_file = "logs/scripts/{}_{}".format(
    model_name, os.path.basename(__file__)
)
copyfile(source_file, destination_file)
if_normals = config.normals
if_normal_noise = True

# Number of points kept per shape after random sub-sampling (train and val).
NUM_SUBSAMPLED_POINTS = 7000

Loss = EmbeddingLoss(margin=1.0, if_mean_shift=False)
if config.mode == 0:
    # Just using points for training
    num_channels = 3
elif config.mode == 5:
    # Using points and normals for training
    num_channels = 6
else:
    # Fail early with a clear message instead of a NameError on `model` below.
    raise ValueError(
        "Unsupported config.mode {}: expected 0 (points) or 5 (points and "
        "normals)".format(config.mode)
    )

model = PrimitivesEmbeddingDGCNGn(
    embedding=True,
    emb_size=128,
    primitives=True,
    num_primitives=10,
    loss_function=Loss.triplet_loss,
    mode=config.mode,
    num_channels=num_channels,
)

# Keep a handle on the bare module: attributes set on it stay reachable even
# when `model` is rebound to a DataParallel wrapper.
model_bkp = model
model_bkp.l_permute = np.arange(NUM_SUBSAMPLED_POINTS)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.cuda()

split_dict = {"train": config.num_train, "val": config.num_val, "test": config.num_test}

dataset = Dataset(
    config.batch_size,
    config.num_train,
    config.num_val,
    config.num_test,
    primitives=True,
    normals=True,
)

get_train_data = dataset.get_train(
    randomize=True,
    augment=True,
    align_canonical=True,
    anisotropic=False,
    if_normal_noise=if_normal_noise,
)
get_val_data = dataset.get_val(
    align_canonical=True, anisotropic=False, if_normal_noise=if_normal_noise
)
optimizer = optim.Adam(model.parameters(), lr=config.lr)


def _endless_loader(batch_generator):
    """Wrap a batch generator in a 2-worker DataLoader and return its iterator.

    The generator already yields whole batches, so the loader uses
    ``batch_size=1`` with an identity ``collate_fn``; each ``next()`` therefore
    returns a one-element list holding the batch.
    """
    return iter(
        DataLoader(
            generator_iter(batch_generator, int(1e10)),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x,
            num_workers=2,
            pin_memory=False,
        )
    )


def _subsample_points(points, labels, normals, primitives,
                      num_points=NUM_SUBSAMPLED_POINTS):
    """Pick the same random subset of points from every per-point array.

    All four arrays are indexed along axis 1 with one shared random index set,
    so they stay aligned. The number of available points is read from
    ``points.shape[1]`` rather than being hard-coded.
    """
    keep = np.random.permutation(points.shape[1])[:num_points]
    return points[:, keep], labels[:, keep], normals[:, keep], primitives[:, keep]


def _forward_batch(points, labels, normals, primitives):
    """Run the network on one numpy batch and compute its losses and IoU.

    Returns ``(embed_loss, p_loss, iou)`` where ``embed_loss`` is the mean of
    the embedding loss returned by the model, ``p_loss`` is the output of
    ``primitive_loss`` and ``iou`` is the output of ``evaluate_miou``.
    Gradient tracking is left to the caller (wrap in ``torch.no_grad()`` for
    evaluation).
    """
    points = torch.from_numpy(points).cuda()
    normals = torch.from_numpy(normals).cuda()
    primitives = torch.from_numpy(primitives.astype(np.int64)).cuda()
    labels = torch.from_numpy(labels).cuda()

    # Normals are concatenated on the last axis only when the config asks for
    # them; the model is fed channels-first, hence the permute.
    inputs = torch.cat([points, normals], 2) if if_normals else points
    _, primitives_log_prob, embed_loss = model(
        inputs.permute(0, 2, 1), labels, True
    )
    embed_loss = torch.mean(embed_loss)
    p_loss = primitive_loss(primitives_log_prob, primitives)
    iou = evaluate_miou(
        primitives.data.cpu().numpy(),
        primitives_log_prob.permute(0, 2, 1).data.cpu().numpy(),
    )
    return embed_loss, p_loss, iou


get_train_data = _endless_loader(get_train_data)
get_val_data = _endless_loader(get_val_data)

scheduler = ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=4, verbose=True, min_lr=1e-4
)

model_bkp.triplet_loss = Loss.triplet_loss
# Start from +inf so the first epoch always produces a checkpoint, whatever
# the scale of the embedding loss.
prev_test_loss = float("inf")

steps_per_epoch = config.num_train // config.batch_size

for e in range(config.epochs):
    train_emb_losses = []
    train_prim_losses = []
    train_iou = []
    train_losses = []
    model.train()

    # this is used for gradient accumulation because of small gpu memory.
    num_iter = 3
    for train_b_id in range(steps_per_epoch):
        optimizer.zero_grad()
        losses = 0
        ious = 0
        p_losses = 0
        embed_losses = 0
        torch.cuda.empty_cache()
        for _ in range(num_iter):
            # randomly sub-sampling points to increase robustness to density
            # and saving gpu memory
            batch = _subsample_points(*next(get_train_data)[0])
            embed_loss, p_loss, iou = _forward_batch(*batch)
            loss = embed_loss + p_loss
            # backward() runs once per micro-batch with a single step() after
            # the loop, so the micro-batch gradients are summed.
            loss.backward()

            losses += loss.detach().cpu().numpy() / num_iter
            p_losses += p_loss.detach().cpu().numpy() / num_iter
            ious += iou / num_iter
            embed_losses += embed_loss.detach().cpu().numpy() / num_iter

        optimizer.step()
        train_iou.append(ious)
        train_losses.append(losses)
        train_prim_losses.append(p_losses)
        train_emb_losses.append(embed_losses)
        # Report the averages over the accumulated micro-batches, matching
        # what is appended to the per-epoch statistics above.
        print(
            "\rEpoch: {} iter: {}, prim loss: {}, emb loss: {}, iou: {}".format(
                e, train_b_id, p_losses, embed_losses, ious
            ),
            end="",
        )
        global_step = train_b_id + e * steps_per_epoch
        log_value("iou", ious, global_step)
        log_value("embed_loss", embed_losses, global_step)

    test_emb_losses = []
    test_prim_losses = []
    test_losses = []
    test_iou = []
    model.eval()

    # NOTE(review): batches come from the validation generator but the count
    # is derived from config.num_test - confirm this is intended.
    for val_b_id in range(config.num_test // config.batch_size - 1):
        batch = _subsample_points(*next(get_val_data)[0])
        with torch.no_grad():
            embed_loss, p_loss, iou = _forward_batch(*batch)
        loss = embed_loss + p_loss
        test_iou.append(iou)
        test_prim_losses.append(p_loss.detach().cpu().numpy())
        test_emb_losses.append(embed_loss.detach().cpu().numpy())
        test_losses.append(loss.detach().cpu().numpy())
    torch.cuda.empty_cache()
    print("\n")
    logger.info(
        "Epoch: {}/{} => TrL:{}, TsL:{}, TrP:{}, TsP:{}, TrE:{}, TsE:{}, TrI:{}, TsI:{}".format(
            e,
            config.epochs,
            np.mean(train_losses),
            np.mean(test_losses),
            np.mean(train_prim_losses),
            np.mean(test_prim_losses),
            np.mean(train_emb_losses),
            np.mean(test_emb_losses),
            np.mean(train_iou),
            np.mean(test_iou),
        )
    )
    log_value("train iou", np.mean(train_iou), e)
    log_value("test iou", np.mean(test_iou), e)

    log_value("train emb loss", np.mean(train_emb_losses), e)
    log_value("test emb loss", np.mean(test_emb_losses), e)

    # Both the LR schedule and checkpointing track the held-out embedding loss.
    mean_test_emb_loss = np.mean(test_emb_losses)
    scheduler.step(mean_test_emb_loss)
    if prev_test_loss > mean_test_emb_loss:
        logger.info("improvement, saving model at epoch: {}".format(e))
        prev_test_loss = mean_test_emb_loss
        torch.save(
            model.state_dict(),
            "logs/trained_models/{}.pth".format(model_name),
        )
        torch.save(
            optimizer.state_dict(),
            "logs/trained_models/{}_optimizer.pth".format(model_name),
        )