├── .gitignore
├── README.md
├── build_losses.sh
├── core
│   ├── arg_parser.py
│   ├── epoch_loops.py
│   ├── experiments.py
│   ├── main.py
│   └── setup.py
├── datasets
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── real_data.py
│   ├── shapenet.py
│   ├── shapenet_3depn.py
│   ├── shapenet_completion3d.py
│   └── utils
│       ├── dataset_generator.py
│       └── shapenet_category_mapping.py
├── images
│   └── hyperpocket_arch.png
├── losses
│   └── champfer_loss.py
├── model
│   ├── encoder.py
│   ├── full_model.py
│   ├── hyper_network.py
│   └── target_network.py
├── requirements.txt
├── settings
│   ├── config.json.sample
│   ├── config_3depn_airplane.json.sample
│   ├── config_3depn_chair.json.sample
│   ├── config_3depn_table.json.sample
│   ├── config_completion.json.sample
│   └── config_missing_shapenet.json.sample
├── splits
│   ├── 3depn
│   │   └── shapenet-official-split.csv
│   └── shapenet
│       ├── test.list
│       ├── train.list
│       └── val.list
├── util_scripts
│   ├── download_shapenet_2048.py
│   ├── generate_eval_gen_test_set.py
│   └── generate_partial_dataset.py
└── utils
    ├── __init__.py
    ├── evaluation
    │   ├── chamfer.py
    │   ├── completeness.py
    │   ├── mmd.py
    │   └── total_mutual_diff.py
    ├── metrics.py
    ├── pcutil.py
    ├── plyfile.py
    ├── points.py
    ├── pytorch_structural_losses
    │   ├── approxmatch.cu
    │   ├── match_cost.py
    │   ├── nn_distance.py
    │   ├── nndistance.cu
    │   ├── setup.py
    │   └── structural_loss.cpp
    ├── sphere_triangles.py
    ├── telegram_logging.py
    └── util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
settings/*json
start.sh
*__pycache__
init.sh
data

utils/pytorch_structural_losses/StructuralLossesBackend*
utils/pytorch_structural_losses/build/
losses/emd/build/
losses/emd/dist/
losses/emd/emd.egg-info/
/losses/chamfer/dist/
/losses/chamfer/chamfer.egg-info/
/losses/chamfer/build/
/losses/chamfer_dist/dist/
/losses/chamfer_dist/chamfer.egg-info/
/losses/chamfer_dist/build/

submission.zip

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HyperPocket: Generative Point Cloud Completion
This repository contains the source code for the paper:

[HyperPocket: Generative Point Cloud Completion](https://arxiv.org/abs/2102.05973)

![Overview](images/hyperpocket_arch.png)

#### Abstract
Scanning real-life scenes with modern registration devices typically gives incomplete point cloud representations,
mostly due to the limitations of the scanning process and 3D occlusions. Therefore, completing such partial
representations remains a fundamental challenge of many computer vision applications.
Most of the existing approaches aim to solve this problem by learning to reconstruct individual 3D objects in a
synthetic setup of an uncluttered environment, which is far from a real-life scenario. In this work, we reformulate
the problem of point cloud completion into an object hallucination task.
Thus, we introduce a novel autoencoder-based architecture called HyperPocket that disentangles latent representations
and, as a result, enables the generation of multiple variants of the completed 3D point clouds. We split point cloud
processing into two disjoint data streams and leverage a hypernetwork paradigm to fill the spaces, dubbed pockets,
that are left by the missing object parts.
As a result, the generated point clouds are not only smooth but also
plausible and geometrically consistent with the scene. Our method offers performance competitive with other
state-of-the-art models, and it enables a plethora of novel applications.


## Requirements
- Python 3.7+
- dependencies stored in `requirements.txt`
- NVIDIA GPU + CUDA

## Installation
We highly recommend using [Conda](https://docs.conda.io/en/latest) or
[Miniconda](https://docs.conda.io/en/latest/miniconda.html).

Create and activate your conda env:
- run `conda create --name <env_name> python=3.7`
- run `conda activate <env_name>`
- go to the project dir

Install requirements:
- run `conda install pytorch torchvision torchaudio cudatoolkit=<version> -c pytorch`
- run `pip install -r requirements.txt`
- set your CUDA_HOME by the command: `export CUDA_HOME=... # e.g., /var/lib/cuda-10.2/`
- install the CUDA extensions by running `./build_losses.sh`


## Usage
**Add the project root directory to PYTHONPATH**

```export PYTHONPATH=${project_path}:$PYTHONPATH```

**Download a dataset**

We use four datasets in our paper (a combined sketch of the dataset config fields follows this list).

1. 3D-EPN

   Download it from the [link](https://ujchmura-my.sharepoint.com/:u:/g/personal/przemyslaw_spurek_uj_edu_pl/ESrI4SBeef5MrpxNz3PhUa4BdSw-CQazfPHAPvHDJUzVQw?e=r7w4dc) or generate it yourself:
   1) Download the partial scan point cloud data from [the website](http://kaldir.vc.in.tum.de/adai/CNNComplete/shapenet_dim32_sdf_pc.zip)
      and extract it into the folder for storing the dataset (e.g., `${project_path}/data/dataset/3depn`).
   2) For the complete point cloud data, download it from [PKU disk](https://disk.pku.edu.cn:443/link/9A3E1AC9FBA4DEBD705F028650CBE8C7)
      (provided by [MSC](https://github.com/ChrisWu1997/Multimodal-Shape-Completion)) and extract it into the same folder.
   3) Copy the `splits/3depn/shapenet-official-split.csv` file to that folder.
   4) (if you haven't done it earlier) make a copy of the sample config by executing

      `cp settings/config.json.sample settings/config.json`
   5) specify your dataset preferences in the `settings/config.json` file:
      ```
      ["dataset"]["name"] = "3depn"
      ["dataset"]["path"] = "<path_to_the_dataset_folder>"
      ["dataset"]["num_samples"] = <number_of_slices_per_cloud>
      ```
   6) run `python3 util_scripts/generate_partial_dataset.py --config settings/config.json`

2. PartNet

   1) Download it from [the official website](https://www.shapenet.org/download/parts)

3. Completion3D
   1) Download it from [the official website](http://download.cs.stanford.edu/downloads/completion3d/dataset2019.zip)
   2) Extract it into your folder for datasets (e.g., `${project_path}/data/dataset/completion`)
   3) (if you haven't done it earlier) make a copy of the sample config by executing

      `cp settings/config.json.sample settings/config.json`
   4) specify your dataset preferences in the `settings/config.json` file:
      ```
      ["dataset"]["name"] = "completion"
      ["dataset"]["path"] = "<path_to_the_dataset_folder>"
      ```

4. MissingShapeNet

   Download it from the [link](https://ujchmura-my.sharepoint.com/:u:/g/personal/przemyslaw_spurek_uj_edu_pl/EfNG1CNZwDhDnCJlblwf7r0BvbIRcbhSw5XqR98wXmiWPg?e=fpao42) or generate it yourself:
   1) (if you haven't done it earlier) make a copy of the sample config by executing

      `cp settings/config.json.sample settings/config.json`
   2) specify your dataset preferences in the `settings/config.json` file:
      ```
      ["dataset"]["name"] = "shapenet"
      ["dataset"]["path"] = "<path_to_the_dataset_folder>"
      ["dataset"]["num_samples"] = <number_of_slices_per_cloud>
      ["dataset"]["is_rotated"] = <bool, randomly rotate the clouds>
      ["dataset"]["gen_test_set"] = <bool, use the generativity test set>
      ```
   3) run `python3 util_scripts/download_shapenet_2048.py --config settings/config.json`
   4) run `python3 util_scripts/generate_partial_dataset.py --config settings/config.json`
   5) copy `splits/shapenet/*.list` to the specified folder
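All of the keys above live under `["dataset"]` in `settings/config.json` and are read by `get_datasets` in `datasets/__init__.py`. For reference, a combined sketch of the fields (the values here are placeholders, not recommended settings):

```
["dataset"]["name"]         = "shapenet"              # or "completion", "3depn"
["dataset"]["path"]         = "data/dataset/shapenet" # placeholder path
["dataset"]["classes"]      = []                      # e.g., ["02691156"]; empty list = all supported classes
["dataset"]["num_samples"]  = 4                       # slices generated per ground-truth cloud
["dataset"]["is_rotated"]   = false                   # "shapenet" only
["dataset"]["gen_test_set"] = false                   # "shapenet" only
```

Note that `["dataset"]["classes"]` is always read by `get_datasets`, so keep the key present even if you leave the list empty.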
**Training**

We have prepared several settings for working with the different datasets:
```
# train a single class of the 3D-EPN dataset
config_3depn_airplane.json.sample
config_3depn_chair.json.sample
config_3depn_table.json.sample

# train a model for the Completion3D benchmark
config_completion.json.sample

# train on MissingShapeNet
config_missing_shapenet.json.sample
```

1) (if you haven't done it earlier) make a copy of the preferred config by executing
   `cp settings/config_<name>.json.sample settings/config_<name>.json`

2) specify your personal configs in `settings/config_<name>.json`:
   - change the `["dataset"]["path"]` and `["results_root"]` fields
   - select your GPU in the `["setup"]["gpu_id"]` field
   - select a batch size suitable for your device in `["training"]["dataloader"]`
   - you may also change the Optimizer and LRScheduler in the appropriate fields (see the sketch after this list)

3) execute the script
   - run `python3 core/main.py --config settings/config.json`
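The optimizer and scheduler are instantiated generically in `core/main.py` via `getattr(optim, type)(params, **hyperparams)`, so any `torch.optim` optimizer and `torch.optim.lr_scheduler` class can be plugged in by name. A sketch of the relevant fragment (the concrete values are illustrative, not the published training settings):

```
["training"]["optimizer"]["type"] = "Adam"
["training"]["optimizer"]["hyperparams"] = {"lr": 0.0001}
["training"]["lr_scheduler"]["type"] = "StepLR"
["training"]["lr_scheduler"]["hyperparams"] = {"step_size": 200, "gamma": 0.5}
["training"]["dataloader"]["train"] = {"batch_size": 16, "shuffle": true, "num_workers": 4}
["training"]["dataloader"]["val"] = {"batch_size": 16, "num_workers": 4}
```

The two dataloader dicts are passed directly to `torch.utils.data.DataLoader` as keyword arguments, so any `DataLoader` option can be set there.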
**Pre-trained Models**

Pre-trained models can be downloaded from [our Release page](https://github.com/gmum/3d-point-clouds-autocomplete/releases).
To use them:

1) Download the model weights zip file (the naming convention is the same as for the configs above).
2) Extract the zip file into your results directory.
3) If you have not trained models with the sample configs, you may set `["experiments"]["epoch"]` to `"latest"`;
   otherwise you need to specify the exact epoch (listed on the release page).


**Experiments**

1) If you trained the model yourself, just change `["mode"]` in the config file to `"experiments"`;
   otherwise you also need to specify the fields mentioned above.
2) Indicate which experiments you want to run by changing the bool fields
   `["experiments"]["settings"][<experiment_name>]["execute"]` (this is the path `core/main.py` reads)

Experiments list:
- fixed
- evaluate_generativity
- compute_mmd_tmd_uhd (requires the `fixed` experiment to be run before)
- merge_different_categories
- same_model_different_slices
- completion3d_submission (generates a submission.zip file in your `${project_path}` folder)


## Extending
In case you want to create your own experiments:
1) write your experiment function in `core/experiments.py`
2) add it to `experiment_functions_dict` in `core/experiments.py`
3) include your special parameters in the config file under `["experiments"]["settings"][<experiment_name>]`
   (be sure to add a bool field `"execute"` there); a minimal skeleton is sketched below
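For orientation, `core/main.py` calls every entry of `experiment_functions_dict` with the model, device, test dataset dict, results directory, and restored epoch, plus whatever extra keys sit next to `"execute"` in the config. A minimal sketch of a custom experiment (`my_experiment`, `amount`, and the saved artifact are illustrative names, not part of the codebase):

```python
import os
from os.path import join

import numpy as np


def my_experiment(full_model, device, datasets_dict, results_dir, epoch, amount=1):
    """Sketch of a custom experiment added to core/experiments.py."""
    out_dir = join(results_dir, 'my_experiment')
    os.makedirs(out_dir, exist_ok=True)
    for cat_name, ds in datasets_dict.items():
        for i in range(min(amount, len(ds))):
            # every dataset item is (existing, missing, gt, id)
            existing, missing, gt, _ = ds[i]
            np.save(join(out_dir, f'{cat_name}_{i}_gt'), gt)


# register it so main.py can dispatch by name
experiment_functions_dict['my_experiment'] = my_experiment
```

It would then be enabled like any built-in experiment, e.g. `["experiments"]["settings"]["my_experiment"] = {"execute": true, "amount": 1}`.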
--------------------------------------------------------------------------------
/build_losses.sh:
--------------------------------------------------------------------------------
#!/bin/bash

readonly PYTHON_VERSION=$(python -c 'import sys; print("%d.%d" % (sys.version_info[0], sys.version_info[1]))')

cd utils/pytorch_structural_losses || exit
rm StructuralLossesBackend.cpython* || true
rm -rf build || true

python setup.py build

cp `find build/lib* -name "StructuralLossesBackend*"` .

--------------------------------------------------------------------------------
/core/arg_parser.py:
--------------------------------------------------------------------------------
import argparse
import json


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='config file path')
    args = parser.parse_args()

    config = None
    if args.config is not None and args.config.endswith('.json'):
        with open(args.config) as f:
            config = json.load(f)
    assert config is not None

    return config

--------------------------------------------------------------------------------
/core/epoch_loops.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import torch
import numpy as np

from model.full_model import FullModel


def train_epoch(epoch, full_model: FullModel, optimizer, loader, device, rec_loss_function, loss_coef=0.05):
    full_model.train()
    # running sums over the epoch are kept as plain floats,
    # separate from the per-batch loss tensors used for backprop
    epoch_loss_all = 0.0
    epoch_loss_r = 0.0
    epoch_loss_kld = 0.0

    for i, point_data in tqdm(enumerate(loader, 1), total=len(loader)):
        optimizer.zero_grad()

        existing, missing, gt, _ = point_data

        existing = existing.to(device)
        missing = missing.to(device)
        gt = gt.to(device)

        reconstruction, logvar, mu = full_model(existing, missing, list(gt.shape), epoch, device)

        loss_r = torch.mean(
            loss_coef * rec_loss_function(gt, reconstruction.permute(0, 2, 1)))

        if full_model.mode.has_generativity():
            # closed-form KL divergence between N(mu, exp(logvar)) and N(0, I), averaged over the batch
            loss_kld = 0.5 * (torch.exp(logvar) + torch.square(mu) - 1 - logvar).sum()
            loss_kld = torch.div(loss_kld, existing.shape[0])
            loss_all = loss_r + loss_kld
            epoch_loss_kld += loss_kld.item()
        else:
            loss_all = loss_r

        epoch_loss_r += loss_r.item()
        epoch_loss_all += loss_all.item()

        loss_all.backward()
        optimizer.step()

    epoch_loss_all = epoch_loss_all / i
    epoch_loss_kld = epoch_loss_kld / i
    epoch_loss_r = epoch_loss_r / i

    return full_model, optimizer, epoch_loss_all, epoch_loss_kld, epoch_loss_r, \
        existing.detach().cpu().numpy(), gt.detach().cpu().numpy(), reconstruction.detach().cpu().numpy()


def val_epoch(epoch, full_model, device, loaders_dict, val_classes_names, loss_function, loss_coef=0.05):
    full_model.eval()

    val_losses = dict.fromkeys(val_classes_names)
    val_samples = dict.fromkeys(val_classes_names)

    with torch.no_grad():
        for cat_name, dl in loaders_dict.items():
            loss = 0.0
            for i, point_data in enumerate(dl, 1):
                existing, missing, gt, _ = point_data
                existing = existing.to(device)
                missing = missing.to(device)
                gt = gt.to(device)

                # in eval mode the model is expected to return only the reconstruction
                reconstruction = full_model(existing, missing, list(gt.shape), epoch, device)

                loss_our_cd = torch.mean(
                    loss_coef * loss_function(gt, reconstruction.permute(0, 2, 1)))

                loss += loss_our_cd.item()

            existing = existing.detach().cpu().numpy()
            gt = gt.detach().cpu().numpy()
            reconstruction = reconstruction.detach().cpu().numpy()

            val_samples[cat_name] = (existing[0], gt[0], reconstruction[0])
            val_losses[cat_name] = np.array([loss / i])

    total = np.zeros(1)
    for v in val_losses.values():
        total = np.add(total, v)
    val_losses['total'] = total / len(val_losses.keys())

    return val_losses, val_samples
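The KLD term in `train_epoch` is the standard closed form for KL(N(mu, diag(exp(logvar))) || N(0, I)). A quick standalone sanity check against `torch.distributions` (not part of the repository) confirms the expression:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 8)
logvar = torch.randn(4, 8)

# closed form used in train_epoch, summed over batch and latent dimensions
closed_form = 0.5 * (torch.exp(logvar) + torch.square(mu) - 1 - logvar).sum()

# reference value from torch.distributions (Normal takes a standard deviation)
reference = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)),
                          Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()

assert torch.allclose(closed_form, reference, atol=1e-5)
```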
--------------------------------------------------------------------------------
/core/experiments.py:
--------------------------------------------------------------------------------
import os
import shutil
from datetime import datetime
import json
from os.path import join, basename
from zipfile import ZipFile

import h5py
from sklearn import manifold
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from datasets.utils.dataset_generator import SlicedDatasetGenerator
from losses.champfer_loss import ChamferLoss
from model.full_model import FullModel
from utils.metrics import compute_all_metrics, jsd_between_point_cloud_sets
from utils.pcutil import plot_3d_point_cloud


def fixed(full_model: FullModel, device, datasets_dict, results_dir: str, epoch, amount=30, mean=0.0, std=0.015,
          noises_per_item=10, batch_size=8, save_plots=False,
          triangulation_config={'execute': False, 'method': 'edge', 'depth': 2}):
    # clean dir
    shutil.rmtree(join(results_dir, 'fixed'), ignore_errors=True)
    os.makedirs(join(results_dir, 'fixed'), exist_ok=True)

    dataloaders_dict = {cat_name: DataLoader(cat_ds, pin_memory=True, batch_size=batch_size)
                        for cat_name, cat_ds in datasets_dict.items()}
    for cat_name, dl in dataloaders_dict.items():

        for i, data in tqdm(enumerate(dl), total=len(dl)):

            existing, _, _, idx = data
            existing = existing.to(device)

            for j in range(noises_per_item):
                fixed_noise = torch.zeros(existing.shape[0], full_model.get_noise_size()).normal_(mean=mean, std=std).to(
                    device)
                reconstruction = full_model(existing, None, [existing.shape[0], 2048, 3], epoch, device,
                                            noise=fixed_noise).cpu()
                for k in range(reconstruction.shape[0]):
                    np.save(join(results_dir, 'fixed', f'{cat_name}_{i * batch_size + k}_{j}_reconstruction'),
                            reconstruction[k].numpy())
                    if save_plots:
                        fig = plot_3d_point_cloud(reconstruction[k][0], reconstruction[k][1], reconstruction[k][2],
                                                  in_u_sphere=True, show=False)
                        fig.savefig(join(results_dir, 'fixed', f'{cat_name}_{i * batch_size + k}_{j}_fixed_reconstructed.png'))
                        plt.close(fig)
                    # np.save(join(results_dir, 'fixed', f'{i*batch_size+k}_{j}_fixed_noise'), np.array(fixed_noise[k].cpu().numpy()))

            existing = existing.cpu()
            for k in range(existing.shape[0]):
                np.save(join(results_dir, 'fixed', f'{cat_name}_{i * batch_size + k}_existing'), np.array(existing[k].cpu().numpy()))
                if save_plots:
                    fig = plot_3d_point_cloud(existing[k][0], existing[k][1], existing[k][2], in_u_sphere=True, show=False)
                    fig.savefig(join(results_dir, 'fixed', f'{cat_name}_{i * batch_size + k}_existing.png'))
                    plt.close(fig)


def evaluate_generativity(full_model: FullModel, device, datasets_dict, results_dir, epoch, batch_size, num_workers,
                          mean=0.0, std=0.005):
    dataloaders_dict = {cat_name: DataLoader(cat_ds, pin_memory=True, batch_size=1, num_workers=num_workers)
                        for cat_name, cat_ds in datasets_dict.items()}
    chamfer_loss = ChamferLoss().to(device)
    with torch.no_grad():
        results = {}

        for cat_name, dl in dataloaders_dict.items():
            cat_gt = []
            for data in dl:
                _, missing, _, _ = data
                missing = missing.to(device)
                cat_gt.append(missing)
            cat_gt = torch.cat(cat_gt).contiguous()

            cat_results = {}

            for data in tqdm(dl, total=len(dl)):
                existing, _, _, _ = data
                existing = existing.to(device)

                obj_recs = []

                for j in range(len(cat_gt)):
                    fixed_noise = torch.zeros(1, full_model.get_noise_size()).normal_(mean=mean, std=std).to(device)
                    reconstruction = full_model(existing, None, [1, 2048, 3], epoch, device, noise=fixed_noise)

                    pc = reconstruction.cpu().detach().numpy()[0]
                    obj_recs.append(torch.from_numpy(pc.T[pc[1].argsort()[:1024]]).unsqueeze(0).to(device))

                obj_recs = torch.cat(obj_recs)

                for k, v in compute_all_metrics(obj_recs, cat_gt, batch_size, chamfer_loss).items():
                    cat_results[k] = cat_results.get(k, 0.0) + v.item()
                cat_results['jsd'] = cat_results.get('jsd', 0.0) + jsd_between_point_cloud_sets(
                    obj_recs.cpu().detach().numpy(), cat_gt.cpu().numpy())
            results[cat_name] = cat_results
            print(cat_name, cat_results)

        with open(join(results_dir, 'evaluate_generativity', str(epoch) + 'eval_gen_by_cat.json'), mode='w') as f:
            json.dump(results, f)


def compute_mmd_tmd_uhd(full_model: FullModel, device, dataset, results_dir, epoch, batch_size=64):
    from utils.evaluation.total_mutual_diff import process as tmd
    from utils.evaluation.completeness import process as uhd
    from utils.evaluation.mmd import process as mmd

    res = {}
    shape_dir_path = join(results_dir, 'fixed')

    mmd_v = mmd(shape_dir_path, dataset, device, batch_size)
    print('MMD * 1000', mmd_v * 1000)
    res['MMD * 1000'] = mmd_v * 1000

    uhd_v = uhd(shape_dir_path)
    print('UHD * 100', uhd_v * 100)
    res['UHD * 100'] = uhd_v * 100

    tmd_v = tmd(shape_dir_path)
    print('TMD * 100', tmd_v * 100)
    res['TMD * 100'] = tmd_v * 100

    with open(join(results_dir, 'compute_mmd_tmd_uhd', str(epoch) + 'res.json'), mode='w') as f:
        json.dump(res, f)
def merge_different_categories(full_model, device, dataset, results_dir, epoch, amount=10, first_cat='car',
                               second_cat='airplane'):
    first_cat_dataset = dataset[first_cat]
    second_cat_dataset = dataset[second_cat]

    if len(first_cat_dataset) < amount or len(second_cat_dataset) < amount:
        raise ValueError(f'with current dataset config the max amount value is '
                         f'{np.min([len(first_cat_dataset), len(second_cat_dataset)])}')

    first_cat_ids = np.random.choice(len(first_cat_dataset), amount, replace=False)
    # sample the second category from its own dataset (was len(first_cat_dataset))
    second_cat_ids = np.random.choice(len(second_cat_dataset), amount, replace=False)

    with torch.no_grad():
        for i in range(amount):
            f_existing, f_missing, f_gt, _ = first_cat_dataset[first_cat_ids[i]]

            s_existing, s_missing, s_gt, _ = second_cat_dataset[second_cat_ids[i]]

            f_existing = f_gt[f_gt.T[0].argsort()[1024:]]
            f_missing = f_gt[f_gt.T[0].argsort()[:1024]]
            s_existing = s_gt[s_gt.T[0].argsort()[1024:]]
            s_missing = s_gt[s_gt.T[0].argsort()[:1024]]

            np.save(join(results_dir, 'merge_different_categories', f'{first_cat}_{i}_existing'), f_existing)
            np.save(join(results_dir, 'merge_different_categories', f'{first_cat}_{i}_missing'), f_missing)
            np.save(join(results_dir, 'merge_different_categories', f'{first_cat}_{i}_gt'), f_gt)

            np.save(join(results_dir, 'merge_different_categories', f'{second_cat}_{i}_existing'), s_existing)
            np.save(join(results_dir, 'merge_different_categories', f'{second_cat}_{i}_missing'), s_missing)
            np.save(join(results_dir, 'merge_different_categories', f'{second_cat}_{i}_gt'), s_gt)

            f_existing = torch.from_numpy(f_existing).unsqueeze(0).to(device)
            s_existing = torch.from_numpy(s_existing).unsqueeze(0).to(device)

            gt_shape = list(torch.from_numpy(f_gt).unsqueeze(0).shape)

            for j in range(amount):
                _, temp_f_missing, temp_f_gt, _ = first_cat_dataset[first_cat_ids[j]]
                _, temp_s_missing, temp_s_gt, _ = second_cat_dataset[second_cat_ids[j]]

                temp_f_missing = temp_f_gt[temp_f_gt.T[0].argsort()[:1024]]
                temp_s_missing = temp_s_gt[temp_s_gt.T[0].argsort()[:1024]]

                temp_f_missing = torch.from_numpy(temp_f_missing).unsqueeze(0).to(device)
                temp_s_missing = torch.from_numpy(temp_s_missing).unsqueeze(0).to(device)

                rec_ff = full_model(f_existing, temp_f_missing, gt_shape, epoch, device)
                np.save(join(results_dir, 'merge_different_categories', f'{first_cat}_{i}~{first_cat}_{j}_rec'),
                        rec_ff.cpu().numpy()[0].T)

                rec_fs = full_model(f_existing, temp_s_missing, gt_shape, epoch, device)
                np.save(join(results_dir, 'merge_different_categories', f'{first_cat}_{i}~{second_cat}_{j}_rec'),
                        rec_fs.cpu().numpy()[0].T)

                rec_sf = full_model(s_existing, temp_f_missing, gt_shape, epoch, device)
                np.save(join(results_dir, 'merge_different_categories', f'{second_cat}_{i}~{first_cat}_{j}_rec'),
                        rec_sf.cpu().numpy()[0].T)

                # second~second should use the second category's missing part (was temp_f_missing)
                rec_ss = full_model(s_existing, temp_s_missing, gt_shape, epoch, device)
                np.save(join(results_dir, 'merge_different_categories', f'{second_cat}_{i}~{second_cat}_{j}_rec'),
                        rec_ss.cpu().numpy()[0].T)


def same_model_different_slices(full_model, device, datasets_dict, results_dir, epoch, amount=10, slices_number=10,
                                mean=0.0, std=0.015):
    def process_existing(pcd, cat_name, name, i, j):
        np.save(join(results_dir, 'same_model_different_slices', f'{cat_name}_{i}_{j}_{name}_pcd'), pcd)
        noise = torch.zeros(1, full_model.get_noise_size()).normal_(mean=mean, std=std)
        np.save(join(results_dir, 'same_model_different_slices', f'{cat_name}_{i}_{j}_{name}_noise'), noise.numpy())

        pcd = torch.from_numpy(pcd).unsqueeze(0).to(device)
        noise = noise.to(device)
        rec = full_model(pcd, None, [1, 2048, 3], epoch, device, noise=noise)[0].cpu().numpy()

        np.save(join(results_dir, 'same_model_different_slices', f'{cat_name}_{i}_{j}_{name}_rec'), rec)

        fig = plot_3d_point_cloud(rec[0], rec[1], rec[2], in_u_sphere=True, show=False)
        fig.savefig(join(results_dir, 'same_model_different_slices', f'{cat_name}_{i}_{j}_{name}_rec.png'))
        plt.close(fig)

    with torch.no_grad():
        for cat_name, ds in datasets_dict.items():
            ids = np.random.choice(len(ds), amount, replace=False)
            for i, idx in tqdm(enumerate(ids), total=len(ids)):
                _, _, points, _ = ds[idx]
                points = points.T
                fig = plot_3d_point_cloud(points[0], points[1], points[2], in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'same_model_different_slices', f'{cat_name}_{i}_gt.png'))
                plt.close(fig)
                points = points.T
                np.save(join(results_dir, 'same_model_different_slices', f'{cat_name}_{i}_gt'), points)
                for j in range(slices_number):
                    f_pcd, s_pcd = SlicedDatasetGenerator.generate_item(points, 1024)
                    process_existing(f_pcd, cat_name, 'f', i, j)
                    process_existing(s_pcd, cat_name, 's', i, j)
def completion3d_submission(full_model, device, datasets_dict, results_dir, epoch, batch_size=1):
    benchmark_dir = join(results_dir, 'completion3d_submission')

    shutil.rmtree(benchmark_dir, ignore_errors=True)
    os.makedirs(benchmark_dir, exist_ok=True)

    submission_zip = ZipFile('submission.zip', 'w')

    test_dataloader = DataLoader(datasets_dict['all'], batch_size=batch_size)

    for i, point_data in tqdm(enumerate(test_dataloader, 1), total=len(test_dataloader)):
        existing, _, _, model_id = point_data
        existing = existing.to(device)
        reconstruction = full_model(existing, None, [batch_size, 2048, 3], epoch, device).cpu()
        for idx, x in enumerate(reconstruction.permute(0, 2, 1)):
            ofile = join(benchmark_dir, model_id[idx].split('/')[-1] + '.h5')
            with h5py.File(ofile, "w") as f:
                f.create_dataset("data", data=x.numpy())
            submission_zip.write(ofile, 'all/' + basename(ofile))


def make_tsne_reduction(full_model, device, dataset_dict, results_dir, epoch):
    '''
    this experiment requires changing model.full_model.py:FullModel#forward method return to
    reconstruction, latent, target_networks_weights
    '''

    from datasets.shapenet import ShapeNetDataset

    cat_name = 'car'
    amount = 100

    train_dataset_dict = ShapeNetDataset._get_datasets_for_classes(
        'D:\\UJ\\bachelors\\3d-point-clouds-autocomplete\\data\\shapenet',
        'train',
        use_pcn_model_list=True,
        is_random_rotated=False,
        num_samples=1,
        # classes=['04530566', '02933112']
    )

    is_compute = False

    with torch.no_grad():
        latents = {}
        tnws = {}
        if is_compute:
            dataloaders_dict = {cat_name: DataLoader(cat_ds, pin_memory=True, batch_size=1, num_workers=0)
                                for cat_name, cat_ds in train_dataset_dict.items()}
            for cat_name, dl in dataloaders_dict.items():
                if cat_name != 'car':
                    continue
                cat_latent = []
                cat_tnw = []
                for data in tqdm(dl, total=len(dl)):
                    existing, missing, gt, _ = data
                    existing = existing.to(device)
                    missing = missing.to(device)
                    rec, latent, tnw = full_model(existing, missing, list(gt.shape), epoch, device)
                    cat_latent.append(latent.detach().cpu())
                    cat_tnw.append(tnw.detach().cpu())
                latents[cat_name] = torch.cat(cat_latent).numpy()
                tnws[cat_name] = torch.cat(cat_tnw).numpy()
            latents['all'] = np.concatenate([v for v in latents.values()])
            tnws['all'] = np.concatenate([v for v in tnws.values()])

            for cat_name in latents.keys():
                np.save(join(results_dir, 'temp_exp', f'{cat_name}_latent1'), latents[cat_name])
                np.save(join(results_dir, 'temp_exp', f'{cat_name}_tnw1'), tnws[cat_name])
        else:
            for cat_name in train_dataset_dict.keys():
                if cat_name != 'car':
                    continue
                latents[cat_name] = np.load(join(results_dir, 'temp_exp', f'{cat_name}_latent1.npy'))
                tnws[cat_name] = np.load(join(results_dir, 'temp_exp', f'{cat_name}_tnw1.npy'))
        for cat_name, ds in dataset_dict.items():

            if cat_name != 'car':
                continue

            cat_ids = np.random.choice(len(ds), amount, replace=False)

            cat_latent = []
            cat_tnw = []

            for i in range(amount):
                _, _, gt, _ = ds[cat_ids[i]]

                np.save(join(results_dir, 'temp_exp', 'gts', f'{cat_name}_{i}'), gt)

                existing_x = gt[gt.T[0].argsort()[1024:]]
                missing_x = gt[gt.T[0].argsort()[:1024]]

                existing_y = gt[gt.T[1].argsort()[1024:]]
                missing_y = gt[gt.T[1].argsort()[:1024]]

                gt_shape = list(torch.from_numpy(gt).unsqueeze(0).shape)

                existing = torch.from_numpy(existing_x).unsqueeze(0).to(device)
                missing = torch.from_numpy(missing_x).unsqueeze(0).to(device)
                _, latent, tnw = full_model(existing, missing, gt_shape, epoch, device)
                cat_latent.append(latent.cpu())
                cat_tnw.append(tnw.cpu())

                existing = torch.from_numpy(existing_y).unsqueeze(0).to(device)
                missing = torch.from_numpy(missing_y).unsqueeze(0).to(device)
                _, latent, tnw = full_model(existing, missing, gt_shape, epoch, device)
                cat_latent.append(latent.cpu())
                cat_tnw.append(tnw.cpu())

            cat_latent = torch.cat(cat_latent).numpy()
            cat_tnw = torch.cat(cat_tnw).numpy()

            cc_latent = np.concatenate([latents[cat_name], cat_latent])
            cc_tnw = np.concatenate([tnws[cat_name], cat_tnw])

            start_time = datetime.now()
            print(start_time)
            latent_tsne = manifold.TSNE(n_components=2, init='pca').fit_transform(cc_latent)
            print(datetime.now() - start_time)
            cat_test_tsne = latent_tsne[-(2 * amount):]
            plt.plot(latent_tsne.T[0], latent_tsne.T[1], 'o', cat_test_tsne.T[0], cat_test_tsne.T[1], 'o')
            plt.title('latent')
            plt.show()

            np.save(join(results_dir, 'temp_exp', f'{cat_name}_latent_tsne'), latent_tsne)

            start_time = datetime.now()
            print(start_time)
            tnw_tsne = manifold.TSNE(n_components=2, init='pca').fit_transform(cc_tnw)
            print(datetime.now() - start_time)
            cat_test_tnw = tnw_tsne[-(2 * amount):]
            plt.plot(tnw_tsne.T[0], tnw_tsne.T[1], 'o', cat_test_tnw.T[0], cat_test_tnw.T[1], 'o')
            plt.title('tnw')
            plt.show()

            np.save(join(results_dir, 'temp_exp', f'{cat_name}_tnw_tsne'), tnw_tsne)

            latent_tsne = np.load(join(results_dir, 'temp_exp', f'{cat_name}_latent_tsne.npy'))
            tnw_tsne = np.load(join(results_dir, 'temp_exp', f'{cat_name}_tnw_tsne.npy'))

            cat_test_tsne = latent_tsne[-(2 * amount):]
            cat_test_tnw = tnw_tsne[-(2 * amount):]

            latent_dist = np.zeros(amount)
            tnw_dist = np.zeros(amount)

            for i in range(amount):
                latent_dist[i] = np.linalg.norm(cat_test_tsne[2 * i] - cat_test_tsne[2 * i + 1])
                tnw_dist[i] = np.linalg.norm(cat_test_tnw[2 * i] - cat_test_tnw[2 * i + 1])

            plt.plot(latent_tsne.T[0], latent_tsne.T[1], 'o', cat_test_tsne.T[0], cat_test_tsne.T[1], 'o')
            plt.title('latent')
            plt.show()

            plt.plot(tnw_tsne.T[0], tnw_tsne.T[1], 'o', cat_test_tnw.T[0][0], cat_test_tnw.T[1][0], 'o',
                     cat_test_tnw.T[0][1],
                     cat_test_tnw.T[1][1], 'o')
            plt.title('tnw')
            plt.show()


def temp_exp(full_model, device, dataset_dict, results_dir, epoch):
    # you may write your experiment here
    pass


experiment_functions_dict = {
    'fixed': fixed,
    'evaluate_generativity': evaluate_generativity,
    'compute_mmd_tmd_uhd': compute_mmd_tmd_uhd,
    'merge_different_categories': merge_different_categories,
    'same_model_different_slices': same_model_different_slices,
    "completion3d_submission": completion3d_submission,
    "temp_exp": temp_exp
}

--------------------------------------------------------------------------------
/core/main.py:
--------------------------------------------------------------------------------
import json
import logging
from os.path import join
from datetime import datetime

import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader

from core.arg_parser import parse_config
from core.epoch_loops import train_epoch, val_epoch
from core.experiments import experiment_functions_dict
from datasets import get_datasets
from model.full_model import FullModel
from losses.champfer_loss import ChamferLoss

from core.setup import seed_setup, logging_setup, cuda_setup, results_dir_setup, restore_model_state, \
    get_results_dir_path, restore_metrics, weights_init
from utils.telegram_logging import TelegramLogger
from utils.util import find_latest_epoch, save_plot, get_model_name


def main(config: dict):
    # region Setup
    seed_setup(config['setup']['seed'])

    run_mode: str = config['mode']
    result_dir_path: str = get_results_dir_path(config, run_mode)

    if run_mode == 'training':
        dirs_to_create = ('weights', 'samples', 'metrics')
        weights_path = join(result_dir_path, 'weights')
        metrics_path = join(result_dir_path, 'metrics')
    elif run_mode == 'experiments':
        dirs_to_create = tuple(experiment_functions_dict.keys())
        weights_path = join(get_results_dir_path(config, 'training'), 'weights')
        metrics_path = join(get_results_dir_path(config, 'training'), 'metrics')
    else:
        raise ValueError("mode should be `training` or `experiments`")

    results_dir_setup(result_dir_path, dirs_to_create)

    with open(join(result_dir_path, 'last_config.json'), mode='w') as f:
        json.dump(config, f)

    logging_setup(result_dir_path)
    log = logging.getLogger()

    log.info(f'Current mode {run_mode}')

    if config['telegram_logger']['enable']:
        tg_log = TelegramLogger.get_logger(config['telegram_logger'])

    device = cuda_setup(config['setup']['gpu_id'])
    log.info(f'Device variable: {device}')

    reconstruction_loss = ChamferLoss().to(device)
    full_model = FullModel(config['full_model']).to(device)
    full_model.apply(weights_init)
    optimizer = getattr(optim, config['training']['optimizer']['type'])  # optimizer class
    optimizer = optimizer(full_model.parameters(), **config['training']['optimizer']['hyperparams'])

    scheduler = getattr(optim.lr_scheduler, config['training']['lr_scheduler']['type'])  # scheduler class
    scheduler = scheduler(optimizer, **config['training']['lr_scheduler']['hyperparams'])
    log.info(f'Model {get_model_name(config)} created')

    latest_epoch = find_latest_epoch(result_dir_path if run_mode == "training" else weights_path)

    log.info(f'Latest epoch found: {latest_epoch}')

    if latest_epoch > 0:
        if run_mode == "training":
            latest_epoch = restore_model_state(weights_path, metrics_path, config['setup']['gpu_id'], latest_epoch,
                                               "latest", full_model, optimizer, scheduler)
        elif run_mode == "experiments":
            latest_epoch = restore_model_state(weights_path, metrics_path, config['setup']['gpu_id'], latest_epoch,
                                               config['experiments']['epoch'], full_model)
        log.info(f'Restored epoch : {latest_epoch}')
    elif run_mode == "experiments":
        raise FileNotFoundError(f'no weights found at {weights_path}')
    # endregion Setup

    train_dataset, val_dataset_dict, test_dataset_dict = get_datasets(config['dataset'])

    log.info(f'Dataset loaded for classes: {[cat_name for cat_name in val_dataset_dict.keys()]}')

    if run_mode == 'training':
        samples_path = join(result_dir_path, 'samples')
        train_dataloader = DataLoader(train_dataset, pin_memory=True, **config['training']['dataloader']['train'])
        val_dataloaders_dict = {cat_name: DataLoader(cat_ds, pin_memory=True, **config['training']['dataloader']['val'])
                                for cat_name, cat_ds in val_dataset_dict.items()}
        if latest_epoch == 0:
            best_epoch_loss = np.inf  # np.Infinity alias was removed in NumPy 2.0
            train_losses = []
            val_losses = []
        else:
            train_losses, val_losses, best_epoch_loss = restore_metrics(metrics_path, latest_epoch)
        for epoch in range(latest_epoch + 1, config['training']['max_epoch'] + 1):
            start_epoch_time = datetime.now()
            log.debug("Epoch: %s" % epoch)

            full_model, optimizer, epoch_loss_all, epoch_loss_kld, epoch_loss_r, latest_existing, latest_gt, latest_rec \
                = train_epoch(epoch, full_model, optimizer, train_dataloader, device, reconstruction_loss,
                              config['training']['loss_coef'])
            scheduler.step()

            train_losses.append(np.array([epoch_loss_all, epoch_loss_r, epoch_loss_kld]))

            log_string = f'[{epoch}/{config["training"]["max_epoch"]}] ' \
                         f'Loss_ALL: {epoch_loss_all:.4f} ' \
                         f'Loss_R: {epoch_loss_r:.4f} ' \
                         f'Loss_E: {epoch_loss_kld:.4f} ' \
                         f'Time: {datetime.now() - start_epoch_time}'
            log.info(log_string)

            train_plots = []
            for k in range(min(5, latest_rec.shape[0])):
                train_plots.append(save_plot(latest_existing[k], epoch, k, samples_path, 'existing'))
                train_plots.append(save_plot(latest_rec[k], epoch, k, samples_path, 'reconstructed'))
                train_plots.append(save_plot(latest_gt[k].T, epoch, k, samples_path, 'gt'))

            if config['telegram_logger']['enable']:
                tg_log.log_images(train_plots[:9], log_string)

            epoch_val_losses, epoch_val_samples = val_epoch(epoch, full_model, device, val_dataloaders_dict,
                                                            val_dataset_dict.keys(), reconstruction_loss,
                                                            config['training']['loss_coef'])

            is_new_best = epoch_val_losses['total'][0] < best_epoch_loss

            if is_new_best:
                best_epoch_loss = epoch_val_losses['total'][0]

            val_losses.append(epoch_val_losses['total'])

            log_string = f'val results[{config["training"]["loss_coef"]}*our_cd]:\n'
            for k, v in epoch_val_losses.items():
                log_string += k + ': ' + str(v) + '\n'

            if is_new_best:
                log_string += "new best epoch"

            log.info(log_string)

            val_plots = []
            for cat_name, sample in epoch_val_samples.items():
                val_plots.append(save_plot(sample[0], epoch, cat_name, samples_path, 'val_existing'))
                val_plots.append(save_plot(sample[2], epoch, cat_name, samples_path, 'val_rec'))
                val_plots.append(save_plot(sample[1].T, epoch, cat_name, samples_path, 'val_gt'))

            if config['telegram_logger']['enable']:
                # np.int was removed from NumPy; plain Python ints are enough for indexing here
                chosen_plot_idx = np.random.choice(np.arange(len(val_plots) // 3),
                                                   min(3, len(val_plots) // 3), replace=False)
                plots_to_log = []
                for idx in chosen_plot_idx:
                    plots_to_log.extend(val_plots[3 * idx:3 * idx + 3])
                tg_log.log_images(plots_to_log, log_string)

            if (epoch % config['training']['state_save_frequency'] == 0 or is_new_best) \
                    and epoch > config['training'].get('min_save_epoch', 0):
                torch.save(full_model.state_dict(), join(weights_path, f'{epoch:05}_model.pth'))
                torch.save(optimizer.state_dict(), join(weights_path, f'{epoch:05}_O.pth'))
                torch.save(scheduler.state_dict(), join(weights_path, f'{epoch:05}_S.pth'))

                np.save(join(metrics_path, f'{epoch:05}_train'), np.array(train_losses))
                np.save(join(metrics_path, f'{epoch:05}_val'), np.array(val_losses))

                log_string = "Epoch: %s saved" % epoch
                log.debug(log_string)
                if config['telegram_logger']['enable']:
                    tg_log.log(log_string)

    elif run_mode == 'experiments':

        # from datasets.real_data import RealDataNPYDataset
        # test_dataset_dict = RealDataNPYDataset(root_dir="D:\\UJ\\bachelors\\3d-point-clouds-autocomplete\\data\\real_car_data")

        full_model.eval()

        with torch.no_grad():
            for experiment_name, experiment_dict in config['experiments']['settings'].items():
                if experiment_dict.pop('execute', False):
                    log.info(experiment_name)
                    experiment_functions_dict[experiment_name](full_model, device, test_dataset_dict, result_dir_path,
                                                               latest_epoch, **experiment_dict)

    exit(0)


if __name__ == '__main__':
    main(parse_config())

--------------------------------------------------------------------------------
/core/setup.py:
--------------------------------------------------------------------------------
import logging
from os import makedirs
from os.path import join, exists
import random

import torch
import numpy as np

from utils.util import get_classes_dir, get_distribution_dir, get_model_name


def seed_setup(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_results_dir_path(config, mode):
    return join(config['results_root'], mode, get_distribution_dir(config['full_model']),
                config['dataset']['name'], get_classes_dir(config['dataset']), get_model_name(config))


def results_dir_setup(dir_path, dirs_to_create=('weights', 'samples', 'metrics')):
    makedirs(dir_path, exist_ok=True)
    for dir_to_create in dirs_to_create:
        makedirs(join(dir_path, dir_to_create), exist_ok=True)
    return dir_path


def logging_setup(log_dir):
    makedirs(log_dir, exist_ok=True)

    logpath = join(log_dir, 'log.txt')
    filemode = 'a' if exists(logpath) else 'w'

    # set up logging to file
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%m-%d %H:%M:%S',
                        filename=logpath,
                        filemode=filemode)
    # define a Handler which writes INFO messages or higher to sys.stderr
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    # set a format which is simpler for console use
    formatter = logging.Formatter('%(asctime)s: %(levelname)-8s %(message)s')
    # tell the handler to use this format
    console.setFormatter(formatter)
    # add the handler to the root logger
    logging.getLogger('').addHandler(console)


def cuda_setup(gpu_idx=0):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.set_device(gpu_idx)
    return device


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_uniform_(m.weight, gain)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)
    elif classname.find('Linear') != -1:
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_uniform_(m.weight, gain)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)


def restore_model_state(weights_path, metrics_path, gpu_id, epoch, restore_policy, full_model, optimizer=None,
                        scheduler=None):
    if restore_policy == "latest":
        pass
    elif restore_policy == "best_val":
        val_losses = np.load(join(metrics_path, f'{epoch:05}_val.npy'), allow_pickle=True)
        epoch = np.argmin(val_losses) + 1
    else:
        try:
            epoch = int(restore_policy)
        except ValueError:
            raise ValueError('`[epoch]` value can take only values: `latest`, `best_val` or positive integer')

    full_model.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_model.pth'),
                                          map_location='cuda:' + str(gpu_id)))

    if optimizer is not None:
        optimizer.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_O.pth')))

    if scheduler is not None:
        scheduler.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_S.pth')))
    return epoch


def restore_metrics(metrics_path, epoch):
    train_losses = np.load(join(metrics_path, f'{epoch:05}_train.npy'), allow_pickle=True)
    val_losses = np.load(join(metrics_path, f'{epoch:05}_val.npy'), allow_pickle=True)
    return train_losses.tolist(), val_losses.tolist(), np.min(val_losses)

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from datasets.shapenet import ShapeNetDataset
from datasets.shapenet_completion3d import ShapeNetCompletion3DDataset
from datasets.shapenet_3depn import ShapeNet3DEPNDataset


def get_datasets(config):
    dataset_name = config['name']
    if dataset_name == 'shapenet':
        train_dataset = ShapeNetDataset(root_dir=config['path'], classes=config['classes'], split='train',
                                        is_random_rotated=config['is_rotated'], num_samples=config['num_samples'],
                                        use_pcn_model_list=True)
        val_dataset_dict = ShapeNetDataset.get_validation_datasets(root_dir=config['path'],
                                                                   classes=config['classes'],
                                                                   is_random_rotated=config['is_rotated'],
                                                                   num_samples=config['num_samples'],
                                                                   use_pcn_model_list=True)
        test_dataset_dict = ShapeNetDataset.get_test_datasets(root_dir=config['path'],
                                                              classes=config['classes'],
                                                              is_random_rotated=config['is_rotated'],
                                                              num_samples=config['num_samples'],
                                                              use_pcn_model_list=True,
                                                              is_gen=config['gen_test_set'])
    elif dataset_name == 'completion':
        train_dataset = ShapeNetCompletion3DDataset(root_dir=config['path'], split='train', classes=config['classes'])
        val_dataset_dict = ShapeNetCompletion3DDataset.get_validation_datasets(config['path'], classes=config['classes'])
        test_dataset_dict = ShapeNetCompletion3DDataset.get_test_datasets(config['path'])
    elif dataset_name == '3depn':
        train_dataset = ShapeNet3DEPNDataset(root_dir=config['path'], split='train', classes=config['classes'])
        val_dataset_dict = ShapeNet3DEPNDataset.get_validation_datasets(config['path'], classes=config['classes'])
        test_dataset_dict = ShapeNet3DEPNDataset.get_test_datasets(config['path'], classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `completion` or `3depn`. Got: `{dataset_name}`')

    return train_dataset, val_dataset_dict, test_dataset_dict
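A hedged usage sketch of `get_datasets` for the `shapenet` branch (the path and values below are placeholders, and the project root is assumed to be on PYTHONPATH; the keys mirror exactly what the function reads):

```python
from datasets import get_datasets

# illustrative config fragment; 'data/shapenet' is a placeholder path
dataset_config = {
    'name': 'shapenet',
    'path': 'data/shapenet',
    'classes': [],          # empty -> all supported classes
    'is_rotated': False,
    'num_samples': 4,
    'gen_test_set': False,
}

train_ds, val_ds_dict, test_ds_dict = get_datasets(dataset_config)
# every item is an (existing, missing, gt, category_id) tuple
existing, missing, gt, category_id = train_ds[0]
```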
--------------------------------------------------------------------------------
/datasets/base_dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset


class BaseDataset(Dataset):

    def __init__(self, root_dir, split='train', classes=[]):
        self.root_dir = root_dir
        self.split = split

    @classmethod
    def get_validation_datasets(cls, root_dir, classes=[], **kwargs):
        raise NotImplementedError

    @classmethod
    def get_test_datasets(cls, root_dir, classes=[], **kwargs):
        raise NotImplementedError

--------------------------------------------------------------------------------
/datasets/real_data.py:
--------------------------------------------------------------------------------
from os import listdir
from os.path import join

import numpy as np
from datasets.base_dataset import BaseDataset
from utils.util import resample_pcd


class RealDataNPYDataset(BaseDataset):

    def __init__(self, root_dir):
        super().__init__(root_dir)

        self.scenes = []
        self.objs = []
        self.boxes = []

        for f in listdir(self.root_dir):
            if f.startswith('object_box'):
                self.boxes.append(f)
            elif f.startswith('object'):
                self.objs.append(f)
            elif f.startswith('scen'):
                self.scenes.append(f)

    def _get_scales(self, pcd):
        axis_mins = np.min(pcd.T, axis=1)
        axis_maxs = np.max(pcd.T, axis=1)

        scale = np.max(axis_maxs - axis_mins)
        pcd_center = (axis_maxs + axis_mins) / 2

        # dividing by 0.9 leaves a small margin around the normalized cloud
        return pcd_center, scale / 0.9

    def __getitem__(self, idx):
        pcd = np.load(join(self.root_dir, self.objs[idx])).astype(np.float32)
        pcd_center, scale = self._get_scales(pcd)
        pcd = (pcd - pcd_center) / scale
        return resample_pcd(pcd, 1024), 0, 0, idx

    def get_full_object(self, idx):
        return np.load(join(self.root_dir, self.objs[idx])).astype(np.float32)

    def get_scene(self, idx):
        if self.scenes:
            return np.load(join(self.root_dir, self.scenes[idx])).astype(np.float32)
        else:
            raise ValueError("Dataset does not include scenes")

    def get_obj_box(self, idx):
        if self.boxes:
            return np.load(join(self.root_dir, self.boxes[idx])).astype(np.float32)
        else:
            raise ValueError("Dataset does not include object boxes")

    def inverse_scale_to_scene(self, idx, scaled_pcd):
        scene = np.load(join(self.root_dir, self.scenes[idx])).astype(np.float32)
        pcd = np.load(join(self.root_dir, self.objs[idx])).astype(np.float32)
        pcd_center, scale = self._get_scales(pcd)
        scaled_pcd_center, scaled_pcd_scale = self._get_scales(scaled_pcd)
        return np.concatenate([scene, (scaled_pcd / scaled_pcd_scale * scale) + pcd_center])

    def inverse_scale(self, idx, scaled_pcd):
        pcd = np.load(join(self.root_dir, self.objs[idx])).astype(np.float32)
        pcd_center, scale = self._get_scales(pcd)
        scaled_pcd_center, scaled_pcd_scale = self._get_scales(scaled_pcd)
        return (scaled_pcd / scaled_pcd_scale * scale) + pcd_center

    def __len__(self):
        return len(self.objs)

    @classmethod
    def get_validation_datasets(cls, root_dir, classes=[], **kwargs):
        raise NotImplementedError
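A hedged round-trip sketch for the scaling helpers (the temporary file is stand-in data, not a real scan): `__getitem__` normalizes each object into a roughly unit-sized box around the origin, and `inverse_scale` re-estimates the scale of a completed cloud to map it back into approximately the original coordinate frame.

```python
import os
import tempfile

import numpy as np
from datasets.real_data import RealDataNPYDataset

# build a toy object file so the dataset has something to read
tmp = tempfile.mkdtemp()
np.save(os.path.join(tmp, 'object_0.npy'), np.random.rand(4096, 3).astype(np.float32))

ds = RealDataNPYDataset(root_dir=tmp)
partial, _, _, idx = ds[0]                  # 1024 normalized points
restored = ds.inverse_scale(idx, partial)   # back near the original coordinates
```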
--------------------------------------------------------------------------------
/datasets/shapenet.py:
--------------------------------------------------------------------------------
from os.path import join

import numpy as np
import pandas as pd

from datasets.base_dataset import BaseDataset
from datasets.utils.shapenet_category_mapping import synth_id_to_category, category_to_synth_id, synth_id_to_number
from utils.plyfile import load_ply
from utils.util import resample_pcd, get_filenames_by_cat


class ShapeNetDataset(BaseDataset):

    def __init__(self, root_dir='/home/datasets/shapenet', split='train', classes=[],
                 is_random_rotated=False, num_samples=4, use_pcn_model_list=True, is_gen=False):
        """
        Args:
            root_dir (string): Directory with all the point clouds.
        """
        super().__init__(root_dir, split, classes)

        self.is_random_rotated = is_random_rotated
        self.use_pcn_model_list = use_pcn_model_list
        self.is_gen = is_gen
        self.num_samples = num_samples

        if is_gen:
            self.num_samples = 1

        if self.use_pcn_model_list:
            list_path = join(root_dir, self.split + '.list')
            with open(list_path) as file:
                if classes:
                    self.point_clouds_names = [line.strip() for line in file if line.strip().split('/')[0] in classes]
                else:
                    self.point_clouds_names = [line.strip() for line in file]
        else:
            pc_df = get_filenames_by_cat(self.root_dir)
            if classes:
                if classes[0] not in synth_id_to_category.keys():
                    classes = [category_to_synth_id[c] for c in classes]
                pc_df = pc_df[pc_df.category.isin(classes)].reset_index(drop=True)
            else:
                classes = synth_id_to_category.keys()

            if self.split == 'train':
                # first 85%
                self.point_clouds_names = pd.concat(
                    [pc_df[pc_df['category'] == c][:int(0.85 * len(pc_df[pc_df['category'] == c]))]
                     .reset_index(drop=True) for c in classes])
            elif self.split == 'val':
                # middle 5% (between 85% and 90%)
                self.point_clouds_names = pd.concat([pc_df[pc_df['category'] == c][
                                                     int(0.85 * len(pc_df[pc_df['category'] == c])):int(
                                                         0.9 * len(pc_df[pc_df['category'] == c]))]
                                                     .reset_index(drop=True) for c in classes])
            else:
                # last 10%
                self.point_clouds_names = pd.concat(
                    [pc_df[pc_df['category'] == c][int(0.9 * len(pc_df[pc_df['category'] == c])):]
                     .reset_index(drop=True) for c in classes])

    def __len__(self):
        return len(self.point_clouds_names) * self.num_samples

    def __getitem__(self, idx):
        if self.use_pcn_model_list:
            pc_category, pc_filename = self.point_clouds_names[idx // self.num_samples].split('/')
            pc_filename += '.ply'
        else:
            pc_category, pc_filename = self.point_clouds_names.iloc[idx // self.num_samples].values

        if self.is_random_rotated:
            from scipy.spatial.transform import Rotation
            random_rotation_matrix = Rotation.from_euler('z', np.random.randint(360), degrees=True).as_matrix().astype(
                np.float32)

        scan_idx = str(idx % self.num_samples)

        if self.is_gen and self.split == 'test':
            existing = resample_pcd(load_ply(join(self.root_dir, 'test_gen', 'right', pc_category, pc_filename)), 1024)
            missing = resample_pcd(load_ply(join(self.root_dir, 'test_gen', 'left', pc_category, pc_filename)), 1024)
            gt = load_ply(join(self.root_dir, 'test_gen', 'gt', pc_category, pc_filename))
        else:
            existing = load_ply(join(self.root_dir, 'slices', 'existing', pc_category, scan_idx + '~' + pc_filename))
            missing = load_ply(join(self.root_dir, 'slices', 'missing', pc_category, scan_idx + '~' + pc_filename))
            gt = load_ply(join(self.root_dir, pc_category, pc_filename))

        if self.is_random_rotated:
            existing = existing @ random_rotation_matrix
            missing = missing @ random_rotation_matrix
            gt = gt @ random_rotation_matrix

        return existing, missing, gt, synth_id_to_number[pc_category]

    @classmethod
    def _get_datasets_for_classes(cls, root_dir, split, classes=[], **kwargs):
        if not classes:
            if kwargs.get('use_pcn_model_list'):
                classes = ['02691156', '02933112', '02958343', '03001627', '03636649', '04256520', '04379243',
                           '04530566']
            else:
                classes = list(synth_id_to_category.keys())
        return {synth_id_to_category[category_id]: ShapeNetDataset(root_dir=root_dir,
                                                                   split=split,
                                                                   classes=[category_id], **kwargs)
                for category_id in classes}

    @classmethod
    def get_validation_datasets(cls, root_dir, classes=[], **kwargs):
        return cls._get_datasets_for_classes(root_dir, 'val', classes, **kwargs)

    @classmethod
    def get_test_datasets(cls, root_dir, classes=[], **kwargs):
        return cls._get_datasets_for_classes(root_dir, 'test', classes, **kwargs)
--------------------------------------------------------------------------------
/datasets/shapenet_3depn.py:
--------------------------------------------------------------------------------
import numpy as np
import trimesh
import torch

import csv
import os
import random
from os.path import join, exists

from datasets.base_dataset import BaseDataset
from datasets.utils.shapenet_category_mapping import synth_id_to_category
from utils.plyfile import load_ply

# code is based on
# https://github.com/ChrisWu1997/Multimodal-Shape-Completion/blob/master/dataset/dataset_3depn.py


def downsample_point_cloud(points, n_pts):
    """downsample points by random choice

    :param points: (n, 3)
    :param n_pts: int
    :return:
    """
    p_idx = random.choices(list(range(points.shape[0])), k=n_pts)
    return points[p_idx]


def upsample_point_cloud(points, n_pts):
    """upsample points by random choice

    :param points: (n, 3)
    :param n_pts: int, > n
    :return:
    """
    p_idx = random.choices(list(range(points.shape[0])), k=n_pts - points.shape[0])
    dup_points = points[p_idx]
    points = np.concatenate([points, dup_points], axis=0)
    return points


def sample_point_cloud_by_n(points, n_pts):
    """resample point cloud to given number of points"""
    if n_pts > points.shape[0]:
        return upsample_point_cloud(points, n_pts)
    elif n_pts < points.shape[0]:
        return downsample_point_cloud(points, n_pts)
    else:
        return points


def collect_train_split_by_id(path, cat_id):
    split_info = {"train": [], "validation": [], "test": []}
    with open(path, 'r') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        line_cnt = 0
        for row in csv_reader:
            if line_cnt == 0 or row[1] != cat_id:
                pass
            else:
                if row[-1] == "train":
                    split_info["train"].append(row[-2])
                elif row[-1] == "val":
                    split_info["validation"].append(row[-2])
                else:
                    split_info["test"].append(row[-2])
            line_cnt += 1
    return split_info


class ShapeNet3DEPNDataset(BaseDataset):

    def __init__(self, root_dir='/home/datasets/completion', split='train', classes=[], num_samples=4):
        super(ShapeNet3DEPNDataset, self).__init__(root_dir, split, classes)

        if self.split == 'test':
            self.cat_pc_root = join(root_dir, 'ShapeNetPointCloud', classes[0])
            self.cat_pc_raw_root = join(root_dir, 'shapenet_dim32_sdf_pc', classes[0])
            shape_names = []
            with open(join(self.root_dir, 'shapenet-official-split.csv'), 'r') as csv_file:
                csv_reader = csv.reader(csv_file, delimiter=',')
                line_cnt = 0
                for row in csv_reader:
                    if line_cnt == 0 or (row[1] != classes[0]):
                        pass
                    else:
                        if row[-1] == self.split:
                            shape_names.append(row[-2])
                    line_cnt += 1

            self.shape_names = []
            for name in shape_names:
                ply_path = join(self.cat_pc_root, name + '.ply')
                path = join(self.cat_pc_raw_root, "{}__0__.ply".format(name))
                if exists(ply_path) and exists(path):
                    self.shape_names.append(name)

            self.raw_ply_names = sorted(os.listdir(self.cat_pc_raw_root))

            self.rng = random.Random(1234)  # from original publication
        else:
            self.shape_names = os.listdir(join(self.root_dir, 'slices', 'gt', classes[0]))
            self.num_samples = num_samples
            self.cat = classes[0]

    def __getitem__(self, index):
        if self.split == 'test':
            raw_n = self.rng.randint(0, 7)
            raw_pc_name = self.shape_names[index] + "__{}__.ply".format(raw_n)
            raw_ply_path = os.path.join(self.cat_pc_raw_root, raw_pc_name)
            raw_pc = np.array(trimesh.load(raw_ply_path).vertices)
            raw_pc = self._rotate_point_cloud_by_axis_angle(raw_pc)
            raw_pc = sample_point_cloud_by_n(raw_pc, 1024)
            raw_pc = torch.tensor(raw_pc, dtype=torch.float32)

            # process existing complete shapes
            real_shape_name = self.shape_names[index]
            real_ply_path = os.path.join(self.cat_pc_root, real_shape_name + '.ply')
            real_pc = np.array(trimesh.load(real_ply_path).vertices)
            real_pc = sample_point_cloud_by_n(real_pc, 2048)
            real_pc = torch.tensor(real_pc, dtype=torch.float32)

            return raw_pc, 0, real_pc, real_shape_name
        else:
            pc_filename = self.shape_names[index // self.num_samples]
            existing = load_ply(join(self.root_dir, 'slices', 'existing', self.cat,
                                     str(index % self.num_samples) + '~' + pc_filename))
            missing = load_ply(join(self.root_dir, 'slices', 'missing', self.cat,
                                    str(index % self.num_samples) + '~' + pc_filename))
            gt = load_ply(join(self.root_dir, 'slices', 'gt', self.cat, pc_filename))
            return existing, missing, gt, pc_filename[:-4]

    def __len__(self):
        if self.split == 'test':
            return len(self.shape_names)
        else:
            return len(self.shape_names) * self.num_samples

    def _rotate_point_cloud_by_axis_angle(self, points):
        # a 90-degree rotation around the y-axis (the 2.22e-16 entries are numerically zero)
        rot_m = np.array([[2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
                          [0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
                          [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
        return np.dot(rot_m, points.T).T

    @classmethod
    def get_validation_datasets(cls, root_dir, classes=[], **kwargs):
        if not classes:
            classes = ['02691156', '03001627', '04379243']

        return {synth_id_to_category[category_id]: ShapeNet3DEPNDataset(root_dir=root_dir, split='val',
                                                                        classes=[category_id])
                for category_id in classes}

    @classmethod
    def get_test_datasets(cls, root_dir, classes=[], **kwargs):
        return {synth_id_to_category[category_id]: ShapeNet3DEPNDataset(root_dir=root_dir, split='test',
                                                                        classes=[category_id])
                for category_id in classes}
join(self.cat_pc_root, name + '.ply') 94 | path = join(self.cat_pc_raw_root, "{}__0__.ply".format(name)) 95 | if exists(ply_path) and exists(path): 96 | self.shape_names.append(name) 97 | 98 | self.raw_ply_names = sorted(os.listdir(self.cat_pc_raw_root)) 99 | 100 | self.rng = random.Random(1234) # from original publication 101 | else: 102 | self.shape_names = os.listdir(join(self.root_dir, 'slices', 'gt', classes[0])) 103 | self.num_samples = num_samples 104 | self.cat = classes[0] 105 | 106 | def __getitem__(self, index): 107 | if self.split == 'test': 108 | raw_n = self.rng.randint(0, 7) 109 | raw_pc_name = self.shape_names[index] + "__{}__.ply".format(raw_n) 110 | raw_ply_path = os.path.join(self.cat_pc_raw_root, raw_pc_name) 111 | raw_pc = np.array(trimesh.load(raw_ply_path).vertices) 112 | raw_pc = self._rotate_point_cloud_by_axis_angle(raw_pc) 113 | raw_pc = sample_point_cloud_by_n(raw_pc, 1024) 114 | raw_pc = torch.tensor(raw_pc, dtype=torch.float32) 115 | 116 | # process existing complete shapes 117 | real_shape_name = self.shape_names[index] 118 | real_ply_path = os.path.join(self.cat_pc_root, real_shape_name + '.ply') 119 | real_pc = np.array(trimesh.load(real_ply_path).vertices) 120 | real_pc = sample_point_cloud_by_n(real_pc, 2048) 121 | real_pc = torch.tensor(real_pc, dtype=torch.float32) 122 | 123 | return raw_pc, 0, real_pc, real_shape_name 124 | else: 125 | pc_filename = self.shape_names[index // self.num_samples] 126 | existing = load_ply(join(self.root_dir, 'slices', 'existing', self.cat, 127 | str(index % self.num_samples) + '~' + pc_filename)) 128 | missing = load_ply(join(self.root_dir, 'slices', 'missing', self.cat, 129 | str(index % self.num_samples) + '~' + pc_filename)) 130 | gt = load_ply(join(self.root_dir, 'slices', 'gt', self.cat, pc_filename)) 131 | return existing, missing, gt, pc_filename[:-4] 132 | 133 | def __len__(self): 134 | if self.split == 'test': 135 | return len(self.shape_names) 136 | else: 137 | return len(self.shape_names) * self.num_samples 138 | 139 | def _rotate_point_cloud_by_axis_angle(self, points): 140 | rot_m = np.array([[2.22044605e-16, 0.00000000e+00, 1.00000000e+00], 141 | [0.00000000e+00, 1.00000000e+00, 0.00000000e+00], 142 | [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]]) 143 | return np.dot(rot_m, points.T).T 144 | 145 | @classmethod 146 | def get_validation_datasets(cls, root_dir, classes=[], **kwargs): 147 | if not classes: 148 | classes = ['02691156', '03001627', '04379243'] 149 | 150 | return {synth_id_to_category[category_id]: ShapeNet3DEPNDataset(root_dir=root_dir, split='val', 151 | classes=[category_id]) 152 | for category_id in classes} 153 | 154 | @classmethod 155 | def get_test_datasets(cls, root_dir, classes=[], **kwargs): 156 | return {synth_id_to_category[category_id]: ShapeNet3DEPNDataset(root_dir=root_dir, split='test', 157 | classes=[category_id]) 158 | for category_id in classes} 159 | -------------------------------------------------------------------------------- /datasets/shapenet_completion3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy as np 5 | 6 | from datasets.base_dataset import BaseDataset 7 | from datasets.utils.shapenet_category_mapping import synth_id_to_category 8 | 9 | 10 | class ShapeNetCompletion3DDataset(BaseDataset): 11 | 12 | def __init__(self, root_dir='/home/datasets/completion', split='train', classes=[], model_list=None): 13 | super(ShapeNetCompletion3DDataset, self).__init__(root_dir, split, 
classes) 14 | 15 | if self.split == 'train': 16 | self.list_path = os.path.join(root_dir, 'train.list') 17 | elif self.split == 'val': 18 | self.list_path = os.path.join(root_dir, 'val.list') 19 | else: 20 | self.list_path = os.path.join(root_dir, 'test.list') 21 | 22 | if model_list is None: 23 | with open(self.list_path) as file: 24 | if classes: 25 | self.model_list = [line.strip() for line in file if line.strip().split('/')[0] in classes] 26 | else: 27 | self.model_list = [line.strip() for line in file] 28 | else: 29 | self.model_list = model_list 30 | self.len = len(self.model_list) 31 | 32 | def __len__(self): 33 | return self.len 34 | 35 | def _load_h5(self, path): 36 | f = h5py.File(path, 'r') 37 | cloud_data = np.array(f['data']) 38 | f.close() 39 | return cloud_data.astype(np.float32) 40 | 41 | def __getitem__(self, index): 42 | model_name = self.model_list[index] 43 | existing = self._load_h5(os.path.join(self.root_dir, self.split, 'partial', model_name + '.h5')) 44 | if self.split != 'test': 45 | gt = self._load_h5(os.path.join(self.root_dir, self.split, 'gt', model_name + '.h5')) 46 | else: 47 | gt = existing 48 | return existing, 0, gt, model_name 49 | 50 | @classmethod 51 | def get_validation_datasets(cls, root_dir, classes=[], **kwargs): 52 | if not classes: 53 | classes = ['02691156', '02933112', '02958343', '03001627', '03636649', '04256520', '04379243', '04530566'] 54 | 55 | list_path = os.path.join(root_dir, 'val.list') 56 | 57 | model_lists = dict.fromkeys(classes) 58 | for k in model_lists.keys(): 59 | model_lists[k] = list() 60 | 61 | with open(list_path) as file: 62 | for line in file: 63 | model_lists[line.strip().split('/')[0]].append(line.strip()) 64 | return {synth_id_to_category[category_id]: ShapeNetCompletion3DDataset(root_dir=root_dir, split='val', 65 | model_list=model_list) 66 | for category_id, model_list in model_lists.items()} 67 | 68 | @classmethod 69 | def get_test_datasets(cls, root_dir, classes=[], **kwargs): 70 | return {'all': ShapeNetCompletion3DDataset(root_dir=root_dir, split='test')} 71 | -------------------------------------------------------------------------------- /datasets/utils/dataset_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class HyperPlane(object): 5 | 6 | def __init__(self, params, bias): 7 | self.params = params 8 | self.bias = bias 9 | 10 | def check_point(self, point): 11 | return np.sign(np.dot(point, self.params) + self.bias) 12 | 13 | @staticmethod 14 | def get_plane_from_3_points(points): 15 | cp = np.cross(points[1] - points[0], points[2] - points[0]) 16 | return HyperPlane(cp, np.dot(cp, points[0])) 17 | 18 | @staticmethod 19 | def get_random_plane(): 20 | return HyperPlane.get_plane_from_3_points(np.random.rand(3, 3)) 21 | 22 | def __str__(self): 23 | return "Plane A={}, B={}, C={}, D={}".format(*self.params, self.bias) 24 | 25 | 26 | class SlicedDatasetGenerator(object): 27 | 28 | @staticmethod 29 | def generate_item(points, target_partition_points=1024): 30 | 31 | while True: 32 | under = HyperPlane.get_random_plane().check_point(points) > 0 33 | points_under_plane = points[under] 34 | points_above_plane = points[~under] 35 | 36 | if target_partition_points == len(points_under_plane): 37 | return points_under_plane, points_above_plane 38 | if target_partition_points == len(points_above_plane): 39 | return points_above_plane, points_under_plane 40 | -------------------------------------------------------------------------------- 
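SlicedDatasetGenerator.generate_item above rejection-samples random planes until one half contains exactly target_partition_points points; for the 2048-point ShapeNet clouds and the default of 1024, that yields an even existing/missing split. A self-contained sketch of the same rejection loop (illustrative only: synthetic unit-cube points and the textbook signed-side test, assuming just numpy):

import numpy as np

rng = np.random.default_rng(0)
points = rng.random((2048, 3))           # stand-in for one 2048-point cloud
while True:
    a, b, c = rng.random((3, 3))         # three random points span a plane
    normal = np.cross(b - a, c - a)      # plane normal
    side = (points - a) @ normal > 0     # signed-side test for every point
    if side.sum() == 1024:               # accept only an exact 1024/1024 split
        existing, missing = points[side], points[~side]
        break
print(existing.shape, missing.shape)     # (1024, 3) (1024, 3)

In practice the loop terminates after a modest number of draws, because the acceptance count varies smoothly with the sampled plane; the generator above relies on the same behaviour.
--------------------------------------------------------------------------------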
/datasets/utils/shapenet_category_mapping.py: -------------------------------------------------------------------------------- 1 | synth_id_to_category = { 2 | '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', 3 | '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', 4 | '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', 5 | '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', 6 | '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', 7 | '02954340': 'cap', '02958343': 'car', '03001627': 'chair', 8 | '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', 9 | '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', 10 | '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', 11 | '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', 12 | '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', 13 | '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', 14 | '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', 15 | '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', 16 | '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', 17 | '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', 18 | '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', 19 | '04256520': 'sofa', '04330267': 'stove', '04530566': 'watercraft', 20 | '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' 21 | } 22 | 23 | category_to_synth_id = {v: k for k, v in synth_id_to_category.items()} 24 | synth_id_to_number = {k: i for i, k in enumerate(synth_id_to_category.keys())} 25 | -------------------------------------------------------------------------------- /images/hyperpocket_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/3d-point-clouds-autocomplete/13ec26e10da7b0ad2f71d4c97016fbb8499b0cff/images/hyperpocket_arch.png -------------------------------------------------------------------------------- /losses/champfer_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ChamferLoss(nn.Module): 6 | 7 | def __init__(self): 8 | super(ChamferLoss, self).__init__() 9 | self.use_cuda = torch.cuda.is_available() 10 | 11 | def forward(self, preds, gts): 12 | P = self.batch_pairwise_dist(gts, preds) 13 | mins, _ = torch.min(P, 1) 14 | loss_1 = torch.sum(mins) 15 | mins, _ = torch.min(P, 2) 16 | loss_2 = torch.sum(mins) 17 | return loss_1 + loss_2 18 | 19 | def batch_pairwise_dist(self, x, y): 20 | bs, num_points_x, points_dim = x.size() 21 | _, num_points_y, _ = y.size() 22 | xx = torch.bmm(x, x.transpose(2, 1)) 23 | yy = torch.bmm(y, y.transpose(2, 1)) 24 | zz = torch.bmm(x, y.transpose(2, 1)) 25 | if self.use_cuda: 26 | dtype = torch.cuda.LongTensor 27 | else: 28 | dtype = torch.LongTensor 29 | diag_ind_x = torch.arange(0, num_points_x).type(dtype) 30 | diag_ind_y = torch.arange(0, num_points_y).type(dtype) 31 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as( 32 | zz.transpose(2, 1)) 33 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) 34 | P = rx.transpose(2, 1) + ry - 2 * zz 35 | return P 36 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 
| 5 | class Encoder(nn.Module): 6 | def __init__(self, config, is_vae=False): 7 | super().__init__() 8 | 9 | self.output_size = config['output_size'] 10 | self.use_bias = config['use_bias'] 11 | self.relu_slope = config['relu_slope'] 12 | self.is_vae = is_vae 13 | 14 | self.conv = nn.Sequential( 15 | nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1, bias=self.use_bias), 16 | nn.ReLU(inplace=True), 17 | 18 | nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, bias=self.use_bias), 19 | nn.ReLU(inplace=True), 20 | 21 | nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1, bias=self.use_bias), 22 | nn.ReLU(inplace=True), 23 | 24 | nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1, bias=self.use_bias), 25 | nn.ReLU(inplace=True), 26 | 27 | nn.Conv1d(in_channels=512, out_channels=512, kernel_size=1, bias=self.use_bias), 28 | ) 29 | 30 | self.fc = nn.Sequential( 31 | nn.Linear(512, 512, bias=True), 32 | nn.ReLU(inplace=True) 33 | ) 34 | 35 | self.mu_layer = nn.Linear(512, self.output_size, bias=True) 36 | self.std_layer = nn.Linear(512, self.output_size, bias=True) 37 | 38 | def reparameterize(self, mu, logvar): 39 | std = torch.exp(logvar) 40 | eps = torch.randn_like(std) 41 | return eps.mul(std).add_(mu) 42 | 43 | def forward(self, x): 44 | output = self.conv(x) 45 | output2 = output.max(dim=2)[0] 46 | logit = self.fc(output2) 47 | mu = self.mu_layer(logit) 48 | if self.is_vae: 49 | logvar = self.std_layer(logit) 50 | z = self.reparameterize(mu, logvar) 51 | return z, mu, torch.exp(logvar) 52 | else: 53 | return mu 54 | -------------------------------------------------------------------------------- /model/full_model.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Iterator 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import Parameter 7 | 8 | from model.encoder import Encoder 9 | from model.hyper_network import HyperNetwork 10 | from model.target_network import TargetNetwork 11 | from utils.points import generate_points 12 | 13 | 14 | class FullModel(nn.Module): 15 | 16 | @staticmethod 17 | def _complete_config(config): 18 | config['hyper_network']['target_network_layer_out_channels'] = config['target_network']['layer_out_channels'] 19 | config['hyper_network']['target_network_use_bias'] = config['target_network']['use_bias'] 20 | config['hyper_network']['input_size'] = config['random_encoder']['output_size'] + \ 21 | config['real_encoder']['output_size'] 22 | 23 | config['hyper_network']['target_network_freeze_layers_learning'] = config['target_network'][ 24 | 'freeze_layers_learning'] 25 | 26 | def get_noise_size(self): 27 | return self.random_encoder_output_size 28 | 29 | def _resolve_mode(self, config): 30 | self.random_encoder_output_size = config['random_encoder']['output_size'] 31 | if config['random_encoder']['output_size'] > 0 and config['real_encoder']['output_size'] > 0: 32 | self.mode = HyperPocket() 33 | self.random_encoder = Encoder(config['random_encoder'], is_vae=True) 34 | self.real_encoder = Encoder(config['real_encoder'], is_vae=False) 35 | elif config['random_encoder']['output_size'] > 0: 36 | self.mode = HyperCloud() 37 | self.random_encoder = Encoder(config['random_encoder'], is_vae=True) 38 | elif config['real_encoder']['output_size'] > 0: 39 | self.mode = HyperRec() 40 | self.real_encoder = Encoder(config['real_encoder'], is_vae=False) 41 | else: 42 | raise ValueError("at least one encoder should have non zero output") 43 | 44 | def 
__init__(self, config): 45 | super().__init__() 46 | self._complete_config(config) 47 | self._resolve_mode(config) 48 | 49 | self.hyper_network = HyperNetwork(config['hyper_network']) 50 | self.target_network_config = config['target_network'] 51 | 52 | self.point_generator_config = {'target_network_input': config['target_network_input']} 53 | 54 | def forward(self, existing, missing, gt_shape, epoch, device, noise=None): 55 | # clouds may arrive as (B, N, 3) or (B, 3, N); normalize to channels-first 56 | if existing.size(-1) == 3: 57 | existing.transpose_(existing.dim() - 2, existing.dim() - 1) 58 | 59 | if noise is None and missing is not None and missing.size(-1) == 3: 60 | missing.transpose_(missing.dim() - 2, missing.dim() - 1) 61 | 62 | if gt_shape[-1] == 3: 63 | gt_shape[1], gt_shape[2] = gt_shape[2], gt_shape[1] 64 | 65 | latent, mu, logvar = self.mode.get_latent(self, existing, missing, noise) 66 | 67 | target_networks_weights = self.hyper_network(latent) 68 | reconstruction = torch.zeros(gt_shape).to(device) 69 | # decode each latent with its own target network, one shape at a time 70 | for j, target_network_weights in enumerate(target_networks_weights): 71 | target_network = TargetNetwork(self.target_network_config, target_network_weights).to(device) 72 | target_network_input = generate_points(config=self.point_generator_config, epoch=epoch, 73 | size=(gt_shape[2], gt_shape[1])) 74 | reconstruction[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1) 75 | 76 | # reconstruction shape [BATCH_SIZE, 3, N] 77 | if self.training: 78 | return reconstruction, logvar, mu 79 | else: 80 | return reconstruction # , latent, target_networks_weights 81 | 82 | def parameters(self, recurse: bool = True) -> Iterator[Parameter]: 83 | return self.mode.get_parameters(self) 84 | 85 | 86 | class ModelMode(object): 87 | 88 | def get_latent(self, model: FullModel, existing, missing, noise=None): 89 | raise NotImplementedError 90 | 91 | def get_parameters(self, model: FullModel) -> Iterator[Parameter]: 92 | raise NotImplementedError 93 | 94 | def has_generativity(self) -> bool: 95 | raise NotImplementedError 96 | 97 | 98 | class HyperPocket(ModelMode): 99 | 100 | def get_latent(self, model: FullModel, existing, missing, noise=None): 101 | if model.training: 102 | codes, mu, logvar = model.random_encoder(missing) 103 | real_mu = model.real_encoder(existing) 104 | latent = torch.cat([codes, real_mu], 1) 105 | return latent, mu, logvar 106 | else: 107 | if noise is None: 108 | _, random_mu, _ = model.random_encoder(missing) 109 | else: 110 | random_mu = noise 111 | real_mu = model.real_encoder(existing) 112 | latent = torch.cat([random_mu, real_mu], 1) 113 | return latent, None, None 114 | 115 | def get_parameters(self, model: FullModel): 116 | return chain(model.random_encoder.parameters(), 117 | model.real_encoder.parameters(), 118 | model.hyper_network.parameters()) 119 | 120 | def has_generativity(self) -> bool: 121 | return True 122 | 123 | 124 | class HyperRec(ModelMode): 125 | 126 | def get_latent(self, model: FullModel, existing, missing, noise=None): 127 | return model.real_encoder(existing), None, None 128 | 129 | def get_parameters(self, model: FullModel): 130 | return chain(model.real_encoder.parameters(), model.hyper_network.parameters()) 131 | 132 | def has_generativity(self) -> bool: 133 | return False 134 | 135 | 136 | class HyperCloud(ModelMode): 137 | 138 | def get_latent(self, model: FullModel, existing, missing, noise=None): 139 | if model.training: 140 | return model.random_encoder(existing) 141 | else: 142 | if noise is None: 143 | _, random_mu, _ = model.random_encoder(existing) 144 | else: 145 | random_mu = noise 146 | return random_mu, None, None 147 | 148 | def get_parameters(self, model: FullModel): 149 | return chain(model.random_encoder.parameters(), model.hyper_network.parameters()) 150 | 151 | def has_generativity(self) -> bool: 152 | return False 153 | --------------------------------------------------------------------------------
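FullModel.forward above gives every shape its own decoder: the hypernetwork maps a latent code to one flat parameter vector, and TargetNetwork (model/target_network.py below) slices that vector into per-layer weight matrices. A toy, self-contained version of the pattern, with all sizes hypothetical:

import torch
import torch.nn as nn

hyper = nn.Linear(16, 3 * 8 + 8 + 8 * 3 + 3)   # latent (16) -> all 59 target-net parameters
latent = torch.randn(1, 16)
flat = hyper(latent)[0]                        # one flat weight vector per shape

w1, b1 = flat[:24].view(8, 3), flat[24:32]     # hidden layer: 3 -> 8
w2, b2 = flat[32:56].view(3, 8), flat[56:59]   # output layer: 8 -> 3
pts = torch.rand(2048, 3)                      # target-network input points
out = torch.relu(pts @ w1.t() + b1) @ w2.t() + b2
print(out.shape)                               # torch.Size([2048, 3])

Note that w1, b1, w2, b2 are activations computed from the hypernetwork output rather than registered parameters, so gradients flow only into the hypernetwork; this matches parameters() above, which exposes only the encoder and hyper_network weights.
--------------------------------------------------------------------------------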
/model/hyper_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class HyperNetwork(nn.Module): 6 | def __init__(self, config): 7 | super().__init__() 8 | 9 | self.input_size = config['input_size'] 10 | self.use_bias = config['use_bias'] 11 | self.relu_slope = config['relu_slope'] 12 | # target network layers out channels 13 | target_network_out_ch = [3] + config['target_network_layer_out_channels'] + [3] 14 | target_network_use_bias = int(config['target_network_use_bias']) 15 | 16 | self.model = nn.Sequential( 17 | nn.Linear(in_features=self.input_size, out_features=64, bias=self.use_bias), 18 | nn.ReLU(inplace=True), 19 | 20 | nn.Linear(in_features=64, out_features=128, bias=self.use_bias), 21 | nn.ReLU(inplace=True), 22 | 23 | nn.Linear(in_features=128, out_features=512, bias=self.use_bias), 24 | nn.ReLU(inplace=True), 25 | 26 | nn.Linear(in_features=512, out_features=1024, bias=self.use_bias), 27 | nn.ReLU(inplace=True), 28 | 29 | nn.Linear(in_features=1024, out_features=2048, bias=self.use_bias), 30 | ) 31 | 32 | self.output = [ 33 | nn.Linear(2048, (target_network_out_ch[x - 1] + target_network_use_bias) * target_network_out_ch[x], 34 | bias=True) 35 | for x in range(1, len(target_network_out_ch)) 36 | ] 37 | 38 | if not config['target_network_freeze_layers_learning']: 39 | self.output = nn.ModuleList(self.output)  # register heads so they are trained and moved with .to(device) 40 | 41 | def forward(self, x): 42 | output = self.model(x) 43 | return torch.cat([target_network_layer(output) for target_network_layer in self.output], 1) 44 | -------------------------------------------------------------------------------- /model/target_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TargetNetwork(nn.Module): 6 | def __init__(self, config, weights): 7 | super().__init__() 8 | 9 | # self.z_size = config['missing_size'] 10 | self.use_bias = config['use_bias'] 11 | # target network layers out channels 12 | out_ch = config['layer_out_channels'] 13 | 14 | layer_data, split_index = self._get_layer_data(start_index=0, end_index=out_ch[0] * 3, 15 | shape=(out_ch[0], 3), weights=weights) 16 | self.layers = {"1": layer_data} 17 | 18 | for x in range(1, len(out_ch)): 19 | layer_data, split_index = self._get_layer_data(start_index=split_index, 20 | end_index=split_index + (out_ch[x - 1] * out_ch[x]), 21 | shape=(out_ch[x], out_ch[x - 1]), weights=weights) 22 | self.layers[str(x + 1)] = layer_data 23 | 24 | layer_data, split_index = self._get_layer_data(start_index=split_index, 25 | end_index=split_index + (out_ch[-1] * 3), 26 | shape=(3, out_ch[-1]), weights=weights) 27 | self.output = layer_data 28 | self.activation = torch.nn.ReLU() 29 | assert split_index == len(weights) 30 | 31 | def forward(self, x): 32 | for layer_index in self.layers: 33 | x = torch.mm(x, torch.transpose(self.layers[layer_index]["weight"], 0, 1)) 34 | if self.use_bias: 35 | assert "bias" in self.layers[layer_index] 36 | x = x + self.layers[layer_index]["bias"] 37 | x = self.activation(x) 38 | return torch.mm(x, torch.transpose(self.output["weight"], 
0, 1)) + self.output.get("bias", 0) 39 | 40 | def _get_layer_data(self, start_index, end_index, shape, weights): 41 | layer_data = {"weight": weights[start_index:end_index].view(shape[0], shape[1])} 42 | if self.use_bias: 43 | layer_data["bias"] = weights[end_index:end_index + shape[0]] 44 | end_index = end_index + shape[0] 45 | return layer_data, end_index 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | h5py==3.1.0 3 | sklearn 4 | matplotlib 5 | pandas 6 | numpy 7 | trimesh 8 | requests 9 | ray -------------------------------------------------------------------------------- /settings/config.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "training", 3 | "dataset": { 4 | "name": "shapenet", 5 | "path": "data/dataset/shapenet", 6 | "classes": [], 7 | "is_rotated": false, 8 | "num_samples": 1, 9 | "gen_test_set": false 10 | }, 11 | "training": { 12 | "optimizer": { 13 | "type": "Adam", 14 | "hyperparams": { 15 | "lr": 0.0001, 16 | "weight_decay": 0, 17 | "betas": [ 18 | 0.9, 19 | 0.999 20 | ], 21 | "amsgrad": false 22 | } 23 | }, 24 | "lr_scheduler": { 25 | "type": "StepLR", 26 | "hyperparams": { 27 | "step_size": 3000, 28 | "gamma": 0.01 29 | } 30 | }, 31 | "dataloader": { 32 | "train": { 33 | "batch_size": 5, 34 | "shuffle": true, 35 | "num_workers" : 8, 36 | "drop_last" : true 37 | }, 38 | "val": { 39 | "batch_size": 5, 40 | "shuffle": true, 41 | "num_workers" : 8, 42 | "drop_last" : false 43 | } 44 | }, 45 | "state_save_frequency" : 1, 46 | "loss_coef": 0.05, 47 | "max_epoch": 2000 48 | }, 49 | "experiments": { 50 | "epoch": "best_val", 51 | "settings": { 52 | "fixed": { 53 | "execute": false, 54 | "mean": 0.0, 55 | "std": 0.05, 56 | "amount": 64, 57 | "triangulation_config": { 58 | "execute": true, 59 | "method": "edge", 60 | "depth": 2 61 | } 62 | }, 63 | "evaluate_generativity": { 64 | "execute": true, 65 | "batch_size": 25, 66 | "num_workers" : 8 67 | } 68 | } 69 | }, 70 | 71 | "full_model": { 72 | "random_encoder": { 73 | "output_size": 1024, 74 | "use_bias": true, 75 | "relu_slope": 0.2 76 | }, 77 | "real_encoder": { 78 | "output_size": 1024, 79 | "use_bias": true, 80 | "relu_slope": 0.2 81 | }, 82 | "hyper_network": { 83 | "use_bias": true, 84 | "relu_slope": 0.2 85 | }, 86 | "target_network": { 87 | "use_bias": true, 88 | "relu_slope": 0.2, 89 | "freeze_layers_learning": false, 90 | "layer_out_channels": [ 91 | 32, 92 | 64, 93 | 128, 94 | 64 95 | ] 96 | }, 97 | "target_network_input": { 98 | "constant": false, 99 | "normalization": { 100 | "enable": true, 101 | "type": "progressive", 102 | "epoch": 100 103 | } 104 | } 105 | }, 106 | "setup": { 107 | "seed": 2020, 108 | "gpu_id": 0 109 | }, 110 | "telegram_logger": { 111 | "enable": false, 112 | "bot_token": "", 113 | "chat_id": "" 114 | }, 115 | "results_root": "data/results" 116 | } -------------------------------------------------------------------------------- /settings/config_3depn_airplane.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "training", 3 | "dataset": { 4 | "name": "3depn", 5 | "path": "data/dataset/3depn", 6 | "classes": ["02691156"], 7 | "num_samples": 4 8 | }, 9 | "training": { 10 | "optimizer": { 11 | "type": "Adam", 12 | "hyperparams": { 13 | "lr": 0.0001, 14 | "weight_decay": 0, 15 | "betas": [ 16 | 0.9, 17 | 0.999 18 | ], 19 | 
"amsgrad": false 20 | } 21 | }, 22 | "lr_scheduler": { 23 | "type": "StepLR", 24 | "hyperparams": { 25 | "step_size": 30001, 26 | "gamma": 0.01 27 | } 28 | }, 29 | "dataloader": { 30 | "train": { 31 | "batch_size": 64, 32 | "shuffle": true, 33 | "num_workers" : 8, 34 | "drop_last" : true 35 | }, 36 | "val": { 37 | "batch_size": 60, 38 | "shuffle": true, 39 | "num_workers" : 8, 40 | "drop_last" : false 41 | } 42 | }, 43 | "state_save_frequency" : 100, 44 | "loss_coef": 0.05, 45 | "min_save_epoch": 10, 46 | "max_epoch": 350 47 | }, 48 | "experiments": { 49 | "epoch": "best_val", 50 | "settings": { 51 | "fixed": { 52 | "execute": true, 53 | "mean": 0.0, 54 | "std": 0.1, 55 | "amount": 64, 56 | "triangulation_config": { 57 | "execute": false, 58 | "method": "edge", 59 | "depth": 3 60 | } 61 | }, 62 | "evaluate_generativity": { 63 | "execute": false, 64 | "batch_size": 150, 65 | "num_workers" : 8 66 | }, 67 | "compute_mmd_tmd_uhd": { 68 | "execute": true 69 | } 70 | } 71 | }, 72 | "full_model": { 73 | "random_encoder": { 74 | "output_size": 128, 75 | "use_bias": true, 76 | "relu_slope": 0.2 77 | }, 78 | "real_encoder": { 79 | "output_size": 128, 80 | "use_bias": true, 81 | "relu_slope": 0.2 82 | }, 83 | "hyper_network": { 84 | "use_bias": true, 85 | "relu_slope": 0.2 86 | }, 87 | "target_network": { 88 | "use_bias": true, 89 | "relu_slope": 0.2, 90 | "freeze_layers_learning": false, 91 | "layer_out_channels": [ 92 | 32, 64, 128, 64 93 | ] 94 | }, 95 | "target_network_input": { 96 | "constant": false, 97 | "normalization": { 98 | "enable": true, 99 | "type": "progressive", 100 | "epoch": 100 101 | } 102 | } 103 | }, 104 | "setup": { 105 | "seed": 1856, 106 | "gpu_id": 0 107 | }, 108 | "telegram_logger": { 109 | "enable": false, 110 | "bot_token": "", 111 | "chat_id": "" 112 | }, 113 | "results_root": "data/results" 114 | } 115 | -------------------------------------------------------------------------------- /settings/config_3depn_chair.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "training", 3 | "dataset": { 4 | "name": "3depn", 5 | "path": "data/dataset/3depn", 6 | "classes": ["03001627"], 7 | "num_samples": 4 8 | }, 9 | "training": { 10 | "optimizer": { 11 | "type": "Adam", 12 | "hyperparams": { 13 | "lr": 0.0001, 14 | "weight_decay": 0, 15 | "betas": [ 16 | 0.9, 17 | 0.999 18 | ], 19 | "amsgrad": false 20 | } 21 | }, 22 | "lr_scheduler": { 23 | "type": "StepLR", 24 | "hyperparams": { 25 | "step_size": 30001, 26 | "gamma": 0.01 27 | } 28 | }, 29 | "dataloader": { 30 | "train": { 31 | "batch_size": 64, 32 | "shuffle": true, 33 | "num_workers" : 8, 34 | "drop_last" : true 35 | }, 36 | "val": { 37 | "batch_size": 60, 38 | "shuffle": true, 39 | "num_workers" : 8, 40 | "drop_last" : false 41 | } 42 | }, 43 | "state_save_frequency" : 10, 44 | "loss_coef": 0.05, 45 | "min_save_epoch": 10, 46 | "max_epoch": 400 47 | }, 48 | "experiments": { 49 | "epoch": "best_val", 50 | "settings": { 51 | "fixed": { 52 | "execute": true, 53 | "mean": 0.0, 54 | "std": 0.13, 55 | "amount": 64, 56 | "triangulation_config": { 57 | "execute": false, 58 | "method": "edge", 59 | "depth": 2 60 | } 61 | }, 62 | "evaluate_generativity": { 63 | "execute": false, 64 | "batch_size": 150, 65 | "num_workers" : 8 66 | }, 67 | "compute_mmd_tmd_uhd": { 68 | "execute": true 69 | } 70 | } 71 | }, 72 | "full_model": { 73 | "random_encoder": { 74 | "output_size": 128, 75 | "use_bias": true, 76 | "relu_slope": 0.2 77 | }, 78 | "real_encoder": { 79 | "output_size": 128, 80 
| "use_bias": true, 81 | "relu_slope": 0.2 82 | }, 83 | "hyper_network": { 84 | "use_bias": true, 85 | "relu_slope": 0.2 86 | }, 87 | "target_network": { 88 | "use_bias": true, 89 | "relu_slope": 0.2, 90 | "freeze_layers_learning": false, 91 | "layer_out_channels": [ 92 | 32, 64, 128, 64 93 | ] 94 | }, 95 | "target_network_input": { 96 | "constant": false, 97 | "normalization": { 98 | "enable": true, 99 | "type": "progressive", 100 | "epoch": 100 101 | } 102 | } 103 | }, 104 | "setup": { 105 | "seed": 1856, 106 | "gpu_id": 0 107 | }, 108 | "telegram_logger": { 109 | "enable": false, 110 | "bot_token": "", 111 | "chat_id": "" 112 | }, 113 | "results_root": "data/results" 114 | } 115 | -------------------------------------------------------------------------------- /settings/config_3depn_table.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "training", 3 | "dataset": { 4 | "name": "3depn", 5 | "path": "data/dataset/3depn", 6 | "classes": ["04379243"], 7 | "num_samples": 4 8 | }, 9 | "training": { 10 | "optimizer": { 11 | "type": "Adam", 12 | "hyperparams": { 13 | "lr": 0.0001, 14 | "weight_decay": 0, 15 | "betas": [ 16 | 0.9, 17 | 0.999 18 | ], 19 | "amsgrad": false 20 | } 21 | }, 22 | "lr_scheduler": { 23 | "type": "StepLR", 24 | "hyperparams": { 25 | "step_size": 30001, 26 | "gamma": 0.01 27 | } 28 | }, 29 | "dataloader": { 30 | "train": { 31 | "batch_size": 64, 32 | "shuffle": true, 33 | "num_workers" : 8, 34 | "drop_last" : true 35 | }, 36 | "val": { 37 | "batch_size": 60, 38 | "shuffle": true, 39 | "num_workers" : 8, 40 | "drop_last" : false 41 | } 42 | }, 43 | "state_save_frequency" : 100, 44 | "loss_coef": 0.05, 45 | "min_save_epoch": 10, 46 | "max_epoch": 140 47 | }, 48 | "experiments": { 49 | "epoch": "best_val", 50 | "settings": { 51 | "fixed": { 52 | "execute": true, 53 | "mean": 0.0, 54 | "std": 0.065, 55 | "amount": 64, 56 | "triangulation_config": { 57 | "execute": false, 58 | "method": "edge", 59 | "depth": 2 60 | } 61 | }, 62 | "evaluate_generativity": { 63 | "execute": false, 64 | "batch_size": 150, 65 | "num_workers" : 8 66 | }, 67 | "compute_mmd_tmd_uhd": { 68 | "execute": true 69 | } 70 | } 71 | }, 72 | "full_model": { 73 | "random_encoder": { 74 | "output_size": 128, 75 | "use_bias": true, 76 | "relu_slope": 0.2 77 | }, 78 | "real_encoder": { 79 | "output_size": 128, 80 | "use_bias": true, 81 | "relu_slope": 0.2 82 | }, 83 | "hyper_network": { 84 | "use_bias": true, 85 | "relu_slope": 0.2 86 | }, 87 | "target_network": { 88 | "use_bias": true, 89 | "relu_slope": 0.2, 90 | "freeze_layers_learning": false, 91 | "layer_out_channels": [ 92 | 32, 64, 128, 64 93 | ] 94 | }, 95 | "target_network_input": { 96 | "constant": false, 97 | "normalization": { 98 | "enable": true, 99 | "type": "progressive", 100 | "epoch": 100 101 | } 102 | } 103 | }, 104 | "setup": { 105 | "seed": 1856, 106 | "gpu_id": 0 107 | }, 108 | "telegram_logger": { 109 | "enable": false, 110 | "bot_token": "", 111 | "chat_id": "" 112 | }, 113 | "results_root": "data/results" 114 | } 115 | -------------------------------------------------------------------------------- /settings/config_completion.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "training", 3 | "dataset": { 4 | "name": "completion", 5 | "path": "data/dataset/completion", 6 | "classes": [] 7 | }, 8 | "training": { 9 | "optimizer": { 10 | "type": "Adam", 11 | "hyperparams": { 12 | "lr": 0.0001, 13 | "weight_decay": 0, 14 | 
"betas": [ 15 | 0.9, 16 | 0.999 17 | ], 18 | "amsgrad": false 19 | } 20 | }, 21 | "lr_scheduler": { 22 | "type": "StepLR", 23 | "hyperparams": { 24 | "step_size": 41, 25 | "gamma": 0.01 26 | } 27 | }, 28 | "dataloader": { 29 | "train": { 30 | "batch_size": 66, 31 | "shuffle": true, 32 | "num_workers" : 8, 33 | "drop_last" : true 34 | }, 35 | "val": { 36 | "batch_size": 10, 37 | "shuffle": true, 38 | "num_workers" : 8, 39 | "drop_last" : false 40 | } 41 | }, 42 | "state_save_frequency" : 10, 43 | "loss_coef": 0.05, 44 | "min_save_epoch": 20, 45 | "max_epoch": 100 46 | }, 47 | "experiments": { 48 | "epoch": "best_val", 49 | "settings": { 50 | "fixed": { 51 | "execute": false, 52 | "mean": 0.0, 53 | "std": 0.05, 54 | "amount": 64, 55 | "triangulation_config": { 56 | "execute": false, 57 | "method": "edge", 58 | "depth": 2 59 | } 60 | }, 61 | "evaluate_generativity": { 62 | "execute": false, 63 | "batch_size": 150, 64 | "num_workers" : 8 65 | }, 66 | "compute_mmd_tmd_uhd": { 67 | "execute": false 68 | }, 69 | "completion3d_submission":{ 70 | "execute": true, 71 | "batch_size": 16 72 | } 73 | } 74 | }, 75 | "full_model": { 76 | "random_encoder": { 77 | "output_size": 0, 78 | "use_bias": true, 79 | "relu_slope": 0.2 80 | }, 81 | "real_encoder": { 82 | "output_size": 128, 83 | "use_bias": true, 84 | "relu_slope": 0.2 85 | }, 86 | "hyper_network": { 87 | "use_bias": true, 88 | "relu_slope": 0.2 89 | }, 90 | "target_network": { 91 | "use_bias": true, 92 | "relu_slope": 0.2, 93 | "freeze_layers_learning": false, 94 | "layer_out_channels": [ 95 | 32, 64, 128, 64 96 | ] 97 | }, 98 | "target_network_input": { 99 | "constant": false, 100 | "normalization": { 101 | "enable": true, 102 | "type": "progressive", 103 | "epoch": 100 104 | } 105 | } 106 | }, 107 | "setup": { 108 | "seed": 2020, 109 | "gpu_id": 0 110 | }, 111 | "telegram_logger": { 112 | "enable": false, 113 | "bot_token": "", 114 | "chat_id": "" 115 | }, 116 | "results_root": "data/results" 117 | } 118 | -------------------------------------------------------------------------------- /settings/config_missing_shapenet.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "training", 3 | "dataset": { 4 | "name": "shapenet", 5 | "path": "data/dataset/shapenet", 6 | "classes": [], 7 | "is_rotated": false, 8 | "num_samples": 4, 9 | "gen_test_set": true 10 | }, 11 | "training": { 12 | "optimizer": { 13 | "type": "Adam", 14 | "hyperparams": { 15 | "lr": 0.0001, 16 | "weight_decay": 0, 17 | "betas": [ 18 | 0.9, 19 | 0.999 20 | ], 21 | "amsgrad": false 22 | } 23 | }, 24 | "lr_scheduler": { 25 | "type": "StepLR", 26 | "hyperparams": { 27 | "step_size": 30001, 28 | "gamma": 0.01 29 | } 30 | }, 31 | "dataloader": { 32 | "train": { 33 | "batch_size": 5, 34 | "shuffle": true, 35 | "num_workers" : 8, 36 | "drop_last" : true 37 | }, 38 | "val": { 39 | "batch_size": 5, 40 | "shuffle": true, 41 | "num_workers" : 8, 42 | "drop_last" : false 43 | } 44 | }, 45 | "state_save_frequency" : 1, 46 | "loss_coef": 0.05, 47 | "max_epoch": 230 48 | }, 49 | "experiments": { 50 | "epoch": "best_val", 51 | "settings": { 52 | "fixed": { 53 | "execute": false, 54 | "mean": 0.0, 55 | "std": 0.05, 56 | "amount": 64, 57 | "triangulation_config": { 58 | "execute": true, 59 | "method": "edge", 60 | "depth": 2 61 | } 62 | }, 63 | "evaluate_generativity": { 64 | "execute": true, 65 | "batch_size": 25, 66 | "num_workers" : 8 67 | } 68 | } 69 | }, 70 | 71 | "full_model": { 72 | "random_encoder": { 73 | "output_size": 128, 74 | 
"use_bias": true, 75 | "relu_slope": 0.2 76 | }, 77 | "real_encoder": { 78 | "output_size": 128, 79 | "use_bias": true, 80 | "relu_slope": 0.2 81 | }, 82 | "hyper_network": { 83 | "use_bias": true, 84 | "relu_slope": 0.2 85 | }, 86 | "target_network": { 87 | "use_bias": true, 88 | "relu_slope": 0.2, 89 | "freeze_layers_learning": false, 90 | "layer_out_channels": [ 91 | 32, 92 | 64, 93 | 128, 94 | 64 95 | ] 96 | }, 97 | "target_network_input": { 98 | "constant": false, 99 | "normalization": { 100 | "enable": true, 101 | "type": "progressive", 102 | "epoch": 100 103 | } 104 | } 105 | }, 106 | "setup": { 107 | "seed": 2020, 108 | "gpu_id": 0 109 | }, 110 | "telegram_logger": { 111 | "enable": false, 112 | "bot_token": "", 113 | "chat_id": "" 114 | }, 115 | "results_root": "data/results" 116 | } -------------------------------------------------------------------------------- /util_scripts/download_shapenet_2048.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import urllib 3 | from os import makedirs, remove, listdir 4 | from os.path import exists, join 5 | from zipfile import ZipFile 6 | 7 | from core.arg_parser import parse_config 8 | 9 | 10 | def main(config): 11 | dataset_config: dict = config['dataset'] 12 | dataset_path: str = dataset_config['path'] 13 | 14 | if exists(dataset_path): 15 | raise Exception(f'directory {dataset_path} already exists') 16 | 17 | makedirs(dataset_path) 18 | 19 | url = 'https://www.dropbox.com/s/vmsdrae6x5xws1v/shape_net_core_uniform_samples_2048.zip?dl=1' 20 | 21 | data = urllib.request.urlopen(url) 22 | filename = url.rpartition('/')[2][:-5] 23 | file_path = join(dataset_path, filename) 24 | with open(file_path, mode='wb') as f: 25 | d = data.read() 26 | f.write(d) 27 | 28 | print('Extracting...') 29 | with ZipFile(file_path, mode='r') as zip_f: 30 | zip_f.extractall(dataset_path) 31 | 32 | remove(file_path) 33 | 34 | extracted_dir = join(dataset_path, 35 | 'shape_net_core_uniform_samples_2048') 36 | for d in listdir(extracted_dir): 37 | shutil.move(src=join(extracted_dir, d), 38 | dst=dataset_path) 39 | 40 | shutil.rmtree(extracted_dir) 41 | 42 | 43 | if __name__ == '__main__': 44 | main(parse_config()) 45 | -------------------------------------------------------------------------------- /util_scripts/generate_eval_gen_test_set.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join, exists 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from core.arg_parser import parse_config 7 | from datasets.utils.dataset_generator import HyperPlane 8 | from utils.plyfile import load_ply, quick_save_ply_file 9 | 10 | 11 | def div_left_right_bin_search(dataset_dir, init_plane_points, pc_paths): 12 | for i, pc_path in tqdm(enumerate(pc_paths), total=len(pc_paths)): 13 | 14 | pc = load_ply(join(dataset_dir, pc_path)) 15 | 16 | points = init_plane_points.copy() 17 | 18 | l, r = pc.T[1].min(), pc.T[1].max() 19 | 20 | counter = 0 21 | 22 | while True: 23 | 24 | m = np.divide(l + r, 2) 25 | 26 | points[0][1] = m 27 | points[1][1] = m 28 | points[2][1] = m 29 | 30 | right = HyperPlane.get_plane_from_3_points(points).check_point(pc) > 0 31 | right_points = pc[right] 32 | left_points = pc[~right] 33 | 34 | counter += 1 35 | if counter == 100000000: 36 | quick_save_ply_file(right_points, join(dataset_dir, 'test_gen', 'right', pc_path)) 37 | quick_save_ply_file(left_points, join(dataset_dir, 'test_gen', 'left', pc_path)) 38 | 
quick_save_ply_file(pc, join(dataset_dir, 'test_gen', 'gt', pc_path)) 39 | break 40 | 41 | if len(left_points) > len(right_points): 42 | l = m 43 | elif len(left_points) < len(right_points): 44 | r = m 45 | else: 46 | quick_save_ply_file(left_points, join(dataset_dir, 'test_gen', 'left', pc_path)) 47 | quick_save_ply_file(right_points, join(dataset_dir, 'test_gen', 'right', pc_path)) 48 | quick_save_ply_file(pc, join(dataset_dir, 'test_gen', 'gt', pc_path)) 49 | break 50 | 51 | 52 | def div_left_right_min_y(dataset_dir, pc_paths): 53 | for i, pc_path in tqdm(enumerate(pc_paths), total=len(pc_paths)): 54 | pc = load_ply(join(dataset_dir, pc_path)) 55 | 56 | right_points = pc[pc.T[1].argsort()[1024:]] 57 | left_points = pc[pc.T[1].argsort()[:1024]] 58 | 59 | quick_save_ply_file(left_points, join(dataset_dir, 'test_gen', 'left', pc_path)) 60 | quick_save_ply_file(right_points, join(dataset_dir, 'test_gen', 'right', pc_path)) 61 | quick_save_ply_file(pc, join(dataset_dir, 'test_gen', 'gt', pc_path)) 62 | 63 | 64 | def main(config): 65 | dataset_dir = config['dataset']['path'] 66 | 67 | with open(join(dataset_dir, 'test.list')) as file: 68 | pc_paths = [line.strip() + '.ply' for line in file] 69 | 70 | plane_points = np.zeros((3, 3)) 71 | plane_points[1][2] = 1 72 | plane_points[2][0] = 1 73 | 74 | for cat in ['02691156', '02933112', '02958343', '03001627', '03636649', '04256520', '04379243', '04530566']: 75 | makedirs(join(dataset_dir, 'test_gen', 'left', cat), exist_ok=True) 76 | makedirs(join(dataset_dir, 'test_gen', 'right', cat), exist_ok=True) 77 | makedirs(join(dataset_dir, 'test_gen', 'gt', cat), exist_ok=True) 78 | 79 | div_left_right_min_y(dataset_dir, pc_paths) 80 | 81 | not_existed_pc = [] 82 | 83 | for pc_path in pc_paths: 84 | if not (exists(join(dataset_dir, 'test_gen', 'left', pc_path)) 85 | and exists(join(dataset_dir, 'test_gen', 'right', pc_path))): 86 | not_existed_pc.append(pc_path) 87 | 88 | # div_left_right_bin_search(dataset_dir, plane_points, not_existed_pc) 89 | 90 | not_1024 = [] 91 | for pc_path in pc_paths: 92 | if load_ply(join(dataset_dir, 'test_gen', 'left', pc_path)).shape[0] != 1024: 93 | not_1024.append(pc_path) 94 | 95 | 96 | if __name__ == '__main__': 97 | main(parse_config()) 98 | -------------------------------------------------------------------------------- /util_scripts/generate_partial_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from os.path import join, exists 4 | 5 | import ray 6 | import trimesh 7 | import numpy as np 8 | 9 | from core.arg_parser import parse_config 10 | from datasets.shapenet_3depn import sample_point_cloud_by_n 11 | from datasets.utils.dataset_generator import SlicedDatasetGenerator 12 | from datasets.utils.shapenet_category_mapping import synth_id_to_category 13 | from utils.plyfile import quick_save_ply_file, load_ply 14 | from utils.util import get_filenames_by_cat 15 | 16 | 17 | @ray.remote 18 | def generate_one_shapenet(category: str, filename: str, dataset_path: str, num_samples: int = 4): 19 | pc_filepath = join(dataset_path, category, filename) 20 | points = load_ply(pc_filepath) 21 | 22 | for i in range(num_samples): 23 | existing, missing = SlicedDatasetGenerator.generate_item(points) 24 | quick_save_ply_file(existing, join(dataset_path, 'slices', 'existing', category, str(i) + '~' + filename)) 25 | quick_save_ply_file(missing, join(dataset_path, 'slices', 'missing', category, str(i) + '~' + filename)) 26 | 27 | 28 | @ray.remote 29 | def 
generate_one_3depn(cat: str, name: str, dataset_path: str, pc_root: str, num_samples: int = 4): 30 | ply_path = join(pc_root, name + '.ply') 31 | 32 | pc = np.array(trimesh.load(ply_path).vertices) 33 | pc = sample_point_cloud_by_n(pc, 2048) 34 | 35 | quick_save_ply_file(pc, join(dataset_path, 'slices', 'gt', cat, name + '.ply')) 36 | 37 | for i in range(num_samples): 38 | existing, missing = SlicedDatasetGenerator.generate_item(pc) 39 | quick_save_ply_file(existing, join(dataset_path, 'slices', 'existing', cat, str(i) + '~' + name + '.ply')) 40 | quick_save_ply_file(missing, join(dataset_path, 'slices', 'missing', cat, str(i) + '~' + name + '.ply')) 41 | 42 | 43 | def main(config: dict): 44 | dataset_config: dict = config['dataset'] 45 | 46 | dataset_path: str = dataset_config['path'] 47 | dataset_name: str = dataset_config['name'] 48 | num_samples: int = dataset_config['num_samples'] 49 | 50 | if dataset_name == 'shapenet': 51 | if not exists(join(dataset_path)): 52 | raise Exception(f'no ShapeNet dataset found at {dataset_path}, ' 53 | f'please run `util_scripts/download_shapenet_2048.py` first') 54 | 55 | for category in synth_id_to_category.keys(): 56 | os.makedirs(join(dataset_path, 'slices', 'existing', category), exist_ok=True) 57 | os.makedirs(join(dataset_path, 'slices', 'missing', category), exist_ok=True) 58 | 59 | ray.init(num_cpus=os.cpu_count()) 60 | ray.get([generate_one_shapenet.remote(row['category'], row['filename'], dataset_path, num_samples) for _, row in 61 | get_filenames_by_cat(dataset_path).iterrows()]) 62 | ray.shutdown() 63 | 64 | elif dataset_name == '3depn': 65 | classes: list = ['02691156', '03001627', '04379243'] 66 | 67 | cat_pc_root: dict = {cat: join(dataset_path, 'ShapeNetPointCloud', cat) for cat in classes} 68 | cat_pc_raw_root: dict = {cat: join(dataset_path, 'shapenet_dim32_sdf_pc', cat) for cat in classes} 69 | cat_shape_names: dict = {cat: [] for cat in classes} 70 | 71 | with open(join(dataset_path, 'shapenet-official-split.csv'), 'r') as csv_file: 72 | csv_reader = csv.reader(csv_file, delimiter=',') 73 | line_cnt = 0 74 | for row in csv_reader: 75 | if line_cnt == 0 or (row[1] not in classes): 76 | pass 77 | else: 78 | if row[-1] in ['train', 'val']: 79 | cat_shape_names[row[1]].append(row[-2]) 80 | line_cnt += 1 81 | 82 | refined_shape_names: dict = {cat: [] for cat in classes} 83 | for cat, shapes in cat_shape_names.items(): 84 | for name in shapes: 85 | ply_path = join(cat_pc_root[cat], name + '.ply') 86 | path = join(cat_pc_raw_root[cat], f'{name}__0__.ply') 87 | if exists(ply_path) and exists(path): 88 | refined_shape_names[cat].append(name) 89 | 90 | for cat in classes: 91 | os.makedirs(join(dataset_path, 'slices', 'existing', cat), exist_ok=True) 92 | os.makedirs(join(dataset_path, 'slices', 'missing', cat), exist_ok=True) 93 | os.makedirs(join(dataset_path, 'slices', 'gt', cat), exist_ok=True) 94 | 95 | print('pc to process: ', np.sum([len(v) for v in refined_shape_names.values()])) 96 | print('pc to process: ', {k: len(v) for k, v in refined_shape_names.items()}) 97 | 98 | ray.init(num_cpus=os.cpu_count()) 99 | 100 | ray_tasks = [] 101 | for cat, shapes in refined_shape_names.items(): 102 | for name in shapes: 103 | ray_tasks.append(generate_one_3depn.remote(cat, name, dataset_path, cat_pc_root[cat], num_samples)) 104 | # single thread version 105 | # ply_path = join(cat_pc_root[cat], name + '.ply') 106 | # 107 | # pc = np.array(trimesh.load(ply_path).vertices) 108 | # pc = sample_point_cloud_by_n(pc, 2048) 109 | # 110 | # 
# quick_save_ply_file(pc, join(dataset_path, 'slices', 'gt', cat, name + '.ply')) 111 | # 112 | # for i in range(4): 113 | # existing, missing = SlicedDatasetGenerator.generate_item(pc) 114 | # quick_save_ply_file(existing, join(dataset_path, 'slices', 'existing', cat, 115 | # str(i) + '~' + name + '.ply')) 116 | # quick_save_ply_file(missing, join(dataset_path, 'slices', 'missing', cat, 117 | # str(i) + '~' + name + '.ply')) 118 | ray.get(ray_tasks) 119 | ray.shutdown() 120 | 121 | 122 | if __name__ == '__main__': 123 | main(parse_config()) 124 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/3d-point-clouds-autocomplete/13ec26e10da7b0ad2f71d4c97016fbb8499b0cff/utils/__init__.py -------------------------------------------------------------------------------- /utils/evaluation/chamfer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import cKDTree as KDTree 3 | 4 | # code from 5 | # https://github.com/ChrisWu1997/Multimodal-Shape-Completion/blob/master/evaluation/chamfer.py 6 | 7 | 8 | def compute_trimesh_chamfer(gt_points, gen_points, offset=0, scale=1): 9 | """ 10 | This function computes a symmetric chamfer distance, i.e. the sum of both chamfers. 11 | gt_points: numpy array. trimesh.points.PointCloud of just points, sampled from the surface (see 12 | compute_metrics.py for more documentation) 13 | gen_points: numpy array. points of the output mesh from whichever autoencoding reconstruction 14 | method (see compute_metrics.py for more) 15 | """ 16 | 17 | # gen_points_sampled = trimesh.sample.sample_surface(gen_mesh, num_mesh_samples)[0] 18 | 19 | gen_points = gen_points / scale - offset 20 | 21 | # one direction 22 | gen_points_kd_tree = KDTree(gen_points) 23 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points) 24 | gt_to_gen_chamfer = np.mean(np.square(one_distances)) 25 | 26 | # other direction 27 | gt_points_kd_tree = KDTree(gt_points) 28 | two_distances, two_vertex_ids = gt_points_kd_tree.query(gen_points) 29 | gen_to_gt_chamfer = np.mean(np.square(two_distances)) 30 | 31 | return gt_to_gen_chamfer + gen_to_gt_chamfer 32 | 33 | 34 | def scale_to_unit_sphere(points): 35 | """ 36 | scale point clouds into a unit sphere 37 | :param points: (n, 3) numpy array 38 | :return: 39 | """ 40 | midpoints = (np.max(points, axis=0) + np.min(points, axis=0)) / 2 41 | points = points - midpoints 42 | scale = np.max(np.sqrt(np.sum(points ** 2, axis=1))) 43 | points = points / scale 44 | return points 45 | -------------------------------------------------------------------------------- /utils/evaluation/completeness.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import warnings 4 | 5 | import ray 6 | import torch 7 | import numpy as np 8 | from scipy.spatial import cKDTree as KDTree 9 | 10 | # code is based on 11 | # https://github.com/ChrisWu1997/Multimodal-Shape-Completion/blob/master/evaluation/completeness.py 12 | 13 | 14 | def directed_hausdorff(point_cloud1: torch.Tensor, point_cloud2: torch.Tensor, reduce_mean=True): 15 | """ 16 | 17 | :param point_cloud1: (B, 3, N) 18 | :param point_cloud2: (B, 3, M) 19 | :return: directed hausdorff distance, A -> B 20 | """ 21 | n_pts1 = point_cloud1.shape[2] 22 | n_pts2 = point_cloud2.shape[2]
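# The expansion below materialises an explicit (B, 3, N, M) distance tensor,
# i.e. O(N * M) memory per pair of clouds; with N = M = 2048 that is roughly
# 50 MB of float32 per batch element, which is why process_one_uhd below runs
# one shape per ray task instead of batching everything at once.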
23 | 24 | pc1 = point_cloud1.unsqueeze(3) 25 | pc1 = pc1.repeat((1, 1, 1, n_pts2)) # (B, 3, N, M) 26 | pc2 = point_cloud2.unsqueeze(2) 27 | pc2 = pc2.repeat((1, 1, n_pts1, 1)) # (B, 3, N, M) 28 | 29 | l2_dist = torch.sqrt(torch.sum((pc1 - pc2) ** 2, dim=1)) # (B, N, M) 30 | 31 | shortest_dist, _ = torch.min(l2_dist, dim=2) 32 | 33 | hausdorff_dist, _ = torch.max(shortest_dist, dim=1) # (B, ) 34 | 35 | if reduce_mean: 36 | hausdorff_dist = torch.mean(hausdorff_dist) 37 | 38 | return hausdorff_dist 39 | 40 | 41 | def nn_distance(query_points, ref_points): 42 | ref_points_kd_tree = KDTree(ref_points) 43 | one_distances, one_vertex_ids = ref_points_kd_tree.query(query_points) 44 | return one_distances 45 | 46 | 47 | def completeness(query_points, ref_points, thres=0.03): 48 | a2b_nn_distance = nn_distance(query_points, ref_points) 49 | percentage = np.sum(a2b_nn_distance < thres) / len(a2b_nn_distance) 50 | return percentage 51 | 52 | 53 | @ray.remote 54 | def process_one_uhd(existing, gen_pcs): 55 | with warnings.catch_warnings(): 56 | warnings.simplefilter('ignore') 57 | gen_pcs_tensors = [torch.tensor(pc) for pc in gen_pcs] 58 | gen_pcs_tensors = torch.stack(gen_pcs_tensors, dim=0) 59 | existing_pc_tensor = torch.tensor(existing).unsqueeze(0).repeat((gen_pcs_tensors.size(0), 1, 1)) 60 | return directed_hausdorff(existing_pc_tensor, gen_pcs_tensors, reduce_mean=True).item() 61 | 62 | 63 | def process(shape_dir): 64 | # load generated shape 65 | pc_paths = glob.glob(os.path.join(shape_dir, "*reconstruction.npy")) 66 | pc_paths = sorted(pc_paths) 67 | 68 | # load existing input 69 | existing_paths = glob.glob(os.path.join(shape_dir, "*existing.npy")) 70 | existing_paths = sorted(existing_paths) 71 | 72 | gen_pcs = [] 73 | for i in range(int(len(pc_paths) / 10)): 74 | pcs = [] 75 | for j in range(10): 76 | pcs.append(np.load(pc_paths[i * 10 + j])) 77 | gen_pcs.append(pcs) 78 | gen_pcs = np.array(gen_pcs) 79 | 80 | existing_pcs = [] 81 | for i in range(len(existing_paths)): 82 | existing_pcs.append(np.load(existing_paths[i])) 83 | existing_pcs = np.array(existing_pcs) 84 | 85 | ray.init(num_cpus=4) 86 | ray_uhd_tasks = [process_one_uhd.remote(existing_pcs[i], gen_pcs[i]) for i in range(int(len(pc_paths) / 10))] 87 | uhd = np.mean(ray.get(ray_uhd_tasks)) 88 | ray.shutdown() 89 | return uhd 90 | 91 | # single thread version 92 | # 93 | # # completeness percentage 94 | # gen_comp_res = [] 95 | # 96 | # for i in range(len(gen_pcs)): 97 | # gen_comp = 0 98 | # for sample_pts in gen_pcs[i]: 99 | # comp = completeness(existing_pcs[i], sample_pts) 100 | # gen_comp += comp 101 | # gen_comp_res.append(gen_comp / len(gen_pcs)) 102 | # 103 | # 104 | # # unidirectional hausdorff 105 | # hausdorff_res = [] 106 | # 107 | # for i in tqdm(range(len(gen_pcs))): 108 | # 109 | # gen_pcs_tensors = [torch.tensor(pc).transpose(1, 0) for pc in gen_pcs[i]] 110 | # gen_pcs_tensors = torch.stack(gen_pcs_tensors, dim=0) 111 | # 112 | # existing_pc_tensor = torch.tensor(existing_pcs[i]).transpose(1, 0) 113 | # 114 | # existing_pc_tensor = existing_pc_tensor.unsqueeze(0).repeat((gen_pcs_tensors.size(0), 1, 1)) 115 | # 116 | # hausdorff = directed_hausdorff(existing_pc_tensor, gen_pcs_tensors, reduce_mean=True).item() 117 | # hausdorff_res.append(hausdorff) 118 | # 119 | # return np.mean(hausdorff_res) # np.mean(gen_comp_res), np.mean(hausdorff_res) 120 | # 121 | -------------------------------------------------------------------------------- /utils/evaluation/mmd.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import glob 3 | from os.path import join 4 | 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from utils.pytorch_structural_losses.nn_distance import nn_distance 10 | 11 | # code is based on 12 | # https://github.com/ChrisWu1997/Multimodal-Shape-Completion/blob/master/evaluation/mmd.py 13 | 14 | 15 | def iterate_in_chunks(l, n): 16 | '''Yield successive 'n'-sized chunks from iterable 'l'. 17 | Note: last chunk will be smaller than l if n doesn't divide l perfectly. 18 | ''' 19 | for i in range(0, len(l), n): 20 | yield l[i:i + n] 21 | 22 | 23 | def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, device=None): 24 | 25 | n_ref, n_pc_points, pc_dim = ref_pcs.shape 26 | _, n_pc_points_s, pc_dim_s = sample_pcs.shape 27 | 28 | if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s: 29 | raise ValueError('Incompatible size of point-clouds.') 30 | 31 | matched_dists = [] 32 | pbar = tqdm(range(n_ref)) 33 | for i in pbar: 34 | best_in_all_batches = [] 35 | ref = torch.from_numpy(ref_pcs[i]).unsqueeze(0).to(device).contiguous() 36 | for sample_chunk in iterate_in_chunks(sample_pcs, batch_size): 37 | chunk = torch.from_numpy(sample_chunk).to(device).contiguous() 38 | ref_to_s, s_to_ref = nn_distance(ref, chunk) 39 | all_dist_in_batch = ref_to_s.mean(dim=1) + s_to_ref.mean(dim=1) 40 | best_in_batch = torch.min(all_dist_in_batch).item() 41 | best_in_all_batches.append(best_in_batch) 42 | 43 | matched_dists.append(np.min(best_in_all_batches)) 44 | pbar.set_postfix({"mmd": np.mean(matched_dists)}) 45 | 46 | mmd = np.mean(matched_dists) 47 | return mmd, matched_dists 48 | 49 | 50 | def process(shape_dir, dataset, device, batch_size=64): 51 | random.seed(1234) 52 | ref_pcs = [] 53 | for data in dataset: 54 | _, _, gt, _ = data 55 | ref_pcs.append(gt) 56 | ref_pcs = np.stack(ref_pcs, axis=0) 57 | 58 | pc_paths = glob.glob(join(shape_dir, "*reconstruction.npy")) 59 | pc_paths = sorted(pc_paths) 60 | 61 | sample_pcs = [] 62 | for path in pc_paths: 63 | sample_pcs.append(np.load(path).T) 64 | sample_pcs = np.stack(sample_pcs, axis=0) 65 | 66 | mmd, matched_dists = minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, device) 67 | del sample_pcs 68 | del ref_pcs 69 | return mmd 70 | -------------------------------------------------------------------------------- /utils/evaluation/total_mutual_diff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import ray 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from utils.evaluation.chamfer import compute_trimesh_chamfer 9 | 10 | # code is based on 11 | # https://github.com/ChrisWu1997/Multimodal-Shape-Completion/blob/master/evaluation/total_mutual_diff.py 12 | 13 | 14 | @ray.remote 15 | def process_one_tmd(gen_pcs): 16 | sum_dist = 0 17 | for j in range(len(gen_pcs)): 18 | for k in range(j + 1, len(gen_pcs), 1): 19 | pc1 = gen_pcs[j] 20 | pc2 = gen_pcs[k] 21 | chamfer_dist = compute_trimesh_chamfer(pc1, pc2) 22 | sum_dist += chamfer_dist 23 | mean_dist = sum_dist * 2 / (len(gen_pcs) - 1) 24 | return mean_dist 25 | 26 | 27 | def process(shape_dir): 28 | pc_paths = glob.glob(os.path.join(shape_dir, "*reconstruction.npy")) 29 | 30 | pc_paths = sorted(pc_paths) 31 | gen_pcs = [] 32 | 33 | for i in range(int(len(pc_paths)/10)): 34 | pcs = [] 35 | for j in range(10): 36 | pcs.append(np.load(pc_paths[i*10+j]).T) 37 | gen_pcs.append(pcs) 38 | gen_pcs = 
np.array(gen_pcs) 39 | 40 | # parallel version 41 | # ray.init(num_cpus=os.cpu_count()) 42 | # ray_tmd_tasks = [process_one_tmd.remote(gen_pcs[i]) for i in range(len(gen_pcs))] 43 | # tmd = ray.get(ray_tmd_tasks) 44 | # ray.shutdown() 45 | # return np.mean(tmd) 46 | 47 | results = [] 48 | pbar = tqdm(range(len(gen_pcs))) 49 | for i in pbar: 50 | sum_dist = 0 51 | for j in range(len(gen_pcs[i])): 52 | for k in range(j + 1, len(gen_pcs[i]), 1): 53 | pc1 = gen_pcs[i][j] 54 | pc2 = gen_pcs[i][k] 55 | chamfer_dist = compute_trimesh_chamfer(pc1, pc2) 56 | sum_dist += chamfer_dist 57 | mean_dist = sum_dist * 2 / (len(gen_pcs[i]) - 1) 58 | results.append(mean_dist) 59 | pbar.set_postfix({"tmd": np.mean(results)}) 60 | 61 | return np.mean(results) 62 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from numpy.linalg import norm 7 | from scipy.stats import entropy 8 | from sklearn.neighbors import NearestNeighbors 9 | 10 | # code is based on 11 | # https://github.com/stevenygd/PointFlow/blob/master/metrics/evaluation_metrics.py 12 | 13 | # Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/: 14 | from utils.pytorch_structural_losses.match_cost import match_cost 15 | from utils.pytorch_structural_losses.nn_distance import nn_distance 16 | 17 | 18 | def _average_precision(query: torch.Tensor, retrieved: torch.Tensor) -> torch.Tensor: 19 | corrects = (retrieved == query.view(-1, 1)).half() 20 | denominators = torch.arange(1, retrieved.shape[1] + 1).type_as(corrects) 21 | return (corrects * corrects.cumsum(dim=1, dtype=corrects.dtype) / denominators).sum(dim=1) / corrects.sum(dim=1) 22 | 23 | 24 | def average_precision(query: torch.Tensor, retrieved: torch.Tensor) -> torch.Tensor: 25 | corrects = (query.view(-1, 1) == retrieved).half() 26 | denominators = torch.arange(1, retrieved.size(1) + 1).type_as(corrects) 27 | return (corrects * corrects.cumsum(dim=1, dtype=corrects.dtype) / denominators).sum(dim=1) / corrects.sum(dim=1) 28 | 29 | 30 | def mean_average_precision(query: torch.Tensor, retrieved: torch.Tensor): 31 | return average_precision(query, retrieved).mean() 32 | 33 | 34 | def average_precision_numpy(query, retrieved): 35 | corrects = (retrieved == query.reshape(-1, 1)).astype(int) 36 | denominators = np.arange(1, retrieved.shape[1] + 1) 37 | return ((corrects * corrects.cumsum(axis=1)) / denominators).sum(axis=1) / corrects.sum(axis=1) 38 | 39 | 40 | def mean_average_precision_numpy(query, retrieved): 41 | return np.mean(average_precision_numpy(query, retrieved)) 42 | 43 | 44 | def earth_mover_distance(sample_pcs, ref_pcs, batch_size=None): 45 | """Use this function to calculate EMD in our experiments.""" 46 | sample_pcs = sample_pcs.contiguous() 47 | ref_pcs = ref_pcs.contiguous() 48 | if sample_pcs.dim() == 2: 49 | sample_pcs = sample_pcs.unsqueeze(0) 50 | if ref_pcs.dim() == 2: 51 | ref_pcs = ref_pcs.unsqueeze(0) 52 | 53 | N_sample = sample_pcs.shape[0] 54 | N_ref = ref_pcs.shape[0] 55 | assert N_sample == N_ref, f'REF:{N_ref} SMP:{N_sample}' 56 | 57 | batch_size = min(batch_size or N_sample, 300) 58 | 59 | emd_lst = [] 60 | for b_start in range(0, N_sample, batch_size): 61 | b_end = min(N_sample, b_start + batch_size) 62 | sample_batch = sample_pcs[b_start:b_end] 63 | ref_batch = ref_pcs[b_start:b_end] 64 | 65 | emd_batch = emd_approx(sample_batch, ref_batch) 66 | emd_lst.append(emd_batch) 67 | 68 | return torch.cat(emd_lst) 69 | 
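# Usage sketch for earth_mover_distance (shapes are illustrative; requires the
# CUDA structural-losses extension built by build_losses.sh):
#
#   sample = torch.rand(16, 2048, 3).cuda()    # generated clouds
#   ref = torch.rand(16, 2048, 3).cuda()       # matched references
#   emd = earth_mover_distance(sample, ref)    # -> (16,) per-pair EMD values
#
# match_cost returns one matching cost per pair and emd_approx below divides it
# by the point count, so callers usually reduce the result with .mean().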
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 | from numpy.linalg import norm
7 | from scipy.stats import entropy
8 | from sklearn.neighbors import NearestNeighbors
9 | 
10 | # code is based on
11 | # https://github.com/stevenygd/PointFlow/blob/master/metrics/evaluation_metrics.py
12 | 
13 | # Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/:
14 | from utils.pytorch_structural_losses.match_cost import match_cost
15 | from utils.pytorch_structural_losses.nn_distance import nn_distance
16 | 
17 | 
18 | def _average_precision(query: torch.Tensor, retrieved: torch.Tensor) -> torch.Tensor:
19 |     corrects = (retrieved == query.view(-1, 1)).half()
20 |     denominators = torch.arange(1, retrieved.shape[1] + 1).type_as(corrects)
21 |     return (corrects * corrects.cumsum(dim=1, dtype=corrects.dtype) / denominators).sum(dim=1) / corrects.sum(dim=1)
22 | 
23 | 
24 | def average_precision(query: torch.Tensor, retrieved: torch.Tensor) -> torch.Tensor:
25 |     corrects = (query.view(-1, 1) == retrieved).half()
26 |     denominators = torch.arange(1, retrieved.size(1) + 1).type_as(corrects)
27 |     return (corrects * corrects.cumsum(dim=1, dtype=corrects.dtype) / denominators).sum(dim=1) / corrects.sum(dim=1)
28 | 
29 | 
30 | def mean_average_precision(query: torch.Tensor, retrieved: torch.Tensor):
31 |     return average_precision(query, retrieved).mean()
32 | 
33 | 
34 | def average_precision_numpy(query, retrieved):
35 |     corrects = (retrieved == query.reshape(-1, 1)).astype(int)
36 |     denominators = np.arange(1, retrieved.shape[1] + 1)
37 |     return ((corrects * corrects.cumsum(axis=1)) / denominators).sum(axis=1) / corrects.sum(axis=1)
38 | 
39 | 
40 | def mean_average_precision_numpy(query, retrieved):
41 |     return np.mean(average_precision_numpy(query, retrieved))
42 | 
43 | 
44 | def earth_mover_distance(sample_pcs, ref_pcs, batch_size=None):
45 |     """Use this function to calculate EMD in our experiments."""
46 |     sample_pcs = sample_pcs.contiguous()
47 |     ref_pcs = ref_pcs.contiguous()
48 |     if sample_pcs.dim() == 2:
49 |         sample_pcs = sample_pcs.unsqueeze(0)
50 |     if ref_pcs.dim() == 2:
51 |         ref_pcs = ref_pcs.unsqueeze(0)
52 | 
53 |     N_sample = sample_pcs.shape[0]
54 |     N_ref = ref_pcs.shape[0]
55 |     assert N_sample == N_ref, f'REF:{N_ref} SMP:{N_sample}'
56 | 
57 |     batch_size = min(batch_size or N_sample, 300)
58 | 
59 |     emd_lst = []
60 |     for b_start in range(0, N_sample, batch_size):
61 |         b_end = min(N_sample, b_start + batch_size)
62 |         sample_batch = sample_pcs[b_start:b_end]
63 |         ref_batch = ref_pcs[b_start:b_end]
64 | 
65 |         emd_batch = emd_approx(sample_batch, ref_batch)
66 |         emd_lst.append(emd_batch)
67 | 
68 |     return torch.cat(emd_lst)
69 | 
70 | 
71 | def emd_approx(sample, ref):
72 |     B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
73 |     assert N == N_ref, "Not sure what would EMD do in this case"
74 |     emd = match_cost(sample, ref)  # (B,)
75 |     emd_norm = emd / float(N)  # (B,)
76 |     return emd_norm
77 | 
78 | def dist_chamfer(x, y, chamfer_loss):
79 |     # from losses.champfer_loss import ChamferLoss
80 |     # from utils.util import cuda_setup
81 |     # chamfer_loss = ChamferLoss().to(cuda_setup())
82 |     P = chamfer_loss.batch_pairwise_dist(x, y)
83 |     return P.min(1)[0], P.min(2)[0]
84 | 
85 | 
86 | def EMD_CD(sample_pcs, ref_pcs, batch_size, chamfer_loss, reduced=True):
87 |     N_sample = sample_pcs.shape[0]
88 |     N_ref = ref_pcs.shape[0]
89 |     assert N_sample == N_ref, f'REF:{N_ref} SMP:{N_sample}'
90 | 
91 |     cd_lst = []
92 |     emd_lst = []
93 |     iterator = range(0, N_sample, batch_size)
94 | 
95 |     for b_start in iterator:
96 |         b_end = min(N_sample, b_start + batch_size)
97 |         sample_batch = sample_pcs[b_start:b_end]
98 |         ref_batch = ref_pcs[b_start:b_end]
99 | 
100 |         dl, dr = dist_chamfer(sample_batch, ref_batch, chamfer_loss)
101 |         cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
102 | 
103 |         emd_batch = emd_approx(sample_batch, ref_batch)
104 |         emd_lst.append(emd_batch)
105 | 
106 |     if reduced:
107 |         cd = torch.cat(cd_lst).mean()
108 |         emd = torch.cat(emd_lst).mean()
109 |     else:
110 |         cd = torch.cat(cd_lst)
111 |         emd = torch.cat(emd_lst)
112 | 
113 |     results = {
114 |         'MMD-CD': cd,
115 |         'MMD-EMD': emd
116 |     }
117 | 
118 |     return results
119 | 
120 | 
121 | def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, chamfer_loss):
122 |     N_sample = sample_pcs.shape[0]
123 |     N_ref = ref_pcs.shape[0]
124 |     all_cd = []
125 |     all_emd = []
126 |     iterator = range(N_sample)
127 |     for sample_b_start in tqdm(iterator):
128 |         sample_batch = sample_pcs[sample_b_start]
129 | 
130 |         cd_lst = []
131 |         emd_lst = []
132 |         for ref_b_start in range(0, N_ref, batch_size):
133 |             ref_b_end = min(N_ref, ref_b_start + batch_size)
134 |             ref_batch = ref_pcs[ref_b_start:ref_b_end]
135 | 
136 |             batch_size_ref = ref_batch.size(0)
137 |             sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
138 |             sample_batch_exp = sample_batch_exp.contiguous()
139 | 
140 |             # dl, dr = nn_distance(sample_batch_exp, ref_batch)
141 | 
142 |             dl, dr = dist_chamfer(sample_batch_exp, ref_batch, chamfer_loss)
143 | 
144 | 
145 |             cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
146 | 
147 |             emd_batch = emd_approx(sample_batch_exp, ref_batch)
148 |             emd_lst.append(emd_batch.view(1, -1))
149 | 
150 |         cd_lst = torch.cat(cd_lst, dim=1)
151 |         emd_lst = torch.cat(emd_lst, dim=1)
152 |         all_cd.append(cd_lst)
153 |         all_emd.append(emd_lst)
154 | 
155 |     all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref
156 |     all_emd = torch.cat(all_emd, dim=0)  # N_sample, N_ref
157 | 
158 |     return all_cd, all_emd
159 | 
160 | 
161 | # Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
162 | def knn(Mxx, Mxy, Myy, k, sqrt=False):
163 |     n0 = Mxx.size(0)
164 |     n1 = Myy.size(0)
165 |     label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
166 |     M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
167 |     if sqrt:
168 |         M = M.abs().sqrt()
169 |     INFINITY = float('inf')
170 |     val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
171 | 
172 |     count = torch.zeros(n0 + n1).to(Mxx)
173 |     for i in range(0, k):
174 |         count = count + label.index_select(0, idx[i])
175 |     pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
176 | 
177 |     s = {
178 |         'tp': (pred * label).sum(),
179 |         'fp': (pred * (1 - label)).sum(),
180 |         'fn': ((1 - pred) * label).sum(),
181 |         'tn': ((1 - pred) * (1 - label)).sum(),
182 |     }
183 | 
184 |     s.update({
185 |         'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
186 |         'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
187 |         'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
188 |         'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
189 |         'acc': torch.eq(label, pred).float().mean(),
190 |     })
191 |     return s
192 | 
193 | 
194 | def mmd_cov(all_dist):
195 |     N_sample, N_ref = all_dist.size(0), all_dist.size(1)
196 |     min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
197 |     min_val, _ = torch.min(all_dist, dim=0)
198 |     mmd = min_val.mean()
199 |     mmd_smp = min_val_fromsmp.mean()
200 |     cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
201 |     cov = torch.tensor(cov).to(all_dist)
202 |     return {
203 |         'mmd(Fidelity)': mmd,
204 |         'cov(Coverage)': cov,
205 |         'mmd_smp': mmd_smp,
206 |     }
207 | 
208 | 
209 | def compute_all_metrics(sample_pcs, ref_pcs, batch_size, chamfer_loss):
210 |     results = {}
211 | 
212 |     M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, chamfer_loss)
213 | 
214 |     res_cd = mmd_cov(M_rs_cd.t())
215 |     results.update({
216 |         "%s-CD" % k: v for k, v in res_cd.items()
217 |     })
218 | 
219 | 
220 |     res_emd = mmd_cov(M_rs_emd.t())
221 |     results.update({
222 |         "%s-EMD" % k: v for k, v in res_emd.items()
223 |     })
224 |     '''
225 |     M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, chamfer_loss)
226 |     M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, chamfer_loss)
227 | 
228 |     # 1-NN results
229 |     one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
230 |     results.update({
231 |         "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
232 |     })
233 |     one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
234 |     results.update({
235 |         "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
236 |     })
237 |     '''
238 |     return results
239 | 
240 | 
241 | #######################################################
242 | # JSD : from https://github.com/optas/latent_3d_points
243 | #######################################################
244 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
245 |     """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
246 |     that is placed in the unit-cube.
247 |     If clip_sphere is True it drops the "corner" cells that lie outside the unit-sphere.
248 |     """
249 |     grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
250 |     spacing = 1.0 / float(resolution - 1)
251 |     for i in range(resolution):
252 |         for j in range(resolution):
253 |             for k in range(resolution):
254 |                 grid[i, j, k, 0] = i * spacing - 0.5
255 |                 grid[i, j, k, 1] = j * spacing - 0.5
256 |                 grid[i, j, k, 2] = k * spacing - 0.5
257 | 
258 |     if clip_sphere:
259 |         grid = grid.reshape(-1, 3)
260 |         grid = grid[norm(grid, axis=1) <= 0.5]
261 | 
262 |     return grid, spacing
263 | 
264 | 
265 | def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
266 |     """Computes the JSD between two sets of point-clouds, as introduced in the paper
267 |     ```Learning Representations And Generative Models For 3D Point Clouds```.
268 |     Args:
269 |         sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
270 |         ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
271 |         resolution: (int) grid-resolution. Affects granularity of measurements.
272 |     """
273 |     in_unit_sphere = True
274 |     sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
275 |     ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
276 |     return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
277 | 
278 | 
279 | def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
280 |     """Given a collection of point-clouds, estimate the entropy of the random variables
281 |     corresponding to occupancy-grid activation patterns.
282 |     Inputs:
283 |         pclouds: (numpy array) #point-clouds x points per point-cloud x 3
284 |         grid_resolution (int) size of occupancy grid that will be used.
285 |     """
286 |     epsilon = 10e-4
287 |     bound = 0.5 + epsilon
288 |     if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
289 |         if verbose:
290 |             warnings.warn('Point-clouds are not in unit cube.')
291 | 
292 |     if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
293 |         if verbose:
294 |             warnings.warn('Point-clouds are not in unit sphere.')
295 | 
296 |     grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
297 |     grid_coordinates = grid_coordinates.reshape(-1, 3)
298 |     grid_counters = np.zeros(len(grid_coordinates))
299 |     grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
300 |     nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
301 | 
302 |     for pc in pclouds:
303 |         _, indices = nn.kneighbors(pc)
304 |         indices = np.squeeze(indices)
305 |         for i in indices:
306 |             grid_counters[i] += 1
307 |         indices = np.unique(indices)
308 |         for i in indices:
309 |             grid_bernoulli_rvars[i] += 1
310 | 
311 |     acc_entropy = 0.0
312 |     n = float(len(pclouds))
313 |     for g in grid_bernoulli_rvars:
314 |         if g > 0:
315 |             p = float(g) / n
316 |             acc_entropy += entropy([p, 1.0 - p])
317 | 
318 |     return acc_entropy / len(grid_counters), grid_counters
319 | 
320 | 
321 | def jensen_shannon_divergence(P, Q):
322 |     if np.any(P < 0) or np.any(Q < 0):
323 |         raise ValueError('Negative values.')
324 |     if len(P) != len(Q):
325 |         raise ValueError('Non equal size.')
326 | 
327 |     P_ = P / np.sum(P)  # Ensure probabilities.
328 |     Q_ = Q / np.sum(Q)
329 | 
330 |     e1 = entropy(P_, base=2)
331 |     e2 = entropy(Q_, base=2)
332 |     e_sum = entropy((P_ + Q_) / 2.0, base=2)
333 |     res = e_sum - ((e1 + e2) / 2.0)
334 | 
335 |     res2 = _jsdiv(P_, Q_)
336 | 
337 |     if not np.allclose(res, res2, atol=10e-5, rtol=0):
338 |         warnings.warn('Numerical values of two JSD methods don\'t agree.')
339 | 
340 |     return res
341 | 
342 | 
343 | def _jsdiv(P, Q):
344 |     """another way of computing JSD"""
345 | 
346 |     def _kldiv(A, B):
347 |         a = A.copy()
348 |         b = B.copy()
349 |         idx = np.logical_and(a > 0, b > 0)
350 |         a = a[idx]
351 |         b = b[idx]
352 |         return np.sum([v for v in a * np.log2(a / b)])
353 | 
354 |     P_ = P / np.sum(P)
355 |     Q_ = Q / np.sum(Q)
356 | 
357 |     M = 0.5 * (P_ + Q_)
358 | 
359 |     return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
360 | 
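To make the entry points of this module concrete, here is a hedged sketch of a full evaluation pass. The tensors are random stand-ins for clouds decoded by the model, the batch size is illustrative, and `ChamferLoss` is taken from `losses/champfer_loss.py` as hinted in `dist_chamfer` above; note that every cloud must carry the same number of points, since the EMD kernel asserts this.

```
import torch

from losses.champfer_loss import ChamferLoss
from utils.metrics import compute_all_metrics, jsd_between_point_cloud_sets

device = torch.device('cuda')
chamfer_loss = ChamferLoss().to(device)

# Random stand-ins: 100 sample and 100 reference clouds, 2048 points each.
sample_pcs = torch.rand(100, 2048, 3, device=device)
ref_pcs = torch.rand(100, 2048, 3, device=device)

results = compute_all_metrics(sample_pcs, ref_pcs, batch_size=32,
                              chamfer_loss=chamfer_loss)
for name, value in results.items():
    print(name, value.item())

# JSD operates on numpy arrays binned into a resolution^3 occupancy grid.
jsd = jsd_between_point_cloud_sets(sample_pcs.cpu().numpy(),
                                   ref_pcs.cpu().numpy())
print('JSD:', jsd)
```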
--------------------------------------------------------------------------------
/utils/pcutil.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from numpy.linalg import norm
4 | 
5 | # Don't delete this line, even if PyCharm says it's an unused import.
6 | # It is required for projection='3d' in add_subplot()
7 | from mpl_toolkits.mplot3d import Axes3D
8 | 
9 | 
10 | def rand_rotation_matrix(deflection=1.0, seed=None):
11 |     """Creates a random rotation matrix.
12 | 
13 |     Args:
14 |         deflection: the magnitude of the rotation. For 0, no rotation; for 1,
15 |                     completely random rotation. Small deflection => small
16 |                     perturbation.
17 | 
18 |     Sources: http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
19 |              http://blog.lostinmyterminal.com/python/2015/05/12/random-rotation-matrix.html
20 |     """
21 |     if seed is not None:
22 |         np.random.seed(seed)
23 | 
24 |     theta, phi, z = np.random.uniform(size=(3,))
25 | 
26 |     theta = theta * 2.0 * deflection * np.pi  # Rotation about the pole (Z).
27 |     phi = phi * 2.0 * np.pi  # For direction of pole deflection.
28 |     z = z * 2.0 * deflection  # For magnitude of pole deflection.
29 | 
30 |     # Compute a vector V used for distributing points over the sphere
31 |     # via the reflection I - V Transpose(V). This formulation of V
32 |     # will guarantee that if x[1] and x[2] are uniformly distributed,
33 |     # the reflected points will be uniform on the sphere. Note that V
34 |     # has length sqrt(2) to eliminate the 2 in the Householder matrix.
35 | 
36 |     r = np.sqrt(z)
37 |     V = (np.sin(phi) * r,
38 |          np.cos(phi) * r,
39 |          np.sqrt(2.0 - z))
40 | 
41 |     st = np.sin(theta)
42 |     ct = np.cos(theta)
43 | 
44 |     R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1)))
45 | 
46 |     # Construct the rotation matrix ( V Transpose(V) - I ) R.
47 |     M = (np.outer(V, V) - np.eye(3)).dot(R)
48 |     return M
49 | 
50 | 
51 | def add_gaussian_noise_to_pcloud(pcloud, mu=0, sigma=1):
52 |     gnoise = np.random.normal(mu, sigma, pcloud.shape[0])
53 |     gnoise = np.tile(gnoise, (3, 1)).T
54 |     pcloud += gnoise
55 |     return pcloud
56 | 
57 | 
58 | def add_rotation_to_pcloud(pcloud):
59 |     r_rotation = rand_rotation_matrix()
60 | 
61 |     if len(pcloud.shape) == 2:
62 |         return pcloud.dot(r_rotation)
63 |     else:
64 |         return np.asarray([e.dot(r_rotation) for e in pcloud])
65 | 
66 | 
67 | def apply_augmentations(batch, conf):
68 |     if conf.gauss_augment is not None or conf.z_rotate:
69 |         batch = batch.copy()
70 | 
71 |     if conf.gauss_augment is not None:
72 |         mu = conf.gauss_augment['mu']
73 |         sigma = conf.gauss_augment['sigma']
74 |         batch += np.random.normal(mu, sigma, batch.shape)
75 | 
76 |     if conf.z_rotate:
77 |         r_rotation = rand_rotation_matrix()
78 |         r_rotation[0, 2] = 0
79 |         r_rotation[2, 0] = 0
80 |         r_rotation[1, 2] = 0
81 |         r_rotation[2, 1] = 0
82 |         r_rotation[2, 2] = 1
83 |         batch = batch.dot(r_rotation)
84 |     return batch
85 | 
86 | 
87 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
88 |     """Returns the center coordinates of each cell of a 3D grid with
89 |     resolution^3 cells, that is placed in the unit-cube.
90 |     If clip_sphere is True it drops the "corner" cells that lie outside
91 |     the unit-sphere.
92 |     """
93 |     grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
94 |     spacing = 1.0 / float(resolution - 1)
95 |     for i in range(resolution):
96 |         for j in range(resolution):
97 |             for k in range(resolution):
98 |                 grid[i, j, k, 0] = i * spacing - 0.5
99 |                 grid[i, j, k, 1] = j * spacing - 0.5
100 |                 grid[i, j, k, 2] = k * spacing - 0.5
101 | 
102 |     if clip_sphere:
103 |         grid = grid.reshape(-1, 3)
104 |         grid = grid[norm(grid, axis=1) <= 0.5]
105 | 
106 |     return grid, spacing
107 | 
108 | 
109 | def plot_3d_point_cloud(x, y, z, show=True, show_axis=True, in_u_sphere=False,
110 |                         marker='.', s=8, alpha=.8, figsize=(5, 5), elev=10,
111 |                         azim=240, axis=None, title=None, x1=None, y1=None, z1=None, *args, **kwargs):
112 |     plt.switch_backend('agg')
113 |     if axis is None:
114 |         fig = plt.figure(figsize=figsize)
115 |         ax = fig.add_subplot(111, projection='3d')
116 |     else:
117 |         ax = axis
118 |         fig = axis
119 | 
120 |     if title is not None:
121 |         plt.title(title)
122 | 
123 |     if x1 is not None and y1 is not None and z1 is not None:
124 |         ax.scatter(x1, y1, z1, color='r', marker=marker, s=s * 3, alpha=1, zorder=2, *args, **kwargs)
125 |         alpha = 0.3
126 | 
127 |     sc = ax.scatter(x, y, z, marker=marker, s=s, alpha=alpha, zorder=1, *args, **kwargs)
128 |     ax.view_init(elev=elev, azim=azim)
129 | 
130 |     if in_u_sphere:
131 |         ax.set_xlim3d(-0.5, 0.5)
132 |         ax.set_ylim3d(-0.5, 0.5)
133 |         ax.set_zlim3d(-0.5, 0.5)
134 |     else:
135 |         # Multiply with 0.7 to squeeze free-space.
136 |         miv = 0.7 * np.min([np.min(x), np.min(y), np.min(z)])
137 |         mav = 0.7 * np.max([np.max(x), np.max(y), np.max(z)])
138 |         ax.set_xlim(miv, mav)
139 |         ax.set_ylim(miv, mav)
140 |         ax.set_zlim(miv, mav)
141 |         plt.tight_layout()
142 | 
143 |     if not show_axis:
144 |         plt.axis('off')
145 | 
146 |     if 'c' in kwargs:
147 |         plt.colorbar(sc)
148 | 
149 |     if show:
150 |         plt.show()
151 | 
152 |     return fig
153 | 
154 | 
155 | def transform_point_clouds(X, only_z_rotation=False, deflection=1.0):
156 |     r_rotation = rand_rotation_matrix(deflection)
157 |     if only_z_rotation:
158 |         r_rotation[0, 2] = 0
159 |         r_rotation[2, 0] = 0
160 |         r_rotation[1, 2] = 0
161 |         r_rotation[2, 1] = 0
162 |         r_rotation[2, 2] = 1
163 |     X = X.dot(r_rotation).astype(np.float32)
164 |     return X
165 | 
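A hedged plotting sketch for the helpers above (the file name and seed are arbitrary; `plot_3d_point_cloud` switches matplotlib to the non-interactive `agg` backend, so the figure is saved rather than shown):

```
import numpy as np

from utils.pcutil import plot_3d_point_cloud, rand_rotation_matrix

# Illustrative cloud: 1024 uniform points, randomly rotated.
cloud = np.random.uniform(-0.5, 0.5, size=(1024, 3))
cloud = cloud.dot(rand_rotation_matrix(seed=42))

fig = plot_3d_point_cloud(cloud[:, 0], cloud[:, 1], cloud[:, 2],
                          show=False, in_u_sphere=True, title='sample')
fig.savefig('sample_cloud.png')
```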
--------------------------------------------------------------------------------
/utils/plyfile.py:
--------------------------------------------------------------------------------
1 | # Copyright 2014 Darsh Ranjan
2 | #
3 | # This file is part of python-plyfile.
4 | #
5 | # python-plyfile is free software: you can redistribute it and/or
6 | # modify it under the terms of the GNU General Public License as
7 | # published by the Free Software Foundation, either version 3 of the
8 | # License, or (at your option) any later version.
9 | #
10 | # python-plyfile is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 | # General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
16 | # along with python-plyfile. If not, see
17 | # <http://www.gnu.org/licenses/>.
18 | 
19 | from itertools import islice as _islice
20 | 
21 | import numpy as _np
22 | from sys import byteorder as _byteorder
23 | 
24 | 
25 | try:
26 |     _range = xrange
27 | except NameError:
28 |     _range = range
29 | 
30 | 
31 | # Many-many relation
32 | _data_type_relation = [
33 |     ('int8', 'i1'),
34 |     ('char', 'i1'),
35 |     ('uint8', 'u1'),
36 |     ('uchar', 'b1'),
37 |     ('uchar', 'u1'),
38 |     ('int16', 'i2'),
39 |     ('short', 'i2'),
40 |     ('uint16', 'u2'),
41 |     ('ushort', 'u2'),
42 |     ('int32', 'i4'),
43 |     ('int', 'i4'),
44 |     ('uint32', 'u4'),
45 |     ('uint', 'u4'),
46 |     ('float32', 'f4'),
47 |     ('float', 'f4'),
48 |     ('float64', 'f8'),
49 |     ('double', 'f8')
50 | ]
51 | 
52 | _data_types = dict(_data_type_relation)
53 | _data_type_reverse = dict((b, a) for (a, b) in _data_type_relation)
54 | 
55 | _types_list = []
56 | _types_set = set()
57 | for (_a, _b) in _data_type_relation:
58 |     if _a not in _types_set:
59 |         _types_list.append(_a)
60 |         _types_set.add(_a)
61 |     if _b not in _types_set:
62 |         _types_list.append(_b)
63 |         _types_set.add(_b)
64 | 
65 | 
66 | _byte_order_map = {
67 |     'ascii': '=',
68 |     'binary_little_endian': '<',
69 |     'binary_big_endian': '>'
70 | }
71 | 
72 | _byte_order_reverse = {
73 |     '<': 'binary_little_endian',
74 |     '>': 'binary_big_endian'
75 | }
76 | 
77 | _native_byte_order = {'little': '<', 'big': '>'}[_byteorder]
78 | 
79 | 
80 | def _lookup_type(type_str):
81 |     if type_str not in _data_type_reverse:
82 |         try:
83 |             type_str = _data_types[type_str]
84 |         except KeyError:
85 |             raise ValueError("field type %r not in %r" %
86 |                              (type_str, _types_list))
87 | 
88 |     return _data_type_reverse[type_str]
89 | 
90 | 
91 | def _split_line(line, n):
92 |     fields = line.split(None, n)
93 |     if len(fields) == n:
94 |         fields.append('')
95 | 
96 |     assert len(fields) == n + 1
97 | 
98 |     return fields
99 | 
100 | 
101 | def make2d(array, cols=None, dtype=None):
102 |     """
103 |     Make a 2D array from an array of arrays. The `cols' and `dtype'
104 |     arguments can be omitted if the array is not empty.
105 | 
106 |     """
107 |     if (cols is None or dtype is None) and not len(array):
108 |         raise RuntimeError("cols and dtype must be specified for empty "
109 |                            "array")
110 | 
111 |     if cols is None:
112 |         cols = len(array[0])
113 | 
114 |     if dtype is None:
115 |         dtype = array[0].dtype
116 | 
117 |     return _np.fromiter(array, [('_', dtype, (cols,))],
118 |                         count=len(array))['_']
119 | 
120 | 
121 | class PlyParseError(Exception):
122 | 
123 |     """
124 |     Raised when a PLY file cannot be parsed.
125 | 
126 |     The attributes `element', `row', `property', and `message' give
127 |     additional information.
128 | 
129 |     """
130 | 
131 |     def __init__(self, message, element=None, row=None, prop=None):
132 |         self.message = message
133 |         self.element = element
134 |         self.row = row
135 |         self.prop = prop
136 | 
137 |         s = ''
138 |         if self.element:
139 |             s += 'element %r: ' % self.element.name
140 |         if self.row is not None:
141 |             s += 'row %d: ' % self.row
142 |         if self.prop:
143 |             s += 'property %r: ' % self.prop.name
144 |         s += self.message
145 | 
146 |         Exception.__init__(self, s)
147 | 
148 |     def __repr__(self):
149 |         return ('PlyParseError(%r, element=%r, row=%r, prop=%r)' %
150 |                 (self.message, self.element, self.row, self.prop))
151 | 
152 | 
153 | class PlyData(object):
154 | 
155 |     """
156 |     PLY file header and data.
157 | 
158 |     A PlyData instance is created in one of two ways: by the static
159 |     method PlyData.read (to read a PLY file), or directly from __init__
160 |     given a sequence of elements (which can then be written to a PLY
161 |     file).
162 | 163 | """ 164 | 165 | def __init__(self, elements=[], text=False, byte_order='=', 166 | comments=[], obj_info=[]): 167 | """ 168 | elements: sequence of PlyElement instances. 169 | 170 | text: whether the resulting PLY file will be text (True) or 171 | binary (False). 172 | 173 | byte_order: '<' for little-endian, '>' for big-endian, or '=' 174 | for native. This is only relevant if `text' is False. 175 | 176 | comments: sequence of strings that will be placed in the header 177 | between the 'ply' and 'format ...' lines. 178 | 179 | obj_info: like comments, but will be placed in the header with 180 | "obj_info ..." instead of "comment ...". 181 | 182 | """ 183 | if byte_order == '=' and not text: 184 | byte_order = _native_byte_order 185 | 186 | self.byte_order = byte_order 187 | self.text = text 188 | 189 | self.comments = list(comments) 190 | self.obj_info = list(obj_info) 191 | self.elements = elements 192 | 193 | def _get_elements(self): 194 | return self._elements 195 | 196 | def _set_elements(self, elements): 197 | self._elements = tuple(elements) 198 | self._index() 199 | 200 | elements = property(_get_elements, _set_elements) 201 | 202 | def _get_byte_order(self): 203 | return self._byte_order 204 | 205 | def _set_byte_order(self, byte_order): 206 | if byte_order not in ['<', '>', '=']: 207 | raise ValueError("byte order must be '<', '>', or '='") 208 | 209 | self._byte_order = byte_order 210 | 211 | byte_order = property(_get_byte_order, _set_byte_order) 212 | 213 | def _index(self): 214 | self._element_lookup = dict((elt.name, elt) for elt in 215 | self._elements) 216 | if len(self._element_lookup) != len(self._elements): 217 | raise ValueError("two elements with same name") 218 | 219 | @staticmethod 220 | def _parse_header(stream): 221 | """ 222 | Parse a PLY header from a readable file-like stream. 223 | 224 | """ 225 | lines = [] 226 | comments = {'comment': [], 'obj_info': []} 227 | while True: 228 | line = stream.readline().decode('ascii').strip() 229 | fields = _split_line(line, 1) 230 | 231 | if fields[0] == 'end_header': 232 | break 233 | 234 | elif fields[0] in comments.keys(): 235 | lines.append(fields) 236 | else: 237 | lines.append(line.split()) 238 | 239 | a = 0 240 | if lines[a] != ['ply']: 241 | raise PlyParseError("expected 'ply'") 242 | 243 | a += 1 244 | while lines[a][0] in comments.keys(): 245 | comments[lines[a][0]].append(lines[a][1]) 246 | a += 1 247 | 248 | if lines[a][0] != 'format': 249 | raise PlyParseError("expected 'format'") 250 | 251 | if lines[a][2] != '1.0': 252 | raise PlyParseError("expected version '1.0'") 253 | 254 | if len(lines[a]) != 3: 255 | raise PlyParseError("too many fields after 'format'") 256 | 257 | fmt = lines[a][1] 258 | 259 | if fmt not in _byte_order_map: 260 | raise PlyParseError("don't understand format %r" % fmt) 261 | 262 | byte_order = _byte_order_map[fmt] 263 | text = fmt == 'ascii' 264 | 265 | a += 1 266 | while a < len(lines) and lines[a][0] in comments.keys(): 267 | comments[lines[a][0]].append(lines[a][1]) 268 | a += 1 269 | 270 | return PlyData(PlyElement._parse_multi(lines[a:]), 271 | text, byte_order, 272 | comments['comment'], comments['obj_info']) 273 | 274 | @staticmethod 275 | def read(stream): 276 | """ 277 | Read PLY data from a readable file-like object or filename. 
278 | 279 | """ 280 | (must_close, stream) = _open_stream(stream, 'read') 281 | try: 282 | data = PlyData._parse_header(stream) 283 | for elt in data: 284 | elt._read(stream, data.text, data.byte_order) 285 | finally: 286 | if must_close: 287 | stream.close() 288 | 289 | return data 290 | 291 | def write(self, stream): 292 | """ 293 | Write PLY data to a writeable file-like object or filename. 294 | 295 | """ 296 | (must_close, stream) = _open_stream(stream, 'write') 297 | try: 298 | stream.write(self.header.encode('ascii')) 299 | stream.write(b'\r\n') 300 | for elt in self: 301 | elt._write(stream, self.text, self.byte_order) 302 | finally: 303 | if must_close: 304 | stream.close() 305 | 306 | @property 307 | def header(self): 308 | """ 309 | Provide PLY-formatted metadata for the instance. 310 | 311 | """ 312 | lines = ['ply'] 313 | 314 | if self.text: 315 | lines.append('format ascii 1.0') 316 | else: 317 | lines.append('format ' + 318 | _byte_order_reverse[self.byte_order] + 319 | ' 1.0') 320 | 321 | # Some information is lost here, since all comments are placed 322 | # between the 'format' line and the first element. 323 | for c in self.comments: 324 | lines.append('comment ' + c) 325 | 326 | for c in self.obj_info: 327 | lines.append('obj_info ' + c) 328 | 329 | lines.extend(elt.header for elt in self.elements) 330 | lines.append('end_header') 331 | return '\r\n'.join(lines) 332 | 333 | def __iter__(self): 334 | return iter(self.elements) 335 | 336 | def __len__(self): 337 | return len(self.elements) 338 | 339 | def __contains__(self, name): 340 | return name in self._element_lookup 341 | 342 | def __getitem__(self, name): 343 | return self._element_lookup[name] 344 | 345 | def __str__(self): 346 | return self.header 347 | 348 | def __repr__(self): 349 | return ('PlyData(%r, text=%r, byte_order=%r, ' 350 | 'comments=%r, obj_info=%r)' % 351 | (self.elements, self.text, self.byte_order, 352 | self.comments, self.obj_info)) 353 | 354 | 355 | def _open_stream(stream, read_or_write): 356 | if hasattr(stream, read_or_write): 357 | return (False, stream) 358 | try: 359 | return (True, open(stream, read_or_write[0] + 'b')) 360 | except TypeError: 361 | raise RuntimeError("expected open file or filename") 362 | 363 | 364 | class PlyElement(object): 365 | 366 | """ 367 | PLY file element. 368 | 369 | A client of this library doesn't normally need to instantiate this 370 | directly, so the following is only for the sake of documenting the 371 | internals. 372 | 373 | Creating a PlyElement instance is generally done in one of two ways: 374 | as a byproduct of PlyData.read (when reading a PLY file) and by 375 | PlyElement.describe (before writing a PLY file). 376 | 377 | """ 378 | 379 | def __init__(self, name, properties, count, comments=[]): 380 | """ 381 | This is not part of the public interface. The preferred methods 382 | of obtaining PlyElement instances are PlyData.read (to read from 383 | a file) and PlyElement.describe (to construct from a numpy 384 | array). 
385 | 386 | """ 387 | self._name = str(name) 388 | self._check_name() 389 | self._count = count 390 | 391 | self._properties = tuple(properties) 392 | self._index() 393 | 394 | self.comments = list(comments) 395 | 396 | self._have_list = any(isinstance(p, PlyListProperty) 397 | for p in self.properties) 398 | 399 | @property 400 | def count(self): 401 | return self._count 402 | 403 | def _get_data(self): 404 | return self._data 405 | 406 | def _set_data(self, data): 407 | self._data = data 408 | self._count = len(data) 409 | self._check_sanity() 410 | 411 | data = property(_get_data, _set_data) 412 | 413 | def _check_sanity(self): 414 | for prop in self.properties: 415 | if prop.name not in self._data.dtype.fields: 416 | raise ValueError("dangling property %r" % prop.name) 417 | 418 | def _get_properties(self): 419 | return self._properties 420 | 421 | def _set_properties(self, properties): 422 | self._properties = tuple(properties) 423 | self._check_sanity() 424 | self._index() 425 | 426 | properties = property(_get_properties, _set_properties) 427 | 428 | def _index(self): 429 | self._property_lookup = dict((prop.name, prop) 430 | for prop in self._properties) 431 | if len(self._property_lookup) != len(self._properties): 432 | raise ValueError("two properties with same name") 433 | 434 | def ply_property(self, name): 435 | return self._property_lookup[name] 436 | 437 | @property 438 | def name(self): 439 | return self._name 440 | 441 | def _check_name(self): 442 | if any(c.isspace() for c in self._name): 443 | msg = "element name %r contains spaces" % self._name 444 | raise ValueError(msg) 445 | 446 | def dtype(self, byte_order='='): 447 | """ 448 | Return the numpy dtype of the in-memory representation of the 449 | data. (If there are no list properties, and the PLY format is 450 | binary, then this also accurately describes the on-disk 451 | representation of the element.) 452 | 453 | """ 454 | return [(prop.name, prop.dtype(byte_order)) 455 | for prop in self.properties] 456 | 457 | @staticmethod 458 | def _parse_multi(header_lines): 459 | """ 460 | Parse a list of PLY element definitions. 461 | 462 | """ 463 | elements = [] 464 | while header_lines: 465 | (elt, header_lines) = PlyElement._parse_one(header_lines) 466 | elements.append(elt) 467 | 468 | return elements 469 | 470 | @staticmethod 471 | def _parse_one(lines): 472 | """ 473 | Consume one element definition. The unconsumed input is 474 | returned along with a PlyElement instance. 475 | 476 | """ 477 | a = 0 478 | line = lines[a] 479 | 480 | if line[0] != 'element': 481 | raise PlyParseError("expected 'element'") 482 | if len(line) > 3: 483 | raise PlyParseError("too many fields after 'element'") 484 | if len(line) < 3: 485 | raise PlyParseError("too few fields after 'element'") 486 | 487 | (name, count) = (line[1], int(line[2])) 488 | 489 | comments = [] 490 | properties = [] 491 | while True: 492 | a += 1 493 | if a >= len(lines): 494 | break 495 | 496 | if lines[a][0] == 'comment': 497 | comments.append(lines[a][1]) 498 | elif lines[a][0] == 'property': 499 | properties.append(PlyProperty._parse_one(lines[a])) 500 | else: 501 | break 502 | 503 | return (PlyElement(name, properties, count, comments), 504 | lines[a:]) 505 | 506 | @staticmethod 507 | def describe(data, name, len_types={}, val_types={}, 508 | comments=[]): 509 | """ 510 | Construct a PlyElement from an array's metadata. 
511 | 512 | len_types and val_types can be given as mappings from list 513 | property names to type strings (like 'u1', 'f4', etc., or 514 | 'int8', 'float32', etc.). These can be used to define the length 515 | and value types of list properties. List property lengths 516 | always default to type 'u1' (8-bit unsigned integer), and value 517 | types default to 'i4' (32-bit integer). 518 | 519 | """ 520 | if not isinstance(data, _np.ndarray): 521 | raise TypeError("only numpy arrays are supported") 522 | 523 | if len(data.shape) != 1: 524 | raise ValueError("only one-dimensional arrays are " 525 | "supported") 526 | 527 | count = len(data) 528 | 529 | properties = [] 530 | descr = data.dtype.descr 531 | 532 | for t in descr: 533 | if not isinstance(t[1], str): 534 | raise ValueError("nested records not supported") 535 | 536 | if not t[0]: 537 | raise ValueError("field with empty name") 538 | 539 | if len(t) != 2 or t[1][1] == 'O': 540 | # non-scalar field, which corresponds to a list 541 | # property in PLY. 542 | 543 | if t[1][1] == 'O': 544 | if len(t) != 2: 545 | raise ValueError("non-scalar object fields not " 546 | "supported") 547 | 548 | len_str = _data_type_reverse[len_types.get(t[0], 'u1')] 549 | if t[1][1] == 'O': 550 | val_type = val_types.get(t[0], 'i4') 551 | val_str = _lookup_type(val_type) 552 | else: 553 | val_str = _lookup_type(t[1][1:]) 554 | 555 | prop = PlyListProperty(t[0], len_str, val_str) 556 | else: 557 | val_str = _lookup_type(t[1][1:]) 558 | prop = PlyProperty(t[0], val_str) 559 | 560 | properties.append(prop) 561 | 562 | elt = PlyElement(name, properties, count, comments) 563 | elt.data = data 564 | 565 | return elt 566 | 567 | def _read(self, stream, text, byte_order): 568 | """ 569 | Read the actual data from a PLY file. 570 | 571 | """ 572 | if text: 573 | self._read_txt(stream) 574 | else: 575 | if self._have_list: 576 | # There are list properties, so a simple load is 577 | # impossible. 578 | self._read_bin(stream, byte_order) 579 | else: 580 | # There are no list properties, so loading the data is 581 | # much more straightforward. 582 | self._data = _np.fromfile(stream, 583 | self.dtype(byte_order), 584 | self.count) 585 | 586 | if len(self._data) < self.count: 587 | k = len(self._data) 588 | del self._data 589 | raise PlyParseError("early end-of-file", self, k) 590 | 591 | self._check_sanity() 592 | 593 | def _write(self, stream, text, byte_order): 594 | """ 595 | Write the data to a PLY file. 596 | 597 | """ 598 | if text: 599 | self._write_txt(stream) 600 | else: 601 | if self._have_list: 602 | # There are list properties, so serialization is 603 | # slightly complicated. 604 | self._write_bin(stream, byte_order) 605 | else: 606 | # no list properties, so serialization is 607 | # straightforward. 608 | self.data.astype(self.dtype(byte_order), 609 | copy=False).tofile(stream) 610 | 611 | def _read_txt(self, stream): 612 | """ 613 | Load a PLY element from an ASCII-format PLY file. The element 614 | may contain list properties. 
615 | 616 | """ 617 | self._data = _np.empty(self.count, dtype=self.dtype()) 618 | 619 | k = 0 620 | for line in _islice(iter(stream.readline, b''), self.count): 621 | fields = iter(line.strip().split()) 622 | for prop in self.properties: 623 | try: 624 | self._data[prop.name][k] = prop._from_fields(fields) 625 | except StopIteration: 626 | raise PlyParseError("early end-of-line", 627 | self, k, prop) 628 | except ValueError: 629 | raise PlyParseError("malformed input", 630 | self, k, prop) 631 | try: 632 | next(fields) 633 | except StopIteration: 634 | pass 635 | else: 636 | raise PlyParseError("expected end-of-line", self, k) 637 | k += 1 638 | 639 | if k < self.count: 640 | del self._data 641 | raise PlyParseError("early end-of-file", self, k) 642 | 643 | def _write_txt(self, stream): 644 | """ 645 | Save a PLY element to an ASCII-format PLY file. The element may 646 | contain list properties. 647 | 648 | """ 649 | for rec in self.data: 650 | fields = [] 651 | for prop in self.properties: 652 | fields.extend(prop._to_fields(rec[prop.name])) 653 | 654 | _np.savetxt(stream, [fields], '%.18g', newline='\r\n') 655 | 656 | def _read_bin(self, stream, byte_order): 657 | """ 658 | Load a PLY element from a binary PLY file. The element may 659 | contain list properties. 660 | 661 | """ 662 | self._data = _np.empty(self.count, dtype=self.dtype(byte_order)) 663 | 664 | for k in _range(self.count): 665 | for prop in self.properties: 666 | try: 667 | self._data[prop.name][k] = \ 668 | prop._read_bin(stream, byte_order) 669 | except StopIteration: 670 | raise PlyParseError("early end-of-file", 671 | self, k, prop) 672 | 673 | def _write_bin(self, stream, byte_order): 674 | """ 675 | Save a PLY element to a binary PLY file. The element may 676 | contain list properties. 677 | 678 | """ 679 | for rec in self.data: 680 | for prop in self.properties: 681 | prop._write_bin(rec[prop.name], stream, byte_order) 682 | 683 | @property 684 | def header(self): 685 | """ 686 | Format this element's metadata as it would appear in a PLY 687 | header. 688 | 689 | """ 690 | lines = ['element %s %d' % (self.name, self.count)] 691 | 692 | # Some information is lost here, since all comments are placed 693 | # between the 'element' line and the first property definition. 694 | for c in self.comments: 695 | lines.append('comment ' + c) 696 | 697 | lines.extend(list(map(str, self.properties))) 698 | 699 | return '\r\n'.join(lines) 700 | 701 | def __getitem__(self, key): 702 | return self.data[key] 703 | 704 | def __setitem__(self, key, value): 705 | self.data[key] = value 706 | 707 | def __str__(self): 708 | return self.header 709 | 710 | def __repr__(self): 711 | return ('PlyElement(%r, %r, count=%d, comments=%r)' % 712 | (self.name, self.properties, self.count, 713 | self.comments)) 714 | 715 | 716 | class PlyProperty(object): 717 | 718 | """ 719 | PLY property description. This class is pure metadata; the data 720 | itself is contained in PlyElement instances. 
721 | 
722 |     """
723 | 
724 |     def __init__(self, name, val_dtype):
725 |         self._name = str(name)
726 |         self._check_name()
727 |         self.val_dtype = val_dtype
728 | 
729 |     def _get_val_dtype(self):
730 |         return self._val_dtype
731 | 
732 |     def _set_val_dtype(self, val_dtype):
733 |         self._val_dtype = _data_types[_lookup_type(val_dtype)]
734 | 
735 |     val_dtype = property(_get_val_dtype, _set_val_dtype)
736 | 
737 |     @property
738 |     def name(self):
739 |         return self._name
740 | 
741 |     def _check_name(self):
742 |         if any(c.isspace() for c in self._name):
743 |             msg = "Error: property name %r contains spaces" % self._name
744 |             raise RuntimeError(msg)
745 | 
746 |     @staticmethod
747 |     def _parse_one(line):
748 |         assert line[0] == 'property'
749 | 
750 |         if line[1] == 'list':
751 |             if len(line) > 5:
752 |                 raise PlyParseError("too many fields after "
753 |                                     "'property list'")
754 |             if len(line) < 5:
755 |                 raise PlyParseError("too few fields after "
756 |                                     "'property list'")
757 | 
758 |             return PlyListProperty(line[4], line[2], line[3])
759 | 
760 |         else:
761 |             if len(line) > 3:
762 |                 raise PlyParseError("too many fields after "
763 |                                     "'property'")
764 |             if len(line) < 3:
765 |                 raise PlyParseError("too few fields after "
766 |                                     "'property'")
767 | 
768 |             return PlyProperty(line[2], line[1])
769 | 
770 |     def dtype(self, byte_order='='):
771 |         """
772 |         Return the numpy dtype description for this property (as a tuple
773 |         of strings).
774 | 
775 |         """
776 |         return byte_order + self.val_dtype
777 | 
778 |     def _from_fields(self, fields):
779 |         """
780 |         Parse from generator. Raise StopIteration if the property could
781 |         not be read.
782 | 
783 |         """
784 |         return _np.dtype(self.dtype()).type(next(fields))
785 | 
786 |     def _to_fields(self, data):
787 |         """
788 |         Return generator over one item.
789 | 
790 |         """
791 |         yield _np.dtype(self.dtype()).type(data)
792 | 
793 |     def _read_bin(self, stream, byte_order):
794 |         """
795 |         Read data from a binary stream. Raise StopIteration if the
796 |         property could not be read.
797 | 
798 |         """
799 |         try:
800 |             return _np.fromfile(stream, self.dtype(byte_order), 1)[0]
801 |         except IndexError:
802 |             raise StopIteration
803 | 
804 |     def _write_bin(self, data, stream, byte_order):
805 |         """
806 |         Write data to a binary stream.
807 | 
808 |         """
809 |         _np.dtype(self.dtype(byte_order)).type(data).tofile(stream)
810 | 
811 |     def __str__(self):
812 |         val_str = _data_type_reverse[self.val_dtype]
813 |         return 'property %s %s' % (val_str, self.name)
814 | 
815 |     def __repr__(self):
816 |         return 'PlyProperty(%r, %r)' % (self.name,
817 |                                         _lookup_type(self.val_dtype))
818 | 
819 | 
820 | class PlyListProperty(PlyProperty):
821 | 
822 |     """
823 |     PLY list property description.
824 | 
825 |     """
826 | 
827 |     def __init__(self, name, len_dtype, val_dtype):
828 |         PlyProperty.__init__(self, name, val_dtype)
829 | 
830 |         self.len_dtype = len_dtype
831 | 
832 |     def _get_len_dtype(self):
833 |         return self._len_dtype
834 | 
835 |     def _set_len_dtype(self, len_dtype):
836 |         self._len_dtype = _data_types[_lookup_type(len_dtype)]
837 | 
838 |     len_dtype = property(_get_len_dtype, _set_len_dtype)
839 | 
840 |     def dtype(self, byte_order='='):
841 |         """
842 |         List properties always have a numpy dtype of "object".
843 | 
844 |         """
845 |         return '|O'
846 | 
847 |     def list_dtype(self, byte_order='='):
848 |         """
849 |         Return the pair (len_dtype, val_dtype) (both numpy-friendly
850 |         strings).
851 | 
852 |         """
853 |         return (byte_order + self.len_dtype,
854 |                 byte_order + self.val_dtype)
855 | 
856 |     def _from_fields(self, fields):
857 |         (len_t, val_t) = self.list_dtype()
858 | 
859 |         n = int(_np.dtype(len_t).type(next(fields)))
860 | 
861 |         data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1)
862 |         if len(data) < n:
863 |             raise StopIteration
864 | 
865 |         return data
866 | 
867 |     def _to_fields(self, data):
868 |         """
869 |         Return generator over the (numerical) PLY representation of the
870 |         list data (length followed by actual data).
871 | 
872 |         """
873 |         (len_t, val_t) = self.list_dtype()
874 | 
875 |         data = _np.asarray(data, dtype=val_t).ravel()
876 | 
877 |         yield _np.dtype(len_t).type(data.size)
878 |         for x in data:
879 |             yield x
880 | 
881 |     def _read_bin(self, stream, byte_order):
882 |         (len_t, val_t) = self.list_dtype(byte_order)
883 | 
884 |         try:
885 |             n = _np.fromfile(stream, len_t, 1)[0]
886 |         except IndexError:
887 |             raise StopIteration
888 | 
889 |         data = _np.fromfile(stream, val_t, n)
890 |         if len(data) < n:
891 |             raise StopIteration
892 | 
893 |         return data
894 | 
895 |     def _write_bin(self, data, stream, byte_order):
896 |         """
897 |         Write data to a binary stream.
898 | 
899 |         """
900 |         (len_t, val_t) = self.list_dtype(byte_order)
901 | 
902 |         data = _np.asarray(data, dtype=val_t).ravel()
903 | 
904 |         _np.array(data.size, dtype=len_t).tofile(stream)
905 |         data.tofile(stream)
906 | 
907 |     def __str__(self):
908 |         len_str = _data_type_reverse[self.len_dtype]
909 |         val_str = _data_type_reverse[self.val_dtype]
910 |         return 'property list %s %s %s' % (len_str, val_str, self.name)
911 | 
912 |     def __repr__(self):
913 |         return ('PlyListProperty(%r, %r, %r)' %
914 |                 (self.name,
915 |                  _lookup_type(self.len_dtype),
916 |                  _lookup_type(self.val_dtype)))
917 | 
918 | 
919 | def load_ply(file_name: str,
920 |              with_faces: bool = False,
921 |              with_color: bool = False) -> _np.ndarray:
922 |     ply_data = PlyData.read(file_name)
923 |     points = ply_data['vertex']
924 |     points = _np.vstack([points['x'], points['y'], points['z']]).T
925 |     ret_val = [points]
926 | 
927 |     if with_faces:
928 |         faces = _np.vstack(ply_data['face']['vertex_indices'])
929 |         ret_val.append(faces)
930 | 
931 |     if with_color:
932 |         r = _np.vstack(ply_data['vertex']['red'])
933 |         g = _np.vstack(ply_data['vertex']['green'])
934 |         b = _np.vstack(ply_data['vertex']['blue'])
935 |         color = _np.hstack((r, g, b))
936 |         ret_val.append(color)
937 | 
938 |     if len(ret_val) == 1:  # Unwrap the list
939 |         ret_val = ret_val[0]
940 | 
941 |     return ret_val
942 | 
943 | 
944 | def quick_save_ply_file(points, filename: str):
945 |     pl = len(points)
946 |     header = \
947 |         "ply\n" \
948 |         "format binary_little_endian 1.0\n" \
949 |         "element vertex " + str(pl) + "\n" \
950 |         "property float x\n" \
951 |         "property float y\n" \
952 |         "property float z\n" \
953 |         "end_header\n"
954 | 
955 |     dtype_vertex = [('vertex', '<f4', (3,))]
956 |     vertex = _np.empty(pl, dtype=dtype_vertex)
957 |     vertex['vertex'] = points
958 | 
959 |     with open(filename, 'wb') as fp:
960 |         fp.write(bytes(header, 'utf-8'))
961 |         vertex.tofile(fp)
962 | 
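The two helpers appended to this vendored module, `load_ply` and `quick_save_ply_file`, are what the rest of the code base uses for PLY I/O. A hedged round-trip sketch ('chair.ply' is a placeholder for any PLY file with float x/y/z vertex properties):

```
import numpy as np

from utils.plyfile import load_ply, quick_save_ply_file

points = load_ply('chair.ply')  # (N, 3) array of vertex coordinates
print(points.shape)

# quick_save_ply_file writes a binary little-endian PLY with float32 vertices.
quick_save_ply_file(points.astype(np.float32), 'chair_copy.ply')
```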
--------------------------------------------------------------------------------
/utils/points.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def generate_points_from_uniform_distribution(size, low=-1, high=1):
6 |     # Rejection sampling: draw three times as many points as requested and
7 |     # keep those that fall inside the unit ball, until enough remain.
8 |     while True:
9 |         points = torch.zeros([size[0] * 3, *size[1:]]).uniform_(low, high)
10 |         points = points[torch.norm(points, dim=1) < 1]
11 | 
12 |         if points.shape[0] >= size[0]:
13 |             return points[:size[0]]
14 | 
15 | 
16 | def generate_points(config, epoch, size, normalize_points=None):
17 |     if normalize_points is None:
18 |         normalize_points = config['target_network_input']['normalization']['enable']
19 | 
20 |     if normalize_points and config['target_network_input']['normalization']['type'] == 'progressive':
21 |         normalization_max_epoch = config['target_network_input']['normalization']['epoch']
22 | 
23 |         normalization_coef = np.linspace(0, 1, normalization_max_epoch)[epoch - 1] \
24 |             if epoch <= normalization_max_epoch else 1
25 |         points = generate_points_from_uniform_distribution(size=size)
26 | 
points[np.linalg.norm(points, axis=1) < normalization_coef] = \ 27 | normalization_coef * ( 28 | points[ 29 | np.linalg.norm(points, axis=1) < normalization_coef].T / 30 | torch.from_numpy( 31 | np.linalg.norm(points[np.linalg.norm(points, axis=1) < normalization_coef], axis=1)).float() 32 | ).T 33 | else: 34 | points = generate_points_from_uniform_distribution(size=size) 35 | 36 | return points 37 | -------------------------------------------------------------------------------- /utils/pytorch_structural_losses/approxmatch.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | class Formatter { 11 | public: 12 | Formatter() {} 13 | ~Formatter() {} 14 | 15 | template Formatter &operator<<(const Type &value) { 16 | stream_ << value; 17 | return *this; 18 | } 19 | 20 | std::string str() const { return stream_.str(); } 21 | operator std::string() const { return stream_.str(); } 22 | 23 | enum ConvertToString { to_str }; 24 | 25 | std::string operator>>(ConvertToString) { return stream_.str(); } 26 | 27 | private: 28 | std::stringstream stream_; 29 | Formatter(const Formatter &); 30 | Formatter &operator=(Formatter &); 31 | }; 32 | 33 | 34 | __global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ 35 | float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 36 | float multiL,multiR; 37 | if (n>=m){ 38 | multiL=1; 39 | multiR=n/m; 40 | }else{ 41 | multiL=m/n; 42 | multiR=1; 43 | } 44 | const int Block=1024; 45 | __shared__ float buf[Block*4]; 46 | for (int i=blockIdx.x;i=-2;j--){ 55 | for (int j=7;j>-2;j--){ 56 | float level=-powf(4.0f,j); 57 | if (j==-2){ 58 | level=0; 59 | } 60 | for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,out); 258 | //} 259 | 260 | __global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ 261 | __shared__ float sum_grad[256*3]; 262 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2); 325 | //} 326 | 327 | /*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, 328 | cudaStream_t stream)*/ 329 | // temp: TensorShape{b,(n+m)*2} 330 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){ 331 | approxmatchkernel 332 | <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp); 333 | 334 | cudaError_t err = cudaGetLastError(); 335 | if (cudaSuccess != err) 336 | throw std::runtime_error(Formatter() 337 | << "CUDA kernel failed : " << std::to_string(err)); 338 | } 339 | 340 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){ 341 | matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out); 342 | 343 | cudaError_t err = cudaGetLastError(); 344 | if (cudaSuccess != err) 345 | throw std::runtime_error(Formatter() 346 | << "CUDA kernel failed : " << std::to_string(err)); 347 | } 348 | 349 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){ 350 | matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1); 351 | matchcostgrad2kernel<<>>(b,n,m,xyz1,xyz2,match,grad2); 352 | 353 | 
cudaError_t err = cudaGetLastError(); 354 | if (cudaSuccess != err) 355 | throw std::runtime_error(Formatter() 356 | << "CUDA kernel failed : " << std::to_string(err)); 357 | } 358 | -------------------------------------------------------------------------------- /utils/pytorch_structural_losses/match_cost.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | from utils.pytorch_structural_losses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad 3 | 4 | # Inherit from Function 5 | class MatchCostFunction(Function): 6 | # Note that both forward and backward are @staticmethods 7 | @staticmethod 8 | # bias is an optional argument 9 | def forward(ctx, seta, setb): 10 | #print("Match Cost Forward") 11 | ctx.save_for_backward(seta, setb) 12 | ''' 13 | input: 14 | set1 : batch_size * #dataset_points * 3 15 | set2 : batch_size * #query_points * 3 16 | returns: 17 | match : batch_size * #query_points * #dataset_points 18 | ''' 19 | match, temp = ApproxMatch(seta, setb) 20 | ctx.match = match 21 | cost = MatchCost(seta, setb, match) 22 | # If you want to return matching matrix too, swap the return lines 23 | # Remember to use method `match_cost()` directly, because method 24 | # `utils.metrics.earth_mover_distance()` will try to divide the tuple by 25 | # the batch size. We omit the if statement for performance purposes. 26 | # return match, temp, cost 27 | return cost 28 | 29 | """ 30 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 31 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 32 | """ 33 | # This function has only a single output, so it gets only one gradient 34 | @staticmethod 35 | def backward(ctx, grad_output): 36 | #print("Match Cost Backward") 37 | # This is a pattern that is very convenient - at the top of backward 38 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 39 | # None. Thanks to the fact that additional trailing Nones are 40 | # ignored, the return statement is simple even when the function has 41 | # optional inputs. 
42 | seta, setb = ctx.saved_tensors 43 | #grad_input = grad_weight = grad_bias = None 44 | grada, gradb = MatchCostGrad(seta, setb, ctx.match) 45 | grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) 46 | return grada*grad_output_expand, gradb*grad_output_expand 47 | 48 | match_cost = MatchCostFunction.apply 49 | 50 | -------------------------------------------------------------------------------- /utils/pytorch_structural_losses/nn_distance.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | 3 | from utils.pytorch_structural_losses.StructuralLossesBackend import NNDistance, NNDistanceGrad 4 | 5 | # Inherit from Function 6 | class NNDistanceFunction(Function): 7 | # Note that both forward and backward are @staticmethods 8 | @staticmethod 9 | # bias is an optional argument 10 | def forward(ctx, seta, setb): 11 | #print("Match Cost Forward") 12 | ctx.save_for_backward(seta, setb) 13 | ''' 14 | input: 15 | set1 : batch_size * #dataset_points * 3 16 | set2 : batch_size * #query_points * 3 17 | returns: 18 | dist1, idx1, dist2, idx2 19 | ''' 20 | dist1, idx1, dist2, idx2 = NNDistance(seta, setb) 21 | # print(dist1, idx1, dist2, idx2, flush=True) 22 | ctx.idx1 = idx1 23 | ctx.idx2 = idx2 24 | return dist1, dist2 25 | 26 | # This function has only a single output, so it gets only one gradient 27 | @staticmethod 28 | def backward(ctx, grad_dist1, grad_dist2): 29 | #print("Match Cost Backward") 30 | # This is a pattern that is very convenient - at the top of backward 31 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 32 | # None. Thanks to the fact that additional trailing Nones are 33 | # ignored, the return statement is simple even when the function has 34 | # optional inputs. 
35 | seta, setb = ctx.saved_tensors 36 | idx1 = ctx.idx1 37 | idx2 = ctx.idx2 38 | grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) 39 | return grada, gradb 40 | 41 | nn_distance = NNDistanceFunction.apply 42 | 43 | -------------------------------------------------------------------------------- /utils/pytorch_structural_losses/nndistance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 9 | const int batch=512; 10 | __shared__ float buf[batch*3]; 11 | for (int i=blockIdx.x;ibest){ 123 | result[(i*n+j)]=best; 124 | result_i[(i*n+j)]=best_i; 125 | } 126 | } 127 | __syncthreads(); 128 | } 129 | } 130 | } 131 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 132 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 133 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 134 | } 135 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 136 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 159 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 160 | } 161 | 162 | -------------------------------------------------------------------------------- /utils/pytorch_structural_losses/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | 4 | # Python interface 5 | setup( 6 | name='PyTorchStructuralLosses', 7 | ext_modules=[ 8 | CUDAExtension( 9 | name='StructuralLossesBackend', 10 | sources=[ 11 | 'structural_loss.cpp', 12 | 'approxmatch.cu', 13 | 'nndistance.cu' 14 | ], 15 | ) 16 | ], 17 | cmdclass={'build_ext': BuildExtension}, 18 | ) 19 | -------------------------------------------------------------------------------- /utils/pytorch_structural_losses/structural_loss.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream); 12 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream); 13 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream); 14 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 15 | void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 16 | 17 | 18 | /* 19 | input: 20 | set1 : batch_size * #dataset_points * 3 21 | set2 : 
batch_size * #query_points * 3 22 | returns: 23 | match : batch_size * #query_points * #dataset_points 24 | */ 25 | // temp: TensorShape{b,(n+m)*2} 26 | std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) { 27 | //std::cout << "[ApproxMatch] Called." << std::endl; 28 | int64_t batch_size = set_d.size(0); 29 | int64_t n_dataset_points = set_d.size(1); // n 30 | int64_t n_query_points = set_q.size(1); // m 31 | //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl; 32 | at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 33 | at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 34 | CHECK_INPUT(set_d); 35 | CHECK_INPUT(set_q); 36 | CHECK_INPUT(match); 37 | CHECK_INPUT(temp); 38 | 39 | approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream()); 40 | return {match, temp}; 41 | } 42 | 43 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 44 | //std::cout << "[MatchCost] Called." << std::endl; 45 | int64_t batch_size = set_d.size(0); 46 | int64_t n_dataset_points = set_d.size(1); // n 47 | int64_t n_query_points = set_q.size(1); // m 48 | //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl; 49 | at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 50 | CHECK_INPUT(set_d); 51 | CHECK_INPUT(set_q); 52 | CHECK_INPUT(match); 53 | CHECK_INPUT(out); 54 | matchcost(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),out.data(),at::cuda::getCurrentCUDAStream()); 55 | return out; 56 | } 57 | 58 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 59 | //std::cout << "[MatchCostGrad] Called." << std::endl; 60 | int64_t batch_size = set_d.size(0); 61 | int64_t n_dataset_points = set_d.size(1); // n 62 | int64_t n_query_points = set_q.size(1); // m 63 | //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl; 64 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 65 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 66 | CHECK_INPUT(set_d); 67 | CHECK_INPUT(set_q); 68 | CHECK_INPUT(match); 69 | CHECK_INPUT(grad1); 70 | CHECK_INPUT(grad2); 71 | matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream()); 72 | return {grad1, grad2}; 73 | } 74 | 75 | 76 | /* 77 | input: 78 | set_d : batch_size * #dataset_points * 3 79 | set_q : batch_size * #query_points * 3 80 | returns: 81 | dist1, idx1 : batch_size * #dataset_points 82 | dist2, idx2 : batch_size * #query_points 83 | */ 84 | std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) { 85 | //std::cout << "[NNDistance] Called." 

/*
input:
    set_d : batch_size * #dataset_points * 3
    set_q : batch_size * #query_points * 3
returns:
    dist1, idx1 : batch_size * #dataset_points
    dist2, idx2 : batch_size * #query_points
*/
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {
  //std::cout << "[NNDistance] Called." << std::endl;
  int64_t batch_size = set_d.size(0);
  int64_t n_dataset_points = set_d.size(1); // n
  int64_t n_query_points = set_q.size(1); // m
  //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl;
  at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
  at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
  at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
  at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  CHECK_INPUT(dist1);
  CHECK_INPUT(idx1);
  CHECK_INPUT(dist2);
  CHECK_INPUT(idx2);
  // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
  nndistance(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),dist1.data<float>(),idx1.data<int>(),dist2.data<float>(),idx2.data<int>(), at::cuda::getCurrentCUDAStream());
  return {dist1, idx1, dist2, idx2};
}

std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {
  //std::cout << "[NNDistanceGrad] Called." << std::endl;
  int64_t batch_size = set_d.size(0);
  int64_t n_dataset_points = set_d.size(1); // n
  int64_t n_query_points = set_q.size(1); // m
  //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl;
  at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
  at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  CHECK_INPUT(idx1);
  CHECK_INPUT(idx2);
  CHECK_INPUT(grad_dist1);
  CHECK_INPUT(grad_dist2);
  CHECK_INPUT(grad1);
  CHECK_INPUT(grad2);
  //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
  nndistancegrad(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),
                 grad_dist1.data<float>(),idx1.data<int>(),
                 grad_dist2.data<float>(),idx2.data<int>(),
                 grad1.data<float>(),grad2.data<float>(),
                 at::cuda::getCurrentCUDAStream());
  return {grad1, grad2};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
  m.def("ApproxMatch", &ApproxMatch);
  m.def("MatchCost", &MatchCost);
  m.def("MatchCostGrad", &MatchCostGrad);
  m.def("NNDistance", &NNDistance);
  m.def("NNDistanceGrad", &NNDistanceGrad);
}
--------------------------------------------------------------------------------
/utils/sphere_triangles.py:
--------------------------------------------------------------------------------
import matplotlib.tri as mtri
import numpy as np
from collections import namedtuple
import torch

# code from
# https://github.com/gmum/3d-point-clouds-HyperCloud/blob/master/utils/sphere_triangles.py

Triangle = namedtuple("Triangle", "a,b,c")
Point = namedtuple("Point", "x,y,z")
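# All subdivision helpers below keep their vertices on the unit sphere: the
# initial octahedron corners are unit-norm, and midpoints/centroids are
# re-normalized as they are created, so every subdivide_* generator yields
# leaf Triangles of unit-norm Points.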

def normalize(p):
    s = sum(u*u for u in p) ** 0.5
    return Point(*(u/s for u in p))


def midpoint(u, v):
    return Point(*((a+b)/2 for a, b in zip(u, v)))


def subdivide_hybrid3(tri, depth):
    def triangle(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_centroid(tri, 1):
            yield from edge(t, depth - 1)

    def centroid(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_midpoint(tri, 2):
            yield from triangle(t, depth - 1)

    def edge(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_edge(tri, 1):
            yield from centroid(t, depth - 1)

    return centroid(tri, depth)


def subdivide_hybrid2(tri, depth):
    def centroid(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_centroid(tri, 1):
            yield from edge(t, depth - 1)

    def edge(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_edge(tri, 1):
            yield from centroid(t, depth - 1)

    return centroid(tri, depth)


def subdivide_hybrid(tri, depth):
    def centroid(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_centroid(tri, 1):
            yield from edge(t, depth - 1)

    def edge(tri, depth):
        if depth == 0:
            yield tri
            return
        for t in subdivide_edge(tri, 1):
            yield from centroid(t, depth - 1)

    return edge(tri, depth)


def subdivide_midpoint2(tri, depth):
    if depth == 0:
        yield tri
        return
    #     p0
    #     /|\
    #    / | \
    #   /  |  \
    #  /___|___\
    # p1  m12   p2
    p0, p1, p2 = tri
    m12 = normalize(midpoint(p1, p2))
    # WRONG TRIANGULATION!
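    # Why it is wrong (derived from the recursion below): with the children
    # ordered (p0, m12, p1) and (p0, p2, m12), only the p1-p2 side and its
    # halves are ever bisected, so new vertices accumulate along a single arc
    # and each face degenerates into a fan of slivers anchored at p0.
    # subdivide_midpoint below rotates the child vertex order so that a
    # different edge is split at each level.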
    yield from subdivide_midpoint2(Triangle(p0, m12, p1), depth-1)
    yield from subdivide_midpoint2(Triangle(p0, p2, m12), depth-1)


def subdivide_midpoint(tri, depth):
    if depth == 0:
        yield tri
        return
    #     p0
    #     /|\
    #    / | \
    #   /  |  \
    #  /___|___\
    # p1  m12   p2
    p0, p1, p2 = tri
    m12 = normalize(midpoint(p1, p2))
    yield from subdivide_midpoint(Triangle(m12, p0, p1), depth-1)
    yield from subdivide_midpoint(Triangle(m12, p2, p0), depth-1)


def subdivide_edge(tri, depth):
    if depth == 0:
        yield tri
        return
    #        p0
    #       /  \
    #  m01 /....\ m02
    #     / \  / \
    #    /___\/___\
    #  p1    m12   p2
    p0, p1, p2 = tri
    m01 = normalize(midpoint(p0, p1))
    m02 = normalize(midpoint(p0, p2))
    m12 = normalize(midpoint(p1, p2))
    triangles = [
        Triangle(p0, m01, m02),
        Triangle(m01, p1, m12),
        Triangle(m02, m12, p2),
        Triangle(m01, m02, m12),
    ]
    for t in triangles:
        yield from subdivide_edge(t, depth-1)


def subdivide_centroid(tri, depth):
    if depth == 0:
        yield tri
        return
    #      p0
    #     /  \
    #    /    \
    #   /  c   \
    #  /________\
    # p1         p2
    p0, p1, p2 = tri
    centroid = normalize(Point(
        (p0.x + p1.x + p2.x) / 3,
        (p0.y + p1.y + p2.y) / 3,
        (p0.z + p1.z + p2.z) / 3,
    ))
    t1 = Triangle(p0, p1, centroid)
    t2 = Triangle(p2, centroid, p0)
    t3 = Triangle(centroid, p1, p2)

    yield from subdivide_centroid(t1, depth - 1)
    yield from subdivide_centroid(t2, depth - 1)
    yield from subdivide_centroid(t3, depth - 1)


def subdivide(faces, depth, method):
    for tri in faces:
        yield from method(tri, depth)


def generate(_method, _depth):
    method = {
        "hybrid": subdivide_hybrid,
        "hybrid2": subdivide_hybrid2,
        "hybrid3": subdivide_hybrid3,
        "midpoint": subdivide_midpoint,
        "midpoint2": subdivide_midpoint2,
        "centroid": subdivide_centroid,
        "edge": subdivide_edge,
    }[_method]
    depth = int(_depth)

    # octahedron
    p = 2**0.5 / 2
    faces = [
        # top half
        Triangle(Point(0, 1, 0), Point(-p, 0, p), Point( p, 0, p)),
        Triangle(Point(0, 1, 0), Point( p, 0, p), Point( p, 0,-p)),
        Triangle(Point(0, 1, 0), Point( p, 0,-p), Point(-p, 0,-p)),
        Triangle(Point(0, 1, 0), Point(-p, 0,-p), Point(-p, 0, p)),

        # bottom half
        Triangle(Point(0,-1, 0), Point( p, 0, p), Point(-p, 0, p)),
        Triangle(Point(0,-1, 0), Point( p, 0,-p), Point( p, 0, p)),
        Triangle(Point(0,-1, 0), Point(-p, 0,-p), Point( p, 0,-p)),
        Triangle(Point(0,-1, 0), Point(-p, 0, p), Point(-p, 0,-p)),
    ]

    X = []
    Y = []
    Z = []
    T = []

    for i, tri in enumerate(subdivide(faces, depth, method)):
        X.extend([p.x for p in tri])
        Y.extend([p.y for p in tri])
        Z.extend([p.z for p in tri])
        T.append([3*i, 3*i+1, 3*i+2])

    X = np.array(X)
    Y = np.array(Y)
    Z = np.array(Z)
    T = mtri.Triangulation(X, Y, np.array(T))
    points = np.concatenate((X.reshape(-1, 1), Y.reshape(-1, 1), Z.reshape(-1, 1)), axis=1)

    return torch.from_numpy(points).float(), T
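
# Example (the counts follow from the code above): "edge" subdivision
# quadruples the 8 octahedron faces at every level, so depth d yields
# 8 * 4**d triangles with three (duplicated) vertex rows each:
#
#   points, T = generate("edge", 3)
#   points.shape  # torch.Size([1536, 3]); T is a matplotlib Triangulation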
-------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | import requests 5 | 6 | 7 | class TelegramLogger(object): 8 | 9 | @staticmethod 10 | def get_logger(config): 11 | return TelegramLogger(config['bot_token'], config['chat_id']) 12 | 13 | def __init__(self, bot_token: str, chat_id: str, disable_req_log: bool = True): 14 | self._api_url = f'https://api.telegram.org/bot{bot_token}/' 15 | self._message_url = self._api_url + 'sendMessage' 16 | self._image_url = self._api_url + 'sendMediaGroup' 17 | self._chat_id = chat_id 18 | 19 | if disable_req_log: 20 | import logging 21 | logging.getLogger("requests").setLevel(logging.CRITICAL) 22 | logging.getLogger("urllib3").setLevel(logging.CRITICAL) 23 | 24 | def log(self, message: str): 25 | try: 26 | send_data = { 27 | 'chat_id': self._chat_id, 28 | 'text': message, 29 | } 30 | requests.post(self._message_url, json=send_data) 31 | except Exception: 32 | pass 33 | 34 | def log_images(self, image_paths: List[str], message: str = ''): 35 | try: 36 | send_data = { 37 | 'chat_id': self._chat_id, 38 | 'media': json.dumps([ 39 | { 40 | 'type': 'photo', 41 | 'media': f'attach://image_{i}.png', 42 | 'caption': message if i == 0 else '', 43 | } for i in range(len(image_paths)) 44 | ]) 45 | } 46 | files = {f'image_{i}.png': open(image_path, 'rb') for i, image_path in enumerate(image_paths)} 47 | requests.post(self._image_url, params=send_data, files=files) 48 | except Exception: 49 | pass 50 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from os import listdir 3 | from os.path import join, exists 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | from datasets.utils.shapenet_category_mapping import synth_id_to_category 10 | from utils.pcutil import plot_3d_point_cloud 11 | 12 | 13 | def find_latest_epoch(dirpath): 14 | # Files with weights are in format ddddd_{D,E,G}.pth 15 | epoch_regex = re.compile(r'^(?P\d+)_([DEG]|model)\.pth$') 16 | epochs_completed = [] 17 | if exists(join(dirpath, 'weights')): 18 | dirpath = join(dirpath, 'weights') 19 | for f in listdir(dirpath): 20 | m = epoch_regex.match(f) 21 | if m: 22 | epochs_completed.append(int(m.group('n_epoch'))) 23 | return max(epochs_completed) if epochs_completed else 0 24 | 25 | 26 | def get_classes_dir(config): 27 | return 'all' if not config.get('classes') else '_'.join(config['classes']) 28 | 29 | 30 | def get_distribution_dir(config): 31 | normed_str = '' 32 | if config['target_network_input']['normalization']['enable']: 33 | if config['target_network_input']['normalization']['type'] == 'progressive': 34 | norm_max_epoch = config['target_network_input']['normalization']['epoch'] 35 | normed_str = 'normed_progressive_to_epoch_%d' % norm_max_epoch 36 | 37 | return '%s%s' % ('uniform', '_' + normed_str if normed_str else '') 38 | 39 | 40 | def get_model_name(config): 41 | model_name = '' 42 | encoders_num = 0 43 | real_size = config['full_model']['real_encoder']['output_size'] 44 | random_size = config['full_model']['random_encoder']['output_size'] 45 | 46 | if real_size > 0: 47 | encoders_num += 1 48 | model_name += str(real_size) 49 | 50 | if random_size > 0: 51 | encoders_num += 1 52 | model_name += 'x' + str(random_size) if real_size >0 else str(random_size) 53 | 54 | model_name = str(encoders_num) + 'e' + model_name 55 | 56 | 
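
# Usage sketch (the config keys match get_logger above; the token, chat id and
# image path are placeholders):
#   logger = TelegramLogger.get_logger({'bot_token': '<token>', 'chat_id': '<chat>'})
#   logger.log('epoch 10 finished')
#   logger.log_images(['results/10_0_fixed.png'], 'epoch 10 samples')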
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
import re
from os import listdir
from os.path import join, exists

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from datasets.utils.shapenet_category_mapping import synth_id_to_category
from utils.pcutil import plot_3d_point_cloud


def find_latest_epoch(dirpath):
    # Files with weights are in format ddddd_{D,E,G}.pth
    epoch_regex = re.compile(r'^(?P<n_epoch>\d+)_([DEG]|model)\.pth$')
    epochs_completed = []
    if exists(join(dirpath, 'weights')):
        dirpath = join(dirpath, 'weights')
    for f in listdir(dirpath):
        m = epoch_regex.match(f)
        if m:
            epochs_completed.append(int(m.group('n_epoch')))
    return max(epochs_completed) if epochs_completed else 0


def get_classes_dir(config):
    return 'all' if not config.get('classes') else '_'.join(config['classes'])


def get_distribution_dir(config):
    normed_str = ''
    if config['target_network_input']['normalization']['enable']:
        if config['target_network_input']['normalization']['type'] == 'progressive':
            norm_max_epoch = config['target_network_input']['normalization']['epoch']
            normed_str = 'normed_progressive_to_epoch_%d' % norm_max_epoch

    return '%s%s' % ('uniform', '_' + normed_str if normed_str else '')


def get_model_name(config):
    model_name = ''
    encoders_num = 0
    real_size = config['full_model']['real_encoder']['output_size']
    random_size = config['full_model']['random_encoder']['output_size']

    if real_size > 0:
        encoders_num += 1
        model_name += str(real_size)

    if random_size > 0:
        encoders_num += 1
        model_name += 'x' + str(random_size) if real_size > 0 else str(random_size)

    model_name = str(encoders_num) + 'e' + model_name

    model_name += config['training']['lr_scheduler']['type']

    for k, v in config['training']['lr_scheduler']['hyperparams'].items():
        model_name += '_' + k + str(v).replace(' ', '')

    return model_name


def show_3d_cloud(points_cloud):
    import pptk
    pptk.viewer(points_cloud).set()


def replace_and_rename_pcd_file(source, dest):
    from shutil import copyfile
    model_ids = listdir(source)
    for model_id in model_ids:
        for sample in listdir(join(source, model_id)):
            for filename in listdir(join(source, model_id, sample)):
                copyfile(join(source, model_id, sample, filename), join(dest, f'{model_id}_{sample}_{filename}'))


def get_filenames_by_cat(path) -> pd.DataFrame:
    filenames = []
    for category_id in synth_id_to_category.keys():
        for f in listdir(join(path, category_id)):
            if f not in ['.DS_Store']:
                filenames.append((category_id, f))
    return pd.DataFrame(filenames, columns=['category', 'filename'])


def save_plot(X, epoch, k, results_dir, t):
    fig = plot_3d_point_cloud(X[0], X[1], X[2], in_u_sphere=True, show=False, title=f'{t}_{k} epoch: {epoch}')
    fig_path = join(results_dir, f'{epoch}_{k}_{t}.png')
    fig.savefig(fig_path)
    plt.close(fig)
    return fig_path


def resample_pcd(pcd, n):
    """Drop or duplicate points so that pcd has exactly n points"""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]]
--------------------------------------------------------------------------------
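
A minimal end-to-end sketch of the pieces above (assumptions: the extension has been built and is importable, the project root is on `PYTHONPATH`, and `nn_distance` returns the two per-point distance maps while keeping the nearest-neighbour indices on the autograd context, as its `backward` suggests):

```python
import numpy as np
import torch

from utils.pytorch_structural_losses.nn_distance import nn_distance
from utils.util import resample_pcd

# Bring a partial scan up to a fixed size by dropping/duplicating points.
partial = resample_pcd(np.random.rand(1500, 3).astype(np.float32), 2048)

rec = torch.rand(4, 2048, 3, device='cuda', requires_grad=True)
ref = torch.rand(4, 2048, 3, device='cuda')

# Per-point squared nearest-neighbour distances in both directions.
dist_ab, dist_ba = nn_distance(rec.contiguous(), ref.contiguous())

# Symmetric (squared) Chamfer distance, one value per cloud in the batch.
chamfer = dist_ab.mean(dim=1) + dist_ba.mean(dim=1)  # shape: (4,)
chamfer.mean().backward()  # gradients flow through NNDistanceGrad
```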