├── utils ├── __init__.py ├── libkdtree │ ├── README │ ├── .gitignore │ ├── pykdtree │ │ ├── __init__.py │ │ ├── render_template.py │ │ └── kdtree.pyx │ ├── setup.cfg │ ├── MANIFEST.in │ ├── __init__.py │ ├── README.rst │ └── LICENSE.txt ├── libmcubes │ ├── .gitignore │ ├── pyarray_symbol.h │ ├── __init__.py │ ├── pywrapper.h │ ├── LICENSE │ ├── mcubes.pyx │ ├── exporter.py │ ├── README.rst │ ├── pyarraymodule.h │ └── pywrapper.cpp ├── libmise │ ├── .gitignore │ ├── __init__.py │ └── test.py ├── libsimplify │ ├── test.py │ ├── __init__.py │ └── simplify_mesh.pyx ├── visualize.py └── binvox_rw.py ├── .gitignore ├── .vscode └── settings.json ├── run_train.sh ├── configs ├── train │ ├── reg.yaml │ └── nreg.yaml ├── test │ ├── reg_subsamp.yaml │ ├── reg_ideal.yaml │ └── reg_noisy.yaml └── default.yaml ├── run_test.sh ├── models ├── __init__.py ├── encoder_latent.py ├── legacy.py └── occnet.py ├── encoders ├── __init__.py ├── pix2mesh_cond.py ├── voxels.py ├── psgn_cond.py ├── r2n2.py ├── conv.py └── pointnet.py ├── README.md ├── register_utils.py ├── callbacks.py ├── metricrecord.py ├── train.py ├── misc.py ├── field_tfs.py ├── generation.py ├── se3.py ├── checkpoints.py ├── so3.py ├── sinc.py ├── fmr_transforms.py ├── test.py ├── fields.py ├── transforms.py ├── common.py ├── layers.py └── dataset.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.so 3 | out/ -------------------------------------------------------------------------------- /utils/libkdtree/README: -------------------------------------------------------------------------------- 1 | README.rst -------------------------------------------------------------------------------- /utils/libkdtree/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | -------------------------------------------------------------------------------- /utils/libkdtree/pykdtree/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/libmcubes/.gitignore: -------------------------------------------------------------------------------- 1 | PyMCubes.egg-info 2 | build 3 | -------------------------------------------------------------------------------- /utils/libmise/.gitignore: -------------------------------------------------------------------------------- 1 | mise.c 2 | mise.cpp 3 | mise.html 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.extraPaths": ["./models"] 3 | } -------------------------------------------------------------------------------- /utils/libkdtree/setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_rpm] 2 | requires=numpy 3 | release=1 4 | 5 | 6 | -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python train.py configs/train/reg.yaml --exit-after 100000 -------------------------------------------------------------------------------- 
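For context on the command above: `train.py` merges the YAML it is given over `configs/default.yaml` via `misc.load_config` / `update_recursive` (both shown in `misc.py` further down), so `configs/train/reg.yaml` (below) only states its overrides. A minimal sketch of the merge, assuming it is run from the repo root:

```python
# Sketch of the config resolution performed at the top of train.py.
import misc

cfg = misc.load_config('configs/train/reg.yaml', 'configs/default.yaml')

print(cfg['training']['out_dir'])                   # 'out/reg'  (overridden by reg.yaml)
print(cfg['training']['lr'])                        # 0.0001     (inherited from default.yaml)
print(cfg['checkpoint']['model_selection_metric'])  # 'angle'    (overridden by reg.yaml)
```

Test-time configs additionally use the `inherit_from` key (e.g. `configs/test/reg_ideal.yaml` inherits the training run's saved `out/reg/cfg.yaml`), which `load_config` resolves recursively.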
/utils/libkdtree/MANIFEST.in: -------------------------------------------------------------------------------- 1 | exclude pykdtree/render_template.py 2 | include LICENSE.txt 3 | -------------------------------------------------------------------------------- /utils/libmcubes/pyarray_symbol.h: -------------------------------------------------------------------------------- 1 | 2 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 3 | -------------------------------------------------------------------------------- /utils/libmise/__init__.py: -------------------------------------------------------------------------------- 1 | from .mise import MISE 2 | 3 | 4 | __all__ = [ 5 | 'MISE' 6 | ] 7 | -------------------------------------------------------------------------------- /utils/libkdtree/__init__.py: -------------------------------------------------------------------------------- 1 | from .pykdtree.kdtree import KDTree 2 | 3 | 4 | __all__ = [ 5 | 'KDTree' 6 | ] 7 | -------------------------------------------------------------------------------- /configs/train/reg.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | out_dir: out/reg 3 | 4 | checkpoint: 5 | model_selection_metric: angle 6 | model_selection_mode: minimize -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | python test.py configs/test/reg_ideal.yaml 4 | # python test.py configs/test/reg_noisy.yaml 5 | # python test.py configs/test/reg_subsamp.yaml -------------------------------------------------------------------------------- /utils/libsimplify/test.py: -------------------------------------------------------------------------------- 1 | from simplify_mesh import mesh_simplify 2 | import numpy as np 3 | 4 | v = np.random.rand(100, 3) 5 | f = np.random.choice(range(100), (50, 3)) 6 | 7 | mesh_simplify(v, f, 50) -------------------------------------------------------------------------------- /configs/train/nreg.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | out_dir: out/nreg 3 | 4 | data: 5 | train: 6 | reg: false 7 | presamp_n: 1200 8 | val: 9 | reg: true 10 | presamp_n: 2048 11 | 12 | 13 | -------------------------------------------------------------------------------- /utils/libkdtree/pykdtree/render_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from mako.template import Template 4 | 5 | mytemplate = Template(filename='_kdtree_core.c.mako') 6 | with open('_kdtree_core.c', 'w') as fp: 7 | fp.write(mytemplate.render()) 8 | -------------------------------------------------------------------------------- /utils/libmcubes/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcubes import ( 2 | marching_cubes, marching_cubes_func 3 | ) 4 | from .exporter import ( 5 | export_mesh, export_obj, export_off 6 | ) 7 | 8 | 9 | __all__ = [ 10 | 'marching_cubes', 'marching_cubes_func', 11 | 'export_mesh', 'export_obj', 'export_off' 12 | ] 13 | -------------------------------------------------------------------------------- /configs/test/reg_subsamp.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: out/reg/cfg.yaml 2 | 3 | testing: 4 | out_dir: test_subsamp 5 | model_file: model_best.pt 6 | 7 | data: 8 | test: 9 | reg: 
true 10 | presamp_n: 1024 11 | noise: 0.01 12 | resamp: false 13 | rotate: 180 14 | 15 | subsamp: true 16 | n1: 1024 17 | n2_min: 512 18 | n2_max: 512 19 | centralize: false -------------------------------------------------------------------------------- /configs/test/reg_ideal.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: out/reg/cfg.yaml 2 | 3 | testing: 4 | out_dir: test_ideal 5 | model_file: model_best.pt 6 | 7 | data: 8 | test: 9 | reg: true 10 | presamp_n: 1024 11 | noise: 0.0 12 | resamp: false 13 | rotate: 180 14 | 15 | subsamp: false 16 | # n1: 1024 17 | # n2_min: 500 18 | # n2_max: 1200 19 | centralize: false -------------------------------------------------------------------------------- /configs/test/reg_noisy.yaml: -------------------------------------------------------------------------------- 1 | inherit_from: out/reg/cfg.yaml 2 | 3 | testing: 4 | out_dir: test_noisy 5 | model_file: model_best.pt 6 | 7 | data: 8 | test: 9 | reg: true 10 | presamp_n: 1024 11 | noise: 0.01 12 | resamp: false 13 | rotate: 180 14 | 15 | subsamp: false 16 | # n1: 1024 17 | # n2_min: 500 18 | # n2_max: 1200 19 | centralize: false -------------------------------------------------------------------------------- /utils/libsimplify/__init__.py: -------------------------------------------------------------------------------- 1 | from .simplify_mesh import ( 2 | mesh_simplify 3 | ) 4 | import trimesh 5 | 6 | 7 | def simplify_mesh(mesh, f_target=10000, agressiveness=7.): 8 | vertices = mesh.vertices 9 | faces = mesh.faces 10 | 11 | vertices, faces = mesh_simplify(vertices, faces, f_target, agressiveness) 12 | 13 | mesh_simplified = trimesh.Trimesh(vertices, faces, process=False) 14 | 15 | return mesh_simplified 16 | -------------------------------------------------------------------------------- /utils/libmcubes/pywrapper.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _PYWRAPPER_H 3 | #define _PYWRAPPER_H 4 | 5 | #include <Python.h> 6 | #include "pyarraymodule.h" 7 | 8 | #include <vector> 9 | 10 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue); 11 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue); 12 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue); 13 | PyObject* marching_cubes_func(PyObject* lower, PyObject* upper, 14 | int numx, int numy, int numz, PyObject* f, double isovalue); 15 | 16 | #endif // _PYWRAPPER_H 17 | -------------------------------------------------------------------------------- /utils/libmise/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mise import MISE 3 | import time 4 | 5 | t0 = time.time() 6 | extractor = MISE(1, 2, 0.) 7 | 8 | p = extractor.query() 9 | i = 0 10 | 11 | while p.shape[0] != 0: 12 | print(i) 13 | print(p) 14 | v = 2 * (p.sum(axis=-1) > 2).astype(np.float64) - 1 15 | extractor.update(p, v) 16 | p = extractor.query() 17 | i += 1 18 | if (i >= 8): 19 | break 20 | 21 | print(extractor.to_dense()) 22 | # p, v = extractor.get_points() 23 | # print(p) 24 | # print(v) 25 | print('Total time: %f' % (time.time() - t0)) 26 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import distributions as dist 4 | from . 
import encoder_latent, decoder 5 | from .occnet import OccupancyNetwork 6 | 7 | # Encoder latent dictionary 8 | encoder_latent_dict = { 9 | 'simple': encoder_latent.Encoder, 10 | } 11 | 12 | # Decoder dictionary 13 | decoder_dict = { 14 | 'simple': decoder.Decoder, 15 | 'cbatchnorm': decoder.DecoderCBatchNorm, 16 | 'cbatchnorm2': decoder.DecoderCBatchNorm2, 17 | 'batchnorm': decoder.DecoderBatchNorm, 18 | 'cbatchnorm_noresnet': decoder.DecoderCBatchNormNoResnet, 19 | 'cbatchnorm_vn': decoder.VNDecoderCBatchNorm 20 | } 21 | -------------------------------------------------------------------------------- /encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | conv, pix2mesh_cond, pointnet, 3 | psgn_cond, r2n2, voxels, 4 | ) 5 | 6 | 7 | encoder_dict = { 8 | 'simple_conv': conv.ConvEncoder, 9 | 'resnet18': conv.Resnet18, 10 | 'resnet34': conv.Resnet34, 11 | 'resnet50': conv.Resnet50, 12 | 'resnet101': conv.Resnet101, 13 | 'r2n2_simple': r2n2.SimpleConv, 14 | 'r2n2_resnet': r2n2.Resnet, 15 | 'pointnet_simple': pointnet.SimplePointnet, 16 | 'pointnet_resnet': pointnet.ResnetPointnet, 17 | 'psgn_cond': psgn_cond.PCGN_Cond, 18 | 'voxel_simple': voxels.VoxelEncoder, 19 | 'pixel2mesh_cond': pix2mesh_cond.Pix2mesh_Cond, 20 | 'pointnet_resnet_vn': pointnet.VNResnetPointnet, 21 | } 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EquivReg 2 | Official repo for the CoRL 2021 paper **Correspondence-Free Point Cloud Registration with SO(3)-Equivariant Implicit Shape Representations** [(link)](https://proceedings.mlr.press/v164/zhu22b.html) 3 | 4 | ## Environment 5 | This repo has the same environment as the [occupancy network repo](https://github.com/autonomousvision/occupancy_networks). The code is also developed based on that repo. 6 | ## Dataset 7 | The preprocessed ModelNet40 dataset can be downloaded at this [Google Drive link](https://drive.google.com/file/d/1XU62rCk-S9OB_Hn7Z7I0D9aUmFuHCBpz/view?usp=share_link). It was processed with this [repo](https://github.com/davidstutz/mesh-fusion) to obtain watertight meshes and occupancy values for points in space, which are not available in the original ModelNet40 dataset (as mentioned in the OccNet repo). Extract the files and create a symbolic link named `ModelNet40_install` under the root of this repo. 8 | 9 | ## Training and testing 10 | Examples are given in the files `run_train.sh` and `run_test.sh`. 
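At test time, registration reduces to an orthogonal Procrustes fit between the SO(3)-equivariant feature embeddings of the two point clouds (`solve_R` in `register_utils.py`). A minimal sketch of that step, with random features of shape `(1, M, 3)` standing in for real encoder outputs:

```python
import torch
from register_utils import solve_R, angle_diff_func

c1 = torch.randn(1, 128, 3)             # stand-in equivariant code of cloud 1: (B=1, M, 3)

A = torch.linalg.qr(torch.randn(3, 3)).Q
A = A * torch.sign(torch.det(A))        # flip the sign if det(A) = -1, making A a proper rotation
R_gt = A.unsqueeze(0)

c2 = torch.matmul(c1, R_gt.transpose(-1, -2))  # equivariance: rotating the cloud rotates its code

R_est = solve_R(c1, c2)                 # SVD-based Procrustes fit with reflection correction
print(angle_diff_func(R_gt, R_est))     # angular error in degrees; ~0 in this noise-free case
```

With real encoder outputs the fit is only approximate; `angle_diff_func` gives the angular error in degrees, the same quantity that `configs/train/reg.yaml` minimizes as its `angle` model-selection metric.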
11 | 12 | ## Citation 13 | If this work is helpful for your research, please consider citing our work: 14 | ``` 15 | @inproceedings{zhu2022correspondence, 16 | title={Correspondence-free point cloud registration with SO (3)-equivariant implicit shape representations}, 17 | author={Zhu, Minghan and Ghaffari, Maani and Peng, Huei}, 18 | booktitle={Conference on Robot Learning}, 19 | pages={1412--1422}, 20 | year={2022}, 21 | organization={PMLR} 22 | } 23 | ``` 24 | -------------------------------------------------------------------------------- /register_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import torch.distributions as dist 3 | import logging 4 | 5 | import numpy as np 6 | 7 | def solve_R(f1, f2): 8 | """f1 and f2: (b*)m*3 9 | only work for batch_size=1 10 | """ 11 | S = torch.matmul(f1.transpose(-1, -2), f2) # 3*3 12 | U, sigma, V = torch.svd(S) 13 | R = torch.matmul(V, U.transpose(-1, -2)) 14 | det = torch.det(R) 15 | # logging.info(R) 16 | diag_1 = torch.tensor([1, 1, 0], device=R.device, dtype=R.dtype) 17 | diag_2 = torch.tensor([0, 0, 1], device=R.device, dtype=R.dtype) 18 | det_mat = torch.diag(diag_1 + diag_2 * det) 19 | 20 | # det_mat = torch.eye(3, device=R.device, dtype=R.dtype) 21 | # det_mat[2, 2] = det 22 | 23 | det_mat = det_mat.unsqueeze(0) 24 | # logging.info(det_mat) 25 | R = torch.matmul(V, torch.matmul(det_mat, U.transpose(-1, -2))) 26 | logging.debug(f'det(R)={det}') 27 | # logging.info(V.shape) 28 | 29 | return R 30 | 31 | def angle_of_R(R): 32 | # logging.info("R_diff", R_diff) 33 | cos_angle_diff = (torch.diagonal(R, dim1=-2, dim2=-1).sum(-1) - 1) / 2 34 | # logging.info("cos_angle_diff", cos_angle_diff) 35 | cos_angle_diff = torch.clamp(cos_angle_diff, -1, 1) 36 | angle_diff = torch.acos(cos_angle_diff) 37 | angle_diff = angle_diff / np.pi * 180 38 | return angle_diff 39 | 40 | def angle_diff_func(R1, R2): 41 | R_diff = torch.matmul(torch.inverse(R1), R2) 42 | angle_diff = angle_of_R(R_diff) 43 | return angle_diff 44 | -------------------------------------------------------------------------------- /utils/libmcubes/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2015, P. M. Neila 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /utils/libmcubes/mcubes.pyx: -------------------------------------------------------------------------------- 1 | 2 | # distutils: language = c++ 3 | # cython: embedsignature = True 4 | 5 | # from libcpp.vector cimport vector 6 | import numpy as np 7 | 8 | # Define PY_ARRAY_UNIQUE_SYMBOL 9 | cdef extern from "pyarray_symbol.h": 10 | pass 11 | 12 | cimport numpy as np 13 | 14 | np.import_array() 15 | 16 | cdef extern from "pywrapper.h": 17 | cdef object c_marching_cubes "marching_cubes"(np.ndarray, double) except + 18 | cdef object c_marching_cubes2 "marching_cubes2"(np.ndarray, double) except + 19 | cdef object c_marching_cubes3 "marching_cubes3"(np.ndarray, double) except + 20 | cdef object c_marching_cubes_func "marching_cubes_func"(tuple, tuple, int, int, int, object, double) except + 21 | 22 | def marching_cubes(np.ndarray volume, float isovalue): 23 | 24 | verts, faces = c_marching_cubes(volume, isovalue) 25 | verts.shape = (-1, 3) 26 | faces.shape = (-1, 3) 27 | return verts, faces 28 | 29 | def marching_cubes2(np.ndarray volume, float isovalue): 30 | 31 | verts, faces = c_marching_cubes2(volume, isovalue) 32 | verts.shape = (-1, 3) 33 | faces.shape = (-1, 3) 34 | return verts, faces 35 | 36 | def marching_cubes3(np.ndarray volume, float isovalue): 37 | 38 | verts, faces = c_marching_cubes3(volume, isovalue) 39 | verts.shape = (-1, 3) 40 | faces.shape = (-1, 3) 41 | return verts, faces 42 | 43 | def marching_cubes_func(tuple lower, tuple upper, int numx, int numy, int numz, object f, double isovalue): 44 | 45 | verts, faces = c_marching_cubes_func(lower, upper, numx, numy, numz, f, isovalue) 46 | verts.shape = (-1, 3) 47 | faces.shape = (-1, 3) 48 | return verts, faces 49 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | training: 3 | out_dir: out/default 4 | lr: 1.0e-4 5 | lr_schedule: null 6 | testing: 7 | out_dir: test 8 | model_file: model_best.pt 9 | 10 | data: 11 | input: 12 | path: ModelNet40_install 13 | pointcloud_file: pointcloud.npz 14 | T_file: null 15 | input_bench: 16 | path: ModelNet40_benchmark 17 | pointcloud_file_1: pcl_1.npy 18 | pointcloud_file_2: pcl_2.npy 19 | T21_file: R21.npz 20 | occ: 21 | points_file: points.npz 22 | points_subsample: 1024 23 | points_unpackbits: true 24 | points_iou_file: points.npz 25 | voxels_file: null 26 | train: 27 | reg: true 28 | presamp_n: 2048 29 | noise: 0.01 30 | resamp: false 31 | rotate: 180 32 | 33 | subsamp: true 34 | n1: 1024 35 | n2_min: 400 36 | n2_max: 1200 37 | centralize: false # true 38 | val: {} 39 | test: 40 | reg_benchmark: false 41 | centralize: false # 42 | vis: 43 | reg: false 44 | split: val 45 | 46 | dataloader: 47 | train: 48 | batch_size: 10 49 | num_workers: 10 50 | val: 51 | batch_size: 10 52 | num_workers: 5 53 | vis: 54 | batch_size: 12 55 | 56 
| trainer: 57 | angloss_w: 10 58 | closs_w: 0 59 | occloss_w: 1 60 | cos_loss: false 61 | cos_mse: false 62 | threshold: 0.5 # 0.2 63 | 64 | tester: {} 65 | 66 | model: 67 | encoder: pointnet_resnet_vn 68 | decoder: cbatchnorm_vn 69 | encoder_latent: null 70 | z_dim: 0 71 | c_dim: 513 72 | encoder_kwargs: 73 | hidden_dim: 1026 74 | ball_radius: 0. 75 | pooling: mean 76 | init_lrf: false 77 | lrf_cross: false 78 | n_knn: 20 79 | global_relu: false 80 | decoder_kwargs: {} 81 | encoder_latent_kwargs: {} 82 | 83 | checkpoint: 84 | model_selection_metric: iou 85 | model_selection_mode: maximize 86 | callback: 87 | print_every: 200 88 | visualize_every: 0 89 | validate_every: 10000 90 | checkpoint_every: 50000 91 | autosave_every: 10000 -------------------------------------------------------------------------------- /utils/libmcubes/exporter.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | def export_obj(vertices, triangles, filename): 6 | """ 7 | Exports a mesh in the (.obj) format. 8 | """ 9 | 10 | with open(filename, 'w') as fh: 11 | 12 | for v in vertices: 13 | fh.write("v {} {} {}\n".format(*v)) 14 | 15 | for f in triangles: 16 | fh.write("f {} {} {}\n".format(*(f + 1))) 17 | 18 | 19 | def export_off(vertices, triangles, filename): 20 | """ 21 | Exports a mesh in the (.off) format. 22 | """ 23 | 24 | with open(filename, 'w') as fh: 25 | fh.write('OFF\n') 26 | fh.write('{} {} 0\n'.format(len(vertices), len(triangles))) 27 | 28 | for v in vertices: 29 | fh.write("{} {} {}\n".format(*v)) 30 | 31 | for f in triangles: 32 | fh.write("3 {} {} {}\n".format(*f)) 33 | 34 | 35 | def export_mesh(vertices, triangles, filename, mesh_name="mcubes_mesh"): 36 | """ 37 | Exports a mesh in the COLLADA (.dae) format. 38 | 39 | Needs PyCollada (https://github.com/pycollada/pycollada). 40 | """ 41 | 42 | import collada 43 | 44 | mesh = collada.Collada() 45 | 46 | vert_src = collada.source.FloatSource("verts-array", vertices, ('X','Y','Z')) 47 | geom = collada.geometry.Geometry(mesh, "geometry0", mesh_name, [vert_src]) 48 | 49 | input_list = collada.source.InputList() 50 | input_list.addInput(0, 'VERTEX', "#verts-array") 51 | 52 | triset = geom.createTriangleSet(np.copy(triangles), input_list, "") 53 | geom.primitives.append(triset) 54 | mesh.geometries.append(geom) 55 | 56 | geomnode = collada.scene.GeometryNode(geom, []) 57 | node = collada.scene.Node(mesh_name, children=[geomnode]) 58 | 59 | myscene = collada.scene.Scene("mcubes_scene", [node]) 60 | mesh.scenes.append(myscene) 61 | mesh.scene = myscene 62 | 63 | mesh.write(filename) 64 | -------------------------------------------------------------------------------- /utils/libmcubes/README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | PyMCubes 3 | ======== 4 | 5 | PyMCubes is an implementation of the marching cubes algorithm to extract 6 | isosurfaces from volumetric data. The volumetric data can be given as a 7 | three-dimensional NumPy array or as a Python function ``f(x, y, z)``. The first 8 | option is much faster, but it requires more memory and becomes unfeasible for 9 | very large volumes. 10 | 11 | PyMCubes also provides a function to export the results of the marching cubes as 12 | COLLADA ``(.dae)`` files. This requires the 13 | `PyCollada <https://github.com/pycollada/pycollada>`_ library. 
14 | 15 | Installation 16 | ============ 17 | 18 | Just as any standard Python package, clone or download the project 19 | and run:: 20 | 21 | $ cd path/to/PyMCubes 22 | $ python setup.py build 23 | $ python setup.py install 24 | 25 | If you do not have write permission on the directory of Python packages, 26 | install with the ``--user`` option:: 27 | 28 | $ python setup.py install --user 29 | 30 | Example 31 | ======= 32 | 33 | The following example creates a data volume with spherical isosurfaces and 34 | extracts one of them (i.e., a sphere) with PyMCubes. The result is exported as 35 | ``sphere.dae``:: 36 | 37 | >>> import numpy as np 38 | >>> import mcubes 39 | 40 | # Create a data volume (30 x 30 x 30) 41 | >>> X, Y, Z = np.mgrid[:30, :30, :30] 42 | >>> u = (X-15)**2 + (Y-15)**2 + (Z-15)**2 - 8**2 43 | 44 | # Extract the 0-isosurface 45 | >>> vertices, triangles = mcubes.marching_cubes(u, 0) 46 | 47 | # Export the result to sphere.dae 48 | >>> mcubes.export_mesh(vertices, triangles, "sphere.dae", "MySphere") 49 | 50 | The second example is very similar to the first one, but it uses a function 51 | to represent the volume instead of a NumPy array:: 52 | 53 | >>> import numpy as np 54 | >>> import mcubes 55 | 56 | # Create the volume 57 | >>> f = lambda x, y, z: x**2 + y**2 + z**2 58 | 59 | # Extract the 16-isosurface 60 | >>> vertices, triangles = mcubes.marching_cubes_func((-10,-10,-10), (10,10,10), 61 | ... 100, 100, 100, f, 16) 62 | 63 | # Export the result to sphere2.dae 64 | >>> mcubes.export_mesh(vertices, triangles, "sphere2.dae", "MySphere") 65 | -------------------------------------------------------------------------------- /models/encoder_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Max Pooling operation 7 | def maxpool(x, dim=-1, keepdim=False): 8 | out, _ = x.max(dim=dim, keepdim=keepdim) 9 | return out 10 | 11 | 12 | class Encoder(nn.Module): 13 | ''' Latent encoder class. 14 | 15 | It encodes the input points and returns mean and standard deviation for the 16 | posterior Gaussian distribution. 
17 | 18 | Args: 19 | z_dim (int): dimension of output code z 20 | c_dim (int): dimension of latent conditioned code c 21 | dim (int): input dimension 22 | leaky (bool): whether to use leaky ReLUs 23 | ''' 24 | def __init__(self, z_dim=128, c_dim=128, dim=3, leaky=False): 25 | super().__init__() 26 | self.z_dim = z_dim 27 | self.c_dim = c_dim 28 | 29 | # Submodules 30 | self.fc_pos = nn.Linear(dim, 128) 31 | 32 | if c_dim != 0: 33 | self.fc_c = nn.Linear(c_dim, 128) 34 | 35 | self.fc_0 = nn.Linear(1, 128) 36 | self.fc_1 = nn.Linear(128, 128) 37 | self.fc_2 = nn.Linear(256, 128) 38 | self.fc_3 = nn.Linear(256, 128) 39 | self.fc_mean = nn.Linear(128, z_dim) 40 | self.fc_logstd = nn.Linear(128, z_dim) 41 | 42 | if not leaky: 43 | self.actvn = F.relu 44 | self.pool = maxpool 45 | else: 46 | self.actvn = lambda x: F.leaky_relu(x, 0.2) 47 | self.pool = torch.mean 48 | 49 | def forward(self, p, x, c=None, **kwargs): 50 | batch_size, T, D = p.size() 51 | 52 | # output size: B x T x F 53 | net = self.fc_0(x.unsqueeze(-1)) 54 | net = net + self.fc_pos(p) 55 | 56 | if self.c_dim != 0: 57 | net = net + self.fc_c(c).unsqueeze(1) 58 | 59 | net = self.fc_1(self.actvn(net)) 60 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 61 | net = torch.cat([net, pooled], dim=2) 62 | 63 | net = self.fc_2(self.actvn(net)) 64 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 65 | net = torch.cat([net, pooled], dim=2) 66 | 67 | net = self.fc_3(self.actvn(net)) 68 | # Reduce 69 | # to B x F 70 | net = self.pool(net, dim=1) 71 | 72 | mean = self.fc_mean(net) 73 | logstd = self.fc_logstd(net) 74 | 75 | return mean, logstd 76 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import logging 3 | 4 | class Callback: 5 | def __init__(self, freq): 6 | self.freq = freq 7 | 8 | def __call__(self, it, *args: Any, **kwds: Any): 9 | if it % self.freq == 0: 10 | return self.do(it, *args, **kwds) 11 | 12 | def do(self, it, *args: Any, **kwds: Any): 13 | raise NotImplementedError 14 | 15 | class PrintCallback(Callback): 16 | def do(self, it, epoch_it, loss, d_loss): 17 | txt = '[Epoch %02d] it=%03d, loss=%.4f' % (epoch_it, it, loss) 18 | for key in d_loss: 19 | txt = txt + ", %s: %.5f" % (key, d_loss[key]) 20 | logging.info(txt) 21 | 22 | class VisualizeCallback(Callback): 23 | def __init__(self, freq, trainer, vis_loader): 24 | super().__init__(freq) 25 | self.trainer = trainer 26 | self.vis_loader = vis_loader 27 | self.vis_iter = iter(self.vis_loader) 28 | 29 | def do(self, *args, **kwds): 30 | logging.info('Visualizing') 31 | try: 32 | batch = next(self.vis_iter) 33 | except StopIteration: 34 | logging.info('Finished a loop of the visualization dataset. 
') 35 | self.vis_iter = iter(self.vis_loader) 36 | batch = next(self.vis_iter) 37 | 38 | self.trainer.visualize(batch) 39 | 40 | class CheckpointsaveCallback(Callback): 41 | def __init__(self, freq, checkpoint_io): 42 | super().__init__(freq) 43 | self.checkpoint_io = checkpoint_io 44 | 45 | def do(self, it, epoch_it, *args, **kwds): 46 | self.checkpoint_io.save_process(it=it, epoch_it=epoch_it) 47 | 48 | class AutosaveCallback(Callback): 49 | def __init__(self, freq, checkpoint_io): 50 | super().__init__(freq) 51 | self.checkpoint_io = checkpoint_io 52 | 53 | def do(self, it, epoch_it, *args, **kwds): 54 | logging.info('Autosave latest checkpoint') 55 | self.checkpoint_io.save_latest(it=it, epoch_it=epoch_it) 56 | 57 | class ValidationCallback(Callback): 58 | def __init__(self, freq, checkpoint_io, trainer, val_loader, writer, 59 | *args, **kwds): 60 | super().__init__(freq) 61 | self.checkpoint_io = checkpoint_io 62 | self.trainer = trainer 63 | self.val_loader = val_loader 64 | self.writer = writer 65 | 66 | def do(self, it, epoch_it, *args, **kwds): 67 | eval_dict = self.trainer.evaluate(self.val_loader) 68 | 69 | for k, v in eval_dict.items(): 70 | self.writer.add_scalar('val/%s' % k, v, it) 71 | 72 | self.checkpoint_io.save_if_best(eval_dict, it=it, epoch_it=epoch_it) -------------------------------------------------------------------------------- /metricrecord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | class Record: 6 | def __init__(self, name, num=10, highest=True) -> None: 7 | self.name = name 8 | self.num = num 9 | self.highest = highest 10 | self.vals = [] 11 | self.items = [] 12 | 13 | def update(self, val, item): 14 | if len(self.vals) < self.num: 15 | # Add the item and value, and then sort the list 16 | self.vals.append(val) 17 | self.items.append(item) 18 | self.sort() 19 | else: 20 | if self.highest: 21 | if val > self.vals[-1]: 22 | # Replace the smallest value and item, and then sort the list 23 | self.vals[-1] = val 24 | self.items[-1] = item 25 | self.sort() 26 | else: 27 | if val < self.vals[-1]: 28 | # Replace the largest value and item, and then sort the list 29 | self.vals[-1] = val 30 | self.items[-1] = item 31 | self.sort() 32 | 33 | def sort(self): 34 | sorted_indices = sorted(range(len(self.vals)), key=lambda i: self.vals[i], reverse=self.highest) 35 | self.vals = [self.vals[i] for i in sorted_indices] 36 | self.items = [self.items[i] for i in sorted_indices] 37 | # combined = list(zip(self.vals, self.items)) # zip returns an iterable object, each element is a tuple 38 | # combined.sort(reverse=self.highest) 39 | # self.vals, self.items = map(list, zip(*combined)) # map apply list() to every tuple returned by zip 40 | 41 | def __str__(self) -> str: 42 | return f"{self.name}: {list(zip(self.items, self.vals))}" 43 | 44 | class Metric: 45 | def __init__(self, name) -> None: 46 | self.name = name 47 | self.val = 0 48 | self.count = 0 49 | def update(self, val): 50 | if isinstance(val, torch.Tensor): 51 | val = val.item() 52 | self.val += val 53 | self.count += 1 54 | def avg(self): 55 | return self.val / max(1, self.count) 56 | def __str__(self) -> str: 57 | return f"{self.name}: {self.avg()} (avg over {self.count})" 58 | 59 | if __name__ == '__main__': 60 | 61 | # Example usage: 62 | record = Record("Top Scores", num=5, highest=False) 63 | scores = [100, 50, 75, 120, 80, 60, 110] 64 | names = ["Alice", "Bob", "Carol", "David", "Eve", 
"Frank", "Grace"] 65 | 66 | for name, score in zip(names, scores): 67 | record.update(score, name) 68 | 69 | print(record) -------------------------------------------------------------------------------- /utils/libsimplify/simplify_mesh.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | from libcpp.vector cimport vector 3 | import numpy as np 4 | cimport numpy as np 5 | 6 | 7 | cdef extern from "Simplify.h": 8 | cdef struct vec3f: 9 | double x, y, z 10 | 11 | cdef cppclass SymetricMatrix: 12 | SymetricMatrix() except + 13 | 14 | 15 | cdef extern from "Simplify.h" namespace "Simplify": 16 | cdef struct Triangle: 17 | int v[3] 18 | double err[4] 19 | int deleted, dirty, attr 20 | vec3f uvs[3] 21 | int material 22 | 23 | cdef struct Vertex: 24 | vec3f p 25 | int tstart, tcount 26 | SymetricMatrix q 27 | int border 28 | 29 | cdef vector[Triangle] triangles 30 | cdef vector[Vertex] vertices 31 | cdef void simplify_mesh(int, double) 32 | 33 | 34 | cpdef mesh_simplify(double[:, ::1] vertices_in, long[:, ::1] triangles_in, 35 | int f_target, double agressiveness=7.) except +: 36 | vertices.clear() 37 | triangles.clear() 38 | 39 | # Read in vertices and triangles 40 | cdef Vertex v 41 | for iv in range(vertices_in.shape[0]): 42 | v = Vertex() 43 | v.p.x = vertices_in[iv, 0] 44 | v.p.y = vertices_in[iv, 1] 45 | v.p.z = vertices_in[iv, 2] 46 | vertices.push_back(v) 47 | 48 | cdef Triangle t 49 | for it in range(triangles_in.shape[0]): 50 | t = Triangle() 51 | t.v[0] = triangles_in[it, 0] 52 | t.v[1] = triangles_in[it, 1] 53 | t.v[2] = triangles_in[it, 2] 54 | triangles.push_back(t) 55 | 56 | # Simplify 57 | # print('Simplify...') 58 | simplify_mesh(f_target, agressiveness) 59 | 60 | # Only use triangles that are not deleted 61 | cdef vector[Triangle] triangles_notdel 62 | triangles_notdel.reserve(triangles.size()) 63 | 64 | for t in triangles: 65 | if not t.deleted: 66 | triangles_notdel.push_back(t) 67 | 68 | # Read out triangles 69 | vertices_out = np.empty((vertices.size(), 3), dtype=np.float64) 70 | triangles_out = np.empty((triangles_notdel.size(), 3), dtype=np.int64) 71 | 72 | cdef double[:, :] vertices_out_view = vertices_out 73 | cdef long[:, :] triangles_out_view = triangles_out 74 | 75 | for iv in range(vertices.size()): 76 | vertices_out_view[iv, 0] = vertices[iv].p.x 77 | vertices_out_view[iv, 1] = vertices[iv].p.y 78 | vertices_out_view[iv, 2] = vertices[iv].p.z 79 | 80 | for it in range(triangles_notdel.size()): 81 | triangles_out_view[it, 0] = triangles_notdel[it].v[0] 82 | triangles_out_view[it, 1] = triangles_notdel[it].v[1] 83 | triangles_out_view[it, 2] = triangles_notdel[it].v[2] 84 | 85 | # Clear vertices and triangles 86 | vertices.clear() 87 | triangles.clear() 88 | 89 | return vertices_out, triangles_out -------------------------------------------------------------------------------- /encoders/pix2mesh_cond.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Pix2mesh_Cond(nn.Module): 5 | r''' Conditioning Network proposed in the authors' Pixel2Mesh implementation. 6 | 7 | The network consists of several 2D convolution layers, and several of the 8 | intermediate feature maps are returned to features for the image 9 | projection layer of the encoder network. 10 | ''' 11 | def __init__(self, c_dim=512, return_feature_maps=True): 12 | r''' Initialisation. 
13 | 14 | Args: 15 | c_dim (int): channels of the final output 16 | return_feature_maps (bool): whether intermediate feature maps 17 | should be returned 18 | ''' 19 | super().__init__() 20 | actvn = nn.ReLU() 21 | self.return_feature_maps = return_feature_maps 22 | num_fm = int(c_dim/32) 23 | if num_fm != 16: 24 | raise ValueError('Pixel2Mesh requires a fixed c_dim of 512!') 25 | 26 | self.block_1 = nn.Sequential( 27 | nn.Conv2d(3, num_fm, 3, stride=1, padding=1), actvn, 28 | nn.Conv2d(num_fm, num_fm, 3, stride=1, padding=1), actvn, 29 | nn.Conv2d(num_fm, num_fm*2, 3, stride=2, padding=1), actvn, 30 | nn.Conv2d(num_fm*2, num_fm*2, 3, stride=1, padding=1), actvn, 31 | nn.Conv2d(num_fm*2, num_fm*2, 3, stride=1, padding=1), actvn, 32 | nn.Conv2d(num_fm*2, num_fm*4, 3, stride=2, padding=1), actvn, 33 | nn.Conv2d(num_fm*4, num_fm*4, 3, stride=1, padding=1), actvn, 34 | nn.Conv2d(num_fm*4, num_fm*4, 3, stride=1, padding=1), actvn) 35 | 36 | self.block_2 = nn.Sequential( 37 | nn.Conv2d(num_fm*4, num_fm*8, 3, stride=2, padding=1), actvn, 38 | nn.Conv2d(num_fm*8, num_fm*8, 3, stride=1, padding=1), actvn, 39 | nn.Conv2d(num_fm*8, num_fm*8, 3, stride=1, padding=1), actvn) 40 | 41 | self.block_3 = nn.Sequential( 42 | nn.Conv2d(num_fm*8, num_fm*16, 5, stride=2, padding=2), actvn, 43 | nn.Conv2d(num_fm*16, num_fm*16, 3, stride=1, padding=1), actvn, 44 | nn.Conv2d(num_fm*16, num_fm*16, 3, stride=1, padding=1), actvn) 45 | 46 | self.block_4 = nn.Sequential( 47 | nn.Conv2d(num_fm*16, num_fm*32, 5, stride=2, padding=2), actvn, 48 | nn.Conv2d(num_fm*32, num_fm*32, 3, stride=1, padding=1), actvn, 49 | nn.Conv2d(num_fm*32, num_fm*32, 3, stride=1, padding=1), actvn, 50 | nn.Conv2d(num_fm*32, num_fm*32, 3, stride=1, padding=1), actvn, 51 | ) 52 | 53 | def forward(self, x): 54 | # x has size 224 x 224 55 | x_0 = self.block_1(x) # 64 x 56 x 56 56 | x_1 = self.block_2(x_0) # 128 x 28 x 28 57 | x_2 = self.block_3(x_1) # 256 x 14 x 14 58 | x_3 = self.block_4(x_2) # 512 x 7 x 7 59 | 60 | if self.return_feature_maps: 61 | return x_0, x_1, x_2, x_3 62 | return x_3 63 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import os 8 | from torch.utils.tensorboard import SummaryWriter 9 | import logging 10 | 11 | import misc 12 | import config 13 | 14 | def get_args(): 15 | # Arguments 16 | parser = argparse.ArgumentParser( 17 | description='Train a 3D reconstruction model.' 
18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | parser.add_argument('--exit-after', type=int, default=-1, 22 | help='Checkpoint and exit after specified number of iterations ' 23 | 'with exit code 2.') 24 | args = parser.parse_args() 25 | return args 26 | 27 | if __name__ == "__main__": 28 | 29 | args = get_args() 30 | 31 | cfg = misc.load_config(args.config, 'configs/default.yaml') 32 | 33 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 34 | device = torch.device("cuda" if is_cuda else "cpu") 35 | 36 | ### configure logger 37 | out_dir = config.cfg_f_out(cfg) 38 | 39 | ### configure dataset 40 | train_dataset, val_dataset, train_loader, val_loader, vis_loader, duo_loader = config.cfg_dataloader(cfg) 41 | 42 | ### configure model 43 | model = config.cfg_model(cfg, device) 44 | 45 | ### configure optimizer, lr scheduler, and loss functions 46 | trainer, optimizer, lr_scheduler = config.cfg_trainer(cfg, device, model) 47 | 48 | ### configure checkpoints 49 | checkpoint_io, epoch_it, it = config.cfg_checkpoint(cfg, out_dir, model, optimizer, lr_scheduler) 50 | writer = SummaryWriter(os.path.join(out_dir, 'logs')) 51 | 52 | ### configure callbacks 53 | callback_list, callback_dict = config.cfg_callbacks(cfg, trainer, vis_loader, val_loader, checkpoint_io, writer) 54 | 55 | # Shorthands 56 | 57 | # Iteration on epochs 58 | while True: 59 | epoch_it += 1 60 | train_iter = iter(train_loader) 61 | 62 | while True: 63 | try: 64 | batch = next(train_iter) 65 | except StopIteration: 66 | break 67 | 68 | for key, value in batch.items(): 69 | if isinstance(value, torch.Tensor): 70 | batch[key] = value.to(device) 71 | 72 | it += 1 73 | loss, d_loss = trainer.train_step(batch) 74 | writer.add_scalar('train/loss', loss, it) 75 | for key in d_loss: 76 | writer.add_scalar('train/{}'.format(key), d_loss[key], it) 77 | 78 | for callback in callback_list: 79 | callback_dict[callback](it=it, epoch_it=epoch_it, loss=loss, d_loss=d_loss) 80 | 81 | 82 | if args.exit_after > 0 and it > args.exit_after: 83 | logging.info(f'Exiting at {it} iterations.') 84 | break -------------------------------------------------------------------------------- /encoders/voxels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VoxelEncoder(nn.Module): 7 | ''' 3D-convolutional encoder network for voxel input. 
8 | 9 | Args: 10 | dim (int): input dimension 11 | c_dim (int): output dimension 12 | ''' 13 | 14 | def __init__(self, dim=3, c_dim=128): 15 | super().__init__() 16 | self.actvn = F.relu 17 | 18 | self.conv_in = nn.Conv3d(1, 32, 3, padding=1) 19 | 20 | self.conv_0 = nn.Conv3d(32, 64, 3, padding=1, stride=2) 21 | self.conv_1 = nn.Conv3d(64, 128, 3, padding=1, stride=2) 22 | self.conv_2 = nn.Conv3d(128, 256, 3, padding=1, stride=2) 23 | self.conv_3 = nn.Conv3d(256, 512, 3, padding=1, stride=2) 24 | self.fc = nn.Linear(512 * 2 * 2 * 2, c_dim) 25 | 26 | def forward(self, x): 27 | batch_size = x.size(0) 28 | 29 | x = x.unsqueeze(1) 30 | net = self.conv_in(x) 31 | net = self.conv_0(self.actvn(net)) 32 | net = self.conv_1(self.actvn(net)) 33 | net = self.conv_2(self.actvn(net)) 34 | net = self.conv_3(self.actvn(net)) 35 | 36 | hidden = net.view(batch_size, 512 * 2 * 2 * 2) 37 | c = self.fc(self.actvn(hidden)) 38 | 39 | return c 40 | 41 | 42 | class CoordVoxelEncoder(nn.Module): 43 | ''' 3D-convolutional encoder network for voxel input. 44 | 45 | It additional concatenates the coordinate data. 46 | 47 | Args: 48 | dim (int): input dimension 49 | c_dim (int): output dimension 50 | ''' 51 | 52 | def __init__(self, dim=3, c_dim=128): 53 | super().__init__() 54 | self.actvn = F.relu 55 | 56 | self.conv_in = nn.Conv3d(4, 32, 3, padding=1) 57 | 58 | self.conv_0 = nn.Conv3d(32, 64, 3, padding=1, stride=2) 59 | self.conv_1 = nn.Conv3d(64, 128, 3, padding=1, stride=2) 60 | self.conv_2 = nn.Conv3d(128, 256, 3, padding=1, stride=2) 61 | self.conv_3 = nn.Conv3d(256, 512, 3, padding=1, stride=2) 62 | self.fc = nn.Linear(512 * 2 * 2 * 2, c_dim) 63 | 64 | def forward(self, x): 65 | batch_size = x.size(0) 66 | device = x.device 67 | 68 | coord1 = torch.linspace(-0.5, 0.5, x.size(1)).to(device) 69 | coord2 = torch.linspace(-0.5, 0.5, x.size(2)).to(device) 70 | coord3 = torch.linspace(-0.5, 0.5, x.size(3)).to(device) 71 | 72 | coord1 = coord1.view(1, -1, 1, 1).expand_as(x) 73 | coord2 = coord2.view(1, 1, -1, 1).expand_as(x) 74 | coord3 = coord3.view(1, 1, 1, -1).expand_as(x) 75 | 76 | coords = torch.stack([coord1, coord2, coord3], dim=1) 77 | 78 | x = x.unsqueeze(1) 79 | net = torch.cat([x, coords], dim=1) 80 | net = self.conv_in(net) 81 | net = self.conv_0(self.actvn(net)) 82 | net = self.conv_1(self.actvn(net)) 83 | net = self.conv_2(self.actvn(net)) 84 | net = self.conv_3(self.actvn(net)) 85 | 86 | hidden = net.view(batch_size, 512 * 2 * 2 * 2) 87 | c = self.fc(self.actvn(hidden)) 88 | 89 | return c 90 | -------------------------------------------------------------------------------- /encoders/psgn_cond.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class PCGN_Cond(nn.Module): 5 | r''' Point Set Generation Network encoding network. 6 | 7 | The PSGN conditioning network from the original publication consists of 8 | several 2D convolution layers. The intermediate outputs from some layers 9 | are used as additional input to the encoder network, similar to U-Net. 
10 | 11 | Args: 12 | c_dim (int): output dimension of the latent embedding 13 | ''' 14 | def __init__(self, c_dim=512): 15 | super().__init__() 16 | actvn = nn.ReLU() 17 | num_fm = int(c_dim/32) 18 | 19 | self.conv_block1 = nn.Sequential( 20 | nn.Conv2d(3, num_fm, 3, 1, 1), actvn, 21 | nn.Conv2d(num_fm, num_fm, 3, 1, 1), actvn) 22 | self.conv_block2 = nn.Sequential( 23 | nn.Conv2d(num_fm, num_fm*2, 3, 2, 1), actvn, 24 | nn.Conv2d(num_fm*2, num_fm*2, 3, 1, 1), actvn, 25 | nn.Conv2d(num_fm*2, num_fm*2, 3, 1, 1), actvn) 26 | self.conv_block3 = nn.Sequential( 27 | nn.Conv2d(num_fm*2, num_fm*4, 3, 2, 1), actvn, 28 | nn.Conv2d(num_fm*4, num_fm*4, 3, 1, 1), actvn, 29 | nn.Conv2d(num_fm*4, num_fm*4, 3, 1, 1), actvn) 30 | self.conv_block4 = nn.Sequential( 31 | nn.Conv2d(num_fm*4, num_fm*8, 3, 2, 1), actvn, 32 | nn.Conv2d(num_fm*8, num_fm*8, 3, 1, 1), actvn, 33 | nn.Conv2d(num_fm*8, num_fm*8, 3, 1, 1), actvn) 34 | self.conv_block5 = nn.Sequential( 35 | nn.Conv2d(num_fm*8, num_fm*16, 3, 2, 1), actvn, 36 | nn.Conv2d(num_fm*16, num_fm*16, 3, 1, 1), actvn, 37 | nn.Conv2d(num_fm*16, num_fm*16, 3, 1, 1), actvn) 38 | self.conv_block6 = nn.Sequential( 39 | nn.Conv2d(num_fm*16, num_fm*32, 3, 2, 1), actvn, 40 | nn.Conv2d(num_fm*32, num_fm*32, 3, 1, 1), actvn, 41 | nn.Conv2d(num_fm*32, num_fm*32, 3, 1, 1), actvn, 42 | nn.Conv2d(num_fm*32, num_fm*32, 3, 1, 1), actvn) 43 | self.conv_block7 = nn.Sequential( 44 | nn.Conv2d(num_fm*32, num_fm*32, 5, 2, 2), actvn) 45 | 46 | self.trans_conv1 = nn.Conv2d(num_fm*8, num_fm*4, 3, 1, 1) 47 | self.trans_conv2 = nn.Conv2d(num_fm*16, num_fm*8, 3, 1, 1) 48 | self.trans_conv3 = nn.Conv2d(num_fm*32, num_fm*16, 3, 1, 1) 49 | 50 | def forward(self, x, return_feature_maps=True): 51 | r''' Performs a forward pass through the network. 52 | 53 | Args: 54 | x (tensor): input data 55 | return_feature_maps (bool): whether intermediate feature maps 56 | should be returned 57 | ''' 58 | feature_maps = [] 59 | 60 | x = self.conv_block1(x) 61 | x = self.conv_block2(x) 62 | x = self.conv_block3(x) 63 | x = self.conv_block4(x) 64 | 65 | feature_maps.append(self.trans_conv1(x)) 66 | 67 | x = self.conv_block5(x) 68 | feature_maps.append(self.trans_conv2(x)) 69 | 70 | x = self.conv_block6(x) 71 | feature_maps.append(self.trans_conv3(x)) 72 | 73 | x = self.conv_block7(x) 74 | 75 | if return_feature_maps: 76 | return x, feature_maps 77 | return x 78 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import logging 4 | 5 | def load_config(path, default_path=None): 6 | ''' Loads config file. 
7 | 8 | Args: 9 | path (str): path to config file 10 | default_path (str): path to the default config file 11 | ''' 12 | # Load configuration from file itself 13 | with open(path, 'r') as f: 14 | cfg_specific = yaml.load(f, Loader=yaml.SafeLoader) 15 | 16 | # Check if we should inherit from a config 17 | inherit_from = cfg_specific.get('inherit_from') 18 | 19 | # If yes, load this config first as default 20 | if inherit_from is not None: 21 | cfg = load_config(inherit_from, default_path) 22 | # If no, use the default_path 23 | elif default_path is not None: 24 | with open(default_path, 'r') as f: 25 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 26 | else: 27 | cfg = dict() 28 | 29 | # update cfg using cfg_specific 30 | update_recursive(cfg, cfg_specific) 31 | 32 | return cfg 33 | 34 | def update_recursive(dict1, dict2): 35 | ''' Update two config dictionaries recursively. 36 | 37 | Args: 38 | dict1 (dict): first dictionary to be updated 39 | dict2 (dict): second dictionary whose entries should be used 40 | 41 | ''' 42 | for k, v in dict2.items(): 43 | if k not in dict1: 44 | dict1[k] = dict() 45 | if isinstance(v, dict): 46 | update_recursive(dict1[k], v) 47 | else: 48 | dict1[k] = v 49 | 50 | def setup_logging(out_dir): 51 | # If a module imported before this one has already configured basicConfig, we cannot set it here again. 52 | # Therefore we reload it. 53 | # Since Python 3.8, one can skip the reload and set force=True in basicConfig. 54 | # https://stackoverflow.com/questions/20240464/python-logging-file-is-not-working-when-using-logging-basicconfig 55 | from importlib import reload # Python 2.x doesn't need this import; reload is a builtin there 56 | reload(logging) 57 | 58 | # Set up the logging format and output path 59 | level = logging.INFO 60 | format = '%(asctime)s %(message)s' 61 | datefmt = '%m-%d %H:%M:%S' 62 | logfile = os.path.join(out_dir, 'msgs.log') 63 | handlers = [logging.FileHandler(logfile), logging.StreamHandler()] 64 | 65 | logging.basicConfig(level=level, format=format, datefmt=datefmt, handlers=handlers) 66 | logging.info('Hey, logging is written to {}!'.format(logfile)) 67 | return 68 | 69 | # def setup_logging(out_dir): 70 | # # Set up the logging format and output path 71 | # level = logging.INFO 72 | # format = '%(asctime)s %(message)s' 73 | # datefmt = '%m-%d %H:%M:%S' 74 | 75 | # formatter = logging.Formatter(format, datefmt) 76 | # handler_file = logging.FileHandler(os.path.join(out_dir, 'msgs.log')) 77 | # handler_file.setFormatter(formatter) 78 | # handler_stream = logging.StreamHandler() 79 | # handler_stream.setFormatter(formatter) 80 | 81 | # logger = logging.getLogger() 82 | # logger.addHandler(handler_file) 83 | # logger.addHandler(handler_stream) 84 | # logger.setLevel(level) 85 | 86 | # logger.info('Hey, logging is written to {}!'.format(os.path.join(out_dir, 'msgs.log'))) 87 | # return -------------------------------------------------------------------------------- /encoders/r2n2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # import torch.nn.functional as F 3 | from common import normalize_imagenet 4 | 5 | 6 | class SimpleConv(nn.Module): 7 | ''' 3D Recurrent Reconstruction Neural Network (3D-R2-N2) encoder network. 
8 | 9 | Args: 10 | c_dim (int): output dimension 11 | ''' 12 | 13 | def __init__(self, c_dim=1024): 14 | super().__init__() 15 | actvn = nn.LeakyReLU() 16 | pooling = nn.MaxPool2d(2, padding=1) 17 | self.convnet = nn.Sequential( 18 | nn.Conv2d(3, 96, 7, padding=3), 19 | pooling, actvn, 20 | nn.Conv2d(96, 128, 3, padding=1), 21 | pooling, actvn, 22 | nn.Conv2d(128, 256, 3, padding=1), 23 | pooling, actvn, 24 | nn.Conv2d(256, 256, 3, padding=1), 25 | pooling, actvn, 26 | nn.Conv2d(256, 256, 3, padding=1), 27 | pooling, actvn, 28 | nn.Conv2d(256, 256, 3, padding=1), 29 | pooling, actvn, 30 | ) 31 | self.fc_out = nn.Linear(256*3*3, c_dim) 32 | 33 | def forward(self, x): 34 | batch_size = x.size(0) 35 | 36 | net = normalize_imagenet(x) 37 | net = self.convnet(net) 38 | net = net.view(batch_size, 256*3*3) 39 | out = self.fc_out(net) 40 | 41 | return out 42 | 43 | 44 | class Resnet(nn.Module): 45 | ''' 3D Recurrent Reconstruction Neural Network (3D-R2-N2) ResNet-based 46 | encoder network. 47 | 48 | It is the ResNet variant of the previous encoder. 49 | 50 | Args: 51 | c_dim (int): output dimension 52 | ''' 53 | 54 | def __init__(self, c_dim=1024): 55 | super().__init__() 56 | actvn = nn.LeakyReLU() 57 | pooling = nn.MaxPool2d(2, padding=1) 58 | self.convnet = nn.Sequential( 59 | nn.Conv2d(3, 96, 7, padding=3), 60 | actvn, 61 | nn.Conv2d(96, 96, 3, padding=1), 62 | actvn, pooling, 63 | ResnetBlock(96, 128), 64 | pooling, 65 | ResnetBlock(128, 256), 66 | pooling, 67 | ResnetBlock(256, 256), 68 | pooling, 69 | ResnetBlock(256, 256), 70 | pooling, 71 | ResnetBlock(256, 256), 72 | pooling, 73 | ) 74 | self.fc_out = nn.Linear(256*3*3, c_dim) 75 | 76 | def forward(self, x): 77 | batch_size = x.size(0) 78 | 79 | net = normalize_imagenet(x) 80 | net = self.convnet(net) 81 | net = net.view(batch_size, 256*3*3) 82 | out = self.fc_out(net) 83 | 84 | return out 85 | 86 | 87 | class ResnetBlock(nn.Module): 88 | ''' ResNet block class. 89 | 90 | Args: 91 | f_in (int): input dimension 92 | f_out (int): output dimension 93 | ''' 94 | 95 | def __init__(self, f_in, f_out): 96 | super().__init__() 97 | actvn = nn.LeakyReLU() 98 | self.convnet = nn.Sequential( 99 | nn.Conv2d(f_in, f_out, 3, padding=1), 100 | actvn, 101 | nn.Conv2d(f_out, f_out, 3, padding=1), 102 | actvn, 103 | ) 104 | self.shortcut = nn.Conv2d(f_in, f_out, 1) 105 | 106 | def forward(self, x): 107 | out = self.convnet(x) + self.shortcut(x) 108 | return out 109 | -------------------------------------------------------------------------------- /field_tfs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class SubsamplePoints(object): 5 | ''' Points subsampling transformation class. 6 | 7 | It subsamples the points data. 8 | 9 | Args: 10 | N (int): number of points to be subsampled 11 | ''' 12 | def __init__(self, N): 13 | self.N = N 14 | 15 | def __call__(self, data): 16 | ''' Calls the transformation. 
17 | 18 | Args: 19 | data (dictionary): data dictionary 20 | ''' 21 | points = data[None] 22 | occ = data['occ'] 23 | 24 | data_out = data.copy() 25 | if isinstance(self.N, int): 26 | idx = np.random.randint(points.shape[0], size=self.N) 27 | data_out.update({ 28 | None: points[idx, :], 29 | 'occ': occ[idx], 30 | }) 31 | else: 32 | Nt_out, Nt_in = self.N 33 | occ_binary = (occ >= 0.5) 34 | points0 = points[~occ_binary] 35 | points1 = points[occ_binary] 36 | 37 | idx0 = np.random.randint(points0.shape[0], size=Nt_out) 38 | idx1 = np.random.randint(points1.shape[0], size=Nt_in) 39 | 40 | points0 = points0[idx0, :] 41 | points1 = points1[idx1, :] 42 | points = np.concatenate([points0, points1], axis=0) 43 | 44 | occ0 = np.zeros(Nt_out, dtype=np.float32) 45 | occ1 = np.ones(Nt_in, dtype=np.float32) 46 | occ = np.concatenate([occ0, occ1], axis=0) 47 | 48 | volume = occ_binary.sum() / len(occ_binary) 49 | volume = volume.astype(np.float32) 50 | 51 | data_out.update({ 52 | None: points, 53 | 'occ': occ, 54 | 'volume': volume, 55 | }) 56 | return data_out 57 | 58 | class SubsamplePointcloud(object): 59 | ''' Point cloud subsampling transformation class. 60 | 61 | It subsamples the point cloud data. 62 | 63 | Args: 64 | N (int): number of points to be subsampled 65 | ''' 66 | def __init__(self, N): 67 | self.N = N 68 | 69 | def __call__(self, data): 70 | ''' Calls the transformation. 71 | 72 | Args: 73 | data (dict): dictionary with None and 'normals' keys 74 | ''' 75 | data_out = data.copy() 76 | points = data[None] 77 | normals = data['normals'] 78 | 79 | indices = np.random.randint(points.shape[0], size=self.N) 80 | data_out[None] = points[indices, :] 81 | data_out['normals'] = normals[indices, :] 82 | 83 | return data_out 84 | 85 | class PointcloudNoise(object): 86 | ''' Point cloud noise transformation class. 87 | 88 | It adds noise to point cloud data. 89 | 90 | Args: 91 | stddev (int): standard deviation 92 | ''' 93 | 94 | def __init__(self, stddev): 95 | self.stddev = stddev 96 | 97 | def __call__(self, data): 98 | ''' Calls the transformation. 99 | 100 | Args: 101 | data (dictionary): dictionary with None and 'normals' keys 102 | ''' 103 | data_out = data.copy() 104 | points = data[None] 105 | noise = self.stddev * np.random.randn(*points.shape) 106 | noise = noise.astype(np.float32) 107 | data_out[None] = points + noise 108 | return data_out -------------------------------------------------------------------------------- /generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch import autograd 4 | import numpy as np 5 | from tqdm import trange 6 | import trimesh 7 | from utils import libmcubes 8 | from common import make_3d_grid 9 | from utils.libsimplify import simplify_mesh 10 | from utils.libmise import MISE 11 | import time 12 | import logging 13 | 14 | # from pytorch3d.transforms import RotateAxisAngle, Rotate, random_rotations, axis_angle_to_matrix 15 | 16 | from transforms import SubSamplePairBatchIP, CentralizePairBatchIP, RotatePairBatchIP 17 | class Generator3D(object): 18 | ''' Generator class for Occupancy Networks. 19 | 20 | It provides functions to generate the final mesh as well refining options. 
21 | 22 | Args: 23 | model (nn.Module): trained Occupancy Network model 24 | points_batch_size (int): batch size for points evaluation 25 | threshold (float): threshold value 26 | refinement_step (int): number of refinement steps 27 | device (device): pytorch device 28 | resolution_0 (int): start resolution for MISE 29 | upsampling steps (int): number of upsampling steps 30 | with_normals (bool): whether normals should be estimated 31 | padding (float): how much padding should be used for MISE 32 | use_sampling (bool): whether z should be sampled 33 | simplify_nfaces (int): number of faces the mesh should be simplified to 34 | preprocessor (nn.Module): preprocessor for inputs 35 | ''' 36 | 37 | def __init__(self, model, points_batch_size=100000, 38 | threshold=0.5, refinement_step=0, device=None, 39 | resolution_0=16, upsampling_steps=3, 40 | with_normals=False, padding=0.1, use_sampling=False, 41 | simplify_nfaces=None, 42 | preprocessor=None, 43 | rotate=-1, 44 | noise=0, 45 | centralize=False, 46 | n1=0, n2=0, subsamp=True, reg_benchmark=False, transform_test=None, **kwargs, 47 | ): 48 | self.model = model.to(device) 49 | self.points_batch_size = points_batch_size 50 | self.refinement_step = refinement_step 51 | self.threshold = threshold 52 | self.device = device 53 | self.resolution_0 = resolution_0 54 | self.upsampling_steps = upsampling_steps 55 | self.with_normals = with_normals 56 | self.padding = padding 57 | self.use_sampling = use_sampling 58 | self.simplify_nfaces = simplify_nfaces 59 | self.preprocessor = preprocessor 60 | 61 | self.transform_test = transform_test 62 | 63 | # # self.rotate = rotate 64 | # # self.noise = noise # noise only effective when not sampling different points 65 | # self.subsamp = subsamp 66 | # self.sub_op = SubSamplePairBatchIP(n1, n2, n2, device) if subsamp else None 67 | # # self.rotate_op = RotatePairBatchIP() 68 | # self.centralize = centralize 69 | # self.ctr_op = CentralizePairBatchIP() if centralize else None 70 | 71 | def generate_latent_conditioned(self, data): 72 | self.model.eval() 73 | # device = self.device 74 | # stats_dict = {} 75 | 76 | if self.transform_test is not None: 77 | self.transform_test(data) 78 | 79 | inputs = data['inputs'] 80 | inputs_2 = data['inputs_2'] 81 | 82 | input_max = torch.max(torch.abs(inputs)) 83 | norm_max = torch.max(torch.norm(inputs, dim=-1)) 84 | logging.debug(f"max inf norm, max 2 norm, {input_max}, {norm_max}") 85 | 86 | # rot_d = {} 87 | # rot_d['angles'] = data['T21.deg'] 88 | # # rot_d['trot'] = trot 89 | # rot_d['rotmats'] = data['T21'] 90 | # # rot_d['t'] = t 91 | 92 | # Encode inputs 93 | # t0 = time.time() 94 | with torch.no_grad(): 95 | c = self.model.encode_inputs(inputs) 96 | c_rot = self.model.encode_inputs(inputs_2) 97 | 98 | return c, c_rot 99 | -------------------------------------------------------------------------------- /models/legacy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers import ResnetBlockFC, AffineLayer 5 | 6 | 7 | class VoxelDecoder(nn.Module): 8 | def __init__(self, dim=3, z_dim=128, c_dim=128, hidden_size=128): 9 | super().__init__() 10 | self.c_dim = c_dim 11 | self.z_dim = z_dim 12 | # Submodules 13 | self.actvn = F.relu 14 | # 3D decoder 15 | self.fc_in = nn.Linear(c_dim + z_dim, 256*4*4*4) 16 | self.convtrp_0 = nn.ConvTranspose3d(256, 128, 3, stride=2, 17 | padding=1, output_padding=1) 18 | self.convtrp_1 = nn.ConvTranspose3d(128, 64, 3, 
stride=2, 19 | padding=1, output_padding=1) 20 | self.convtrp_2 = nn.ConvTranspose3d(64, 32, 3, stride=2, 21 | padding=1, output_padding=1) 22 | # Fully connected decoder 23 | self.z_dim = z_dim 24 | if not z_dim == 0: 25 | self.fc_z = nn.Linear(z_dim, hidden_size) 26 | self.fc_f = nn.Linear(32, hidden_size) 27 | self.fc_c = nn.Linear(c_dim, hidden_size) 28 | self.fc_p = nn.Linear(dim, hidden_size) 29 | 30 | self.block0 = ResnetBlockFC(hidden_size, hidden_size) 31 | self.block1 = ResnetBlockFC(hidden_size, hidden_size) 32 | self.fc_out = nn.Linear(hidden_size, 1) 33 | 34 | def forward(self, p, z, c, **kwargs): 35 | batch_size = c.size(0) 36 | 37 | if self.z_dim != 0: 38 | net = torch.cat([z, c], dim=1) 39 | else: 40 | net = c 41 | 42 | net = self.fc_in(net) 43 | net = net.view(batch_size, 256, 4, 4, 4) 44 | net = self.convtrp_0(self.actvn(net)) 45 | net = self.convtrp_1(self.actvn(net)) 46 | net = self.convtrp_2(self.actvn(net)) 47 | 48 | net = F.grid_sample( 49 | net, 2*p.unsqueeze(1).unsqueeze(1), padding_mode='border') 50 | net = net.squeeze(2).squeeze(2).transpose(1, 2) 51 | net = self.fc_f(self.actvn(net)) 52 | 53 | net_p = self.fc_p(p) 54 | net = net + net_p 55 | 56 | if self.z_dim != 0: 57 | net_z = self.fc_z(z).unsqueeze(1) 58 | net = net + net_z 59 | 60 | if self.c_dim != 0: 61 | net_c = self.fc_c(c).unsqueeze(1) 62 | net = net + net_c 63 | 64 | net = self.block0(net) 65 | net = self.block1(net) 66 | 67 | out = self.fc_out(self.actvn(net)) 68 | out = out.squeeze(-1) 69 | 70 | return out 71 | 72 | 73 | class FeatureDecoder(nn.Module): 74 | def __init__(self, dim=3, z_dim=128, c_dim=128, hidden_size=256): 75 | super().__init__() 76 | self.z_dim = z_dim 77 | self.c_dim = c_dim 78 | self.dim = dim 79 | 80 | self.actvn = nn.ReLU() 81 | 82 | self.affine = AffineLayer(c_dim, dim) 83 | if not z_dim == 0: 84 | self.fc_z = nn.Linear(z_dim, hidden_size) 85 | self.fc_p1 = nn.Linear(dim, hidden_size) 86 | self.fc_p2 = nn.Linear(dim, hidden_size) 87 | 88 | self.fc_c1 = nn.Linear(c_dim, hidden_size) 89 | self.fc_c2 = nn.Linear(c_dim, hidden_size) 90 | 91 | self.block0 = ResnetBlockFC(hidden_size, hidden_size) 92 | self.block1 = ResnetBlockFC(hidden_size, hidden_size) 93 | self.block2 = ResnetBlockFC(hidden_size, hidden_size) 94 | self.block3 = ResnetBlockFC(hidden_size, hidden_size) 95 | 96 | self.fc_out = nn.Linear(hidden_size, 1) 97 | 98 | def forward(self, p, z, c, **kwargs): 99 | batch_size, T, D = p.size() 100 | 101 | c1 = c.view(batch_size, self.c_dim, -1).max(dim=2)[0] 102 | Ap = self.affine(c1, p) 103 | Ap2 = Ap[:, :, :2] / (Ap[:, :, 2:].abs() + 1e-5) 104 | 105 | c2 = F.grid_sample(c, 2*Ap2.unsqueeze(1), padding_mode='border') 106 | c2 = c2.squeeze(2).transpose(1, 2) 107 | 108 | net = self.fc_p1(p) + self.fc_p2(Ap) 109 | 110 | if self.z_dim != 0: 111 | net_z = self.fc_z(z).unsqueeze(1) 112 | net = net + net_z 113 | 114 | net_c = self.fc_c2(c2) + self.fc_c1(c1).unsqueeze(1) 115 | net = net + net_c 116 | 117 | net = self.block0(net) 118 | net = self.block1(net) 119 | net = self.block2(net) 120 | net = self.block3(net) 121 | 122 | out = self.fc_out(self.actvn(net)) 123 | out = out.squeeze(-1) 124 | 125 | return out -------------------------------------------------------------------------------- /models/occnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import distributions as dist 4 | 5 | class OccupancyNetwork(nn.Module): 6 | ''' Occupancy Network class. 
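# --- Editor's sketch (not part of the original file): a minimal shape check
# for the legacy VoxelDecoder above, assuming models/ is importable and that
# query points live in [-0.5, 0.5]^3 (they are rescaled by 2 before
# grid_sample).
import torch
from legacy import VoxelDecoder

decoder = VoxelDecoder(dim=3, z_dim=128, c_dim=128)
p = torch.rand(2, 64, 3) - 0.5   # (batch, n_points, 3) query points
z = torch.randn(2, 128)          # latent code z
c = torch.randn(2, 128)          # conditioning code c
logits = decoder(p, z, c)        # -> (2, 64): one occupancy logit per point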
7 | 8 | Args: 9 | decoder (nn.Module): decoder network 10 | encoder (nn.Module): encoder network 11 | encoder_latent (nn.Module): latent encoder network 12 | p0_z (dist): prior distribution for latent code z 13 | device (device): torch device 14 | ''' 15 | 16 | def __init__(self, decoder, encoder=None, encoder_latent=None, p0_z=None, 17 | device=None): 18 | super().__init__() 19 | if p0_z is None: 20 | p0_z = dist.Normal(torch.tensor([]), torch.tensor([])) 21 | 22 | self.decoder = decoder.to(device) 23 | 24 | if encoder_latent is not None: 25 | self.encoder_latent = encoder_latent.to(device) 26 | else: 27 | self.encoder_latent = None 28 | 29 | if encoder is not None: 30 | self.encoder = encoder.to(device) 31 | else: 32 | self.encoder = None 33 | 34 | self._device = device 35 | self.p0_z = p0_z 36 | 37 | def forward(self, p, inputs, sample=True, **kwargs): 38 | ''' Performs a forward pass through the network. 39 | 40 | Args: 41 | p (tensor): sampled points 42 | inputs (tensor): conditioning input 43 | sample (bool): whether to sample for z 44 | ''' 45 | batch_size = p.size(0) 46 | c = self.encode_inputs(inputs) 47 | z = self.get_z_from_prior((batch_size,), sample=sample) 48 | p_r = self.decode(p, z, c, **kwargs) 49 | return p_r 50 | 51 | def compute_elbo(self, p, occ, inputs, **kwargs): 52 | ''' Computes the expectation lower bound. 53 | 54 | Args: 55 | p (tensor): sampled points 56 | occ (tensor): occupancy values for p 57 | inputs (tensor): conditioning input 58 | ''' 59 | c = self.encode_inputs(inputs) 60 | q_z = self.infer_z(p, occ, c, **kwargs) 61 | z = q_z.rsample() 62 | p_r = self.decode(p, z, c, **kwargs) 63 | 64 | rec_error = -p_r.log_prob(occ).sum(dim=-1) 65 | kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1) 66 | elbo = -rec_error - kl 67 | 68 | return elbo, rec_error, kl 69 | 70 | def encode_inputs(self, inputs): 71 | ''' Encodes the input. 72 | 73 | Args: 74 | input (tensor): the input 75 | ''' 76 | 77 | if self.encoder is not None: 78 | c = self.encoder(inputs) 79 | else: 80 | # Return inputs? 81 | c = torch.empty(inputs.size(0), 0) 82 | 83 | return c 84 | 85 | def decode(self, p, z, c, **kwargs): 86 | ''' Returns occupancy probabilities for the sampled points. 87 | 88 | Args: 89 | p (tensor): points 90 | z (tensor): latent code z 91 | c (tensor): latent conditioned code c 92 | ''' 93 | 94 | logits = self.decoder(p, z, c, **kwargs) 95 | p_r = dist.Bernoulli(logits=logits) 96 | return p_r 97 | 98 | def infer_z(self, p, occ, c, **kwargs): 99 | ''' Infers z. 100 | 101 | Args: 102 | p (tensor): points tensor 103 | occ (tensor): occupancy values for occ 104 | c (tensor): latent conditioned code c 105 | ''' 106 | if self.encoder_latent is not None: 107 | mean_z, logstd_z = self.encoder_latent(p, occ, c, **kwargs) 108 | else: 109 | batch_size = p.size(0) 110 | mean_z = torch.empty(batch_size, 0).to(self._device) 111 | logstd_z = torch.empty(batch_size, 0).to(self._device) 112 | 113 | q_z = dist.Normal(mean_z, torch.exp(logstd_z)) 114 | return q_z 115 | 116 | def get_z_from_prior(self, size=torch.Size([]), sample=True): 117 | ''' Returns z from prior distribution. 118 | 119 | Args: 120 | size (Size): size of z 121 | sample (bool): whether to sample 122 | ''' 123 | if sample: 124 | z = self.p0_z.sample(size).to(self._device) 125 | else: 126 | z = self.p0_z.mean.to(self._device) 127 | z = z.expand(*size, *z.size()) 128 | 129 | return z 130 | 131 | def to(self, device): 132 | ''' Puts the model to the device. 
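# --- Editor's sketch (toy modules and all shapes are assumptions, not
# original code): exercising OccupancyNetwork end to end with stand-in
# encoder/decoder modules.
import torch
import torch.nn as nn

class _ToyDecoder(nn.Module):
    def __init__(self, c_dim=16):
        super().__init__()
        self.fc = nn.Linear(3 + c_dim, 1)
    def forward(self, p, z, c, **kwargs):
        c_exp = c.unsqueeze(1).expand(-1, p.size(1), -1)  # broadcast code over points
        return self.fc(torch.cat([p, c_exp], dim=-1)).squeeze(-1)

device = torch.device('cpu')
encoder = nn.Sequential(nn.Flatten(), nn.LazyLinear(16))  # stand-in point encoder
net = OccupancyNetwork(_ToyDecoder(), encoder=encoder, device=device)

p = torch.rand(4, 128, 3) - 0.5       # query points
inputs = torch.rand(4, 300, 3)        # conditioning point cloud
p_r = net(p, inputs, sample=True)     # dist.Bernoulli over occupancies
print(p_r.probs.shape)                # torch.Size([4, 128])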
133 | 134 | Args: 135 | device (device): pytorch device 136 | ''' 137 | model = super().to(device) 138 | model._device = device 139 | return model 140 | -------------------------------------------------------------------------------- /se3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rigid body transfomation group and corresponding Lie algebra. """ 2 | import torch 3 | from sinc import sinc1, sinc2, sinc3 4 | import so3 5 | 6 | def twist_prod(x, y): 7 | x_ = x.view(-1, 6) 8 | y_ = y.view(-1, 6) 9 | 10 | xw, xv = x_[:, 0:3], x_[:, 3:6] 11 | yw, yv = y_[:, 0:3], y_[:, 3:6] 12 | 13 | zw = so3.cross_prod(xw, yw) 14 | zv = so3.cross_prod(xw, yv) + so3.cross_prod(xv, yw) 15 | 16 | z = torch.cat((zw, zv), dim=1) 17 | 18 | return z.view_as(x) 19 | 20 | def liebracket(x, y): 21 | return twist_prod(x, y) 22 | 23 | 24 | def mat(x): 25 | # size: [*, 6] -> [*, 4, 4] 26 | x_ = x.view(-1, 6) 27 | w1, w2, w3 = x_[:, 0], x_[:, 1], x_[:, 2] 28 | v1, v2, v3 = x_[:, 3], x_[:, 4], x_[:, 5] 29 | O = torch.zeros_like(w1) 30 | 31 | X = torch.stack(( 32 | torch.stack(( O, -w3, w2, v1), dim=1), 33 | torch.stack(( w3, O, -w1, v2), dim=1), 34 | torch.stack((-w2, w1, O, v3), dim=1), 35 | torch.stack(( O, O, O, O), dim=1)), dim=1) 36 | return X.view(*(x.size()[0:-1]), 4, 4) 37 | 38 | def vec(X): 39 | X_ = X.view(-1, 4, 4) 40 | w1, w2, w3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0] 41 | v1, v2, v3 = X_[:, 0, 3], X_[:, 1, 3], X_[:, 2, 3] 42 | x = torch.stack((w1, w2, w3, v1, v2, v3), dim=1) 43 | return x.view(*X.size()[0:-2], 6) 44 | 45 | def genvec(): 46 | return torch.eye(6) 47 | 48 | def genmat(): 49 | return mat(genvec()) 50 | 51 | def exp(x): 52 | x_ = x.view(-1, 6) 53 | w, v = x_[:, 0:3], x_[:, 3:6] 54 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 55 | W = so3.mat(w) 56 | S = W.bmm(W) 57 | I = torch.eye(3).to(w) 58 | 59 | # Rodrigues' rotation formula. 
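# --- Editor's sketch (not original code): numerical round-trip checks for the
# exp/log pair defined in this module, in float64 to keep tolerances tight.
import torch
import se3

x = 0.3 * torch.randn(5, 6, dtype=torch.float64)   # small twists, away from the pi singularity of log
g = se3.exp(x)                                     # (5, 4, 4) rigid transforms
assert torch.allclose(se3.log(g), x, atol=1e-6)

pts = torch.randn(5, 3, 100, dtype=torch.float64)  # points as columns: * x 3 x N
moved = se3.transform(g, pts)                      # R p + t per batch element
restored = se3.transform(se3.inverse(g), moved)
assert torch.allclose(restored, pts, atol=1e-6)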
60 | #R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 61 | # = eye(3) + sinc1(t)*W + sinc2(t)*S 62 | R = I + sinc1(t)*W + sinc2(t)*S 63 | 64 | #V = sinc1(t)*eye(3) + sinc2(t)*W + sinc3(t)*(w*w') 65 | # = eye(3) + sinc2(t)*W + sinc3(t)*S 66 | V = I + sinc2(t)*W + sinc3(t)*S 67 | 68 | p = V.bmm(v.contiguous().view(-1, 3, 1)) 69 | 70 | z = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(x_.size(0), 1, 1).to(x) 71 | Rp = torch.cat((R, p), dim=2) 72 | g = torch.cat((Rp, z), dim=1) 73 | 74 | return g.view(*(x.size()[0:-1]), 4, 4) 75 | 76 | def inverse(g): 77 | g_ = g.view(-1, 4, 4) 78 | R = g_[:, 0:3, 0:3] 79 | p = g_[:, 0:3, 3] 80 | Q = R.transpose(1, 2) 81 | q = -Q.matmul(p.unsqueeze(-1)) 82 | 83 | z = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(g_.size(0), 1, 1).to(g) 84 | Qq = torch.cat((Q, q), dim=2) 85 | ig = torch.cat((Qq, z), dim=1) 86 | 87 | return ig.view(*(g.size()[0:-2]), 4, 4) 88 | 89 | 90 | def log(g): 91 | g_ = g.view(-1, 4, 4) 92 | R = g_[:, 0:3, 0:3] 93 | p = g_[:, 0:3, 3] 94 | 95 | w = so3.log(R) 96 | H = so3.inv_vecs_Xg_ig(w) 97 | v = H.bmm(p.contiguous().view(-1, 3, 1)).view(-1, 3) 98 | 99 | x = torch.cat((w, v), dim=1) 100 | return x.view(*(g.size()[0:-2]), 6) 101 | 102 | def transform(g, a): 103 | # g : SE(3), * x 4 x 4 104 | # a : R^3, * x 3[x N] 105 | g_ = g.view(-1, 4, 4) 106 | R = g_[:, 0:3, 0:3].contiguous().view(*(g.size()[0:-2]), 3, 3) 107 | p = g_[:, 0:3, 3].contiguous().view(*(g.size()[0:-2]), 3) 108 | if len(g.size()) == len(a.size()): 109 | b = R.matmul(a) + p.unsqueeze(-1) 110 | else: 111 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p 112 | return b 113 | 114 | def group_prod(g, h): 115 | # g, h : SE(3) 116 | g1 = g.matmul(h) 117 | return g1 118 | 119 | 120 | class ExpMap(torch.autograd.Function): 121 | """ Exp: se(3) -> SE(3) 122 | """ 123 | @staticmethod 124 | def forward(ctx, x): 125 | """ Exp: R^6 -> M(4), 126 | size: [B, 6] -> [B, 4, 4], 127 | or [B, 1, 6] -> [B, 1, 4, 4] 128 | """ 129 | ctx.save_for_backward(x) 130 | g = exp(x) 131 | return g 132 | 133 | @staticmethod 134 | def backward(ctx, grad_output): 135 | x, = ctx.saved_tensors 136 | g = exp(x) 137 | gen_k = genmat().to(x) 138 | 139 | # Let z = f(g) = f(exp(x)) 140 | # dz = df/dgij * dgij/dxk * dxk 141 | # = df/dgij * (d/dxk)[exp(x)]_ij * dxk 142 | # = df/dgij * [gen_k*g]_ij * dxk 143 | 144 | dg = gen_k.matmul(g.view(-1, 1, 4, 4)) 145 | # (k, i, j) 146 | dg = dg.to(grad_output) 147 | 148 | go = grad_output.contiguous().view(-1, 1, 4, 4) 149 | dd = go * dg 150 | grad_input = dd.sum(-1).sum(-1) 151 | 152 | return grad_input 153 | 154 | Exp = ExpMap.apply 155 | 156 | 157 | #EOF 158 | -------------------------------------------------------------------------------- /utils/libmcubes/pyarraymodule.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _EXTMODULE_H 3 | #define _EXTMODULE_H 4 | 5 | #include 6 | #include 7 | 8 | // #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 9 | #define PY_ARRAY_UNIQUE_SYMBOL mcubes_PyArray_API 10 | #define NO_IMPORT_ARRAY 11 | #include "numpy/arrayobject.h" 12 | 13 | #include 14 | 15 | template 16 | struct numpy_typemap; 17 | 18 | #define define_numpy_type(ctype, dtype) \ 19 | template<> \ 20 | struct numpy_typemap \ 21 | {static const int type = dtype;}; 22 | 23 | define_numpy_type(bool, NPY_BOOL); 24 | define_numpy_type(char, NPY_BYTE); 25 | define_numpy_type(short, NPY_SHORT); 26 | define_numpy_type(int, NPY_INT); 27 | define_numpy_type(long, NPY_LONG); 28 | define_numpy_type(long long, NPY_LONGLONG); 29 | 
define_numpy_type(unsigned char, NPY_UBYTE); 30 | define_numpy_type(unsigned short, NPY_USHORT); 31 | define_numpy_type(unsigned int, NPY_UINT); 32 | define_numpy_type(unsigned long, NPY_ULONG); 33 | define_numpy_type(unsigned long long, NPY_ULONGLONG); 34 | define_numpy_type(float, NPY_FLOAT); 35 | define_numpy_type(double, NPY_DOUBLE); 36 | define_numpy_type(long double, NPY_LONGDOUBLE); 37 | define_numpy_type(std::complex, NPY_CFLOAT); 38 | define_numpy_type(std::complex, NPY_CDOUBLE); 39 | define_numpy_type(std::complex, NPY_CLONGDOUBLE); 40 | 41 | template 42 | T PyArray_SafeGet(const PyArrayObject* aobj, const npy_intp* indaux) 43 | { 44 | // HORROR. 45 | npy_intp* ind = const_cast(indaux); 46 | void* ptr = PyArray_GetPtr(const_cast(aobj), ind); 47 | switch(PyArray_TYPE(aobj)) 48 | { 49 | case NPY_BOOL: 50 | return static_cast(*reinterpret_cast(ptr)); 51 | case NPY_BYTE: 52 | return static_cast(*reinterpret_cast(ptr)); 53 | case NPY_SHORT: 54 | return static_cast(*reinterpret_cast(ptr)); 55 | case NPY_INT: 56 | return static_cast(*reinterpret_cast(ptr)); 57 | case NPY_LONG: 58 | return static_cast(*reinterpret_cast(ptr)); 59 | case NPY_LONGLONG: 60 | return static_cast(*reinterpret_cast(ptr)); 61 | case NPY_UBYTE: 62 | return static_cast(*reinterpret_cast(ptr)); 63 | case NPY_USHORT: 64 | return static_cast(*reinterpret_cast(ptr)); 65 | case NPY_UINT: 66 | return static_cast(*reinterpret_cast(ptr)); 67 | case NPY_ULONG: 68 | return static_cast(*reinterpret_cast(ptr)); 69 | case NPY_ULONGLONG: 70 | return static_cast(*reinterpret_cast(ptr)); 71 | case NPY_FLOAT: 72 | return static_cast(*reinterpret_cast(ptr)); 73 | case NPY_DOUBLE: 74 | return static_cast(*reinterpret_cast(ptr)); 75 | case NPY_LONGDOUBLE: 76 | return static_cast(*reinterpret_cast(ptr)); 77 | default: 78 | throw std::runtime_error("data type not supported"); 79 | } 80 | } 81 | 82 | template 83 | T PyArray_SafeSet(PyArrayObject* aobj, const npy_intp* indaux, const T& value) 84 | { 85 | // HORROR. 
86 | npy_intp* ind = const_cast(indaux); 87 | void* ptr = PyArray_GetPtr(aobj, ind); 88 | switch(PyArray_TYPE(aobj)) 89 | { 90 | case NPY_BOOL: 91 | *reinterpret_cast(ptr) = static_cast(value); 92 | break; 93 | case NPY_BYTE: 94 | *reinterpret_cast(ptr) = static_cast(value); 95 | break; 96 | case NPY_SHORT: 97 | *reinterpret_cast(ptr) = static_cast(value); 98 | break; 99 | case NPY_INT: 100 | *reinterpret_cast(ptr) = static_cast(value); 101 | break; 102 | case NPY_LONG: 103 | *reinterpret_cast(ptr) = static_cast(value); 104 | break; 105 | case NPY_LONGLONG: 106 | *reinterpret_cast(ptr) = static_cast(value); 107 | break; 108 | case NPY_UBYTE: 109 | *reinterpret_cast(ptr) = static_cast(value); 110 | break; 111 | case NPY_USHORT: 112 | *reinterpret_cast(ptr) = static_cast(value); 113 | break; 114 | case NPY_UINT: 115 | *reinterpret_cast(ptr) = static_cast(value); 116 | break; 117 | case NPY_ULONG: 118 | *reinterpret_cast(ptr) = static_cast(value); 119 | break; 120 | case NPY_ULONGLONG: 121 | *reinterpret_cast(ptr) = static_cast(value); 122 | break; 123 | case NPY_FLOAT: 124 | *reinterpret_cast(ptr) = static_cast(value); 125 | break; 126 | case NPY_DOUBLE: 127 | *reinterpret_cast(ptr) = static_cast(value); 128 | break; 129 | case NPY_LONGDOUBLE: 130 | *reinterpret_cast(ptr) = static_cast(value); 131 | break; 132 | default: 133 | throw std::runtime_error("data type not supported"); 134 | } 135 | } 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | import logging 6 | 7 | class CheckpointIO(object): 8 | ''' CheckpointIO class. 9 | 10 | It handles saving and loading checkpoints. 
11 | 12 | Args: 13 | checkpoint_dir (str): path where checkpoints are saved 14 | ''' 15 | def __init__(self, model, optimizer=None, lr_scheduler=None, checkpoint_dir='./chkpts'): 16 | self.model = model 17 | self.optimizer = optimizer 18 | self.lr_scheduler = lr_scheduler 19 | # self.module_dict = kwargs 20 | self.checkpoint_dir = checkpoint_dir 21 | if not os.path.exists(checkpoint_dir): 22 | os.makedirs(checkpoint_dir) 23 | 24 | def set_selection_criteria(self, model_selection_metric, model_selection_sign, metric_val_best=None): 25 | self.model_selection_metric = model_selection_metric 26 | self.model_selection_sign = model_selection_sign 27 | self.metric_val_best = metric_val_best 28 | 29 | def save_if_best(self, eval_dict, it, epoch_it, **kwargs): 30 | metric_val = eval_dict[self.model_selection_metric] 31 | logging.info('Validation metric (%s): %.4f' 32 | % (self.model_selection_metric, metric_val)) 33 | 34 | if self.model_selection_sign * (metric_val - self.metric_val_best) > 0: 35 | self.metric_val_best = metric_val 36 | logging.info('New best model (metric %.4f)' % self.metric_val_best) 37 | self.save('model_best.pt', loss_val_best=self.metric_val_best, it=it, epoch_it=epoch_it, **kwargs) 38 | 39 | def save_latest(self, it, epoch_it, **kwargs): 40 | logging.info('Saving checkpoint model.pt') 41 | self.save('model.pt', loss_val_best=self.metric_val_best, it=it, epoch_it=epoch_it, **kwargs) 42 | 43 | def save_process(self, it, epoch_it, **kwargs): 44 | logging.info('Backup checkpoint model_%d.pt'%it) 45 | self.save('model_%d.pt'%it, loss_val_best=self.metric_val_best, it=it, epoch_it=epoch_it, **kwargs) 46 | # def register_modules(self, **kwargs): 47 | # ''' Registers modules in current module dictionary. 48 | # ''' 49 | # self.module_dict.update(kwargs) 50 | 51 | def save(self, filename, **kwargs): 52 | ''' Saves the current module dictionary. 53 | 54 | Args: 55 | filename (str): name of output file 56 | ''' 57 | if not os.path.isabs(filename): 58 | filename = os.path.join(self.checkpoint_dir, filename) 59 | 60 | outdict = kwargs 61 | outdict['model'] = self.model.state_dict() 62 | if self.optimizer is not None: 63 | outdict['optimizer'] = self.optimizer.state_dict() 64 | if self.lr_scheduler is not None: 65 | outdict['lr_scheduler'] = self.lr_scheduler.state_dict() 66 | 67 | torch.save(outdict, filename) 68 | 69 | def load(self, filename): 70 | '''Loads a module dictionary from local file or url. 71 | 72 | Args: 73 | filename (str): name of saved module dictionary 74 | ''' 75 | if is_url(filename): 76 | return self.load_url(filename) 77 | else: 78 | return self.load_file(filename) 79 | 80 | def load_file(self, filename): 81 | '''Loads a module dictionary from file. 82 | 83 | Args: 84 | filename (str): name of saved module dictionary 85 | ''' 86 | 87 | if not os.path.isabs(filename): 88 | filename = os.path.join(self.checkpoint_dir, filename) 89 | 90 | if os.path.exists(filename): 91 | print(filename) 92 | print('=> Loading checkpoint from local file...') 93 | state_dict = torch.load(filename, map_location=torch.device('cpu')) 94 | scalars = self.load_state_dict(state_dict) 95 | return scalars 96 | else: 97 | raise FileNotFoundError(filename) 98 | 99 | def load_url(self, url): 100 | '''Load a module dictionary from url.
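# --- Editor's sketch (paths and metric name are hypothetical, not original
# code): a typical save/load cycle with CheckpointIO as defined above.
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
ckpt = CheckpointIO(model, optimizer=opt, checkpoint_dir='out/demo')

# minimize an angle metric: selection sign = -1, best value starts at +inf
ckpt.set_selection_criteria('angle', -1, metric_val_best=float('inf'))
ckpt.save_if_best({'angle': 3.2}, it=1000, epoch_it=5)   # writes model_best.pt
ckpt.save_latest(it=1000, epoch_it=5)                    # writes model.pt

scalars = ckpt.load('model.pt')   # leftover scalars: it, epoch_it, loss_val_best
print(scalars['it'], scalars['epoch_it'])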
101 | 102 | Args: 103 | url (str): url to saved model 104 | ''' 105 | print(url) 106 | print('=> Loading checkpoint from url...') 107 | state_dict = model_zoo.load_url(url, progress=True) 108 | scalars = self.load_state_dict(state_dict) 109 | return scalars 110 | 111 | def load_state_dict(self, state_dict): 112 | '''Parse state_dict of model and return scalars. 113 | 114 | Args: 115 | state_dict (dict): State dict of model 116 | ''' 117 | self.model.load_state_dict(state_dict.pop('model')) 118 | if self.optimizer is not None: 119 | try: 120 | self.optimizer.load_state_dict(state_dict.pop('optimizer')) 121 | except Exception as e: 122 | logging.warn('Cannot find optimizer in checkpoint: {}'.format(e)) 123 | if self.lr_scheduler is not None: 124 | try: 125 | self.lr_scheduler.load_state_dict(state_dict.pop('lr_scheduler')) 126 | except Exception as e: 127 | logging.warn('Cannot find lr_scheduler in checkpoint: {}'.format(e)) 128 | 129 | return state_dict 130 | 131 | def is_url(url): 132 | scheme = urllib.parse.urlparse(url).scheme 133 | return scheme in ('http', 'https') -------------------------------------------------------------------------------- /encoders/conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # import torch.nn.functional as F 3 | from torchvision import models 4 | from common import normalize_imagenet 5 | 6 | 7 | class ConvEncoder(nn.Module): 8 | r''' Simple convolutional encoder network. 9 | 10 | It consists of 5 convolutional layers, each downsampling the input by a 11 | factor of 2, and a final fully-connected layer projecting the output to 12 | c_dim dimenions. 13 | 14 | Args: 15 | c_dim (int): output dimension of latent embedding 16 | ''' 17 | 18 | def __init__(self, c_dim=128): 19 | super().__init__() 20 | self.conv0 = nn.Conv2d(3, 32, 3, stride=2) 21 | self.conv1 = nn.Conv2d(32, 64, 3, stride=2) 22 | self.conv2 = nn.Conv2d(64, 128, 3, stride=2) 23 | self.conv3 = nn.Conv2d(128, 256, 3, stride=2) 24 | self.conv4 = nn.Conv2d(256, 512, 3, stride=2) 25 | self.fc_out = nn.Linear(512, c_dim) 26 | self.actvn = nn.ReLU() 27 | 28 | def forward(self, x): 29 | batch_size = x.size(0) 30 | 31 | net = self.conv0(x) 32 | net = self.conv1(self.actvn(net)) 33 | net = self.conv2(self.actvn(net)) 34 | net = self.conv3(self.actvn(net)) 35 | net = self.conv4(self.actvn(net)) 36 | net = net.view(batch_size, 512, -1).mean(2) 37 | out = self.fc_out(self.actvn(net)) 38 | 39 | return out 40 | 41 | 42 | class Resnet18(nn.Module): 43 | r''' ResNet-18 encoder network for image input. 44 | Args: 45 | c_dim (int): output dimension of the latent embedding 46 | normalize (bool): whether the input images should be normalized 47 | use_linear (bool): whether a final linear layer should be used 48 | ''' 49 | 50 | def __init__(self, c_dim, normalize=True, use_linear=True): 51 | super().__init__() 52 | self.normalize = normalize 53 | self.use_linear = use_linear 54 | self.features = models.resnet18(pretrained=True) 55 | self.features.fc = nn.Sequential() 56 | if use_linear: 57 | self.fc = nn.Linear(512, c_dim) 58 | elif c_dim == 512: 59 | self.fc = nn.Sequential() 60 | else: 61 | raise ValueError('c_dim must be 512 if use_linear is False') 62 | 63 | def forward(self, x): 64 | if self.normalize: 65 | x = normalize_imagenet(x) 66 | net = self.features(x) 67 | out = self.fc(net) 68 | return out 69 | 70 | 71 | class Resnet34(nn.Module): 72 | r''' ResNet-34 encoder network. 
73 | 74 | Args: 75 | c_dim (int): output dimension of the latent embedding 76 | normalize (bool): whether the input images should be normalized 77 | use_linear (bool): whether a final linear layer should be used 78 | ''' 79 | 80 | def __init__(self, c_dim, normalize=True, use_linear=True): 81 | super().__init__() 82 | self.normalize = normalize 83 | self.use_linear = use_linear 84 | self.features = models.resnet34(pretrained=True) 85 | self.features.fc = nn.Sequential() 86 | if use_linear: 87 | self.fc = nn.Linear(512, c_dim) 88 | elif c_dim == 512: 89 | self.fc = nn.Sequential() 90 | else: 91 | raise ValueError('c_dim must be 512 if use_linear is False') 92 | 93 | def forward(self, x): 94 | if self.normalize: 95 | x = normalize_imagenet(x) 96 | net = self.features(x) 97 | out = self.fc(net) 98 | return out 99 | 100 | 101 | class Resnet50(nn.Module): 102 | r''' ResNet-50 encoder network. 103 | 104 | Args: 105 | c_dim (int): output dimension of the latent embedding 106 | normalize (bool): whether the input images should be normalized 107 | use_linear (bool): whether a final linear layer should be used 108 | ''' 109 | 110 | def __init__(self, c_dim, normalize=True, use_linear=True): 111 | super().__init__() 112 | self.normalize = normalize 113 | self.use_linear = use_linear 114 | self.features = models.resnet50(pretrained=True) 115 | self.features.fc = nn.Sequential() 116 | if use_linear: 117 | self.fc = nn.Linear(2048, c_dim) 118 | elif c_dim == 2048: 119 | self.fc = nn.Sequential() 120 | else: 121 | raise ValueError('c_dim must be 2048 if use_linear is False') 122 | 123 | def forward(self, x): 124 | if self.normalize: 125 | x = normalize_imagenet(x) 126 | net = self.features(x) 127 | out = self.fc(net) 128 | return out 129 | 130 | 131 | class Resnet101(nn.Module): 132 | r''' ResNet-101 encoder network. 
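# --- Editor's sketch (input size is an assumption, not original code):
# encoding a batch of RGB images with the Resnet18 wrapper above. The first
# call downloads ImageNet weights via torchvision.
import torch

enc = Resnet18(c_dim=256, normalize=True, use_linear=True)
imgs = torch.rand(8, 3, 224, 224)     # values in [0, 1] before ImageNet normalization
with torch.no_grad():
    codes = enc(imgs)                 # -> (8, 256) latent codes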
133 | Args: 134 | c_dim (int): output dimension of the latent embedding 135 | normalize (bool): whether the input images should be normalized 136 | use_linear (bool): whether a final linear layer should be used 137 | ''' 138 | 139 | def __init__(self, c_dim, normalize=True, use_linear=True): 140 | super().__init__() 141 | self.normalize = normalize 142 | self.use_linear = use_linear 143 | self.features = models.resnet101(pretrained=True) 144 | self.features.fc = nn.Sequential() 145 | if use_linear: 146 | self.fc = nn.Linear(2048, c_dim) 147 | elif c_dim == 2048: 148 | self.fc = nn.Sequential() 149 | else: 150 | raise ValueError('c_dim must be 2048 if use_linear is False') 151 | 152 | def forward(self, x): 153 | if self.normalize: 154 | x = normalize_imagenet(x) 155 | net = self.features(x) 156 | out = self.fc(net) 157 | return out 158 | -------------------------------------------------------------------------------- /so3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rotation group and corresponding Lie algebra """ 2 | import torch 3 | import sinc 4 | from sinc import sinc1, sinc2, sinc3 5 | 6 | 7 | def cross_prod(x, y): 8 | z = torch.cross(x.view(-1, 3), y.view(-1, 3), dim=1).view_as(x) 9 | return z 10 | 11 | def liebracket(x, y): 12 | return cross_prod(x, y) 13 | 14 | def mat(x): 15 | # size: [*, 3] -> [*, 3, 3] 16 | x_ = x.view(-1, 3) 17 | x1, x2, x3 = x_[:, 0], x_[:, 1], x_[:, 2] 18 | O = torch.zeros_like(x1) 19 | 20 | X = torch.stack(( 21 | torch.stack((O, -x3, x2), dim=1), 22 | torch.stack((x3, O, -x1), dim=1), 23 | torch.stack((-x2, x1, O), dim=1)), dim=1) 24 | return X.view(*(x.size()[0:-1]), 3, 3) 25 | 26 | def vec(X): 27 | X_ = X.view(-1, 3, 3) 28 | x1, x2, x3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0] 29 | x = torch.stack((x1, x2, x3), dim=1) 30 | return x.view(*X.size()[0:-2], 3) 31 | 32 | def genvec(): 33 | return torch.eye(3) 34 | 35 | def genmat(): 36 | return mat(genvec()) 37 | 38 | def RodriguesRotation(x): 39 | # for autograd 40 | w = x.view(-1, 3) 41 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 42 | W = mat(w) 43 | S = W.bmm(W) 44 | I = torch.eye(3).to(w) 45 | 46 | # Rodrigues' rotation formula. 47 | #R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 48 | #R = eye(3) + sinc1(t)*W + sinc2(t)*S 49 | 50 | R = I + sinc.Sinc1(t)*W + sinc.Sinc2(t)*S 51 | 52 | return R.view(*(x.size()[0:-1]), 3, 3) 53 | 54 | def exp(x): 55 | w = x.view(-1, 3) 56 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 57 | W = mat(w) 58 | S = W.bmm(W) 59 | I = torch.eye(3).to(w) 60 | 61 | # Rodrigues' rotation formula.
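# --- Editor's sketch (not original code): exp() above should land on SO(3);
# quick orthonormality / determinant / round-trip checks in float64.
# log() used below is defined later in this file.
import torch
import so3

w = 0.5 * torch.randn(16, 3, dtype=torch.float64)   # keep |w| < pi for log()
R = so3.exp(w)                                      # (16, 3, 3)
I = torch.eye(3, dtype=torch.float64).expand_as(R)
assert torch.allclose(R.transpose(1, 2) @ R, I, atol=1e-9)
assert torch.allclose(torch.det(R), torch.ones(16, dtype=torch.float64), atol=1e-9)
assert torch.allclose(so3.log(R), w, atol=1e-6)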
62 | #R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 63 | #R = eye(3) + sinc1(t)*W + sinc2(t)*S 64 | 65 | R = I + sinc1(t)*W + sinc2(t)*S 66 | 67 | return R.view(*(x.size()[0:-1]), 3, 3) 68 | 69 | def inverse(g): 70 | R = g.view(-1, 3, 3) 71 | Rt = R.transpose(1, 2) 72 | return Rt.view_as(g) 73 | 74 | def btrace(X): 75 | # batch-trace: [B, N, N] -> [B] 76 | n = X.size(-1) 77 | X_ = X.view(-1, n, n) 78 | tr = torch.zeros(X_.size(0)).to(X) 79 | for i in range(tr.size(0)): 80 | m = X_[i, :, :] 81 | tr[i] = torch.trace(m) 82 | return tr.view(*(X.size()[0:-2])) 83 | 84 | def log(g): 85 | eps = 1.0e-7 86 | R = g.view(-1, 3, 3) 87 | tr = btrace(R) 88 | c = (tr - 1) / 2 89 | t = torch.acos(c) 90 | sc = sinc1(t) 91 | idx0 = (torch.abs(sc) <= eps) 92 | idx1 = (torch.abs(sc) > eps) 93 | sc = sc.view(-1, 1, 1) 94 | 95 | X = torch.zeros_like(R) 96 | if idx1.any(): 97 | X[idx1] = (R[idx1] - R[idx1].transpose(1, 2)) / (2*sc[idx1]) 98 | 99 | if idx0.any(): 100 | # t[idx0] == math.pi 101 | t2 = t[idx0] ** 2 102 | A = (R[idx0] + torch.eye(3).type_as(R).unsqueeze(0)) * t2.view(-1, 1, 1) / 2 103 | aw1 = torch.sqrt(A[:, 0, 0]) 104 | aw2 = torch.sqrt(A[:, 1, 1]) 105 | aw3 = torch.sqrt(A[:, 2, 2]) 106 | sgn_3 = torch.sign(A[:, 0, 2]) 107 | sgn_3[sgn_3 == 0] = 1 108 | sgn_23 = torch.sign(A[:, 1, 2]) 109 | sgn_23[sgn_23 == 0] = 1 110 | sgn_2 = sgn_23 * sgn_3 111 | w1 = aw1 112 | w2 = aw2 * sgn_2 113 | w3 = aw3 * sgn_3 114 | w = torch.stack((w1, w2, w3), dim=-1) 115 | W = mat(w) 116 | X[idx0] = W 117 | 118 | x = vec(X.view_as(g)) 119 | return x 120 | 121 | def transform(g, a): 122 | # g in SO(3): * x 3 x 3 123 | # a in R^3: * x 3[x N] 124 | if len(g.size()) == len(a.size()): 125 | b = g.matmul(a) 126 | else: 127 | b = g.matmul(a.unsqueeze(-1)).squeeze(-1) 128 | return b 129 | 130 | def group_prod(g, h): 131 | # g, h : SO(3) 132 | g1 = g.matmul(h) 133 | return g1 134 | 135 | 136 | 137 | def vecs_Xg_ig(x): 138 | """ Vi = vec(dg/dxi * inv(g)), where g = exp(x) 139 | (== [Ad(exp(x))] * vecs_ig_Xg(x)) 140 | """ 141 | t = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1) 142 | X = mat(x) 143 | S = X.bmm(X) 144 | #B = x.view(-1,3,1).bmm(x.view(-1,1,3)) # B = x*x' 145 | I = torch.eye(3).to(X) 146 | 147 | #V = sinc1(t)*eye(3) + sinc2(t)*X + sinc3(t)*B 148 | #V = eye(3) + sinc2(t)*X + sinc3(t)*S 149 | 150 | V = I + sinc2(t)*X + sinc3(t)*S 151 | 152 | return V.view(*(x.size()[0:-1]), 3, 3) 153 | 154 | def inv_vecs_Xg_ig(x): 155 | """ H = inv(vecs_Xg_ig(x)) """ 156 | t = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1) 157 | X = mat(x) 158 | S = X.bmm(X) 159 | I = torch.eye(3).to(x) 160 | 161 | e = 0.01 162 | eta = torch.zeros_like(t) 163 | s = (t < e) 164 | c = (s == 0) 165 | t2 = t[s] ** 2 166 | eta[s] = ((t2/40 + 1)*t2/42 + 1)*t2/720 + 1/12 # O(t**8) 167 | eta[c] = (1 - (t[c]/2) / torch.tan(t[c]/2)) / (t[c]**2) 168 | 169 | H = I - 1/2*X + eta*S 170 | return H.view(*(x.size()[0:-1]), 3, 3) 171 | 172 | 173 | class ExpMap(torch.autograd.Function): 174 | """ Exp: so(3) -> SO(3) 175 | """ 176 | @staticmethod 177 | def forward(ctx, x): 178 | """ Exp: R^3 -> M(3), 179 | size: [B, 3] -> [B, 3, 3], 180 | or [B, 1, 3] -> [B, 1, 3, 3] 181 | """ 182 | ctx.save_for_backward(x) 183 | g = exp(x) 184 | return g 185 | 186 | @staticmethod 187 | def backward(ctx, grad_output): 188 | x, = ctx.saved_tensors 189 | g = exp(x) 190 | gen_k = genmat().to(x) 191 | #gen_1 = gen_k[0, :, :] 192 | #gen_2 = gen_k[1, :, :] 193 | #gen_3 = gen_k[2, :, :] 194 | 195 | # Let z = f(g) = f(exp(x)) 196 | # dz = df/dgij * dgij/dxk * dxk 197 | # = df/dgij * 
(d/dxk)[exp(x)]_ij * dxk 198 | # = df/dgij * [gen_k*g]_ij * dxk 199 | 200 | dg = gen_k.matmul(g.view(-1, 1, 3, 3)) 201 | # (k, i, j) 202 | dg = dg.to(grad_output) 203 | 204 | go = grad_output.contiguous().view(-1, 1, 3, 3) 205 | dd = go * dg 206 | grad_input = dd.sum(-1).sum(-1) 207 | 208 | return grad_input 209 | 210 | Exp = ExpMap.apply 211 | 212 | 213 | #EOF 214 | -------------------------------------------------------------------------------- /sinc.py: -------------------------------------------------------------------------------- 1 | """ sinc(t) := sin(t) / t """ 2 | import torch 3 | from torch import sin, cos 4 | 5 | def sinc1(t): 6 | """ sinc1: t -> sin(t)/t """ 7 | e = 0.01 8 | r = torch.zeros_like(t) 9 | a = torch.abs(t) 10 | 11 | s = a < e 12 | c = (s == 0) 13 | t2 = t[s] ** 2 14 | r[s] = 1 - t2/6*(1 - t2/20*(1 - t2/42)) # Taylor series O(t^8) 15 | r[c] = sin(t[c]) / t[c] 16 | 17 | return r 18 | 19 | def sinc1_dt(t): 20 | """ d/dt(sinc1) """ 21 | e = 0.01 22 | r = torch.zeros_like(t) 23 | a = torch.abs(t) 24 | 25 | s = a < e 26 | c = (s == 0) 27 | t2 = t ** 2 28 | r[s] = -t[s]/3*(1 - t2[s]/10*(1 - t2[s]/28*(1 - t2[s]/54))) # Taylor series O(t^8) 29 | r[c] = cos(t[c])/t[c] - sin(t[c])/t2[c] 30 | 31 | return r 32 | 33 | def sinc1_dt_rt(t): 34 | """ d/dt(sinc1) / t """ 35 | e = 0.01 36 | r = torch.zeros_like(t) 37 | a = torch.abs(t) 38 | 39 | s = a < e 40 | c = (s == 0) 41 | t2 = t ** 2 42 | r[s] = -1/3*(1 - t2[s]/10*(1 - t2[s]/28*(1 - t2[s]/54))) # Taylor series O(t^8) 43 | r[c] = (cos(t[c]) / t[c] - sin(t[c]) / t2[c]) / t[c] 44 | 45 | return r 46 | 47 | 48 | def rsinc1(t): 49 | """ rsinc1: t -> t/sinc1(t) """ 50 | e = 0.01 51 | r = torch.zeros_like(t) 52 | a = torch.abs(t) 53 | 54 | s = a < e 55 | c = (s == 0) 56 | t2 = t[s] ** 2 57 | r[s] = (((31*t2)/42 + 7)*t2/60 + 1)*t2/6 + 1 # Taylor series O(t^8) 58 | r[c] = t[c] / sin(t[c]) 59 | 60 | return r 61 | 62 | def rsinc1_dt(t): 63 | """ d/dt(rsinc1) """ 64 | e = 0.01 65 | r = torch.zeros_like(t) 66 | a = torch.abs(t) 67 | 68 | s = a < e 69 | c = (s == 0) 70 | t2 = t[s] ** 2 71 | r[s] = ((((127*t2)/30 + 31)*t2/28 + 7)*t2/30 + 1)*t[s]/3 # Taylor series O(t^8) 72 | r[c] = 1/sin(t[c]) - (t[c]*cos(t[c]))/(sin(t[c])*sin(t[c])) 73 | 74 | return r 75 | 76 | def rsinc1_dt_csc(t): 77 | """ d/dt(rsinc1) / sin(t) """ 78 | e = 0.01 79 | r = torch.zeros_like(t) 80 | a = torch.abs(t) 81 | 82 | s = a < e 83 | c = (s == 0) 84 | t2 = t[s] ** 2 85 | r[s] = t2*(t2*((4*t2)/675 + 2/63) + 2/15) + 1/3 # Taylor series O(t^8) 86 | r[c] = (1/sin(t[c]) - (t[c]*cos(t[c]))/(sin(t[c])*sin(t[c]))) / sin(t[c]) 87 | 88 | return r 89 | 90 | 91 | def sinc2(t): 92 | """ sinc2: t -> (1 - cos(t)) / (t**2) """ 93 | e = 0.01 94 | r = torch.zeros_like(t) 95 | a = torch.abs(t) 96 | 97 | s = a < e 98 | c = (s == 0) 99 | t2 = t ** 2 100 | r[s] = 1/2*(1-t2[s]/12*(1-t2[s]/30*(1-t2[s]/56))) # Taylor series O(t^8) 101 | r[c] = (1-cos(t[c]))/t2[c] 102 | 103 | return r 104 | 105 | def sinc2_dt(t): 106 | """ d/dt(sinc2) """ 107 | e = 0.01 108 | r = torch.zeros_like(t) 109 | a = torch.abs(t) 110 | 111 | s = a < e 112 | c = (s == 0) 113 | t2 = t ** 2 114 | r[s] = -t[s]/12*(1 - t2[s]/5*(1.0/3 - t2[s]/56*(1.0/2 - t2[s]/135))) # Taylor series O(t^8) 115 | r[c] = sin(t[c])/t2[c] - 2*(1-cos(t[c]))/(t2[c]*t[c]) 116 | 117 | return r 118 | 119 | 120 | def sinc3(t): 121 | """ sinc3: t -> (t - sin(t)) / (t**3) """ 122 | e = 0.01 123 | r = torch.zeros_like(t) 124 | a = torch.abs(t) 125 | 126 | s = a < e 127 | c = (s == 0) 128 | t2 = t[s] ** 2 129 | r[s] = 1/6*(1-t2/20*(1-t2/42*(1-t2/72))) # 
Taylor series O(t^8) 130 | r[c] = (t[c]-sin(t[c]))/(t[c]**3) 131 | 132 | return r 133 | 134 | def sinc3_dt(t): 135 | """ d/dt(sinc3) """ 136 | e = 0.01 137 | r = torch.zeros_like(t) 138 | a = torch.abs(t) 139 | 140 | s = a < e 141 | c = (s == 0) 142 | t2 = t[s] ** 2 143 | r[s] = -t[s]/60*(1 - t2/21*(1 - t2/24*(1.0/2 - t2/165))) # Taylor series O(t^8) 144 | r[c] = (3*sin(t[c]) - t[c]*(cos(t[c]) + 2))/(t[c]**4) 145 | 146 | return r 147 | 148 | 149 | def sinc4(t): 150 | """ sinc4: t -> 1/t^2 * (1/2 - sinc2(t)) 151 | = 1/t^2 * (1/2 - (1 - cos(t))/t^2) 152 | """ 153 | e = 0.01 154 | r = torch.zeros_like(t) 155 | a = torch.abs(t) 156 | 157 | s = a < e 158 | c = (s == 0) 159 | t2 = t ** 2 160 | r[s] = 1/24*(1-t2[s]/30*(1-t2[s]/56*(1-t2[s]/90))) # Taylor series O(t^8) 161 | r[c] = (0.5 - (1 - cos(t[c]))/t2[c]) / t2[c] 162 | return r 163 | 164 | class Sinc1_autograd(torch.autograd.Function): 165 | @staticmethod 166 | def forward(ctx, theta): 167 | ctx.save_for_backward(theta) 168 | return sinc1(theta) 169 | 170 | @staticmethod 171 | def backward(ctx, grad_output): 172 | theta, = ctx.saved_tensors 173 | grad_theta = None 174 | if ctx.needs_input_grad[0]: 175 | grad_theta = grad_output * sinc1_dt(theta).to(grad_output) 176 | return grad_theta 177 | 178 | Sinc1 = Sinc1_autograd.apply 179 | 180 | class RSinc1_autograd(torch.autograd.Function): 181 | @staticmethod 182 | def forward(ctx, theta): 183 | ctx.save_for_backward(theta) 184 | return rsinc1(theta) 185 | 186 | @staticmethod 187 | def backward(ctx, grad_output): 188 | theta, = ctx.saved_tensors 189 | grad_theta = None 190 | if ctx.needs_input_grad[0]: 191 | grad_theta = grad_output * rsinc1_dt(theta).to(grad_output) 192 | return grad_theta 193 | 194 | RSinc1 = RSinc1_autograd.apply 195 | 196 | class Sinc2_autograd(torch.autograd.Function): 197 | @staticmethod 198 | def forward(ctx, theta): 199 | ctx.save_for_backward(theta) 200 | return sinc2(theta) 201 | 202 | @staticmethod 203 | def backward(ctx, grad_output): 204 | theta, = ctx.saved_tensors 205 | grad_theta = None 206 | if ctx.needs_input_grad[0]: 207 | grad_theta = grad_output * sinc2_dt(theta).to(grad_output) 208 | return grad_theta 209 | 210 | Sinc2 = Sinc2_autograd.apply 211 | 212 | class Sinc3_autograd(torch.autograd.Function): 213 | @staticmethod 214 | def forward(ctx, theta): 215 | ctx.save_for_backward(theta) 216 | return sinc3(theta) 217 | 218 | @staticmethod 219 | def backward(ctx, grad_output): 220 | theta, = ctx.saved_tensors 221 | grad_theta = None 222 | if ctx.needs_input_grad[0]: 223 | grad_theta = grad_output * sinc3_dt(theta).to(grad_output) 224 | return grad_theta 225 | 226 | Sinc3 = Sinc3_autograd.apply 227 | 228 | -------------------------------------------------------------------------------- /fmr_transforms.py: -------------------------------------------------------------------------------- 1 | """ gives some transform methods for 3d points """ 2 | import math 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | import so3 8 | import se3 9 | 10 | import numpy as np 11 | 12 | class Mesh2Points: 13 | def __init__(self): 14 | pass 15 | 16 | def __call__(self, mesh): 17 | mesh = mesh.clone() 18 | v = mesh.vertex_array 19 | return torch.from_numpy(v).type(dtype=torch.float) 20 | 21 | 22 | class OnUnitSphere: 23 | def __init__(self, zero_mean=False): 24 | self.zero_mean = zero_mean 25 | 26 | def __call__(self, tensor): 27 | if self.zero_mean: 28 | m = tensor.mean(dim=0, keepdim=True) # [N, D] -> [1, D] 29 | v = tensor - m 30 | else: 31 | v = tensor 32 | nn = v.norm(p=2, dim=1) # [N, D] -> [N] 33 | 
nmax = torch.max(nn) 34 | return v / nmax 35 | 36 | 37 | class OnUnitCube: 38 | def __init__(self): 39 | pass 40 | 41 | def method1(self, tensor): 42 | m = tensor.mean(dim=0, keepdim=True) # [N, D] -> [1, D] 43 | v = tensor - m 44 | s = torch.max(v.abs()) 45 | v = v / s * 0.5 46 | return v 47 | 48 | def method2(self, tensor, spec=None): 49 | if spec is not None: 50 | s, m = spec 51 | v = tensor / s 52 | return v - m 53 | else: 54 | c = torch.max(tensor, dim=0)[0] - torch.min(tensor, dim=0)[0] # [N, D] -> [D] 55 | s = torch.max(c) # -> scalar 56 | v = tensor / s 57 | m = v.mean(dim=0, keepdim=True) 58 | return v - m, (s, m) 59 | 60 | def __call__(self, tensor, spec=None): 61 | # return self.method1(tensor) 62 | return self.method2(tensor, spec) 63 | 64 | 65 | class Resampler: 66 | """ [N, D] -> [M, D] """ 67 | 68 | def __init__(self, num): 69 | self.num = num 70 | 71 | def __call__(self, tensor): 72 | num_points, dim_p = tensor.size() 73 | out = torch.zeros(self.num, dim_p).to(tensor) 74 | 75 | selected = 0 76 | while selected < self.num: 77 | remainder = self.num - selected 78 | idx = torch.randperm(num_points) 79 | sel = min(remainder, num_points) 80 | val = tensor[idx[:sel]] 81 | out[selected:(selected + sel)] = val 82 | selected += sel 83 | return out 84 | 85 | 86 | class RandomTranslate: 87 | def __init__(self, mag=None, randomly=True): 88 | self.mag = 1.0 if mag is None else mag 89 | self.randomly = randomly 90 | self.igt = None 91 | 92 | def __call__(self, tensor): 93 | # tensor: [N, 3] 94 | amp = torch.rand(1) if self.randomly else 1.0 95 | t = torch.randn(1, 3).to(tensor) 96 | t = t / t.norm(p=2, dim=1, keepdim=True) * amp * self.mag 97 | 98 | g = torch.eye(4).to(tensor) 99 | g[0:3, 3] = t[0, :] 100 | self.igt = g # [4, 4] 101 | 102 | p1 = tensor + t 103 | return p1 104 | 105 | 106 | class RandomRotator: 107 | def __init__(self, mag=None, randomly=True): 108 | self.mag = math.pi if mag is None else mag 109 | self.randomly = randomly 110 | self.igt = None 111 | 112 | def __call__(self, tensor): 113 | # tensor: [N, 3] 114 | amp = torch.rand(1) if self.randomly else 1.0 115 | w = torch.randn(1, 3) 116 | w = w / w.norm(p=2, dim=1, keepdim=True) * amp * self.mag 117 | 118 | g = so3.exp(w).to(tensor) # [1, 3, 3] 119 | self.igt = g.squeeze(0) # [3, 3] 120 | 121 | p1 = so3.transform(g, tensor) # [1, 3, 3] x [N, 3] -> [N, 3] 122 | return p1 123 | 124 | 125 | class RandomRotatorZ: 126 | def __init__(self): 127 | self.mag = 2 * math.pi 128 | 129 | def __call__(self, tensor): 130 | # tensor: [N, 3] 131 | w = torch.Tensor([0, 0, 1]).view(1, 3) * torch.rand(1) * self.mag 132 | 133 | g = so3.exp(w).to(tensor) # [1, 3, 3] 134 | 135 | p1 = so3.transform(g, tensor) 136 | return p1 137 | 138 | 139 | class RandomJitter: 140 | """ generate perturbations """ 141 | 142 | def __init__(self, scale=0.01, clip=0.05): 143 | self.scale = scale 144 | self.clip = clip 145 | self.e = None 146 | 147 | def jitter(self, tensor): 148 | noise = torch.zeros_like(tensor).to(tensor) # [N, 3] 149 | noise.normal_(0, self.scale) 150 | noise.clamp_(-self.clip, self.clip) 151 | self.e = noise 152 | return tensor.add(noise) 153 | 154 | def __call__(self, tensor): 155 | return self.jitter(tensor) 156 | 157 | 158 | class RandomTransformSE3: 159 | """ rigid motion """ 160 | 161 | def __init__(self, mag=1, mag_randomly=False, mag_trans=1): 162 | self.mag = mag * np.pi / 180 # Minghan: use deg as input! 
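# --- Editor's sketch (not original code): composing the point transforms
# defined above. The igt attribute stores the transform actually applied,
# which is what the registration code in test.py compares against. Note the
# magnitudes: RandomRotator takes radians, while RandomTransformSE3 below
# takes degrees.
import math
import torch
from fmr_transforms import RandomRotator, RandomJitter

pts = torch.rand(2048, 3) - 0.5               # toy cloud in the unit cube
rot = RandomRotator(mag=math.pi / 4)          # rotations up to 45 degrees
jit = RandomJitter(scale=0.01, clip=0.05)

p1 = jit(rot(pts))                            # rotated + noised copy
R = rot.igt                                   # (3, 3) ground-truth rotation, pts -> rot(pts)
back = p1 @ R                                 # rows: R^T applied to each point
assert torch.allclose(back, pts, atol=0.1)    # equal up to the clipped jitter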
163 | self.randomly = mag_randomly 164 | self.mag_trans = mag_trans 165 | 166 | self.gt = None 167 | self.igt = None 168 | 169 | def generate_transform(self): 170 | # return: a twist-vector 171 | amp = self.mag 172 | if self.randomly: 173 | amp = torch.rand(1, 1) * self.mag 174 | x = torch.randn(1, 3) 175 | x = x / x.norm(p=2, dim=1, keepdim=True) * amp 176 | 177 | amp = self.mag_trans 178 | if self.randomly: 179 | amp = torch.rand(1, 1) * self.mag_trans 180 | y = torch.randn(1, 3) 181 | y = y / y.norm(p=2, dim=1, keepdim=True) * amp 182 | 183 | x = torch.cat([x, y], dim=1) 184 | 185 | # x = torch.randn(1, 6) 186 | # x = x / x.norm(p=2, dim=1, keepdim=True) * amp 187 | 188 | '''a = torch.rand(3) 189 | a = a * math.pi 190 | b = torch.zeros(1, 6) 191 | b[:, 0:3] = a 192 | x = x+b 193 | ''' 194 | return x # [1, 6] 195 | 196 | def apply_transform(self, p0, x): 197 | # p0: [N, 3] 198 | # x: [1, 6] 199 | g = se3.exp(x).to(p0) # [1, 4, 4] 200 | gt = se3.exp(-x).to(p0) # [1, 4, 4] 201 | 202 | p1 = se3.transform(g, p0) 203 | self.gt = gt.squeeze(0) # gt: p1 -> p0 204 | self.igt = g.squeeze(0) # igt: p0 -> p1 205 | return p1 206 | 207 | def transform(self, tensor): 208 | x = self.generate_transform() 209 | return self.apply_transform(tensor, x) 210 | 211 | def __call__(self, tensor): 212 | return self.transform(tensor) 213 | 214 | # EOF 215 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | import misc 10 | import config 11 | from checkpoints import CheckpointIO 12 | from metricrecord import Metric, Record 13 | 14 | from transforms import apply_rot 15 | from register_utils import * 16 | 17 | def get_args(): 18 | # Arguments 19 | parser = argparse.ArgumentParser( 20 | description='Train a 3D reconstruction model.' 
21 | ) 22 | parser.add_argument('config', type=str, help='Path to config file.') 23 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 24 | args = parser.parse_args() 25 | return args 26 | 27 | if __name__ == "__main__": 28 | args = get_args() 29 | 30 | cfg = misc.load_config(args.config, 'configs/default.yaml') 31 | 32 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 33 | device = torch.device("cuda" if is_cuda else "cpu") 34 | 35 | out_dir, gen_dir = config.cfg_f_out_test(cfg) 36 | 37 | dataset = config.cfg_dataset(cfg, 'test') 38 | test_loader = torch.utils.data.DataLoader( 39 | dataset, batch_size=1, num_workers=3, shuffle=False) 40 | 41 | model = config.cfg_model(cfg, device) 42 | 43 | checkpointio = CheckpointIO(model, checkpoint_dir=out_dir) 44 | checkpointio.load(cfg['testing']['model_file']) 45 | 46 | generator = config.cfg_generator(cfg, device, model) 47 | 48 | metric_dict = dict() 49 | metric_keys = ['angle', 'angle_180', 'angle_90-', 'angle_90+', 'rmse', 'rmse_tp'] 50 | for key in metric_keys: 51 | metric_dict[key] = Metric(key) 52 | 53 | worst_dict = dict() 54 | worst_dict['angle'] = Record('angle', 10, True) 55 | worst_dict['angle_180'] = Record('angle_180', 10, True) 56 | best_dict = dict() 57 | best_dict['angle'] = Record('angle', 10, False) 58 | 59 | # evaluate 60 | for it, data in enumerate(tqdm(test_loader)): 61 | 62 | for key, value in data.items(): 63 | if isinstance(value, torch.Tensor): 64 | data[key] = value.to(device) 65 | 66 | out_1, out_2_rot = generator.generate_latent_conditioned(data) 67 | batch_size = out_1.shape[0] 68 | out_1 = out_1.reshape(batch_size, -1, 3) 69 | out_2_rot = out_2_rot.reshape(batch_size, -1, 3) 70 | 71 | out_1 = out_1.to(torch.float64) 72 | out_2_rot = out_2_rot.to(torch.float64) 73 | 74 | R_est = solve_R(out_1, out_2_rot) 75 | 76 | ##### angular error 77 | R_gt = data['T21'].to(torch.float64) 78 | rotdeg = data['T21.deg'] 79 | 80 | angle_diff = angle_diff_func(R_est, R_gt) 81 | angle_diff = torch.abs(angle_diff).max().item() 82 | rotdeg = torch.abs(rotdeg).max().item() 83 | logging.debug("angle_diff (res, gt) %.4f %.4f"%(angle_diff, rotdeg ) ) 84 | 85 | ##### equivariant embedding error 86 | out_1_rot = apply_rot(R_gt, out_1) 87 | out_1_rot_est = apply_rot(R_est, out_1) 88 | 89 | diff_ori = out_1 - out_2_rot 90 | diff_gt = out_1_rot - out_2_rot 91 | diff_est = out_1_rot_est - out_2_rot 92 | 93 | diff_gt_inf = torch.abs(diff_gt).max().item() 94 | diff_ori_inf = torch.abs(diff_ori).max().item() 95 | diff_est_inf = torch.abs(diff_est).max().item() 96 | 97 | diff_gt_l2 = torch.norm(diff_gt, dim=2).mean().item() # == torch.sqrt((diff_gt**2).sum()) 98 | diff_ori_l2 = torch.norm(diff_ori, dim=2).mean().item() 99 | diff_est_l2 = torch.norm(diff_est, dim=2).mean().item() 100 | 101 | logging.debug("feat_diff_Linf (ori, est, gt) %.4f %.4f %.4f"%(diff_ori_inf, diff_est_inf, diff_gt_inf) ) 102 | logging.debug("feat_diff_L2 (ori, est, gt) %.4f %.4f %.4f"%(diff_ori_l2, diff_est_l2, diff_gt_l2) ) 103 | 104 | ### input alignment error 105 | input_1 = data['inputs'] 106 | input_2_rot = data['inputs_2'] 107 | input_1_rot = apply_rot(R_gt, input_1) 108 | 109 | # ### known issue: gpu 64 == cpu 32 == cpu 64 == numpy 32, but is slightly different from gpu 32 110 | # input_1_rot_64 = apply_rot(R_gt.to(torch.float64), input_1.to(torch.float64)).to(torch.float32) 111 | # input_1_rot_cpu = apply_rot(R_gt.cpu(), input_1.cpu()) 112 | # input_1_rot_cpu_64 = apply_rot(R_gt.cpu().to(torch.float64), 
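# --- Editor's note (illustrative sketch): solve_R above is imported from
# register_utils, which is not shown in this section. For orientation, a
# standard SVD-based orthogonal Procrustes (Kabsch) solver for R with
# R @ a_i ~ b_i looks like this; it is not necessarily this repo's exact
# implementation.
import torch

def kabsch_rotation(a, b):
    # a, b: (B, N, 3) corresponding point sets (here: equivariant latents)
    H = a.transpose(1, 2) @ b                       # (B, 3, 3) cross-covariance
    U, S, Vh = torch.linalg.svd(H)
    d = torch.det(Vh.transpose(1, 2) @ U.transpose(1, 2))
    D = torch.diag_embed(torch.stack(
        [torch.ones_like(d), torch.ones_like(d), d], dim=-1))  # reflection guard
    return Vh.transpose(1, 2) @ D @ U.transpose(1, 2)          # det(R) = +1

# e.g. R_est_sketch = kabsch_rotation(out_1, out_2_rot)  # same shapes as solve_R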
input_1.cpu().to(torch.float64)).to(torch.float32) 113 | # input_1_rot_np = np.matmul(R_gt.detach().cpu().numpy()[0], input_1.detach().cpu().numpy()[0].swapaxes(-1, -2))[None, ...].swapaxes(-1, -2) 114 | 115 | input_1_rot_est = apply_rot(R_est, input_1) 116 | input_2 = apply_rot(torch.inverse(R_gt), input_2_rot) 117 | 118 | if input_1.shape == input_2_rot.shape: 119 | diff_pts_rmse_ori = torch.norm(input_1 - input_2_rot, dim=2).mean() 120 | diff_pts_rmse_gt = torch.norm(input_1_rot - input_2_rot, dim=2).mean() 121 | diff_pts_rmse_est = torch.norm(input_1_rot_est - input_2_rot, dim=2).mean() 122 | logging.debug("diff_pts_rmse (ori, est, gt) %.4f %.4f %.4f"%(diff_pts_rmse_ori, diff_pts_rmse_gt, diff_pts_rmse_est) ) 123 | 124 | # update the metrics 125 | metric_dict['angle'].update(angle_diff) 126 | if angle_diff > 90: 127 | metric_dict['angle_180'].update(180-angle_diff) 128 | metric_dict['angle_90+'].update(180-angle_diff) 129 | else: 130 | metric_dict['angle_180'].update(angle_diff) 131 | metric_dict['angle_90-'].update(angle_diff) 132 | 133 | if input_1.shape == input_2_rot.shape: 134 | metric_dict['rmse'].update(diff_pts_rmse_est) 135 | if diff_pts_rmse_est < 0.2: 136 | metric_dict['rmse_tp'].update(diff_pts_rmse_est) 137 | 138 | ds_cur = dict() 139 | ds_cur['pcl_1'] = input_1.squeeze(0).cpu().numpy() # N*3 140 | ds_cur['pcl_1_rot_est'] = input_1_rot_est.squeeze(0).cpu().numpy() # N*3 141 | ds_cur['pcl_1_rot_gt'] = input_1_rot.squeeze(0).cpu().numpy() # N*3 142 | ds_cur['pcl_2'] = input_2_rot.squeeze(0).cpu().numpy() # N*3 143 | # ds_cur['category_name'] = str(category_id) 144 | 145 | worst_dict['angle'].update(angle_diff, ds_cur) 146 | if angle_diff > 90: 147 | worst_dict['angle_180'].update(180-angle_diff, ds_cur) 148 | else: 149 | worst_dict['angle_180'].update(angle_diff, ds_cur) 150 | best_dict['angle'].update(angle_diff, ds_cur) 151 | 152 | # break 153 | 154 | for key in metric_dict: 155 | logging.info(metric_dict[key]) -------------------------------------------------------------------------------- /utils/libkdtree/README.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://travis-ci.org/storpipfugl/pykdtree.svg?branch=master 2 | :target: https://travis-ci.org/storpipfugl/pykdtree 3 | .. image:: https://ci.appveyor.com/api/projects/status/ubo92368ktt2d25g/branch/master 4 | :target: https://ci.appveyor.com/project/storpipfugl/pykdtree 5 | 6 | ======== 7 | pykdtree 8 | ======== 9 | 10 | Objective 11 | --------- 12 | pykdtree is a kd-tree implementation for fast nearest neighbour search in Python. 13 | The aim is to be the fastest implementation around for common use cases (low dimensions and low number of neighbours) for both tree construction and queries. 14 | 15 | The implementation is based on scipy.spatial.cKDTree and libANN by combining the best features from both and focus on implementation efficiency. 16 | 17 | The interface is similar to that of scipy.spatial.cKDTree except only Euclidean distance measure is supported. 18 | 19 | Queries are optionally multithreaded using OpenMP. 20 | 21 | Installation 22 | ------------ 23 | Default build of pykdtree with OpenMP enabled queries using libgomp 24 | 25 | .. code-block:: bash 26 | 27 | $ cd 28 | $ python setup.py install 29 | 30 | If it fails with undefined compiler flags or you want to use another OpenMP implementation please modify setup.py at the indicated point to match your system. 
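(Editor's addition, not part of the original README; ``data_pts`` and ``query_pts`` below are placeholder float32 arrays.) A quick way to sanity-check the resulting build:

.. code-block:: python

    import numpy as np
    from pykdtree.kdtree import KDTree

    data_pts = np.random.rand(100000, 3).astype(np.float32)
    query_pts = np.random.rand(10000, 3).astype(np.float32)
    kd_tree = KDTree(data_pts, leafsize=16)
    dist, idx = kd_tree.query(query_pts, k=8)   # threaded if built with OpenMP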
31 | 32 | Building without OpenMP support is controlled by the USE_OMP environment variable 33 | 34 | .. code-block:: bash 35 | 36 | $ cd 37 | $ export USE_OMP=0 38 | $ python setup.py install 39 | 40 | Note environment variables are by default not exported when using sudo so in this case do 41 | 42 | .. code-block:: bash 43 | 44 | $ USE_OMP=0 sudo -E python setup.py install 45 | 46 | Usage 47 | ----- 48 | The usage of pykdtree is similar to scipy.spatial.cKDTree so for now refer to its documentation 49 | 50 | >>> from pykdtree.kdtree import KDTree 51 | >>> kd_tree = KDTree(data_pts) 52 | >>> dist, idx = kd_tree.query(query_pts, k=8) 53 | 54 | The number of threads to be used in OpenMP enabled queries can be controlled with the standard OpenMP environment variable OMP_NUM_THREADS. 55 | 56 | The **leafsize** argument (number of data points per leaf) for the tree creation can be used to control the memory overhead of the kd-tree. pykdtree uses a default **leafsize=16**. 57 | Increasing **leafsize** will reduce the memory overhead and construction time but increase query time. 58 | 59 | pykdtree accepts data in double precision (numpy.float64) or single precision (numpy.float32) floating point. If data of another type is used an internal copy in double precision is made resulting in a memory overhead. If the kd-tree is constructed on single precision data the query points must be single precision as well. 60 | 61 | Benchmarks 62 | ---------- 63 | Comparison with scipy.spatial.cKDTree and libANN. This benchmark is on geospatial 3D data with 10053632 data points and 4276224 query points. The results are indexed relative to the construction time of scipy.spatial.cKDTree. A leafsize of 10 (scipy.spatial.cKDTree default) is used. 64 | 65 | Note: libANN is *not* thread safe. In this benchmark libANN is compiled with "-O3 -funroll-loops -ffast-math -fprefetch-loop-arrays" in order to achieve optimum performance. 66 | 67 | ================== ===================== ====== ======== ================== 68 | Operation scipy.spatial.cKDTree libANN pykdtree pykdtree 4 threads 69 | ------------------ --------------------- ------ -------- ------------------ 70 | 71 | Construction 100 304 96 96 72 | 73 | query 1 neighbour 1267 294 223 70 74 | 75 | Total 1 neighbour 1367 598 319 166 76 | 77 | query 8 neighbours 2193 625 449 143 78 | 79 | Total 8 neighbours 2293 929 545 293 80 | ================== ===================== ====== ======== ================== 81 | 82 | Looking at the combined construction and query this gives the following performance improvement relative to scipy.spatial.cKDTree 83 | 84 | ========== ====== ======== ================== 85 | Neighbours libANN pykdtree pykdtree 4 threads 86 | ---------- ------ -------- ------------------ 87 | 1 129% 329% 723% 88 | 89 | 8 147% 320% 682% 90 | ========== ====== ======== ================== 91 | 92 | Note: mileage will vary with the dataset at hand and computer architecture. 93 | 94 | Test 95 | ---- 96 | Run the unit tests using nosetests 97 | 98 | .. code-block:: bash 99 | 100 | $ cd 101 | $ python setup.py nosetests 102 | 103 | Installing on AppVeyor 104 | ---------------------- 105 | 106 | Pykdtree requires the "stdint.h" header file which is not available on certain 107 | versions of Windows or certain Windows compilers including those on the 108 | continuous integration platform AppVeyor. To get around this the header file(s) 109 | can be downloaded and placed in the correct "include" directory.
This can 110 | be done by adding the `anaconda/missing-headers.ps1` script to your repository 111 | and running it the install step of `appveyor.yml`: 112 | 113 | # install missing headers that aren't included with MSVC 2008 114 | # https://github.com/omnia-md/conda-recipes/pull/524 115 | - "powershell ./appveyor/missing-headers.ps1" 116 | 117 | In addition to this, AppVeyor does not support OpenMP so this feature must be 118 | turned off by adding the following to `appveyor.yml` in the 119 | `environment` section: 120 | 121 | environment: 122 | global: 123 | # Don't build with openmp because it isn't supported in appveyor's compilers 124 | USE_OMP: "0" 125 | 126 | Changelog 127 | --------- 128 | v1.3.1 : Fix masking in the "query" method introduced in 1.3.0 129 | 130 | v1.3.0 : Keyword argument "mask" added to "query" method. OpenMP compilation now works for MS Visual Studio compiler 131 | 132 | v1.2.2 : Build process fixes 133 | 134 | v1.2.1 : Fixed OpenMP thread safety issue introduced in v1.2.0 135 | 136 | v1.2.0 : 64 and 32 bit MSVC Windows support added 137 | 138 | v1.1.1 : Same as v1.1 release due to incorrect pypi release 139 | 140 | v1.1 : Build process improvements. Add data attribute to kdtree class for scipy interface compatibility 141 | 142 | v1.0 : Switched license from GPLv3 to LGPLv3 143 | 144 | v0.3 : Avoid zipping of installed egg 145 | 146 | v0.2 : Reduced memory footprint. Can now handle single precision data internally avoiding copy conversion to double precision. Default leafsize changed from 10 to 16 as this reduces the memory footprint and makes it a cache line multiplum (negligible if any query performance observed in benchmarks). Reduced memory allocation for leaf nodes. Applied patch for building on OS X. 147 | 148 | v0.1 : Initial version. 149 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from torchvision.utils import save_image 5 | import common 6 | try: 7 | import open3d as o3d 8 | except: 9 | print("Warning: failed to import open3d, some function may not be used. ") 10 | 11 | 12 | def visualize_data(data, data_type, out_file, data2=None, info=None, c1=None, c2=None, show=False, s1=5, s2=5): 13 | r''' Visualizes the data with regard to its type. 14 | 15 | Args: 16 | data (tensor): batch of data 17 | data_type (string): data type (img, voxels or pointcloud) 18 | out_file (string): output file 19 | ''' 20 | if data_type == 'img': 21 | if data.dim() == 3: 22 | data = data.unsqueeze(0) 23 | save_image(data, out_file, nrow=4) 24 | elif data_type == 'voxels': 25 | visualize_voxels(data, out_file=out_file) 26 | elif data_type == 'pointcloud': 27 | visualize_pointcloud(data, out_file=out_file, points2=data2, info=info, c1=c1, c2=c2, show=show, s1=s1, s2=s2) 28 | # display_open3d(data) 29 | elif data_type is None or data_type == 'idx': 30 | pass 31 | else: 32 | raise ValueError('Invalid data_type "%s"' % data_type) 33 | 34 | 35 | def visualize_voxels(voxels, out_file=None, show=False): 36 | r''' Visualizes voxel data. 
37 | 38 | Args: 39 | voxels (tensor): voxel data 40 | out_file (string): output file 41 | show (bool): whether the plot should be shown 42 | ''' 43 | # Use numpy 44 | voxels = np.asarray(voxels) 45 | # Create plot 46 | fig = plt.figure() 47 | ax = fig.gca(projection=Axes3D.name) 48 | voxels = voxels.transpose(2, 0, 1) 49 | ax.voxels(voxels, edgecolor='k') 50 | ax.set_xlabel('Z') 51 | ax.set_ylabel('X') 52 | ax.set_zlabel('Y') 53 | ax.view_init(elev=30, azim=45) 54 | if out_file is not None: 55 | plt.savefig(out_file) 56 | if show: 57 | plt.show() 58 | plt.close(fig) 59 | 60 | 61 | def display_open3d(template, source=None, transformed_source=None): 62 | to_vis = [] 63 | template_ = o3d.geometry.PointCloud() 64 | template_.points = o3d.utility.Vector3dVector(template) 65 | template_.paint_uniform_color([1, 0, 0]) 66 | to_vis.append(template_) 67 | if source is not None: 68 | source_ = o3d.geometry.PointCloud() 69 | source_.points = o3d.utility.Vector3dVector(source + np.array([0,0,0])) 70 | source_.paint_uniform_color([0, 1, 0]) 71 | to_vis.append(source_) 72 | if transformed_source is not None: 73 | transformed_source_ = o3d.geometry.PointCloud() 74 | transformed_source_.points = o3d.utility.Vector3dVector(transformed_source) 75 | transformed_source_.paint_uniform_color([0, 0, 1]) 76 | to_vis.append(transformed_source_) 77 | o3d.visualization.draw_geometries(to_vis) 78 | # o3d.visualization.draw_geometries([template_, source_, transformed_source_]) 79 | 80 | def hat(v): 81 | mat = np.array([[0, -v[2], v[1]], 82 | [v[2], 0, -v[0]], 83 | [-v[1], v[0], 0]]) 84 | return mat 85 | def visualize_feat_as_vec_field(points, feature, idx=list(range(10)), rotmat=None, 86 | out_file=None, show=False, size=7): 87 | for ii in idx: 88 | out_1_feat_samp = feature[ii] # 3-vector 89 | mat_1 = hat(out_1_feat_samp) 90 | field_1 = points.dot(mat_1.T) 91 | visualize_pointcloud(points, field_1, out_file, show, s1=size) 92 | 93 | if rotmat is not None: 94 | field_1 = field_1.dot(rotmat) 95 | points = points.dot(rotmat) 96 | visualize_pointcloud(points, field_1, out_file, show, s1=size) 97 | return 98 | 99 | def visualize_pointcloud(points, normals=None, 100 | out_file=None, show=False, 101 | points2=None, info=None, 102 | c1=None, c2=None, cm1='viridis', cm2='viridis', s1=5, s2=5): 103 | r''' Visualizes point cloud data. 
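A minimal usage sketch (two hypothetical random clouds; the second is overlaid via points2 and drawn with '^' markers):

>>> import numpy as np
>>> pts = np.random.rand(1024, 3) - 0.5   # hypothetical cloud in the plot's [-0.5, 0.5] range
>>> pts2 = np.random.rand(512, 3) - 0.5
>>> visualize_pointcloud(pts, points2=pts2, out_file='pc.png')  # hypothetical output path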
104 | 105 | Args: 106 | points (tensor): point data 107 | normals (tensor): normal data (if existing) 108 | out_file (string): output file 109 | show (bool): whether the plot should be shown 110 | points2 (tensor): optional second point set (drawn with '^' markers); info (str): optional plot title; c1, c2: per-point colors; cm1, cm2: colormap names; s1, s2: marker sizes ''' 111 | # Use numpy 112 | points = np.asarray(points) 113 | # Create plot 114 | fig = plt.figure() 115 | ax = fig.gca(projection=Axes3D.name) 116 | if c1 is not None: 117 | cmap1 = plt.get_cmap(cm1) # viridis, magma 118 | ax.scatter(points[:, 2], points[:, 0], points[:, 1], s=s1, c=c1, cmap=cmap1) 119 | else: 120 | ax.scatter(points[:, 2], points[:, 0], points[:, 1]) 121 | if points2 is not None: 122 | if c2 is not None: 123 | cmap2 = plt.get_cmap(cm2) 124 | ax.scatter(points2[:, 2], points2[:, 0], points2[:, 1], s=s2, c=c2, cmap=cmap2, marker='^') 125 | else: 126 | ax.scatter(points2[:, 2], points2[:, 0], points2[:, 1], c='r') 127 | if normals is not None: 128 | ax.quiver( 129 | points[:, 2], points[:, 0], points[:, 1], 130 | normals[:, 2], normals[:, 0], normals[:, 1], 131 | length=0.8, color='gray', linewidth=0.8 132 | ) 133 | ax.set_xlabel('Z') 134 | ax.set_ylabel('X') 135 | ax.set_zlabel('Y') 136 | ax.set_xlim(-0.5, 0.5) 137 | ax.set_ylim(-0.5, 0.5) 138 | ax.set_zlim(-0.5, 0.5) 139 | ax.view_init(elev=30, azim=45) 140 | if info is not None: 141 | plt.title("{}".format(info)) 142 | 143 | if out_file is not None: 144 | plt.savefig(out_file) 145 | if show: 146 | plt.show() 147 | plt.close(fig) 148 | 149 | 150 | def visualise_projection( 151 | self, points, world_mat, camera_mat, img, output_file='out.png'): 152 | r''' Visualizes the transformation and projection to image plane. 153 | 154 | The first points of the batch are transformed and projected to the 155 | respective image. After performing the relevant transformations, the 156 | visualization is saved in the provided output_file path. 157 | 158 | Arguments: 159 | points (tensor): batch of point cloud points 160 | world_mat (tensor): batch of matrices to rotate pc to camera-based 161 | coordinates 162 | camera_mat (tensor): batch of camera matrices to project to 2D image 163 | plane 164 | img (tensor): tensor of batch GT image files 165 | output_file (string): where the output should be saved 166 | ''' 167 | points_transformed = common.transform_points(points, world_mat) 168 | points_img = common.project_to_camera(points_transformed, camera_mat) 169 | pimg2 = points_img[0].detach().cpu().numpy() 170 | image = img[0].cpu().numpy() 171 | plt.imshow(image.transpose(1, 2, 0)) 172 | plt.plot( 173 | (pimg2[:, 0] + 1)*image.shape[1]/2, 174 | (pimg2[:, 1] + 1) * image.shape[2]/2, 'x') 175 | plt.savefig(output_file) 176 | -------------------------------------------------------------------------------- /fields.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import utils.binvox_rw as binvox_rw 4 | import logging 5 | 6 | class Field(object): 7 | ''' Data fields class. 8 | ''' 9 | 10 | def load(self, data_path): 11 | ''' Loads a data point. 12 | 13 | Args: 14 | data_path (str): path to the data file or model directory 15 | 16 | Subclasses override this for their specific file formats. 17 | ''' 18 | raise NotImplementedError 19 | 20 | def check_complete(self, files): 21 | ''' Checks if set is complete. 22 | 23 | Args: 24 | files: files 25 | ''' 26 | raise NotImplementedError 27 | 28 | class PointsField(Field): 29 | ''' Point Field. 30 | 31 | It provides the field to load point data. This is used for the points 32 | randomly sampled in the bounding volume of the 3D shape.
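A usage sketch (file and directory names are hypothetical; load() below expects an .npz file holding 'points' and 'occupancies' arrays):

>>> field = PointsField('points.npz', unpackbits=True)
>>> data = field.load('/path/to/model_dir')  # hypothetical model directory
>>> points, occ = data[None], data['occ']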
33 | 34 | Args: 35 | file_name (str): file name 36 | transform (list): list of transformations which will be applied to the 37 | points tensor 38 | unpackbits (bool): whether the occupancy values are bit-packed and 39 | should be unpacked with np.unpackbits 40 | 41 | ''' 42 | def __init__(self, file_name, transform=None, unpackbits=False): 43 | self.file_name = file_name 44 | self.transform = transform 45 | self.unpackbits = unpackbits 46 | 47 | def load(self, model_path): 48 | 49 | # Load data 50 | file_path = os.path.join(model_path, self.file_name) 51 | points_dict = np.load(file_path) 52 | 53 | # Points 54 | points = points_dict['points'] 55 | # Break symmetry if given in float16: 56 | if points.dtype == np.float16: 57 | points = points.astype(np.float32) 58 | points += 1e-4 * np.random.randn(*points.shape) 59 | else: 60 | points = points.astype(np.float32) 61 | 62 | # Occupancies 63 | occupancies = points_dict['occupancies'] 64 | if self.unpackbits: 65 | occupancies = np.unpackbits(occupancies)[:points.shape[0]] 66 | occupancies = occupancies.astype(np.float32) 67 | 68 | # Output dict 69 | data = { 70 | None: points, 71 | 'occ': occupancies, 72 | } 73 | 74 | if self.transform is not None: 75 | data = self.transform(data) 76 | 77 | return data 78 | 79 | class VoxelsField(Field): 80 | ''' Voxel field class. 81 | 82 | It provides the class used for voxel-based data. 83 | 84 | Args: 85 | file_name (str): file name 86 | transform (list): list of transformations applied to data points 87 | ''' 88 | def __init__(self, file_name, transform=None): 89 | self.file_name = file_name 90 | self.transform = transform 91 | 92 | def load(self, model_path): 93 | 94 | file_path = os.path.join(model_path, self.file_name) 95 | 96 | with open(file_path, 'rb') as f: 97 | voxels = binvox_rw.read_as_3d_array(f) 98 | voxels = voxels.data.astype(np.float32) 99 | 100 | if self.transform is not None: 101 | voxels = self.transform(voxels) 102 | 103 | return voxels 104 | 105 | def check_complete(self, files): 106 | 107 | complete = (self.file_name in files) 108 | return complete 109 | 110 | class PointCloudField(Field): 111 | ''' Point cloud field. 112 | 113 | It provides the field used for point cloud data. These are the points 114 | randomly sampled on the mesh.
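A usage sketch (paths are hypothetical; load() accepts either an .npz file with 'points' and 'normals' entries or a plain .npy point array):

>>> field = PointCloudField('pointcloud.npz')
>>> data = field.load('/path/to/model_dir')  # hypothetical model directory
>>> points, normals = data[None], data['normals']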
115 | 116 | Args: 117 | file_name (str): file name 118 | transform (list): list of transformations applied to the loaded 119 | points and normals 120 | 121 | ''' 122 | def __init__(self, file_name, transform=None): 123 | self.file_name = file_name 124 | self.transform = transform 125 | 126 | def load_dict(self, pointcloud_dict): 127 | 128 | points = pointcloud_dict['points'].astype(np.float32) 129 | normals = pointcloud_dict['normals'].astype(np.float32) 130 | 131 | # print("points.shape", points.shape) # 100000, 3 132 | data = { 133 | None: points, 134 | 'normals': normals, 135 | } 136 | 137 | if self.transform is not None: 138 | data = self.transform(data) 139 | 140 | return data 141 | 142 | def load_array(self, pointcloud_array): 143 | data = {None: pointcloud_array} 144 | 145 | if self.transform is not None: 146 | data = self.transform(data) 147 | 148 | return data 149 | 150 | def load(self, model_path): 151 | 152 | file_path = os.path.join(model_path, self.file_name) 153 | 154 | pointcloud_dict = np.load(file_path) 155 | 156 | if isinstance(pointcloud_dict, np.lib.npyio.NpzFile): 157 | data = self.load_dict(pointcloud_dict) 158 | elif isinstance(pointcloud_dict, np.ndarray): 159 | data = self.load_array(pointcloud_dict) 160 | else: 161 | raise ValueError('pointcloud file content {} unexpected: {}'.format(type(pointcloud_dict), file_path)) 162 | 163 | return data 164 | 165 | def check_complete(self, files): 166 | 167 | complete = (self.file_name in files) 168 | return complete 169 | 170 | class RotationField(Field): 171 | '''It provides the field used for a rotation transformation. 172 | When benchmarking registration performance, 173 | it is useful to have fixed initial transformations instead of random ones. 174 | ''' 175 | def __init__(self, file_name) -> None: 176 | super().__init__() 177 | self.file_name = file_name 178 | 179 | def load(self, model_path): 180 | 181 | file_path = os.path.join(model_path, self.file_name) 182 | data = np.load(file_path) 183 | assert 'T' in data and 'deg' in data, data 184 | data_out = { 185 | None: data['T'], 186 | 'deg': data['deg'], 187 | } 188 | 189 | return data_out 190 | 191 | class TransformationField(Field): 192 | '''It provides the field used for a rigid body transformation. 193 | When benchmarking registration performance, 194 | it is useful to have fixed initial transformations instead of random ones. 195 | ''' 196 | def __init__(self, file_name) -> None: 197 | super().__init__() 198 | self.file_name = file_name 199 | 200 | def load(self, model_path): 201 | 202 | file_path = os.path.join(model_path, self.file_name) 203 | T = np.load(file_path) 204 | return T 205 | 206 | class IndexField(Field): 207 | ''' Basic index field.''' 208 | 209 | def check_complete(self, files): 210 | 211 | return True 212 | 213 | class CategoryField(Field): 214 | ''' Basic category field.''' 215 | 216 | def check_complete(self, files): 217 | ''' Check if field is complete.
218 | 219 | Args: 220 | files: files 221 | ''' 222 | return True -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import so3, se3 4 | import torch 5 | import random 6 | 7 | def apply_transformation(T, pts): 8 | '''rotmat: ?*4*4, pts: ?*N*3''' 9 | assert pts.shape[-1] == 3, pts.shape 10 | if pts.ndim == 1: 11 | pts = pts.unsqueeze(0) 12 | assert pts.ndim == T.ndim, "{} {}".format(pts.shape, T.shape) 13 | 14 | T = T.to(pts) 15 | pts_trans = se3.transform(T, pts.transpose(-1, -2)).transpose(-1, -2) 16 | return pts_trans 17 | 18 | # def apply_rot(rotmat, pts): 19 | # '''rotmat: ?*3*3, pts: N*3''' 20 | # assert pts.ndim == 2 and pts.shape[1] == 3, pts.shape 21 | # if rotmat.ndim == 2: 22 | # rotmat = rotmat.unsqueeze(0) 23 | # else: 24 | # assert rotmat.shape[0] == pts.shape[0] 25 | # rotmat = rotmat.to(pts) 26 | # pts_rot = so3.transform(rotmat, pts) # [1 or N,3,3] x [N,3] -> [N,3] 27 | # return pts_rot 28 | 29 | def apply_rot(rotmat, pts): 30 | '''rotmat: ?*3*3, pts: ?*N*3''' 31 | assert pts.shape[-1] == 3, pts.shape 32 | if pts.ndim == 1: 33 | pts = pts.unsqueeze(0) 34 | assert pts.ndim == rotmat.ndim, "{} {}".format(pts.shape, rotmat.shape) 35 | 36 | rotmat = rotmat.to(pts) 37 | pts_rot = so3.transform(rotmat, pts.transpose(-1, -2)).transpose(-1, -2) 38 | return pts_rot 39 | 40 | def gen_randrot(mag_max=None, mag_random=True): 41 | # tensor: [N, 3] 42 | mag_max = 180 if mag_max is None else mag_max 43 | amp = torch.rand(1) if mag_random else 1.0 44 | deg = amp * mag_max 45 | w = torch.randn(1, 3) 46 | w = w / w.norm(p=2, dim=1, keepdim=True) * deg * np.pi / 180 47 | 48 | g = so3.exp(w) # [1, 3, 3] 49 | g = g.squeeze(0) # [3, 3] 50 | return g, deg 51 | 52 | # g = so3.exp(w).to(tensor) # [1, 3, 3] 53 | # p1 = so3.transform(g, tensor) # [1, 3, 3] x [N, 3] -> [N, 3] 54 | # return p1 55 | 56 | def totensor_inplace(data): 57 | for key, value in data.items(): 58 | if isinstance(value, np.ndarray): 59 | data[key] = torch.from_numpy(value) 60 | return data 61 | 62 | class CentralizeBatchIP(object): 63 | '''In-place centralization transform for a batch of PairedDataset data''' 64 | def __init__(self) -> None: 65 | super().__init__() 66 | # self.device = device 67 | 68 | def __call__(self, data): 69 | # device = self.device 70 | inputs = data['inputs'] 71 | inputs_mean = inputs.mean(dim=1, keepdim=True) 72 | data['inputs'] = inputs - inputs_mean 73 | 74 | if 'points' in data: 75 | data['points'] = data['points'] - inputs_mean.to(data['points']) 76 | 77 | return data 78 | 79 | ### Transforms for DualDataset 80 | class CentralizePairBatchIP(object): 81 | '''In-place centralization transform for a batch of PairedDataset data''' 82 | def __init__(self) -> None: 83 | super().__init__() 84 | # self.device = device 85 | 86 | def __call__(self, data): 87 | # device = self.device 88 | inputs = data['inputs'] 89 | inputs_2 = data['inputs_2'] 90 | inputs_mean = inputs.mean(dim=1, keepdim=True) 91 | data['inputs'] = inputs - inputs_mean 92 | inputs_2_mean = inputs_2.mean(dim=1, keepdim=True) 93 | data['inputs_2'] = inputs_2 - inputs_2_mean 94 | 95 | if 'points' in data: 96 | data['points'] = data['points'] - inputs_mean.to(data['points']) 97 | data['points_2'] = data['points_2'] - inputs_2_mean.to(data['points']) 98 | 99 | return data 100 | 101 | class RotateBatchIP(object): 102 | def __init__(self) -> None: 103 | super().__init__() 104 | # self.device = 
device 105 | 106 | def __call__(self, data): 107 | # device = self.device 108 | 109 | data['inputs'] = apply_rot(data['T'], data['inputs']) 110 | if 'points' in data: 111 | data['points'] = apply_rot(data['T'], data['points']) 112 | return 113 | 114 | class RotatePairBatchIP(object): 115 | def __init__(self) -> None: 116 | super().__init__() 117 | # self.device = device 118 | 119 | def __call__(self, data): 120 | # device = self.device 121 | 122 | data['inputs_2'] = apply_rot(data['T21'], data['inputs_2']) 123 | if 'points' in data: 124 | data['points_2'] = apply_rot(data['T21'], data['points_2']) 125 | return 126 | 127 | def noise_pts(pts, stddev): 128 | noise = stddev * torch.randn(*pts.shape, dtype=pts.dtype, device=pts.device) 129 | pts = pts + noise 130 | return pts 131 | 132 | class NoisePairBatchIP(object): 133 | def __init__(self, stddev, device=None) -> None: 134 | super().__init__() 135 | self.stddev = stddev 136 | self.device = device 137 | 138 | def __call__(self, data): 139 | device = self.device if self.device is not None else data['inputs'].device 140 | 141 | inputs = data.get('inputs').to(device) 142 | inputs_2 = data.get('inputs_2', inputs.clone()).to(device) 143 | 144 | inputs = noise_pts(inputs, self.stddev) 145 | inputs_2 = noise_pts(inputs_2, self.stddev) 146 | 147 | data['inputs'] = inputs 148 | data['inputs_2'] = inputs_2 149 | return 150 | 151 | def subsample(pts, n): 152 | n_pts_in = pts.shape[1] 153 | 154 | if n <= 0: 155 | return pts 156 | elif n < n_pts_in: 157 | idx_sample = torch.randperm(n_pts_in)[:n] 158 | pts = pts[:, idx_sample] 159 | return pts 160 | elif n == n_pts_in: 161 | return pts 162 | else: 163 | raise ValueError("n=%d is more than the size of the point cloud %d"%(n, n_pts_in)) 164 | 165 | class SubSampleBatchIP(object): 166 | '''In-place subsampling transform for a batch of PairedDataset data''' 167 | def __init__(self, n2_min, n2_max, device) -> None: 168 | super().__init__() 169 | self.n2_min = n2_min 170 | self.n2_max = n2_max 171 | self.device = device 172 | 173 | def __call__(self, data): 174 | device = self.device 175 | 176 | inputs = data.get('inputs') 177 | n2 = random.randint(self.n2_min, self.n2_max) 178 | inputs = subsample(inputs, n2) 179 | 180 | data['inputs'] = inputs.to(device) 181 | return data 182 | 183 | class SubSamplePairBatchIP(object): 184 | '''In-place subsampling transform for a batch of PairedDataset data''' 185 | def __init__(self, n1, n2_min, n2_max, device) -> None: 186 | super().__init__() 187 | self.n1 = n1 188 | self.n2_min = n2_min 189 | self.n2_max = n2_max 190 | 191 | self.device = device 192 | 193 | def __call__(self, data): 194 | device = self.device 195 | 196 | inputs = data.get('inputs') 197 | inputs_2 = data.get('inputs_2') #, inputs.clone()) 198 | 199 | assert inputs.ndim == 3 and inputs.shape[1] == inputs_2.shape[1], "{}, {}".format(inputs.shape, inputs_2.shape) 200 | inputs = subsample(inputs, self.n1) 201 | 202 | # if 'inputs_2' not in data: 203 | # n2 = random.randint(self.n2_min, self.n2_max) 204 | # inputs_2 = subsample(inputs_2, n2) 205 | # else: 206 | # inputs_2 = subsample(inputs_2, self.n1) # for 7scenes, avoid too few points 207 | n2 = random.randint(self.n2_min, self.n2_max) 208 | inputs_2 = subsample(inputs_2, n2) 209 | 210 | data['inputs'] = inputs.to(device) 211 | data['inputs_2'] = inputs_2.to(device) 212 | return data -------------------------------------------------------------------------------- /utils/libkdtree/LICENSE.txt: 
-------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007, 2015 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. 
You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /utils/libmcubes/pywrapper.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "pywrapper.h" 3 | 4 | #include "marchingcubes.h" 5 | 6 | #include <stdexcept> 7 | 8 | struct PythonToCFunc 9 | { 10 | PyObject* func; 11 | PythonToCFunc(PyObject* func) {this->func = func;} 12 | double operator()(double x, double y, double z) 13 | { 14 | PyObject* res = PyObject_CallFunction(func, "(d,d,d)", x, y, z); // py::extract<double>(func(x,y,z)); 15 | if(res == NULL) 16 | return 0.0; 17 | 18 | double result = PyFloat_AsDouble(res); 19 | Py_DECREF(res); 20 | return result; 21 | } 22 | }; 23 | 24 | PyObject* marching_cubes_func(PyObject* lower, PyObject* upper, 25 | int numx, int numy, int numz, PyObject* f, double isovalue) 26 | { 27 | std::vector<double> vertices; 28 | std::vector<size_t> polygons; 29 | 30 | // Copy the lower and upper coordinates to a C array.
31 | double lower_[3]; 32 | double upper_[3]; 33 | for(int i=0; i<3; ++i) 34 | { 35 | PyObject* l = PySequence_GetItem(lower, i); 36 | if(l == NULL) 37 | throw std::runtime_error("error"); 38 | PyObject* u = PySequence_GetItem(upper, i); 39 | if(u == NULL) 40 | { 41 | Py_DECREF(l); 42 | throw std::runtime_error("error"); 43 | } 44 | 45 | lower_[i] = PyFloat_AsDouble(l); 46 | upper_[i] = PyFloat_AsDouble(u); 47 | 48 | Py_DECREF(l); 49 | Py_DECREF(u); 50 | if(lower_[i]==-1.0 || upper_[i]==-1.0) 51 | { 52 | if(PyErr_Occurred()) 53 | throw std::runtime_error("error"); 54 | } 55 | } 56 | 57 | // Marching cubes. 58 | mc::marching_cubes(lower_, upper_, numx, numy, numz, PythonToCFunc(f), isovalue, vertices, polygons); 59 | 60 | // Copy the result to two Python ndarrays. 61 | npy_intp size_vertices = vertices.size(); 62 | npy_intp size_polygons = polygons.size(); 63 | PyArrayObject* verticesarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 64 | PyArrayObject* polygonsarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 65 | 66 | std::vector<double>::const_iterator it = vertices.begin(); 67 | for(int i=0; it!=vertices.end(); ++i, ++it) 68 | *reinterpret_cast<double*>(PyArray_GETPTR1(verticesarr, i)) = *it; 69 | std::vector<size_t>::const_iterator it2 = polygons.begin(); 70 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 71 | *reinterpret_cast<unsigned long*>(PyArray_GETPTR1(polygonsarr, i)) = *it2; 72 | 73 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 74 | Py_XDECREF(verticesarr); 75 | Py_XDECREF(polygonsarr); 76 | return res; 77 | } 78 | 79 | struct PyArrayToCFunc 80 | { 81 | PyArrayObject* arr; 82 | PyArrayToCFunc(PyArrayObject* arr) {this->arr = arr;} 83 | double operator()(int x, int y, int z) 84 | { 85 | npy_intp c[3] = {x,y,z}; 86 | return PyArray_SafeGet<double>(arr, c); 87 | } 88 | }; 89 | 90 | PyObject* marching_cubes(PyArrayObject* arr, double isovalue) 91 | { 92 | if(PyArray_NDIM(arr) != 3) 93 | throw std::runtime_error("Only three-dimensional arrays are supported."); 94 | 95 | // Prepare data. 96 | npy_intp* shape = PyArray_DIMS(arr); 97 | double lower[3] = {0,0,0}; 98 | double upper[3] = {shape[0]-1, shape[1]-1, shape[2]-1}; 99 | long numx = upper[0] - lower[0] + 1; 100 | long numy = upper[1] - lower[1] + 1; 101 | long numz = upper[2] - lower[2] + 1; 102 | std::vector<double> vertices; 103 | std::vector<size_t> polygons; 104 | 105 | // Marching cubes. 106 | mc::marching_cubes(lower, upper, numx, numy, numz, PyArrayToCFunc(arr), isovalue, 107 | vertices, polygons); 108 | 109 | // Copy the result to two Python ndarrays.
110 | npy_intp size_vertices = vertices.size(); 111 | npy_intp size_polygons = polygons.size(); 112 | PyArrayObject* verticesarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 113 | PyArrayObject* polygonsarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 114 | 115 | std::vector<double>::const_iterator it = vertices.begin(); 116 | for(int i=0; it!=vertices.end(); ++i, ++it) 117 | *reinterpret_cast<double*>(PyArray_GETPTR1(verticesarr, i)) = *it; 118 | std::vector<size_t>::const_iterator it2 = polygons.begin(); 119 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 120 | *reinterpret_cast<unsigned long*>(PyArray_GETPTR1(polygonsarr, i)) = *it2; 121 | 122 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 123 | Py_XDECREF(verticesarr); 124 | Py_XDECREF(polygonsarr); 125 | 126 | return res; 127 | } 128 | 129 | PyObject* marching_cubes2(PyArrayObject* arr, double isovalue) 130 | { 131 | if(PyArray_NDIM(arr) != 3) 132 | throw std::runtime_error("Only three-dimensional arrays are supported."); 133 | 134 | // Prepare data. 135 | npy_intp* shape = PyArray_DIMS(arr); 136 | double lower[3] = {0,0,0}; 137 | double upper[3] = {shape[0]-1, shape[1]-1, shape[2]-1}; 138 | long numx = upper[0] - lower[0] + 1; 139 | long numy = upper[1] - lower[1] + 1; 140 | long numz = upper[2] - lower[2] + 1; 141 | std::vector<double> vertices; 142 | std::vector<size_t> polygons; 143 | 144 | // Marching cubes. 145 | mc::marching_cubes2(lower, upper, numx, numy, numz, PyArrayToCFunc(arr), isovalue, 146 | vertices, polygons); 147 | 148 | // Copy the result to two Python ndarrays. 149 | npy_intp size_vertices = vertices.size(); 150 | npy_intp size_polygons = polygons.size(); 151 | PyArrayObject* verticesarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 152 | PyArrayObject* polygonsarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 153 | 154 | std::vector<double>::const_iterator it = vertices.begin(); 155 | for(int i=0; it!=vertices.end(); ++i, ++it) 156 | *reinterpret_cast<double*>(PyArray_GETPTR1(verticesarr, i)) = *it; 157 | std::vector<size_t>::const_iterator it2 = polygons.begin(); 158 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 159 | *reinterpret_cast<unsigned long*>(PyArray_GETPTR1(polygonsarr, i)) = *it2; 160 | 161 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 162 | Py_XDECREF(verticesarr); 163 | Py_XDECREF(polygonsarr); 164 | 165 | return res; 166 | } 167 | 168 | PyObject* marching_cubes3(PyArrayObject* arr, double isovalue) 169 | { 170 | if(PyArray_NDIM(arr) != 3) 171 | throw std::runtime_error("Only three-dimensional arrays are supported."); 172 | 173 | // Prepare data. 174 | npy_intp* shape = PyArray_DIMS(arr); 175 | double lower[3] = {0,0,0}; 176 | double upper[3] = {shape[0]-1, shape[1]-1, shape[2]-1}; 177 | long numx = upper[0] - lower[0] + 1; 178 | long numy = upper[1] - lower[1] + 1; 179 | long numz = upper[2] - lower[2] + 1; 180 | std::vector<double> vertices; 181 | std::vector<size_t> polygons; 182 | 183 | // Marching cubes. 184 | mc::marching_cubes3(lower, upper, numx, numy, numz, PyArrayToCFunc(arr), isovalue, 185 | vertices, polygons); 186 | 187 | // Copy the result to two Python ndarrays.
188 | npy_intp size_vertices = vertices.size(); 189 | npy_intp size_polygons = polygons.size(); 190 | PyArrayObject* verticesarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_vertices, PyArray_DOUBLE)); 191 | PyArrayObject* polygonsarr = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNew(1, &size_polygons, PyArray_ULONG)); 192 | 193 | std::vector<double>::const_iterator it = vertices.begin(); 194 | for(int i=0; it!=vertices.end(); ++i, ++it) 195 | *reinterpret_cast<double*>(PyArray_GETPTR1(verticesarr, i)) = *it; 196 | std::vector<size_t>::const_iterator it2 = polygons.begin(); 197 | for(int i=0; it2!=polygons.end(); ++i, ++it2) 198 | *reinterpret_cast<unsigned long*>(PyArray_GETPTR1(polygonsarr, i)) = *it2; 199 | 200 | PyObject* res = Py_BuildValue("(O,O)", verticesarr, polygonsarr); 201 | Py_XDECREF(verticesarr); 202 | Py_XDECREF(polygonsarr); 203 | 204 | return res; 205 | } -------------------------------------------------------------------------------- /encoders/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers import ResnetBlockFC, VNResnetBlockFC 4 | from layers_vn import VNLinear, VNLeakyReLU, VNBatchNorm, VNLinearLeakyReLU, VNMaxPool, get_graph_feature_cross, mean_pool, get_graph_feature_lrf 5 | 6 | import torch.autograd.profiler as profiler 7 | 8 | 9 | def maxpool(x, dim=-1, keepdim=False): 10 | out, _ = x.max(dim=dim, keepdim=keepdim) 11 | return out 12 | 13 | 14 | class SimplePointnet(nn.Module): 15 | ''' PointNet-based encoder network. 16 | 17 | Args: 18 | c_dim (int): dimension of latent code c 19 | dim (int): input points dimension 20 | hidden_dim (int): hidden dimension of the network 21 | ''' 22 | 23 | def __init__(self, c_dim=128, dim=3, hidden_dim=128): 24 | super().__init__() 25 | self.c_dim = c_dim 26 | 27 | self.fc_pos = nn.Linear(dim, 2*hidden_dim) 28 | self.fc_0 = nn.Linear(2*hidden_dim, hidden_dim) 29 | self.fc_1 = nn.Linear(2*hidden_dim, hidden_dim) 30 | self.fc_2 = nn.Linear(2*hidden_dim, hidden_dim) 31 | self.fc_3 = nn.Linear(2*hidden_dim, hidden_dim) 32 | self.fc_c = nn.Linear(hidden_dim, c_dim) 33 | 34 | self.actvn = nn.ReLU() 35 | self.pool = maxpool 36 | 37 | def forward(self, p): 38 | batch_size, T, D = p.size() 39 | 40 | # output size: B x T x F 41 | net = self.fc_pos(p) 42 | net = self.fc_0(self.actvn(net)) 43 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 44 | net = torch.cat([net, pooled], dim=2) 45 | 46 | net = self.fc_1(self.actvn(net)) 47 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 48 | net = torch.cat([net, pooled], dim=2) 49 | 50 | net = self.fc_2(self.actvn(net)) 51 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 52 | net = torch.cat([net, pooled], dim=2) 53 | 54 | net = self.fc_3(self.actvn(net)) 55 | 56 | # Reduce to B x F 57 | net = self.pool(net, dim=1) 58 | 59 | c = self.fc_c(self.actvn(net)) 60 | 61 | return c 62 | 63 | 64 | class ResnetPointnet(nn.Module): 65 | ''' PointNet-based encoder network with ResNet blocks.
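A shape-only usage sketch (batch and point counts are hypothetical):

>>> import torch
>>> encoder = ResnetPointnet(c_dim=128, dim=3, hidden_dim=128)
>>> c = encoder(torch.randn(4, 1024, 3))  # B x T x 3 point batch
>>> tuple(c.shape)
(4, 128)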
66 | 67 | Args: 68 | c_dim (int): dimension of latent code c 69 | dim (int): input points dimension 70 | hidden_dim (int): hidden dimension of the network 71 | ''' 72 | 73 | def __init__(self, c_dim=128, dim=3, hidden_dim=128): 74 | super().__init__() 75 | self.c_dim = c_dim 76 | 77 | self.fc_pos = nn.Linear(dim, 2*hidden_dim) 78 | self.block_0 = ResnetBlockFC(2*hidden_dim, hidden_dim) 79 | self.block_1 = ResnetBlockFC(2*hidden_dim, hidden_dim) 80 | self.block_2 = ResnetBlockFC(2*hidden_dim, hidden_dim) 81 | self.block_3 = ResnetBlockFC(2*hidden_dim, hidden_dim) 82 | self.block_4 = ResnetBlockFC(2*hidden_dim, hidden_dim) 83 | self.fc_c = nn.Linear(hidden_dim, c_dim) 84 | 85 | self.actvn = nn.ReLU() 86 | self.pool = maxpool 87 | 88 | def forward(self, p): 89 | batch_size, T, D = p.size() 90 | 91 | # output size: B x T x F 92 | net = self.fc_pos(p) 93 | net = self.block_0(net) 94 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 95 | net = torch.cat([net, pooled], dim=2) 96 | 97 | net = self.block_1(net) 98 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 99 | net = torch.cat([net, pooled], dim=2) 100 | 101 | net = self.block_2(net) 102 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 103 | net = torch.cat([net, pooled], dim=2) 104 | 105 | net = self.block_3(net) 106 | pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 107 | net = torch.cat([net, pooled], dim=2) 108 | 109 | net = self.block_4(net) 110 | 111 | # Reduce to B x F 112 | net = self.pool(net, dim=1) 113 | 114 | c = self.fc_c(self.actvn(net)) 115 | 116 | return c 117 | class VNResnetPointnet(nn.Module): 118 | ''' Vector-neuron (rotation-equivariant) PointNet encoder with ResNet blocks. 119 | 120 | Args: 121 | c_dim (int): dimension of latent code c; dim (int): input points dimension; hidden_dim (int): hidden dimension of the network 122 | pooling (str): 'max' or 'mean' pooling over points 123 | ball_radius (float): radius for neighbourhood queries; n_knn (int): number of neighbours used for the graph features 124 | init_lrf (bool): build the initial equivariant features from local reference frames; lrf_cross (bool): LRF variant with 3 instead of 4 input channels; global_relu (bool): use the global nonlinearity variant ''' 125 | 126 | def __init__(self, c_dim=128, dim=3, hidden_dim=128, pooling='mean', ball_radius=0, 127 | init_lrf=False, lrf_cross=False, n_knn=20, global_relu=False): 128 | super().__init__() 129 | self.c_dim = c_dim 130 | 131 | self.init_lrf = init_lrf 132 | self.lrf_cross = lrf_cross 133 | self.global_relu = global_relu 134 | if self.init_lrf: 135 | if lrf_cross: 136 | self.fc_pos = VNLinear(3, 2*hidden_dim // 3) 137 | else: 138 | self.fc_pos = VNLinear(4, 2*hidden_dim // 3) 139 | else: 140 | # self.fc_pos = nn.Linear(dim, 2*hidden_dim) 141 | self.fc_pos = VNLinear(3, 2*hidden_dim // 3) 142 | 143 | self.block_0 = VNResnetBlockFC(2*hidden_dim//3, hidden_dim//3, global_relu=global_relu) 144 | self.block_1 = VNResnetBlockFC(2*hidden_dim//3, hidden_dim//3, global_relu=global_relu) 145 | self.block_2 = VNResnetBlockFC(2*hidden_dim//3, hidden_dim//3, global_relu=global_relu) 146 | self.block_3 = VNResnetBlockFC(2*hidden_dim//3, hidden_dim//3, global_relu=global_relu) 147 | self.block_4 = VNResnetBlockFC(2*hidden_dim//3, hidden_dim//3, global_relu=global_relu) 148 | # self.fc_c = nn.Linear(hidden_dim, c_dim) 149 | self.fc_c = VNLinear(hidden_dim//3, c_dim//3) 150 | 151 | # self.actvn = nn.ReLU() 152 | self.actvn = VNLeakyReLU(hidden_dim//3, share_nonlinearity=False, negative_slope=0, global_relu=global_relu) 153 | 154 | # self.pool = maxpool 155 | self.pooling = pooling 156 | if self.pooling == 'max': 157 | self.pool = VNMaxPool(2*hidden_dim//3) 158 | self.pool_0 = VNMaxPool(hidden_dim//3) 159 | self.pool_1 = VNMaxPool(hidden_dim//3) 160 | self.pool_2 = VNMaxPool(hidden_dim//3) 161 | self.pool_3 = VNMaxPool(hidden_dim//3) 162 | self.pool_4 =
VNMaxPool(hidden_dim//3) 163 | 164 | elif self.pooling == 'mean': 165 | self.pool = mean_pool 166 | self.pool_0 = mean_pool 167 | self.pool_1 = mean_pool 168 | self.pool_2 = mean_pool 169 | self.pool_3 = mean_pool 170 | self.pool_4 = mean_pool 171 | 172 | self.n_knn = n_knn 173 | self.ball_radius = ball_radius 174 | 175 | def forward(self, p): 176 | batch_size, T, D = p.size() 177 | 178 | p = p.transpose(1, 2) # B*dimension*npoints 179 | p = p.unsqueeze(1) 180 | # print("0", x.shape) # B, 1, 3, N 181 | if self.init_lrf: 182 | 183 | # with profiler.profile() as prof: 184 | with torch.no_grad(): 185 | feat = get_graph_feature_lrf(p, k=self.n_knn, ball_radius=self.ball_radius, lrf_cross=self.lrf_cross) # B*4*3*N 186 | # print(prof.key_averages().table(row_limit=5)) 187 | 188 | net = self.fc_pos(feat) # B, F, 3, N 189 | else: 190 | feat = get_graph_feature_cross(p, k=self.n_knn, ball_radius=self.ball_radius) 191 | # print("1", feat.shape) # B, 3, 3, N, knn 192 | x = self.fc_pos(feat) 193 | # print("2", x.shape) # B, F, 3, N, knn 194 | net = self.pool(x) 195 | # print("3", x.shape) # B, F, 3, N 196 | 197 | net = self.block_0(net) 198 | # print("block_0 net.shape", net.shape) # [B, 342(hidden_dim//3), 3, 1024(N)] 199 | # pooled = self.pool(net, dim=3, keepdim=True).expand(net.size()) 200 | pooled = self.pool_0(net, keepdim=True) 201 | # print("pooled.shape", pooled.shape) 202 | pooled = pooled.expand(net.size()) 203 | net = torch.cat([net, pooled], dim=1) 204 | 205 | net = self.block_1(net) 206 | # print("block_1 net.shape", net.shape) 207 | # pooled = self.pool(net, dim=3, keepdim=True).expand(net.size()) 208 | pooled = self.pool_1(net, keepdim=True).expand(net.size()) 209 | net = torch.cat([net, pooled], dim=1) 210 | 211 | net = self.block_2(net) 212 | # print("block_2 net.shape", net.shape) 213 | # pooled = self.pool(net, dim=3, keepdim=True).expand(net.size()) 214 | pooled = self.pool_2(net, keepdim=True).expand(net.size()) 215 | net = torch.cat([net, pooled], dim=1) 216 | 217 | net = self.block_3(net) 218 | # print("block_3 net.shape", net.shape) 219 | # pooled = self.pool(net, dim=3, keepdim=True).expand(net.size()) 220 | pooled = self.pool_3(net, keepdim=True).expand(net.size()) 221 | net = torch.cat([net, pooled], dim=1) 222 | 223 | net = self.block_4(net) # B*F*3*N 224 | 225 | # Reduce to B x F x 3 226 | # net = self.pool(net, dim=3) 227 | net = self.pool_4(net) 228 | # print("block_4 net.shape", net.shape) 229 | 230 | c = self.fc_c(self.actvn(net)) 231 | 232 | c = torch.flatten(c, 1) # B*F 233 | 234 | return c 235 | 236 | # # output size: B x T X F 237 | # net = self.fc_pos(p) 238 | 239 | # net = self.block_0(net) 240 | # pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 241 | # net = torch.cat([net, pooled], dim=2) 242 | 243 | # net = self.block_1(net) 244 | # pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 245 | # net = torch.cat([net, pooled], dim=2) 246 | 247 | # net = self.block_2(net) 248 | # pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 249 | # net = torch.cat([net, pooled], dim=2) 250 | 251 | # net = self.block_3(net) 252 | # pooled = self.pool(net, dim=1, keepdim=True).expand(net.size()) 253 | # net = torch.cat([net, pooled], dim=2) 254 | 255 | # net = self.block_4(net) 256 | 257 | # # Recude to B x F 258 | # net = self.pool(net, dim=1) 259 | 260 | # c = self.fc_c(self.actvn(net)) 261 | 262 | # return c 263 | -------------------------------------------------------------------------------- /utils/binvox_rw.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see . 16 | # 17 | # Modified by Christopher B. Choy 18 | # for python 3 support 19 | 20 | """ 21 | Binvox to Numpy and back. 22 | 23 | 24 | >>> import numpy as np 25 | >>> import binvox_rw 26 | >>> with open('chair.binvox', 'rb') as f: 27 | ... m1 = binvox_rw.read_as_3d_array(f) 28 | ... 29 | >>> m1.dims 30 | [32, 32, 32] 31 | >>> m1.scale 32 | 41.133000000000003 33 | >>> m1.translate 34 | [0.0, 0.0, 0.0] 35 | >>> with open('chair_out.binvox', 'wb') as f: 36 | ... m1.write(f) 37 | ... 38 | >>> with open('chair_out.binvox', 'rb') as f: 39 | ... m2 = binvox_rw.read_as_3d_array(f) 40 | ... 41 | >>> m1.dims==m2.dims 42 | True 43 | >>> m1.scale==m2.scale 44 | True 45 | >>> m1.translate==m2.translate 46 | True 47 | >>> np.all(m1.data==m2.data) 48 | True 49 | 50 | >>> with open('chair.binvox', 'rb') as f: 51 | ... md = binvox_rw.read_as_3d_array(f) 52 | ... 53 | >>> with open('chair.binvox', 'rb') as f: 54 | ... ms = binvox_rw.read_as_coord_array(f) 55 | ... 56 | >>> data_ds = binvox_rw.dense_to_sparse(md.data) 57 | >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) 58 | >>> np.all(data_sd==md.data) 59 | True 60 | >>> # the ordering of elements returned by numpy.nonzero changes with axis 61 | >>> # ordering, so to compare for equality we first lexically sort the voxels. 62 | >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) 63 | True 64 | """ 65 | 66 | import numpy as np 67 | 68 | class Voxels(object): 69 | """ Holds a binvox model. 70 | data is either a three-dimensional numpy boolean array (dense representation) 71 | or a two-dimensional numpy float array (coordinate representation). 72 | 73 | dims, translate and scale are the model metadata. 74 | 75 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 76 | 77 | scale and translate relate the voxels to the original model coordinates. 78 | 79 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 80 | 81 | x_n = (i+.5)/dims[0] 82 | y_n = (j+.5)/dims[1] 83 | z_n = (k+.5)/dims[2] 84 | x = scale*x_n + translate[0] 85 | y = scale*y_n + translate[1] 86 | z = scale*z_n + translate[2] 87 | 88 | """ 89 | 90 | def __init__(self, data, dims, translate, scale, axis_order): 91 | self.data = data 92 | self.dims = dims 93 | self.translate = translate 94 | self.scale = scale 95 | assert (axis_order in ('xzy', 'xyz')) 96 | self.axis_order = axis_order 97 | 98 | def clone(self): 99 | data = self.data.copy() 100 | dims = self.dims[:] 101 | translate = self.translate[:] 102 | return Voxels(data, dims, translate, self.scale, self.axis_order) 103 | 104 | def write(self, fp): 105 | write(self, fp) 106 | 107 | def read_header(fp): 108 | """ Read binvox header. Mostly meant for internal use. 
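A typical header looks like this (values are illustrative; write() below emits exactly these fields):

#binvox 1
dim 32 32 32
translate 0.0 0.0 0.0
scale 41.133
data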
109 | """ 110 | line = fp.readline().strip() 111 | if not line.startswith(b'#binvox'): 112 | raise IOError('Not a binvox file') 113 | dims = [int(i) for i in fp.readline().strip().split(b' ')[1:]] 114 | translate = [float(i) for i in fp.readline().strip().split(b' ')[1:]] 115 | scale = [float(i) for i in fp.readline().strip().split(b' ')[1:]][0] 116 | line = fp.readline() 117 | return dims, translate, scale 118 | 119 | def read_as_3d_array(fp, fix_coords=True): 120 | """ Read binary binvox format as array. 121 | 122 | Returns the model with accompanying metadata. 123 | 124 | Voxels are stored in a three-dimensional numpy array, which is simple and 125 | direct, but may use a lot of memory for large models. (Storage requirements 126 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy 127 | boolean arrays use a byte per element). 128 | 129 | Doesn't do any checks on input except for the '#binvox' line. 130 | """ 131 | dims, translate, scale = read_header(fp) 132 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 133 | # if just using reshape() on the raw data: 134 | # indexing the array as array[i,j,k], the indices map into the 135 | # coords as: 136 | # i -> x 137 | # j -> z 138 | # k -> y 139 | # if fix_coords is true, then data is rearranged so that 140 | # mapping is 141 | # i -> x 142 | # j -> y 143 | # k -> z 144 | values, counts = raw_data[::2], raw_data[1::2] 145 | data = np.repeat(values, counts).astype(np.bool) 146 | data = data.reshape(dims) 147 | if fix_coords: 148 | # xzy to xyz TODO the right thing 149 | data = np.transpose(data, (0, 2, 1)) 150 | axis_order = 'xyz' 151 | else: 152 | axis_order = 'xzy' 153 | return Voxels(data, dims, translate, scale, axis_order) 154 | 155 | 156 | def read_as_coord_array(fp, fix_coords=True): 157 | """ Read binary binvox format as coordinates. 158 | 159 | Returns binvox model with voxels in a "coordinate" representation, i.e. an 160 | 3 x N array where N is the number of nonzero voxels. Each column 161 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates 162 | of the voxel. (The odd ordering is due to the way binvox format lays out 163 | data). Note that coordinates refer to the binvox voxels, without any 164 | scaling or translation. 165 | 166 | Use this to save memory if your model is very sparse (mostly empty). 167 | 168 | Doesn't do any checks on input except for the '#binvox' line. 169 | """ 170 | dims, translate, scale = read_header(fp) 171 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 172 | 173 | values, counts = raw_data[::2], raw_data[1::2] 174 | 175 | sz = np.prod(dims) 176 | index, end_index = 0, 0 177 | end_indices = np.cumsum(counts) 178 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 179 | 180 | values = values.astype(np.bool) 181 | indices = indices[values] 182 | end_indices = end_indices[values] 183 | 184 | nz_voxels = [] 185 | for index, end_index in zip(indices, end_indices): 186 | nz_voxels.extend(range(index, end_index)) 187 | nz_voxels = np.array(nz_voxels) 188 | # TODO are these dims correct? 
189 | # according to docs, 190 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 191 | 192 | x = nz_voxels / (dims[0]*dims[1]) 193 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 194 | z = zwpy / dims[0] 195 | y = zwpy % dims[0] 196 | if fix_coords: 197 | data = np.vstack((x, y, z)) 198 | axis_order = 'xyz' 199 | else: 200 | data = np.vstack((x, z, y)) 201 | axis_order = 'xzy' 202 | 203 | #return Voxels(data, dims, translate, scale, axis_order) 204 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 205 | 206 | def dense_to_sparse(voxel_data, dtype=int): 207 | """ From dense representation to sparse (coordinate) representation. 208 | No coordinate reordering. 209 | """ 210 | if voxel_data.ndim!=3: 211 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 212 | return np.asarray(np.nonzero(voxel_data), dtype) 213 | 214 | def sparse_to_dense(voxel_data, dims, dtype=bool): 215 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 216 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 217 | if np.isscalar(dims): 218 | dims = [dims]*3 219 | dims = np.atleast_2d(dims).T 220 | # truncate to integers 221 | xyz = voxel_data.astype(int) 222 | # discard voxels that fall outside dims 223 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 224 | xyz = xyz[:,valid_ix] 225 | out = np.zeros(dims.flatten(), dtype=dtype) 226 | out[tuple(xyz)] = True 227 | return out 228 | 229 | #def get_linear_index(x, y, z, dims): 230 | #""" Assuming xzy order. (y increasing fastest. 231 | #TODO ensure this is right when dims are not all same 232 | #""" 233 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 234 | 235 | def write(voxel_model, fp): 236 | """ Write binary binvox format. 237 | 238 | Note that when saving a model in sparse (coordinate) format, it is first 239 | converted to dense format. 240 | 241 | Doesn't check if the model is 'sane'.
242 | 243 | """ 244 | if voxel_model.data.ndim==2: 245 | # TODO avoid conversion to dense 246 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 247 | else: 248 | dense_voxel_data = voxel_model.data 249 | 250 | fp.write('#binvox 1\n') 251 | fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n') 252 | fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n') 253 | fp.write('scale '+str(voxel_model.scale)+'\n') 254 | fp.write('data\n') 255 | if not voxel_model.axis_order in ('xzy', 'xyz'): 256 | raise ValueError('Unsupported voxel model axis order') 257 | 258 | if voxel_model.axis_order=='xzy': 259 | voxels_flat = dense_voxel_data.flatten() 260 | elif voxel_model.axis_order=='xyz': 261 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 262 | 263 | # keep a sort of state machine for writing run length encoding 264 | state = voxels_flat[0] 265 | ctr = 0 266 | for c in voxels_flat: 267 | if c==state: 268 | ctr += 1 269 | # if ctr hits max, dump 270 | if ctr==255: 271 | fp.write(chr(state)) 272 | fp.write(chr(ctr)) 273 | ctr = 0 274 | else: 275 | # if switch state, dump 276 | fp.write(chr(state)) 277 | fp.write(chr(ctr)) 278 | state = c 279 | ctr = 1 280 | # flush out remainders 281 | if ctr > 0: 282 | fp.write(chr(state)) 283 | fp.write(chr(ctr)) 284 | 285 | if __name__ == '__main__': 286 | import doctest 287 | doctest.testmod() 288 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | # import multiprocessing 2 | import torch 3 | from utils.libkdtree import KDTree 4 | import numpy as np 5 | 6 | 7 | def compute_iou(occ1, occ2): 8 | ''' Computes the Intersection over Union (IoU) value for two sets of 9 | occupancy values. 10 | 11 | Args: 12 | occ1 (tensor): first set of occupancy values 13 | occ2 (tensor): second set of occupancy values 14 | ''' 15 | occ1 = np.asarray(occ1) 16 | occ2 = np.asarray(occ2) 17 | 18 | # Put all data in second dimension 19 | # Also works for 1-dimensional data 20 | if occ1.ndim >= 2: 21 | occ1 = occ1.reshape(occ1.shape[0], -1) 22 | if occ2.ndim >= 2: 23 | occ2 = occ2.reshape(occ2.shape[0], -1) 24 | 25 | # Convert to boolean values 26 | occ1 = (occ1 >= 0.5) 27 | occ2 = (occ2 >= 0.5) 28 | 29 | # Compute IOU 30 | area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) 31 | area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) 32 | 33 | iou = (area_intersect / area_union) 34 | 35 | return iou 36 | 37 | 38 | def chamfer_distance(points1, points2, use_kdtree=True, give_id=False): 39 | ''' Returns the chamfer distance for the sets of points. 40 | 41 | Args: 42 | points1 (numpy array): first point set 43 | points2 (numpy array): second point set 44 | use_kdtree (bool): whether to use a kdtree 45 | give_id (bool): whether to return the IDs of nearest points 46 | ''' 47 | if use_kdtree: 48 | return chamfer_distance_kdtree(points1, points2, give_id=give_id) 49 | else: 50 | return chamfer_distance_naive(points1, points2) 51 | 52 | 53 | def chamfer_distance_naive(points1, points2): 54 | ''' Naive implementation of the Chamfer distance. 
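For batched point sets X and Y of equal size T this computes

CD(X, Y) = mean_i min_j ||x_i - y_j||^2 + mean_j min_i ||x_i - y_j||^2

by materializing the full B x T x T squared-distance tensor, so memory grows quadratically in T; prefer the KD-tree variant below for large point clouds.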
55 | 56 | Args: 57 | points1 (tensor): first point set 58 | points2 (tensor): second point set 59 | ''' 60 | assert(points1.size() == points2.size()) 61 | batch_size, T, _ = points1.size() 62 | 63 | points1 = points1.view(batch_size, T, 1, 3) 64 | points2 = points2.view(batch_size, 1, T, 3) 65 | 66 | distances = (points1 - points2).pow(2).sum(-1) 67 | 68 | chamfer1 = distances.min(dim=1)[0].mean(dim=1) 69 | chamfer2 = distances.min(dim=2)[0].mean(dim=1) 70 | 71 | chamfer = chamfer1 + chamfer2 72 | return chamfer 73 | 74 | 75 | def chamfer_distance_kdtree(points1, points2, give_id=False): 76 | ''' KD-tree based implementation of the Chamfer distance. 77 | 78 | Args: 79 | points1 (tensor): first point set 80 | points2 (tensor): second point set 81 | give_id (bool): whether to return the IDs of the nearest points 82 | ''' 83 | # Points have size batch_size x T x 3 84 | batch_size = points1.size(0) 85 | 86 | # First convert points to numpy 87 | points1_np = points1.detach().cpu().numpy() 88 | points2_np = points2.detach().cpu().numpy() 89 | 90 | # Get list of nearest neighbors indices 91 | idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np) 92 | idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device) 93 | # Expand it to batch_size x T x 3 94 | idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1) 95 | 96 | # Get list of nearest neighbors indices 97 | idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np) 98 | idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device) 99 | # Expand it to batch_size x T x 3 100 | idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2) 101 | 102 | # Compute nearest neighbors in points2 to points in points1 103 | # points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k] 104 | points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand) 105 | 106 | # Compute nearest neighbors in points1 to points in points2 107 | # points_21[i, j, k] = points1[i, idx_nn_21_expand[i, j, k], k] 108 | points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand) 109 | 110 | # Compute chamfer distance 111 | chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1) 112 | chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1) 113 | 114 | # Take sum 115 | chamfer = chamfer1 + chamfer2 116 | 117 | # If required, also return nearest neighbors 118 | if give_id: 119 | return chamfer1, chamfer2, idx_nn_12, idx_nn_21 120 | 121 | return chamfer 122 | 123 | 124 | def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1): 125 | ''' Returns the nearest neighbors for point sets batchwise. 126 | 127 | Args: 128 | points_src (numpy array): source points 129 | points_tgt (numpy array): target points 130 | k (int): number of nearest neighbors to return 131 | ''' 132 | indices = [] 133 | distances = [] 134 | 135 | for (p1, p2) in zip(points_src, points_tgt): 136 | kdtree = KDTree(p2) 137 | dist, idx = kdtree.query(p1, k=k) 138 | indices.append(idx) 139 | distances.append(dist) 140 | 141 | return indices, distances 142 | 143 | 144 | def normalize_imagenet(x): 145 | ''' Normalize input images according to ImageNet standards. 146 | 147 | Args: 148 | x (tensor): input images 149 | ''' 150 | x = x.clone() 151 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 152 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 153 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 154 | return x 155 | 156 | 157 | def make_3d_grid(bb_min, bb_max, shape): 158 | ''' Makes a 3D grid.
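A usage sketch (a 32^3 grid spanning the unit cube centered at the origin):

>>> p = make_3d_grid((-0.5,) * 3, (0.5,) * 3, (32,) * 3)
>>> tuple(p.shape)
(32768, 3)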
159 | 
160 |     Args:
161 |         bb_min (tuple): bounding box minimum
162 |         bb_max (tuple): bounding box maximum
163 |         shape (tuple): output shape
164 |     '''
165 |     size = shape[0] * shape[1] * shape[2]
166 | 
167 |     pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
168 |     pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
169 |     pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])
170 | 
171 |     pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
172 |     pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
173 |     pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
174 |     p = torch.stack([pxs, pys, pzs], dim=1)
175 | 
176 |     return p
177 | 
178 | 
179 | def transform_points(points, transform):
180 |     ''' Transforms points with regard to passed camera information.
181 | 
182 |     Args:
183 |         points (tensor): points tensor
184 |         transform (tensor): transformation matrices
185 |     '''
186 |     assert(points.size(2) == 3)
187 |     assert(transform.size(1) == 3)
188 |     assert(points.size(0) == transform.size(0))
189 | 
190 |     if transform.size(2) == 4:
191 |         R = transform[:, :, :3]
192 |         t = transform[:, :, 3:]
193 |         points_out = points @ R.transpose(1, 2) + t.transpose(1, 2)
194 |     elif transform.size(2) == 3:
195 |         K = transform
196 |         points_out = points @ K.transpose(1, 2)
197 | 
198 |     return points_out
199 | 
200 | 
201 | def b_inv(b_mat):
202 |     ''' Performs batch matrix inversion.
203 | 
204 |     Arguments:
205 |         b_mat: the batch of matrices that should be inverted
206 |     '''
207 |     # torch.gesv was removed from PyTorch; torch.linalg.solve is its replacement
208 |     eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
209 |     b_inv = torch.linalg.solve(b_mat, eye)
210 |     return b_inv
211 | 
212 | 
213 | def transform_points_back(points, transform):
214 |     ''' Inverts the transformation.
215 | 
216 |     Args:
217 |         points (tensor): points tensor
218 |         transform (tensor): transformation matrices
219 |     '''
220 |     assert(points.size(2) == 3)
221 |     assert(transform.size(1) == 3)
222 |     assert(points.size(0) == transform.size(0))
223 | 
224 |     if transform.size(2) == 4:
225 |         R = transform[:, :, :3]
226 |         t = transform[:, :, 3:]
227 |         points_out = points - t.transpose(1, 2)
228 |         points_out = points_out @ b_inv(R.transpose(1, 2))
229 |     elif transform.size(2) == 3:
230 |         K = transform
231 |         points_out = points @ b_inv(K.transpose(1, 2))
232 | 
233 |     return points_out
234 | 
235 | 
236 | def project_to_camera(points, transform):
237 |     ''' Projects points to the camera plane.
238 | 
239 |     Args:
240 |         points (tensor): points tensor
241 |         transform (tensor): transformation matrices
242 |     '''
243 |     p_camera = transform_points(points, transform)
244 |     p_camera = p_camera[..., :2] / p_camera[..., 2:]
245 |     return p_camera
246 | 
247 | 
248 | def get_camera_args(data, loc_field=None, scale_field=None, device=None):
249 |     ''' Returns dictionary of camera arguments.
250 | 
251 |     Args:
252 |         data (dict): data dictionary
253 |         loc_field (str): name of location field
254 |         scale_field (str): name of scale field
255 |         device (device): pytorch device
256 |     '''
257 |     Rt = data['inputs.world_mat'].to(device)
258 |     K = data['inputs.camera_mat'].to(device)
259 | 
260 |     if loc_field is not None:
261 |         loc = data[loc_field].to(device)
262 |     else:
263 |         loc = torch.zeros(K.size(0), 3, device=K.device, dtype=K.dtype)
264 | 
265 |     if scale_field is not None:
266 |         scale = data[scale_field].to(device)
267 |     else:
268 |         scale = torch.zeros(K.size(0), device=K.device, dtype=K.dtype)
269 | 
270 |     Rt = fix_Rt_camera(Rt, loc, scale)
271 |     K = fix_K_camera(K, img_size=137.)
272 |     kwargs = {'Rt': Rt, 'K': K}
273 |     return kwargs
274 | 
275 | 
276 | def fix_Rt_camera(Rt, loc, scale):
277 |     ''' Fixes Rt camera matrix.
278 | 
279 |     Args:
280 |         Rt (tensor): Rt camera matrix
281 |         loc (tensor): location
282 |         scale (tensor): scale
283 |     '''
284 |     # Rt is B x 3 x 4
285 |     # loc is B x 3 and scale is B
286 |     batch_size = Rt.size(0)
287 |     R = Rt[:, :, :3]
288 |     t = Rt[:, :, 3:]
289 | 
290 |     scale = scale.view(batch_size, 1, 1)
291 |     R_new = R * scale
292 |     t_new = t + R @ loc.unsqueeze(2)
293 | 
294 |     Rt_new = torch.cat([R_new, t_new], dim=2)
295 | 
296 |     assert(Rt_new.size() == (batch_size, 3, 4))
297 |     return Rt_new
298 | 
299 | 
300 | def fix_K_camera(K, img_size=137):
301 |     """Fix camera projection matrix.
302 | 
303 |     This changes a camera projection matrix that maps to
304 |     [0, img_size] x [0, img_size] to one that maps to [-1, 1] x [-1, 1].
305 | 
306 |     Args:
307 |         K (tensor): camera projection matrix
308 |         img_size (float): size of image plane K projects to
309 |     """
310 |     # Unscale and recenter
311 |     scale_mat = torch.tensor([
312 |         [2./img_size, 0, -1],
313 |         [0, 2./img_size, -1],
314 |         [0, 0, 1.],
315 |     ], device=K.device, dtype=K.dtype)
316 |     K_new = scale_mat.view(1, 3, 3) @ K
317 |     return K_new
--------------------------------------------------------------------------------
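A quick usage sketch for the Chamfer helpers defined in common.py above; the shapes and tensor values are illustrative only, not taken from the repo:

```python
import torch
from common import chamfer_distance

# Two batches of 3D point sets, shape (batch_size, T, 3)
points1 = torch.rand(4, 1024, 3)
points2 = torch.rand(4, 1024, 3)

# Default KD-tree path: nearest neighbours are found on the CPU via pykdtree
d = chamfer_distance(points1, points2)  # -> tensor of shape (4,)

# Naive path: O(T^2) pairwise distances, fine for small point sets
d_naive = chamfer_distance(points1, points2, use_kdtree=False)
```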
/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from layers_vn import VNLinear, VNLeakyReLU, VNBatchNorm, VNLinearLeakyReLU, VNMaxPool
4 | 
5 | # VN version of resnet
6 | class VNResnetBlockFC(nn.Module):
7 |     ''' Fully connected ResNet Block class.
8 | 
9 |     Args:
10 |         size_in (int): input dimension
11 |         size_out (int): output dimension
12 |         size_h (int): hidden dimension
13 |     '''
14 | 
15 |     def __init__(self, size_in, size_out=None, size_h=None, share_nonlinearity=False, negative_slope=0,
16 |                  use_batchnorm=False, global_relu=False):
17 |         super().__init__()
18 |         # Attributes
19 |         if size_out is None:
20 |             size_out = size_in
21 | 
22 |         if size_h is None:
23 |             size_h = min(size_in, size_out)
24 | 
25 |         self.global_relu = global_relu
26 |         self.size_in = size_in
27 |         self.size_h = size_h
28 |         self.size_out = size_out
29 |         # Submodules
30 |         self.actvn = VNLeakyReLU(self.size_in, share_nonlinearity=share_nonlinearity, negative_slope=negative_slope, global_relu=global_relu)
31 |         self.fc_0_actvn = VNLinearLeakyReLU(self.size_in, self.size_h, dim=4, share_nonlinearity=share_nonlinearity,
32 |                                             negative_slope=negative_slope, use_batchnorm=use_batchnorm, global_relu=global_relu)
33 |         # self.fc_0 = VNLinear(self.size_in, self.size_h)
34 |         self.fc_1 = VNLinear(self.size_h, self.size_out)
35 | 
36 |         if size_in == size_out:
37 |             self.shortcut = None
38 |         else:
39 |             self.shortcut = VNLinear(self.size_in, self.size_out)
40 |         # Initialization
41 |         nn.init.zeros_(self.fc_1.map_to_feat.weight)
42 | 
43 |     def forward(self, x):
44 | 
45 |         x_act = self.actvn(x)
46 |         x_0 = self.fc_0_actvn(x_act)
47 |         dx = self.fc_1(x_0)
48 | 
49 |         # net = self.fc_0(self.actvn(x))
50 |         # dx = self.fc_1(self.actvn(net))
51 | 
52 |         if self.shortcut is not None:
53 |             x_s = self.shortcut(x)
54 |         else:
55 |             x_s = x
56 | 
57 |         return x_s + dx
58 | 
59 | 
60 | # Resnet Blocks
61 | class ResnetBlockFC(nn.Module):
62 |     ''' Fully connected ResNet Block class.
63 | 
64 |     Args:
65 |         size_in (int): input dimension
66 |         size_out (int): output dimension
67 |         size_h (int): hidden dimension
68 |     '''
69 | 
70 |     def __init__(self, size_in, size_out=None, size_h=None):
71 |         super().__init__()
72 |         # Attributes
73 |         if size_out is None:
74 |             size_out = size_in
75 | 
76 |         if size_h is None:
77 |             size_h = min(size_in, size_out)
78 | 
79 |         self.size_in = size_in
80 |         self.size_h = size_h
81 |         self.size_out = size_out
82 |         # Submodules
83 |         self.fc_0 = nn.Linear(size_in, size_h)
84 |         self.fc_1 = nn.Linear(size_h, size_out)
85 |         self.actvn = nn.ReLU()
86 | 
87 |         if size_in == size_out:
88 |             self.shortcut = None
89 |         else:
90 |             self.shortcut = nn.Linear(size_in, size_out, bias=False)
91 |         # Initialization
92 |         nn.init.zeros_(self.fc_1.weight)
93 | 
94 |     def forward(self, x):
95 |         net = self.fc_0(self.actvn(x))
96 |         dx = self.fc_1(self.actvn(net))
97 | 
98 |         if self.shortcut is not None:
99 |             x_s = self.shortcut(x)
100 |         else:
101 |             x_s = x
102 | 
103 |         return x_s + dx
104 | 
105 | 
106 | class CResnetBlockConv1d(nn.Module):
107 |     ''' Conditional batch normalization-based Resnet block class.
108 | 
109 |     Args:
110 |         c_dim (int): dimension of latent conditioned code c
111 |         size_in (int): input dimension
112 |         size_out (int): output dimension
113 |         size_h (int): hidden dimension
114 |         norm_method (str): normalization method
115 |         legacy (bool): whether to use legacy blocks
116 |     '''
117 | 
118 |     def __init__(self, c_dim, size_in, size_h=None, size_out=None,
119 |                  norm_method='batch_norm', legacy=False):
120 |         super().__init__()
121 |         # Attributes
122 |         if size_h is None:
123 |             size_h = size_in
124 |         if size_out is None:
125 |             size_out = size_in
126 | 
127 |         self.size_in = size_in
128 |         self.size_h = size_h
129 |         self.size_out = size_out
130 |         # Submodules
131 |         if not legacy:
132 |             self.bn_0 = CBatchNorm1d(
133 |                 c_dim, size_in, norm_method=norm_method)
134 |             self.bn_1 = CBatchNorm1d(
135 |                 c_dim, size_h, norm_method=norm_method)
136 |         else:
137 |             self.bn_0 = CBatchNorm1d_legacy(
138 |                 c_dim, size_in, norm_method=norm_method)
139 |             self.bn_1 = CBatchNorm1d_legacy(
140 |                 c_dim, size_h, norm_method=norm_method)
141 | 
142 |         self.fc_0 = nn.Conv1d(size_in, size_h, 1)
143 |         self.fc_1 = nn.Conv1d(size_h, size_out, 1)
144 |         self.actvn = nn.ReLU()
145 | 
146 |         if size_in == size_out:
147 |             self.shortcut = None
148 |         else:
149 |             self.shortcut = nn.Conv1d(size_in, size_out, 1, bias=False)
150 |         # Initialization
151 |         nn.init.zeros_(self.fc_1.weight)
152 | 
153 |     def forward(self, x, c):
154 |         net = self.fc_0(self.actvn(self.bn_0(x, c)))
155 |         dx = self.fc_1(self.actvn(self.bn_1(net, c)))
156 | 
157 |         if self.shortcut is not None:
158 |             x_s = self.shortcut(x)
159 |         else:
160 |             x_s = x
161 | 
162 |         return x_s + dx
163 | 
164 | 
165 | class ResnetBlockConv1d(nn.Module):
166 |     ''' 1D-Convolutional ResNet block class.
167 | 
168 |     Args:
169 |         size_in (int): input dimension
170 |         size_out (int): output dimension
171 |         size_h (int): hidden dimension
172 |     '''
173 | 
174 |     def __init__(self, size_in, size_h=None, size_out=None):
175 |         super().__init__()
176 |         # Attributes
177 |         if size_h is None:
178 |             size_h = size_in
179 |         if size_out is None:
180 |             size_out = size_in
181 | 
182 |         self.size_in = size_in
183 |         self.size_h = size_h
184 |         self.size_out = size_out
185 |         # Submodules
186 |         self.bn_0 = nn.BatchNorm1d(size_in)
187 |         self.bn_1 = nn.BatchNorm1d(size_h)
188 | 
189 |         self.fc_0 = nn.Conv1d(size_in, size_h, 1)
190 |         self.fc_1 = nn.Conv1d(size_h, size_out, 1)
191 |         self.actvn = nn.ReLU()
192 | 
193 |         if size_in == size_out:
194 |             self.shortcut = None
195 |         else:
196 |             self.shortcut = nn.Conv1d(size_in, size_out, 1, bias=False)
197 | 
198 |         # Initialization
199 |         nn.init.zeros_(self.fc_1.weight)
200 | 
201 |     def forward(self, x):
202 |         net = self.fc_0(self.actvn(self.bn_0(x)))
203 |         dx = self.fc_1(self.actvn(self.bn_1(net)))
204 | 
205 |         if self.shortcut is not None:
206 |             x_s = self.shortcut(x)
207 |         else:
208 |             x_s = x
209 | 
210 |         return x_s + dx
211 | 
212 | 
213 | # Utility modules
214 | class AffineLayer(nn.Module):
215 |     ''' Affine layer class.
216 | 
217 |     Args:
218 |         c_dim (int): dimension of latent conditioned code c
219 |         dim (int): input dimension
220 |     '''
221 | 
222 |     def __init__(self, c_dim, dim=3):
223 |         super().__init__()
224 |         self.c_dim = c_dim
225 |         self.dim = dim
226 |         # Submodules
227 |         self.fc_A = nn.Linear(c_dim, dim * dim)
228 |         self.fc_b = nn.Linear(c_dim, dim)
229 |         self.reset_parameters()
230 | 
231 |     def reset_parameters(self):
232 |         nn.init.zeros_(self.fc_A.weight)
233 |         nn.init.zeros_(self.fc_b.weight)
234 |         with torch.no_grad():
235 |             self.fc_A.bias.copy_(torch.eye(3).view(-1))
236 |             self.fc_b.bias.copy_(torch.tensor([0., 0., 2.]))
237 | 
238 |     def forward(self, x, p):
239 |         assert(x.size(0) == p.size(0))
240 |         assert(p.size(2) == self.dim)
241 |         batch_size = x.size(0)
242 |         A = self.fc_A(x).view(batch_size, 3, 3)
243 |         b = self.fc_b(x).view(batch_size, 1, 3)
244 |         out = p @ A + b
245 |         return out
246 | 
247 | 
248 | class CBatchNorm1d(nn.Module):
249 |     ''' Conditional batch normalization layer class.
250 | 
251 |     Args:
252 |         c_dim (int): dimension of latent conditioned code c
253 |         f_dim (int): feature dimension
254 |         norm_method (str): normalization method
255 |     '''
256 | 
257 |     def __init__(self, c_dim, f_dim, norm_method='batch_norm'):
258 |         super().__init__()
259 |         self.c_dim = c_dim
260 |         self.f_dim = f_dim
261 |         self.norm_method = norm_method
262 |         # Submodules
263 |         self.conv_gamma = nn.Conv1d(c_dim, f_dim, 1)
264 |         self.conv_beta = nn.Conv1d(c_dim, f_dim, 1)
265 |         if norm_method == 'batch_norm':
266 |             self.bn = nn.BatchNorm1d(f_dim, affine=False)
267 |         elif norm_method == 'instance_norm':
268 |             self.bn = nn.InstanceNorm1d(f_dim, affine=False)
269 |         elif norm_method == 'group_norm':
270 |             self.bn = nn.GroupNorm(1, f_dim, affine=False)  # note: PyTorch has no GroupNorm1d; a single-group GroupNorm is assumed here
271 |         else:
272 |             raise ValueError('Invalid normalization method!')
273 |         self.reset_parameters()
274 | 
275 |     def reset_parameters(self):
276 |         nn.init.zeros_(self.conv_gamma.weight)
277 |         nn.init.zeros_(self.conv_beta.weight)
278 |         nn.init.ones_(self.conv_gamma.bias)
279 |         nn.init.zeros_(self.conv_beta.bias)
280 | 
281 |     def forward(self, x, c):
282 |         assert(x.size(0) == c.size(0))
283 |         assert(c.size(1) == self.c_dim)
284 | 
285 |         # c is assumed to be of size batch_size x c_dim x T
286 |         if len(c.size()) == 2:
287 |             c = c.unsqueeze(2)
288 | 
289 |         # Affine mapping
290 |         gamma = self.conv_gamma(c)
291 |         beta = self.conv_beta(c)
292 | 
293 |         # Batchnorm
294 |         net = self.bn(x)
295 |         out = gamma * net + beta
296 | 
297 |         return out
298 | 
299 | 
300 | class CBatchNorm1d_legacy(nn.Module):
301 |     ''' Conditional batch normalization legacy layer class.
302 | 
303 |     Args:
304 |         c_dim (int): dimension of latent conditioned code c
305 |         f_dim (int): feature dimension
306 |         norm_method (str): normalization method
307 |     '''
308 | 
309 |     def __init__(self, c_dim, f_dim, norm_method='batch_norm'):
310 |         super().__init__()
311 |         self.c_dim = c_dim
312 |         self.f_dim = f_dim
313 |         self.norm_method = norm_method
314 |         # Submodules
315 |         self.fc_gamma = nn.Linear(c_dim, f_dim)
316 |         self.fc_beta = nn.Linear(c_dim, f_dim)
317 |         if norm_method == 'batch_norm':
318 |             self.bn = nn.BatchNorm1d(f_dim, affine=False)
319 |         elif norm_method == 'instance_norm':
320 |             self.bn = nn.InstanceNorm1d(f_dim, affine=False)
321 |         elif norm_method == 'group_norm':
322 |             self.bn = nn.GroupNorm(1, f_dim, affine=False)  # note: PyTorch has no GroupNorm1d; a single-group GroupNorm is assumed here
323 |         else:
324 |             raise ValueError('Invalid normalization method!')
325 |         self.reset_parameters()
326 | 
327 |     def reset_parameters(self):
328 |         nn.init.zeros_(self.fc_gamma.weight)
329 |         nn.init.zeros_(self.fc_beta.weight)
330 |         nn.init.ones_(self.fc_gamma.bias)
331 |         nn.init.zeros_(self.fc_beta.bias)
332 | 
333 |     def forward(self, x, c):
334 |         batch_size = x.size(0)
335 |         # Affine mapping
336 |         gamma = self.fc_gamma(c)
337 |         beta = self.fc_beta(c)
338 |         gamma = gamma.view(batch_size, self.f_dim, 1)
339 |         beta = beta.view(batch_size, self.f_dim, 1)
340 |         # Batchnorm
341 |         net = self.bn(x)
342 |         out = gamma * net + beta
343 | 
344 |         return out
345 | 
--------------------------------------------------------------------------------
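A minimal sketch of how the conditional layers above are meant to be called, assuming layers.py (and its layers_vn dependency) is importable; the sizes are arbitrary:

```python
import torch
from layers import CBatchNorm1d, CResnetBlockConv1d

c_dim, f_dim, T = 128, 256, 2048
cbn = CBatchNorm1d(c_dim, f_dim)
block = CResnetBlockConv1d(c_dim, f_dim)

x = torch.rand(8, f_dim, T)   # per-point features: batch x f_dim x T
c = torch.rand(8, c_dim)      # latent conditioning code: batch x c_dim

y = cbn(x, c)                 # same shape as x; gamma/beta are predicted from c
y = block(x, c)               # residual block wrapping two conditional batchnorms
```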
/utils/libkdtree/pykdtree/kdtree.pyx:
--------------------------------------------------------------------------------
1 | #pykdtree, Fast kd-tree implementation with OpenMP-enabled queries
2 | #
3 | #Copyright (C) 2013 - present Esben S. Nielsen
4 | #
5 | # This program is free software: you can redistribute it and/or modify it under
6 | # the terms of the GNU Lesser General Public License as published by the Free
7 | # Software Foundation, either version 3 of the License, or
8 | #(at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful, but WITHOUT
11 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
13 | # details.
14 | #
15 | # You should have received a copy of the GNU Lesser General Public License along
16 | # with this program. If not, see <http://www.gnu.org/licenses/>.
17 | 
18 | import numpy as np
19 | cimport numpy as np
20 | from libc.stdint cimport uint32_t, int8_t, uint8_t
21 | cimport cython
22 | 
23 | 
24 | # Node structure
25 | cdef struct node_float:
26 |     float cut_val
27 |     int8_t cut_dim
28 |     uint32_t start_idx
29 |     uint32_t n
30 |     float cut_bounds_lv
31 |     float cut_bounds_hv
32 |     node_float *left_child
33 |     node_float *right_child
34 | 
35 | cdef struct tree_float:
36 |     float *bbox
37 |     int8_t no_dims
38 |     uint32_t *pidx
39 |     node_float *root
40 | 
41 | cdef struct node_double:
42 |     double cut_val
43 |     int8_t cut_dim
44 |     uint32_t start_idx
45 |     uint32_t n
46 |     double cut_bounds_lv
47 |     double cut_bounds_hv
48 |     node_double *left_child
49 |     node_double *right_child
50 | 
51 | cdef struct tree_double:
52 |     double *bbox
53 |     int8_t no_dims
54 |     uint32_t *pidx
55 |     node_double *root
56 | 
57 | cdef extern tree_float* construct_tree_float(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp) nogil
58 | cdef extern void search_tree_float(tree_float *kdtree, float *pa, float *point_coords, uint32_t num_points, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists) nogil
59 | cdef extern void delete_tree_float(tree_float *kdtree)
60 | 
61 | cdef extern tree_double* construct_tree_double(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp) nogil
62 | cdef extern void search_tree_double(tree_double *kdtree, double *pa, double *point_coords, uint32_t num_points, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists) nogil
63 | cdef extern void delete_tree_double(tree_double *kdtree)
64 | 
65 | cdef class KDTree:
66 |     """kd-tree for fast nearest-neighbour lookup.
67 |     The interface is made to resemble the scipy.spatial kd-tree except
68 |     only Euclidean distance measure is supported.
69 | 
70 |     :Parameters:
71 |     data_pts : numpy array
72 |         Data points with shape (n, dims)
73 |     leafsize : int, optional
74 |         Maximum number of data points in tree leaf
75 |     """
76 | 
77 |     cdef tree_float *_kdtree_float
78 |     cdef tree_double *_kdtree_double
79 |     cdef readonly np.ndarray data_pts
80 |     cdef readonly np.ndarray data
81 |     cdef float *_data_pts_data_float
82 |     cdef double *_data_pts_data_double
83 |     cdef readonly uint32_t n
84 |     cdef readonly int8_t ndim
85 |     cdef readonly uint32_t leafsize
86 | 
87 |     def __cinit__(KDTree self):
88 |         self._kdtree_float = NULL
89 |         self._kdtree_double = NULL
90 | 
91 |     def __init__(KDTree self, np.ndarray data_pts not None, int leafsize=16):
92 | 
93 |         # Check arguments
94 |         if leafsize < 1:
95 |             raise ValueError('leafsize must be greater than zero')
96 | 
97 |         # Get data content
98 |         cdef np.ndarray[float, ndim=1] data_array_float
99 |         cdef np.ndarray[double, ndim=1] data_array_double
100 | 
101 |         if data_pts.dtype == np.float32:
102 |             data_array_float = np.ascontiguousarray(data_pts.ravel(), dtype=np.float32)
103 |             self._data_pts_data_float = <float *>data_array_float.data
104 |             self.data_pts = data_array_float
105 |         else:
106 |             data_array_double = np.ascontiguousarray(data_pts.ravel(), dtype=np.float64)
107 |             self._data_pts_data_double = <double *>data_array_double.data
108 |             self.data_pts = data_array_double
109 | 
110 |         # scipy interface compatibility
111 |         self.data = self.data_pts
112 | 
113 |         # Get tree info
114 |         self.n = data_pts.shape[0]
115 |         self.leafsize = leafsize
116 |         if data_pts.ndim == 1:
117 |             self.ndim = 1
118 |         else:
119 |             self.ndim = data_pts.shape[1]
120 | 
121 |         # Release GIL and construct tree
122 |         if data_pts.dtype == np.float32:
123 |             with nogil:
124 |                 self._kdtree_float = construct_tree_float(self._data_pts_data_float, self.ndim,
125 |                                                           self.n, self.leafsize)
126 |         else:
127 |             with nogil:
128 |                 self._kdtree_double = construct_tree_double(self._data_pts_data_double, self.ndim,
129 |                                                             self.n, self.leafsize)
130 | 
131 | 
132 |     def query(KDTree self, np.ndarray query_pts not None, k=1, eps=0,
133 |               distance_upper_bound=None, sqr_dists=False, mask=None):
134 |         """Query the kd-tree for nearest neighbors
135 | 
136 |         :Parameters:
137 |         query_pts : numpy array
138 |             Query points with shape (m, dims)
139 |         k : int
140 |             The number of nearest neighbours to return
141 |         eps : non-negative float
142 |             Return approximate nearest neighbours; the k-th returned value
143 |             is guaranteed to be no further than (1 + eps) times the distance
144 |             to the real k-th nearest neighbour
145 |         distance_upper_bound : non-negative float
146 |             Return only neighbors within this distance.
147 |             This is used to prune tree searches.
148 |         sqr_dists : bool, optional
149 |             Internally pykdtree works with squared distances.
150 |             Determines if the squared or Euclidean distances are returned.
151 |         mask : numpy array, optional
152 |             Array of booleans where neighbors are considered invalid and
153 |             should not be returned. A mask value of True represents an
154 |             invalid pixel. Mask should have shape (n,) to match data points.
155 |             By default all points are considered valid.
156 | 
157 |         """
158 | 
159 |         # Check arguments
160 |         if k < 1:
161 |             raise ValueError('Number of neighbours must be greater than zero')
162 |         elif eps < 0:
163 |             raise ValueError('eps must be non-negative')
164 |         elif distance_upper_bound is not None:
165 |             if distance_upper_bound < 0:
166 |                 raise ValueError('distance_upper_bound must be non negative')
167 | 
168 |         # Check dimensions
169 |         if query_pts.ndim == 1:
170 |             q_ndim = 1
171 |         else:
172 |             q_ndim = query_pts.shape[1]
173 | 
174 |         if self.ndim != q_ndim:
175 |             raise ValueError('Data and query points must have same dimensions')
176 | 
177 |         if self.data_pts.dtype == np.float32 and query_pts.dtype != np.float32:
178 |             raise TypeError('Type mismatch. query points must be of type float32 when data points are of type float32')
179 | 
180 |         # Get query info
181 |         cdef uint32_t num_qpoints = query_pts.shape[0]
182 |         cdef uint32_t num_n = k
183 |         cdef np.ndarray[uint32_t, ndim=1] closest_idxs = np.empty(num_qpoints * k, dtype=np.uint32)
184 |         cdef np.ndarray[float, ndim=1] closest_dists_float
185 |         cdef np.ndarray[double, ndim=1] closest_dists_double
186 | 
187 | 
188 |         # Set up return arrays
189 |         cdef uint32_t *closest_idxs_data = <uint32_t *>closest_idxs.data
190 |         cdef float *closest_dists_data_float
191 |         cdef double *closest_dists_data_double
192 | 
193 |         # Get query points data
194 |         cdef np.ndarray[float, ndim=1] query_array_float
195 |         cdef np.ndarray[double, ndim=1] query_array_double
196 |         cdef float *query_array_data_float
197 |         cdef double *query_array_data_double
198 |         cdef np.ndarray[np.uint8_t, ndim=1] query_mask
199 |         cdef np.uint8_t *query_mask_data
200 | 
201 |         if mask is not None and mask.size != self.n:
202 |             raise ValueError('Mask must have the same size as data points')
203 |         elif mask is not None:
204 |             query_mask = np.ascontiguousarray(mask.ravel(), dtype=np.uint8)
205 |             query_mask_data = <np.uint8_t *>query_mask.data
206 |         else:
207 |             query_mask_data = NULL
208 | 
209 | 
210 |         if query_pts.dtype == np.float32 and self.data_pts.dtype == np.float32:
211 |             closest_dists_float = np.empty(num_qpoints * k, dtype=np.float32)
212 |             closest_dists = closest_dists_float
213 |             closest_dists_data_float = <float *>closest_dists_float.data
214 |             query_array_float = np.ascontiguousarray(query_pts.ravel(), dtype=np.float32)
215 |             query_array_data_float = <float *>query_array_float.data
216 |         else:
217 |             closest_dists_double = np.empty(num_qpoints * k, dtype=np.float64)
218 |             closest_dists = closest_dists_double
219 |             closest_dists_data_double = <double *>closest_dists_double.data
220 |             query_array_double = np.ascontiguousarray(query_pts.ravel(), dtype=np.float64)
221 |             query_array_data_double = <double *>query_array_double.data
222 | 
223 |         # Setup distance_upper_bound
224 |         cdef float dub_float
225 |         cdef double dub_double
226 |         if distance_upper_bound is None:
227 |             if self.data_pts.dtype == np.float32:
228 |                 dub_float = np.finfo(np.float32).max
229 |             else:
230 |                 dub_double = np.finfo(np.float64).max
231 |         else:
232 |             if self.data_pts.dtype == np.float32:
233 |                 dub_float = <float>(distance_upper_bound * distance_upper_bound)
234 |             else:
235 |                 dub_double = <double>(distance_upper_bound * distance_upper_bound)
236 | 
237 |         # Set epsilon
238 |         cdef double epsilon_float = eps
239 |         cdef double epsilon_double = eps
240 | 
241 |         # Release GIL and query tree
242 |         if self.data_pts.dtype == np.float32:
243 |             with nogil:
244 |                 search_tree_float(self._kdtree_float, self._data_pts_data_float,
245 |                                   query_array_data_float, num_qpoints, num_n, dub_float, epsilon_float,
246 |                                   query_mask_data, closest_idxs_data, closest_dists_data_float)
247 | 
248 |         else:
249 |             with nogil:
250 |                 search_tree_double(self._kdtree_double, self._data_pts_data_double,
251 |                                    query_array_data_double, num_qpoints, num_n, dub_double, epsilon_double,
252 |                                    query_mask_data, closest_idxs_data, closest_dists_data_double)
253 | 
254 |         # Shape result
255 |         if k > 1:
256 |             closest_dists_res = closest_dists.reshape(num_qpoints, k)
257 |             closest_idxs_res = closest_idxs.reshape(num_qpoints, k)
258 |         else:
259 |             closest_dists_res = closest_dists
260 |             closest_idxs_res = closest_idxs
261 | 
262 |         if distance_upper_bound is not None:  # Mark out of bounds results
263 |             if self.data_pts.dtype == np.float32:
264 |                 idx_out = (closest_dists_res >= dub_float)
265 |             else:
266 |                 idx_out = (closest_dists_res >= dub_double)
267 | 
268 |             closest_dists_res[idx_out] = np.inf
269 |             closest_idxs_res[idx_out] = self.n
270 | 
271 |         if not sqr_dists:  # Return actual cartesian distances
272 |             closest_dists_res = np.sqrt(closest_dists_res)
273 | 
274 |         return closest_dists_res, closest_idxs_res
275 | 
276 |     def __dealloc__(KDTree self):
277 |         if self._kdtree_float != NULL:
278 |             delete_tree_float(self._kdtree_float)
279 |         elif self._kdtree_double != NULL:
280 |             delete_tree_double(self._kdtree_double)
281 | 
--------------------------------------------------------------------------------
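A small usage sketch for the wrapper above (float32 data requires float32 queries); the arrays are illustrative only:

```python
import numpy as np
from utils.libkdtree import KDTree

data = np.random.rand(1000, 3).astype(np.float32)
queries = np.random.rand(10, 3).astype(np.float32)

tree = KDTree(data, leafsize=16)
dist, idx = tree.query(queries, k=4, distance_upper_bound=0.5)
# dist, idx have shape (10, 4); distances are Euclidean by default
# (sqr_dists=False), and neighbours beyond the bound are marked with
# dist=inf and idx=len(data)
```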
/dataset.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 | import logging
4 | import torch
5 | from torch.utils.data import Dataset, dataloader
6 | import numpy as np
7 | 
8 | import fields
9 | import transforms, fmr_transforms
10 | 
11 | # logger = logging.getLogger(__name__)
12 | 
13 | class Shapes3dDataset(Dataset):
14 |     ''' 3D Shapes dataset class.
15 |     '''
16 | 
17 |     def __init__(self, dataset_folder, fields, split, rot_magmax=None
18 |                  # categories=None
19 |                  ):
20 |         ''' Initialization of the 3D shape dataset.
21 | 
22 |         Args:
23 |             dataset_folder (str): dataset folder
24 |             fields (dict): dictionary of fields
25 |             split (str): which split is used
26 |             categories (list): list of categories to use
27 |             no_except (bool): no exception
28 |             transform (callable): transformation applied to data points
29 |         '''
30 |         # Attributes
31 |         self.dataset_folder = dataset_folder
32 |         self.fields = fields
33 |         self.rot_magmax = rot_magmax
34 |         self.rotate_op = transforms.RotateBatchIP() if rot_magmax is not None and rot_magmax > 0 else None
35 |         # self.no_except = no_except
36 |         # self.transform = transform
37 | 
38 |         # If categories is None, use all subfolders
39 |         # if categories is None:
40 |         categories = os.listdir(dataset_folder)
41 |         categories = [c for c in categories
42 |                       if os.path.isdir(os.path.join(dataset_folder, c))]
43 | 
44 |         # Read metadata file
45 |         metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
46 | 
47 |         ### for ModelNet40, there is no metadata_file
48 |         if os.path.exists(metadata_file):
49 |             with open(metadata_file, 'r') as f:
50 |                 self.metadata = yaml.load(f, Loader=yaml.FullLoader)
51 |         else:
52 |             self.metadata = {
53 |                 c: {'id': c, 'name': 'n/a'} for c in categories
54 |             }
55 | 
56 |         # Assign an index to each category
57 |         for c_idx, c in enumerate(categories):
58 |             self.metadata[c]['idx'] = c_idx
59 |         ### self.metadata:
60 |         ### For ModelNet40: each category has an id (category folder name == category name), name (n/a), and idx (a number)
61 |         ### For ShapeNet: each category has an id (category folder name), name (category names), and idx (a number)
62 | 
63 |         # Get all models
64 |         self.models = []
65 |         for c_idx, c in enumerate(categories):
66 |             subpath = os.path.join(dataset_folder, c)
67 |             if not os.path.isdir(subpath):
68 |                 logging.warning('Category %s does not exist in dataset.' % c)
69 | 
70 |             split_file = os.path.join(subpath, split + '.lst')
71 |             if os.path.exists(split_file):
72 |                 with open(split_file, 'r') as f:
73 |                     models_c = [m for m in f.read().split('\n') if m != '']  # drop empty entries from trailing newlines
74 |             else:
75 |                 models_c = os.listdir(subpath)
76 | 
77 |             self.models += [
78 |                 {'category': c, 'model': m}  # (category folder name, model folder name)
79 |                 for m in models_c
80 |             ]
81 |         ### self.models:
82 |         ### Each model category has an id (foldername == category name), name (n/a), and idx (a number)
83 | 
84 |     def __len__(self):
85 |         ''' Returns the length of the dataset.
86 |         '''
87 |         return len(self.models)
88 | 
89 |     def load_field(self, data, idx, field_name, field):
90 | 
91 |         category = self.models[idx]['category']
92 |         model = self.models[idx]['model']
93 |         # dataset_folder = self.bench_input_folder if 'inputs' in field_name or 'T21' in field_name else self.dataset_folder
94 |         dataset_folder = self.dataset_folder
95 |         model_path = os.path.join(dataset_folder, category, model)
96 | 
97 |         try:
98 |             field_data = field.load(model_path)
99 |         except Exception as e:
100 |             raise ValueError('Error occurred when loading field %s of model path %s: %s'
101 |                              % (field_name, model_path, e))
102 | 
103 |         if isinstance(field_data, dict):
104 |             for k, v in field_data.items():
105 |                 if k is None:
106 |                     data[field_name] = v
107 |                 else:
108 |                     data['%s.%s' % (field_name, k)] = v
109 |         else:
110 |             data[field_name] = field_data
111 | 
112 |         return
113 | 
114 |     def __getitem__(self, idx):
115 |         data = {}
116 |         for field_name, field in self.fields.items():
117 |             if isinstance(field, fields.IndexField):
118 |                 data[field_name] = idx
119 |             elif isinstance(field, fields.CategoryField):
120 |                 category = self.models[idx]['category']
121 |                 c_idx = self.metadata[category]['idx']
122 |                 data[field_name] = c_idx
123 |             else:
124 |                 self.load_field(data, idx, field_name, field)
125 | 
126 |         # if self.transform is not None:
127 |         #     data = self.transform(data)
128 | 
129 |         transforms.totensor_inplace(data)
130 |         if self.rotate_op is not None:
131 |             rotmat, deg = transforms.gen_randrot(self.rot_magmax)
132 |             data['T'] = rotmat
133 |             data['T.deg'] = deg
134 |             self.rotate_op(data)
135 | 
136 |         return data
137 | 
138 |     def get_model_dict(self, idx):
139 |         return self.models[idx]
140 | 
141 | class PairedDataset(Dataset):
142 |     '''Given a dataset that spits out one sample at a time, return a pair,
143 |     so that they are related by a rigid body transformation corrupted by some noise
144 |     (e.g. Gaussian noise, resampling, density difference).
145 | 
146 |     For training:
147 |         Max number resampling is per instance.
148 |         Gaussian noise is per instance.
149 |         Generation of rigid body transformation is per pair.
150 |         Resampling for randomness in the number of points is per batch. (so that we can have randomness in the number of points but consistency within a batch)
151 |         Application of rigid body transformation is per batch. (so that the transformation runs on the same device as registration)
152 |         Optional centering is per batch.
153 | 
154 |     For testing:
155 |         No resampling or Gaussian noise.
156 |         Load rigid body transformation from file, per pair. '''
157 |     def __init__(self, dataset, rot_magmax=None, duo_mode=False, reg_benchmark_mode=False, resamp_mode=True, pcl_noise=None) -> None:
158 |         '''
159 |         Args:
160 |             dataset: the Dataset object that gives one instance at a time
161 |             duo_mode: if True, the output pair is formed using two instances from different indices
162 |                 (reg_mode: True when we use this PairedDataset.
163 |                 default_mode: inputs and inputs_2 from the same file
164 |                 reg_benchmark_mode: inputs and inputs_2 from different files of the same indices,
165 |                     need a transformation for a pair. [done outside of this dataset]
166 |                 duo_mode: inputs and inputs_2 from the same file of different indices,
167 |                     need a transformation for each index.
168 |                 duo_benchmark_mode?
169 |         )
170 |             transform: if not None, each pair goes through the transforms specified
171 |         '''
172 |         super().__init__()
173 |         self.dataset = dataset
174 |         self.rot_magmax = rot_magmax
175 |         self.duo_mode = duo_mode
176 |         self.reg_benchmark_mode = reg_benchmark_mode
177 |         self.resamp_mode = resamp_mode
178 |         self.pcl_noise = pcl_noise
179 |         self.noise_op = transforms.NoisePairBatchIP(pcl_noise) if pcl_noise is not None and pcl_noise > 0 else None
180 |         self.unitcube_op = fmr_transforms.OnUnitCube()
181 |         self.rotate_op = transforms.RotatePairBatchIP()
182 | 
183 |     def __len__(self):
184 |         return len(self.dataset)
185 |     def get_model_dict(self, idx):
186 |         return self.dataset.get_model_dict(idx)
187 | 
188 |     def duo_load(self, idx, data):
189 |         '''Load the data of the adjacent index. '''
190 |         if idx < len(self)-1:
191 |             idx_2 = idx + 1
192 |         else:
193 |             idx_2 = idx - 1
194 |         data_2 = self.dataset[idx_2]
195 |         transforms.totensor_inplace(data_2)
196 |         for key, value in data_2.items():
197 |             dotidx = key.find('.')
198 |             key_2 = key + '_2' if dotidx == -1 else key[:dotidx] + '_2' + key[dotidx:]
199 |             data[key_2] = value
200 |         return
201 | 
202 |     def duo_preprocess(self, data):
203 |         '''Put inputs and inputs_2 in the same reference frame (the frame of inputs_2),
204 |         and then put both into a unit cube together (using the same spec).
205 |         '''
206 |         ### Put inputs and inputs_2 in the same reference frame
207 |         T01 = data['T']
208 |         T02 = data['T_2']
209 |         T21 = torch.matmul(torch.inverse(T02), T01)
210 |         data['inputs_rawT'] = data['inputs'].clone()
211 |         data['inputs'] = transforms.apply_transformation(T21, data['inputs'])
212 |         if 'points' in data:
213 |             data['points'] = transforms.apply_transformation(T21, data['points'])
214 |         ### Centralize inputs and inputs_2 to unit cube using the same spec
215 |         data['inputs'], spec = self.unitcube_op(data['inputs'])
216 |         data['inputs_2'], _ = self.unitcube_op(data['inputs_2'], spec)
217 |         if 'points' in data:
218 |             data['points'], _ = self.unitcube_op(data['points'], spec)
219 |             data['points_2'], _ = self.unitcube_op(data['points_2'], spec)
220 |         return
221 | 
222 |     def __getitem__(self, idx):
223 |         data = self.dataset[idx]
224 |         transforms.totensor_inplace(data)
225 | 
226 |         if self.reg_benchmark_mode:
227 |             assert 'inputs_2' in data and 'T21' in data, "{}".format(list(data.keys()))
228 | 
229 |         elif self.duo_mode:
230 |             self.duo_load(idx, data)
231 |             self.duo_preprocess(data)
232 |             assert 'inputs_2' in data and 'T21' not in data
233 | 
234 |         else:
235 |             ### when not duo_mode or reg_benchmark_mode (which both give an 'inputs_2' item)
236 |             if self.resamp_mode:
237 |                 self.dataset.load_field(data, idx, 'inputs_2', self.dataset.fields['inputs'])
238 |             else:
239 |                 data['inputs_2'] = data['inputs'].clone()
240 |                 data['points_2'] = data['points'].clone()
241 |             transforms.totensor_inplace(data)
242 |             assert 'inputs_2' in data and 'T21' not in data
243 | 
244 |         if 'T21' not in data:
245 |             rotmat, deg = transforms.gen_randrot(self.rot_magmax)
246 |             data['T21'] = rotmat
247 |             data['T21.deg'] = deg
248 | 
249 |         if self.noise_op is not None:
250 |             self.noise_op(data)
251 | 
252 |         self.rotate_op(data)
253 |         # ### rotate one of the pair
254 |         # data['inputs_2'] = apply_rot(data['T21'], data['inputs'])
255 | 
256 |         # data['inputs_3'] = apply_rot(data['T21'], data['inputs'])
257 |         # diff_pts_rmse = torch.norm(data['inputs_3'] - data['inputs_2'], dim=1).mean()
258 |         # diff_pts_rmse1 = torch.norm(data['inputs_2'], dim=1).mean()
259 |         # diff_pts_rmse2 = torch.norm(data['inputs_3'], dim=1).mean()
260 |         # logging.info("diff_pts_rmse dataset %.4f %.4f %.4f"%(diff_pts_rmse.item(), diff_pts_rmse1.item(), diff_pts_rmse2.item() ) )
261 |         # logging.info("data['inputs'].dtype {} shape {}, device {}".format(data['inputs'].dtype, data['inputs'].shape, data['inputs'].device))
262 |         # logging.info("data['inputs_2'].dtype {} shape {}, device {}".format(data['inputs_2'].dtype, data['inputs_2'].shape, data['inputs_2'].device))
263 | 
264 |         # if 'points' in data:
265 |         #     data['points_2'] = apply_rot(data['T21'], data['points'])
266 |         return data
267 | 
268 | 
269 | def collate_remove_none(batch):
270 |     ''' Collater that puts each data field into a tensor with outer dimension
271 |     batch size.
272 | 
273 |     Args:
274 |         batch: batch
275 |     '''
276 | 
277 |     batch = list(filter(lambda x: x is not None, batch))
278 |     return dataloader.default_collate(batch)
279 | 
280 | 
281 | def worker_init_fn(worker_id):
282 |     ''' Worker init function to ensure true randomness.
283 |     '''
284 |     random_data = os.urandom(4)
285 |     base_seed = int.from_bytes(random_data, byteorder="big")
286 |     np.random.seed(base_seed + worker_id)
--------------------------------------------------------------------------------
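A hypothetical end-to-end sketch of how the pieces in dataset.py fit together. The field construction (fields.PointCloudField and its arguments) and the dataset path are assumptions made for illustration; the real setup lives in the config machinery:

```python
import torch
import fields
from dataset import (Shapes3dDataset, PairedDataset,
                     collate_remove_none, worker_init_fn)

# Assumed field spec: loads one point cloud per model directory
flds = {'inputs': fields.PointCloudField('pointcloud.npz')}

base = Shapes3dDataset('data/ShapeNet', flds, split='train')
paired = PairedDataset(base, rot_magmax=180.0, resamp_mode=True, pcl_noise=0.01)

loader = torch.utils.data.DataLoader(
    paired, batch_size=16, shuffle=True, num_workers=4,
    collate_fn=collate_remove_none, worker_init_fn=worker_init_fn)

batch = next(iter(loader))  # dict with 'inputs', 'inputs_2', 'T21', ...
```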