├── .gitignore ├── LICENSE ├── README.md ├── config ├── d-faust_splits.csv └── dfaust_32_prims.yaml ├── environment.yaml ├── hierarchical_primitives ├── __init__.py ├── common │ ├── __init__.py │ ├── base.py │ ├── dataset.py │ ├── mesh.py │ ├── model_factory.py │ └── parse_splits.py ├── equal_distance_sampler_sq.py ├── external │ ├── __init__.py │ └── libmesh │ │ ├── README.md │ │ ├── __init__.py │ │ ├── inside_mesh.py │ │ ├── triangle_hash.cpp │ │ └── triangle_hash.pyx ├── fast_sampler │ ├── __init__.py │ ├── _sampler.c │ ├── _sampler.pyx │ ├── sampling.cpp │ └── sampling.hpp ├── losses │ ├── __init__.py │ ├── chamfer_loss.py │ ├── common.py │ ├── coverage.py │ ├── implicit_surface_loss.py │ ├── implicit_surface_loss_with_partition.py │ ├── loss_functions.py │ └── regularizers.py ├── mesh.py ├── networks │ ├── __init__.py │ ├── base.py │ ├── feature_extractors.py │ ├── primitive_layer.py │ ├── primitive_parameters.py │ ├── probability.py │ ├── qos.py │ ├── rotation.py │ ├── shape.py │ ├── sharpness.py │ ├── simple_constant_sq.py │ ├── size.py │ ├── translation.py │ └── utils.py ├── primitives.py ├── sample_points_on_primitive.py └── utils │ ├── __init__.py │ ├── filter_sqs.py │ ├── metrics.py │ ├── progbar.py │ ├── sq_mesh.py │ ├── stats_logger.py │ ├── value_registry.py │ └── visualization_utils.py ├── img ├── chair.gif ├── human_punching.gif └── teaser.png ├── scripts ├── arguments.py ├── compute_metrics.py ├── evaluate.py ├── evaluate_to_db.py ├── generate_surface_samples.py ├── output_logger.py ├── render_dfaust.py ├── single_mesh_from_primitives.py ├── train_network.py ├── training_utils.py ├── utils.py ├── visualization_utils.py ├── visualize_parsing_tree.py ├── visualize_partition.py └── visualize_predictions.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build 3 | dist 4 | *.egg-info 5 | *.so 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2020 Despoina Paschalidou , Luc van Gool , Andreas Geiger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 9 | of the Software, and to permit persons to whom the Software is furnished to do 10 | so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Learning Unsupervised Hierarchical Part Decomposition of 3D Objects from a Single RGB Image 2 | 3 | This repository contains the code that accompanies our CVPR 2020 paper 4 | [Learning Unsupervised Hierarchical Part Decomposition of 3D Objects from a Single RGB Image](https://superquadrics.com/hierarchical-primitives.html) 5 | 6 | ![Teaser](img/teaser.png) 7 | 8 | You can find detailed usage instructions for training your own models and using our pretrained models below. 9 | 10 | If you found this work influential or helpful for your research, please consider citing 11 | 12 | ``` 13 | @Inproceedings{Paschalidou2020CVPR, 14 | title = {Learning Unsupervised Hierarchical Part Decomposition of 3D Objects from a Single RGB Image}, 15 | author = {Paschalidou, Despoina and Luc van Gool and Geiger, Andreas}, 16 | booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)}, 17 | year = {2020} 18 | } 19 | ``` 20 | 21 | ## Installation & Dependencies 22 | 23 | Our codebase has the following dependencies: 24 | 25 | - [numpy](https://numpy.org/doc/stable/user/install.html) 26 | - [cython](https://cython.readthedocs.io/en/latest/src/quickstart/build.html) 27 | - [pyquaternion](http://kieranwynn.github.io/pyquaternion/) 28 | - [pyyaml](https://pyyaml.org/wiki/PyYAMLDocumentation) 29 | - [pykdtree](https://github.com/storpipfugl/pykdtree) 30 | - [torch && torchvision](https://pytorch.org/get-started/locally/) 31 | - [trimesh](https://github.com/mikedh/trimesh) 32 | 33 | For the visualizations, we use [simple-3dviz](http://simple-3dviz.com), which 34 | is our easy-to-use library for visualizing 3D data using Python and ModernGL, and 35 | [matplotlib](https://matplotlib.org/) for the colormaps. Note that 36 | [simple-3dviz](http://simple-3dviz.com) provides a lightweight and easy-to-use 37 | scene viewer using [wxpython](https://www.wxpython.org/). If you wish to use 38 | our scripts for visualizing the reconstructed primitives, you will also need to 39 | install [wxpython](https://anaconda.org/anaconda/wxpython). 40 | 41 | The simplest way to make sure that you have all dependencies in place is to use 42 | [conda](https://docs.conda.io/projects/conda/en/4.6.1/index.html). You can 43 | create a conda environment called ```hierarchical_primitives``` using 44 | ``` 45 | conda env create -f environment.yaml 46 | conda activate hierarchical_primitives 47 | ``` 48 | 49 | Next, compile the extension modules. You can do this via 50 | ``` 51 | python setup.py build_ext --inplace 52 | pip install -e . 53 | ``` 54 | 55 | ## Usage 56 | 57 | Once you have installed all dependencies, you can start training new 58 | models from scratch, evaluate our pre-trained models, and visualize the 59 | primitives they recover. 60 | 61 | ### Reconstruction 62 | To visualize the predicted primitives using a trained model, we provide the 63 | ``visualize_predictions.py`` script. In particular, it performs the forward 64 | pass and visualizes the predicted primitives using 65 | [simple-3dviz](https://simple-3dviz.com/).
To execute it, simply run 66 | 67 | ``` 68 | python visualize_predictions.py path_to_config_yaml path_to_output_dir --weight_file path_to_weight_file --model_tag MODEL_TAG --from_fit 69 | ``` 70 | where the argument ``--weight_file`` specifies the path to a trained model and 71 | the argument ``--model_tag`` defines the model tag of the input to be 72 | reconstructed. 73 | 74 | ### Hierarchy Reconstruction 75 | 76 | ### Training 77 | Finally, to train a new network from scratch, we provide the 78 | ``train_network.py`` script. To execute this script, you need to specify the 79 | path to the configuration file you wish to use and the path to the output 80 | directory, where the trained models and the training statistics will be saved. 81 | Namely, to train a new model from scratch, you simply need to run 82 | ``` 83 | python train_network.py path_to_config_yaml path_to_output_dir 84 | ``` 85 | Note that it is also possible to start from a previously trained model by 86 | specifying the ``--weight_file`` argument, which should contain the path to a 87 | previously trained model. Furthermore, by using the arguments ``--model_tag`` and 88 | ``--category_tag``, you can also train your network on a particular model (e.g. 89 | a specific plane, car, human etc.) or a specific object category (e.g. planes, 90 | chairs etc.). 91 | 92 | Also make sure to update the ``dataset_directory`` argument in the provided 93 | config file based on the path where your dataset is stored. 94 | 95 | ## Contribution 96 | 97 | Contributions such as bug fixes, bug reports, suggestions etc. are more than 98 | welcome and should be submitted in the form of new issues and/or pull requests 99 | on GitHub. 100 | 101 | ## License 102 | 103 | Our code is released under the MIT license, which practically allows anyone to 104 | do anything with it. The full license text can be found in the LICENSE file. 105 | 106 | ## Relevant Research 107 | 108 | Below we list some papers that are relevant to our work.
109 | 110 | **Ours:** 111 | - Neural Parts: Learning Expressive 3D Shape Abstractions with Invertible Neural Networks [pdf](https://arxiv.org/pdf/2103.10429.pdf) 112 | - Learning Unsupervised Hierarchical Part Decomposition of 3D Objects from a Single RGB Image [pdf](https://paschalidoud.github.io/) 113 | - Superquadrics Revisited: Learning 3D Shape Parsing beyond Cuboids [pdf](https://arxiv.org/pdf/1904.09970.pdf) [blog](https://autonomousvision.github.io/superquadrics-revisited/) 114 | 115 | **By Others:** 116 | - Learning Shape Abstractions by Assembling Volumetric Primitives [pdf](https://arxiv.org/pdf/1612.00404.pdf) 117 | - 3D-PRNN: Generating Shape Primitives with Recurrent Neural Networks [pdf](https://arxiv.org/abs/1708.01648.pdf) 118 | - Im2Struct: Recovering 3D Shape Structure From a Single RGB Image [pdf](http://openaccess.thecvf.com/content_cvpr_2018/html/Niu_Im2Struct_Recovering_3D_CVPR_2018_paper.pdf) 119 | - Learning shape templates with structured implicit functions [pdf](https://arxiv.org/abs/1904.06447) 120 | - CvxNet: Learnable Convex Decomposition [pdf](https://arxiv.org/abs/1909.05736) 121 | 122 | Below we also list some papers that are more closely related to superquadrics: 123 | - Equal-Distance Sampling of Superellipse Models [pdf](https://pdfs.semanticscholar.org/3e6f/f812b392f9eb70915b3c16e7bfbd57df379d.pdf) 124 | - Revisiting Superquadric Fitting: A Numerically Stable Formulation [link](https://ieeexplore.ieee.org/document/8128485) 125 | - Segmentation and Recovery of Superquadric Models using Convolutional Neural Networks [pdf](https://arxiv.org/abs/2001.10504) 126 | -------------------------------------------------------------------------------- /config/dfaust_32_prims.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_factory: ImageDatasetWithOccupancyLabels 3 | n_points_in_mesh: 10000 4 | n_points_on_mesh: 1000 5 | equal: true 6 | n_primitives: 6 7 | normalize: false 8 | random_view: false 9 | dataset_directory: "/media/paschalidoud/MyData/onet_data/Humans/D-FAUST/" 10 | splits_file: "../config/d-faust_splits.csv" 11 | dataset_type: "dynamic_faust" 12 | mesh_folder: "mesh_seq" 13 | points_folder: "points_seq" 14 | surface_points_folder: "surface_points_seq" 15 | renderings_folder: "renderings-downsampled" 16 | 17 | feature_extractor: 18 | type: resnet18 19 | freeze_bn: true 20 | primitive_network: space_partitioner 21 | structure_layer: 22 | - translations:att_translation 23 | - constant:sq 24 | - probs:all_ones 25 | - sharpness:constant_sharpness 26 | primitive_layer: 27 | - shapes:att_sq 28 | - sizes:att_size 29 | - translations:att_translation 30 | - rotations:att_rotation 31 | - qos:att_qos 32 | - probs:all_ones 33 | - sharpness:constant_sharpness 34 | loss_type: cluster_coverage 35 | loss: 36 | sharpness: 10.0 37 | epochs: 1500 38 | steps_per_epoch: 500 39 | lr: 0.0001 40 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: hierarchical_primitives 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - cython=0.29.21 8 | - numpy=1.19.1 9 | - pyyaml=5.3.1 10 | - pykdtree=1.3.4 11 | - python=3.8 12 | - pytorch=1.6.0 13 | - torchvision=0.7.0 14 | - pillow=7.2.0 15 | - trimesh=3.8.10 16 | - matplotlib=3.3.1 17 | - wxpython=4.0.7 18 | - pip: 19 | - simple_3dviz==0.2.1 20 | - pyquaternion==0.9.9 21 |
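The two YAML files above drive the pipeline: `environment.yaml` pins the conda environment, while `config/dfaust_32_prims.yaml` configures the dataset, the network and the loss. Below is a minimal, hypothetical sketch of how such a config is consumed with the `build_dataloader` helper defined in `hierarchical_primitives/common/base.py` (shown further down); the batch size, worker count, and the assumption that empty tag lists keep every model of the split are illustrative and not part of the original scripts.

```python
import yaml

from hierarchical_primitives.common.base import build_dataloader

# Parse the training configuration shown above.
with open("config/dfaust_32_prims.yaml") as f:
    config = yaml.safe_load(f)

# Build a dataloader over the training split. model_tags/category_tags are
# left empty here (assumed to mean "keep all models of the split").
dataloader = build_dataloader(
    config,
    model_tags=[],
    category_tags=[],
    split=["train"],
    batch_size=32,
    n_processes=4,
)

# With dataset_factory=ImageDatasetWithOccupancyLabels, each batch contains
# the rendered image plus occupancy points, labels and loss weights.
batch = next(iter(dataloader))
```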
-------------------------------------------------------------------------------- /hierarchical_primitives/__init__.py: -------------------------------------------------------------------------------- 1 | """hierarchical_primitives is a model used for representing objects as a binary 2 | tree of primitives.""" 3 | 4 | __author__ = "Despoina Paschalidou" 5 | __copyright__ = "Copyright (c) 2020 Max Planck Institute for Intelligent Systems" 6 | __license__ = "MIT" 7 | __maintainer__ = "Despoina Paschalidou" 8 | __email__ = "despoina.paschalidou@tue.mpg.de" 9 | __url__ = "http://superquadrics.com/" 10 | __version__ = "0.1" 11 | 12 | -------------------------------------------------------------------------------- /hierarchical_primitives/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paschalidoud/hierarchical_primitives/2fa5409ad29f92bedfcaa4cba5de1fa808e43e9b/hierarchical_primitives/common/__init__.py -------------------------------------------------------------------------------- /hierarchical_primitives/common/base.py: -------------------------------------------------------------------------------- 1 | from .dataset import dataset_factory 2 | from .model_factory import DatasetBuilder 3 | from .parse_splits import ShapeNetSplitsBuilder, DynamicFaustSplitsBuilder, \ 4 | CSVSplitsBuilder 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | def splits_factory(dataset_type): 10 | return { 11 | "shapenet": ShapeNetSplitsBuilder, 12 | "dynamic_faust": DynamicFaustSplitsBuilder, 13 | }[dataset_type] 14 | 15 | 16 | def build_dataset( 17 | config, 18 | model_tags, 19 | category_tags, 20 | keep_splits, 21 | random_subset=1.0, 22 | cache_size=0 23 | ): 24 | dataset_directory = config["data"]["dataset_directory"] 25 | dataset_type = config["data"]["dataset_type"] 26 | train_test_splits_file = config["data"]["splits_file"] 27 | # Create a dataset instance to generate the samples for training 28 | dataset = dataset_factory( 29 | config["data"]["dataset_factory"], 30 | (DatasetBuilder(config) 31 | .with_dataset(dataset_type) 32 | .filter_train_test( 33 | splits_factory(dataset_type)(train_test_splits_file), 34 | keep_splits 35 | ) 36 | .filter_category_tags(category_tags) 37 | .filter_tags(model_tags) 38 | .random_subset(random_subset) 39 | .build(dataset_directory)) 40 | ) 41 | return dataset 42 | 43 | 44 | def build_dataloader( 45 | config, 46 | model_tags, 47 | category_tags, 48 | split, 49 | batch_size, 50 | n_processes, 51 | random_subset=1.0, 52 | cache_size=0 53 | ): 54 | dataset = build_dataset( 55 | config, 56 | model_tags, 57 | category_tags, 58 | split, 59 | random_subset=random_subset, 60 | cache_size=cache_size, 61 | ) 62 | print("Dataset has {} elements".format(len(dataset))) 63 | 64 | dataloader = DataLoader( 65 | dataset, 66 | batch_size=batch_size, 67 | num_workers=n_processes, 68 | shuffle=True 69 | ) 70 | 71 | return dataloader 72 | -------------------------------------------------------------------------------- /hierarchical_primitives/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import Compose 5 | from torchvision.transforms import Normalize as TorchNormalize 6 | 7 | 8 | class BaseDataset(Dataset): 9 | """Dataset is a wrapper for all datasets we have 10 | """ 11 | def __init__(self, dataset_object, transform=None): 12 | """ 13 | 
Arguments: 14 | --------- 15 | dataset_object: a dataset object that can be either ShapeNetObject 16 | or SurrealBodiesObject 17 | transform: Callable that applies a transform to a sample 18 | """ 19 | self._dataset_object = dataset_object 20 | self._transform = transform 21 | 22 | def __len__(self): 23 | return len(self._dataset_object) 24 | 25 | def __getitem__(self, idx): 26 | datapoint = self._get_item_inner(idx) 27 | if self._transform: 28 | datapoint = self._transform(datapoint) 29 | 30 | return datapoint 31 | 32 | def get_random_datapoint(self): 33 | return self.__getitem__(np.random.choice(len(self))) 34 | 35 | def _get_item_inner(self, idx): 36 | raise NotImplementedError() 37 | 38 | @property 39 | def shapes(self): 40 | raise NotImplementedError() 41 | 42 | @property 43 | def internal_dataset_object(self): 44 | return self._dataset_object 45 | 46 | 47 | class ImageInput(BaseDataset): 48 | def __init__(self, dataset_object, transform=None): 49 | super(ImageInput, self).__init__(dataset_object, transform) 50 | self._image_size = None 51 | 52 | @property 53 | def image_size(self): 54 | if self._image_size is None: 55 | self._image_size = self._dataset_object[0].get_image(0).shape 56 | self._image_size = (self._image_size[2],) + self._image_size[:2] 57 | return self._image_size 58 | 59 | def _get_item_inner(self, idx): 60 | return np.transpose(self._dataset_object[idx].random_image, (2, 0, 1)) 61 | 62 | @property 63 | def shapes(self): 64 | return [self.image_size] 65 | 66 | 67 | class PointsOnMesh(BaseDataset): 68 | """Get random points on the surface of a mesh.""" 69 | def __init__(self, dataset_object, transform=None): 70 | super(PointsOnMesh, self).__init__(dataset_object, transform) 71 | 72 | def _get_item_inner(self, idx): 73 | return self._dataset_object[idx].sample_faces() 74 | 75 | @property 76 | def shapes(self): 77 | return [(self._n_points_from_mesh, 6)] 78 | 79 | 80 | class PointsAndLabels(BaseDataset): 81 | """Get random points in the bbox and label them inside or outside.""" 82 | def __init__(self, dataset_object, transform=None): 83 | super(PointsAndLabels, self).__init__(dataset_object, transform) 84 | 85 | def _get_item_inner(self, idx): 86 | return self._dataset_object[idx].sample_points() 87 | 88 | @property 89 | def shapes(self): 90 | return [ 91 | (self._n_points_from_mesh, 3), 92 | (self._n_points_from_mesh, 1), 93 | (self._n_points_from_mesh, 1) 94 | ] 95 | 96 | 97 | class DatasetCollection(BaseDataset): 98 | """Implement a pytorch Dataset with a list of BaseDataset objects.""" 99 | def __init__(self, *datasets): 100 | super(DatasetCollection, self).__init__( 101 | datasets[0]._dataset_object, 102 | None 103 | ) 104 | self._datasets = datasets 105 | 106 | def _get_item_inner(self, idx): 107 | def flatten(x): 108 | if not isinstance(x, (tuple, list)): 109 | return [x] 110 | return [ 111 | xij 112 | for xi in x 113 | for xij in flatten(xi) 114 | ] 115 | 116 | return flatten([d[idx] for d in self._datasets]) 117 | 118 | @property 119 | def shapes(self): 120 | return sum( 121 | (d.shapes for d in self._datasets), 122 | [] # initializer 123 | ) 124 | 125 | 126 | class DatasetWithTags(BaseDataset): 127 | """Implement a Dataset with tags.""" 128 | def __init__(self, dataset): 129 | super(DatasetWithTags, self).__init__( 130 | dataset._dataset_object, None 131 | ) 132 | self._dataset = dataset 133 | 134 | def _get_item_inner(self, idx): 135 | return self._dataset[idx] + [self._dataset._dataset_object[idx].tag] 136 | 137 | 138 | class ToTensor(object): 139 | """Convert 
ndarrays in sample to Tensors.""" 140 | def __call__(self, x): 141 | if isinstance(x, (tuple, list)): 142 | return [torch.from_numpy(xi) for xi in x] 143 | return torch.from_numpy(x) 144 | 145 | 146 | class Normalize(object): 147 | """Normalize an image based on the ImageNet statistics.""" 148 | def __call__(self, x): 149 | X = x.float() / 255.0 150 | 151 | normalize = TorchNormalize( 152 | mean=[0.485, 0.456, 0.406], 153 | std=[0.229, 0.224, 0.225] 154 | ) 155 | 156 | return normalize(X) 157 | 158 | 159 | def dataset_factory(name, dataset_object): 160 | to_tensor = ToTensor() 161 | norm = Normalize() 162 | on_surface = PointsOnMesh(dataset_object, transform=to_tensor) 163 | in_bbox = PointsAndLabels(dataset_object, transform=to_tensor) 164 | image_input = ImageInput( 165 | dataset_object, 166 | transform=Compose([to_tensor, norm]) 167 | ) 168 | 169 | return { 170 | "ImageDataset": DatasetCollection(image_input, on_surface), 171 | "ImageDatasetWithOccupancyLabels": DatasetCollection( 172 | image_input, 173 | in_bbox 174 | ), 175 | "ImageDatasetForChamferAndIOU": DatasetCollection( 176 | image_input, 177 | in_bbox, 178 | on_surface 179 | ) 180 | }[name] 181 | -------------------------------------------------------------------------------- /hierarchical_primitives/common/mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | 4 | 5 | class Mesh(object): 6 | def __init__(self, mesh, normalize=False): 7 | self.mesh = mesh 8 | # Normalize points such that they are in the unit cube 9 | if normalize: 10 | bbox = self.mesh.bounding_box.bounds 11 | # Compute location and scale 12 | loc = (bbox[0] + bbox[1]) / 2 13 | scale = (bbox[1] - bbox[0]).max() # / (1 - 0.05) 14 | 15 | # Transform input mesh 16 | self.mesh.apply_translation(-loc) 17 | self.mesh.apply_scale(1 / scale) 18 | 19 | # Make sure that the input meshes are watertight 20 | assert self.mesh.is_watertight 21 | 22 | self._vertices = None 23 | self._vertex_normals = None 24 | self._faces = None 25 | self._face_normals = None 26 | 27 | @property 28 | def vertices(self): 29 | if self._vertices is None: 30 | self._vertices = np.array(self.mesh.vertices) 31 | return self._vertices 32 | 33 | @property 34 | def vertex_normals(self): 35 | if self._vertex_normals is None: 36 | self._vertex_normals = np.array(self.mesh.vertex_normals) 37 | return self._vertex_normals 38 | 39 | @property 40 | def faces(self): 41 | if self._faces is None: 42 | self._faces = np.array(self.mesh.faces) 43 | return self._faces 44 | 45 | @property 46 | def face_normals(self): 47 | if self._face_normals is None: 48 | self._face_normals = np.array(self.mesh.face_normals) 49 | return self._face_normals 50 | 51 | def sample_faces(self, N=10000): 52 | P, t = trimesh.sample.sample_surface(self.mesh, N) 53 | return np.hstack([ 54 | P, self.face_normals[t, :] 55 | ]) 56 | 57 | @classmethod 58 | def from_file(cls, filename, normalize): 59 | return cls(trimesh.load(filename, process=False), normalize) 60 | 61 | 62 | def read_mesh_file(filename, normalize): 63 | return Mesh.from_file(filename, normalize) 64 | -------------------------------------------------------------------------------- /hierarchical_primitives/common/parse_splits.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | 4 | 5 | class SplitsBuilder(object): 6 | def __init__(self, train_test_splits_file): 7 | self._train_test_splits_file = train_test_splits_file 8 |
self._splits = {} 9 | 10 | def train_split(self): 11 | return self._parse_split_file()["train"] 12 | 13 | def test_split(self): 14 | return self._parse_split_file()["test"] 15 | 16 | def val_split(self): 17 | return self._parse_split_file()["val"] 18 | 19 | def _parse_train_test_splits_file(self): 20 | with open(self._train_test_splits_file, "r") as f: 21 | data = [row for row in csv.reader(f)] 22 | return np.array(data) 23 | 24 | def get_splits(self, keep_splits=["train", "val"]): 25 | if not isinstance(keep_splits, list): 26 | keep_splits = [keep_splits] 27 | # Return only the requested splits 28 | s = [] 29 | for ks in keep_splits: 30 | s.extend(self._parse_split_file()[ks]) 31 | return s 32 | 33 | 34 | class CSVSplitsBuilder(SplitsBuilder): 35 | def _parse_split_file(self): 36 | if not self._splits: 37 | data = self._parse_train_test_splits_file() 38 | for s in ["train", "test", "val"]: 39 | self._splits[s] = [r[0] for r in data if r[1] == s] 40 | return self._splits 41 | 42 | 43 | class DynamicFaustSplitsBuilder(SplitsBuilder): 44 | def _parse_split_file(self): 45 | if not self._splits: 46 | data = self._parse_train_test_splits_file() 47 | header = data[0] 48 | for s in ["train", "test", "val"]: 49 | # Only keep the data for the current split 50 | d = data[data[:, -1] == s] 51 | tags = [ 52 | "{}:{}".format(oi, mi) 53 | for oi, mi in zip(d[:, 0], d[:, 1]) 54 | ] 55 | self._splits[s] = tags 56 | 57 | return self._splits 58 | 59 | 60 | class ShapeNetSplitsBuilder(SplitsBuilder): 61 | def _parse_split_file(self): 62 | if not self._splits: 63 | data = self._parse_train_test_splits_file() 64 | header = data[0] 65 | for s in ["train", "test", "val"]: 66 | # Only keep the data for the current split 67 | d = data[data[:, -1] == s] 68 | tags = [ 69 | "{}:{}".format(oi, mi) 70 | for oi, mi in zip(d[:, 1], d[:, 3]) 71 | ] 72 | self._splits[s] = tags 73 | 74 | return self._splits 75 | -------------------------------------------------------------------------------- /hierarchical_primitives/equal_distance_sampler_sq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | 6 | from .fast_sampler import fast_sample, fast_sample_on_batch 7 | 8 | 9 | class EqualDistanceSamplerSQ(object): 10 | 11 | def __init__(self, n_samples): 12 | self._n_samples = n_samples 13 | 14 | @property 15 | def n_samples(self): 16 | return self._n_samples 17 | 18 | def sample(self, **kwargs): 19 | return fast_sample( 20 | a1=kwargs.get("a1", 0.25), 21 | a2=kwargs.get("a2", 0.25), 22 | a3=kwargs.get("a3", 0.25), 23 | e1=kwargs.get("eps1", 0.25), 24 | e2=kwargs.get("eps2", 0.25), 25 | N=self.n_samples 26 | ) 27 | 28 | def sample_on_batch(self, shapes, epsilons): 29 | return fast_sample_on_batch( 30 | shapes, 31 | epsilons, 32 | self.n_samples 33 | ) 34 | 35 | 36 | def fexp(x, p): 37 | return np.sign(x)*(np.abs(x)**p) 38 | 39 | 40 | def sq_surface(a1, a2, a3, e1, e2, eta, omega): 41 | x = a1 * fexp(np.cos(eta), e1) * fexp(np.cos(omega), e2) 42 | y = a2 * fexp(np.cos(eta), e1) * fexp(np.sin(omega), e2) 43 | z = a3 * fexp(np.sin(eta), e1) 44 | return x, y, z 45 | 46 | def bending(x, y, z, a, gamma): 47 | b = np.arctan2(y, x) 48 | r = np.sqrt(x**2 + y**2) * np.cos(a-b) 49 | inv_k = gamma / z 50 | R = inv_k - (inv_k - r) * np.cos(gamma) 51 | x = x + (R-r)*np.cos(a) 52 | y = y + (R-r)*np.sin(a) 53 | z = (inv_k - r)*np.sin(gamma) 54 | return x, y, z 55 | 56 | def bending_inv(x, y, z, a, gamma): 57 | R = np.sqrt(x**2 + y**2) 58 | t1 = np.arctan2(y, x) 59 | R = R * np.cos(a - t1) 60 | inv_k =
gamma / z 61 | t2 = inv_k - R 62 | r = inv_k - np.sqrt(z**2 + t2**2) 63 | gamma = np.arctan2(z, t2) 64 | 65 | x = x - (R-r)*np.cos(a) 66 | y = y - (R-r)*np.sin(a) 67 | z = inv_k * gamma 68 | return x, y, z 69 | 70 | 71 | def visualize_points_on_sq_mesh(e, **kwargs): 72 | print(kwargs) 73 | e1 = kwargs.get("eps1", 0.25) 74 | e2 = kwargs.get("eps2", 0.25) 75 | a1 = kwargs.get("a1", 0.25) 76 | a2 = kwargs.get("a2", 0.25) 77 | a3 = kwargs.get("a3", 0.25) 78 | Kx = kwargs.get("Kx", 0.0) 79 | Ky = kwargs.get("Ky", 0.0) 80 | a = kwargs.get("a", 0.1) 81 | gamma = kwargs.get("gamma", 0.1) 82 | 83 | shapes = np.array([[[a1, a2, a3]]], dtype=np.float32) 84 | epsilons = np.array([[[e1, e2]]], dtype=np.float32) 85 | etas, omegas = e.sample_on_batch(shapes, epsilons) 86 | x, y, z = sq_surface(a1, a2, a3, e1, e2, etas.ravel(), omegas.ravel()) 87 | 88 | # Apply tapering 89 | # fx = Kx * z / a3 + 1 90 | # fy = Ky * z / a3 + 1 91 | # fz = 1 92 | 93 | # x = x * fx 94 | # y = y * fy 95 | # z = z * fz 96 | x, y, z = bending_inv(x, y, z, a, gamma) 97 | 98 | import matplotlib.pyplot as plt 99 | from mpl_toolkits.mplot3d import Axes3D 100 | fig = plt.figure() 101 | ax = fig.add_subplot(111, projection='3d') 102 | ax.scatter(x, y, z) 103 | # ax.set_xlim([-1.25, 1.25]) 104 | # ax.set_ylim([-1.25, 1.25]) 105 | # ax.set_zlim([-1.25, 1.25]) 106 | plt.show() 107 | 108 | 109 | if __name__ == "__main__": 110 | e = EqualDistanceSamplerSQ(600) 111 | # etas, omegas = e.sample(**{ 112 | # 'a1': 0.2074118, 113 | # 'a2': 0.0926611, 114 | # 'a3': 0.2323654, 115 | # 'eps1': 0.20715195, 116 | # 'eps2': 1.6855394 117 | # }) 118 | visualize_points_on_sq_mesh(e, **{ 119 | #'a1': 0.2074118, 120 | #'a2': 0.0926611, 121 | #'a3': 0.2323654, 122 | 'a1': 0.15, 123 | 'a2': 0.15, 124 | 'a3': 0.35, 125 | 'eps1': 0.20715195, 126 | 'eps2': 1.3855394, 127 | 'Kx': 0.0, 128 | 'Ky': 0.0, 129 | 'a': 1.0, 130 | 'gamma': 1.0 131 | # 'k': 0.01 132 | }) 133 | -------------------------------------------------------------------------------- /hierarchical_primitives/external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paschalidoud/hierarchical_primitives/2fa5409ad29f92bedfcaa4cba5de1fa808e43e9b/hierarchical_primitives/external/__init__.py -------------------------------------------------------------------------------- /hierarchical_primitives/external/libmesh/README.md: -------------------------------------------------------------------------------- 1 | ## License 2 | 3 | License for source code corresponding to: 4 | 5 | Mescheder, Lars and Oechsle, Michael and Niemeyer, Michael and Nowozin, Sebastian and Geiger, Andreas. **Occupancy Networks: Learning 3D Reconstruction in Function Space**, CVPR 2019 6 | 7 | Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, Sebastian Nowozin 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 14 | -------------------------------------------------------------------------------- /hierarchical_primitives/external/libmesh/__init__.py: -------------------------------------------------------------------------------- 1 | from .inside_mesh import ( 2 | check_mesh_contains, MeshIntersector, TriangleIntersector2d 3 | ) 4 | 5 | __all__ = [ 6 | "check_mesh_contains", "MeshIntersector", "TriangleIntersector2d" 7 | ] 8 | -------------------------------------------------------------------------------- /hierarchical_primitives/external/libmesh/inside_mesh.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/autonomousvision/occupancy_networks/blob/ddb2908f96de9c0c5a30c093f2a701878ffc1f4a/im2mesh/utils/libmesh/inside_mesh.py 2 | """ 3 | import numpy as np 4 | 5 | from .triangle_hash import TriangleHash as _TriangleHash 6 | 7 | 8 | def check_mesh_contains(mesh, points, hash_resolution=512): 9 | intersector = MeshIntersector(mesh, hash_resolution) 10 | contains = intersector.query(points) 11 | return contains 12 | 13 | 14 | class MeshIntersector: 15 | def __init__(self, mesh, resolution=512): 16 | triangles = mesh.vertices[mesh.faces].astype(np.float64) 17 | n_tri = triangles.shape[0] 18 | 19 | self.resolution = resolution 20 | self.bbox_min = triangles.reshape(3 * n_tri, 3).min(axis=0) 21 | self.bbox_max = triangles.reshape(3 * n_tri, 3).max(axis=0) 22 | # Translate and scale it to [0.5, self.resolution - 0.5]^3 23 | self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min) 24 | self.translate = 0.5 - self.scale * self.bbox_min 25 | 26 | self._triangles = triangles = self.rescale(triangles) 27 | # assert(np.allclose(triangles.reshape(-1, 3).min(0), 0.5)) 28 | # assert(np.allclose(triangles.reshape(-1, 3).max(0), resolution - 0.5)) 29 | 30 | triangles2d = triangles[:, :, :2] 31 | self._tri_intersector2d = TriangleIntersector2d( 32 | triangles2d, resolution) 33 | 34 | def query(self, points): 35 | # Rescale points 36 | points = self.rescale(points) 37 | 38 | # placeholder result with no hits we'll fill in later 39 | contains = np.zeros(len(points), dtype=np.bool) 40 | 41 | # cull points outside of the axis aligned bounding box 42 | # this avoids running ray tests unless points are close 43 | inside_aabb = np.all( 44 | (0 <= points) & (points <= self.resolution), axis=1) 45 | if not inside_aabb.any(): 46 | return contains 47 | 48 | # Only consider points inside bounding box 49 | mask = inside_aabb 50 | points = points[mask] 51 | 52 | # Compute intersection depth and check order 53 | points_indices, tri_indices = self._tri_intersector2d.query(points[:, :2]) 54 | 55 | triangles_intersect = self._triangles[tri_indices] 56 | points_intersect = points[points_indices] 57 | 58 | depth_intersect, abs_n_2 = self.compute_intersection_depth( 59 | points_intersect, triangles_intersect) 60 | 61 | # Count number of intersections in both directions 62 | smaller_depth = depth_intersect >= points_intersect[:, 2] * abs_n_2 63 | bigger_depth = depth_intersect < points_intersect[:,
2] * abs_n_2 64 | points_indices_0 = points_indices[smaller_depth] 65 | points_indices_1 = points_indices[bigger_depth] 66 | 67 | nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0]) 68 | nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0]) 69 | 70 | # Check if point contained in mesh 71 | contains1 = (np.mod(nintersect0, 2) == 1) 72 | contains2 = (np.mod(nintersect1, 2) == 1) 73 | if (contains1 != contains2).any(): 74 | print('Warning: contains1 != contains2 for some points.') 75 | contains[mask] = (contains1 & contains2) 76 | return contains 77 | 78 | def compute_intersection_depth(self, points, triangles): 79 | t1 = triangles[:, 0, :] 80 | t2 = triangles[:, 1, :] 81 | t3 = triangles[:, 2, :] 82 | 83 | v1 = t3 - t1 84 | v2 = t2 - t1 85 | # v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True) 86 | # v2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True) 87 | 88 | normals = np.cross(v1, v2) 89 | alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1) 90 | 91 | n_2 = normals[:, 2] 92 | t1_2 = t1[:, 2] 93 | s_n_2 = np.sign(n_2) 94 | abs_n_2 = np.abs(n_2) 95 | 96 | mask = (abs_n_2 != 0) 97 | 98 | depth_intersect = np.full(points.shape[0], np.nan) 99 | depth_intersect[mask] = \ 100 | t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask] 101 | 102 | # Test the depth: 103 | # TODO: remove and put into tests 104 | # points_new = np.concatenate([points[:, :2], depth_intersect[:, None]], axis=1) 105 | # alpha = (normals * t1).sum(-1) 106 | # mask = (depth_intersect == depth_intersect) 107 | # assert(np.allclose((points_new[mask] * normals[mask]).sum(-1), 108 | # alpha[mask])) 109 | return depth_intersect, abs_n_2 110 | 111 | def rescale(self, array): 112 | array = self.scale * array + self.translate 113 | return array 114 | 115 | 116 | class TriangleIntersector2d: 117 | def __init__(self, triangles, resolution=128): 118 | self.triangles = triangles 119 | self.tri_hash = _TriangleHash(triangles, resolution) 120 | 121 | def query(self, points): 122 | point_indices, tri_indices = self.tri_hash.query(points) 123 | point_indices = np.array(point_indices, dtype=np.int64) 124 | tri_indices = np.array(tri_indices, dtype=np.int64) 125 | points = points[point_indices] 126 | triangles = self.triangles[tri_indices] 127 | mask = self.check_triangles(points, triangles) 128 | point_indices = point_indices[mask] 129 | tri_indices = tri_indices[mask] 130 | return point_indices, tri_indices 131 | 132 | def check_triangles(self, points, triangles): 133 | contains = np.zeros(points.shape[0], dtype=np.bool) 134 | A = triangles[:, :2] - triangles[:, 2:] 135 | A = A.transpose([0, 2, 1]) 136 | y = points - triangles[:, 2] 137 | 138 | detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0] 139 | 140 | mask = (np.abs(detA) != 0.) 
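        # Keep only the point-triangle pairs whose 2x2 barycentric system is non-degenerate (detA != 0).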
141 | A = A[mask] 142 | y = y[mask] 143 | detA = detA[mask] 144 | 145 | s_detA = np.sign(detA) 146 | abs_detA = np.abs(detA) 147 | 148 | u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA 149 | v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA 150 | 151 | sum_uv = u + v 152 | contains[mask] = ( 153 | (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA) 154 | & (0 < sum_uv) & (sum_uv < abs_detA) 155 | ) 156 | return contains 157 | -------------------------------------------------------------------------------- /hierarchical_primitives/external/libmesh/triangle_hash.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language=c++ 2 | import numpy as np 3 | cimport numpy as np 4 | cimport cython 5 | from libcpp.vector cimport vector 6 | from libc.math cimport floor, ceil 7 | 8 | cdef class TriangleHash: 9 | cdef vector[vector[int]] spatial_hash 10 | cdef int resolution 11 | 12 | def __cinit__(self, double[:, :, :] triangles, int resolution): 13 | self.spatial_hash.resize(resolution * resolution) 14 | self.resolution = resolution 15 | self._build_hash(triangles) 16 | 17 | @cython.boundscheck(False) # Deactivate bounds checking 18 | @cython.wraparound(False) # Deactivate negative indexing. 19 | cdef int _build_hash(self, double[:, :, :] triangles): 20 | assert(triangles.shape[1] == 3) 21 | assert(triangles.shape[2] == 2) 22 | 23 | cdef int n_tri = triangles.shape[0] 24 | cdef int bbox_min[2] 25 | cdef int bbox_max[2] 26 | 27 | cdef int i_tri, j, x, y 28 | cdef int spatial_idx 29 | 30 | for i_tri in range(n_tri): 31 | # Compute bounding box 32 | for j in range(2): 33 | bbox_min[j] = min( 34 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 35 | ) 36 | bbox_max[j] = max( 37 | triangles[i_tri, 0, j], triangles[i_tri, 1, j], triangles[i_tri, 2, j] 38 | ) 39 | bbox_min[j] = min(max(bbox_min[j], 0), self.resolution - 1) 40 | bbox_max[j] = min(max(bbox_max[j], 0), self.resolution - 1) 41 | 42 | # Find all voxels where bounding box intersects 43 | for x in range(bbox_min[0], bbox_max[0] + 1): 44 | for y in range(bbox_min[1], bbox_max[1] + 1): 45 | spatial_idx = self.resolution * x + y 46 | self.spatial_hash[spatial_idx].push_back(i_tri) 47 | 48 | @cython.boundscheck(False) # Deactivate bounds checking 49 | @cython.wraparound(False) # Deactivate negative indexing. 
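    # For every 2D query point, look up its spatial-hash cell and return two parallel index arrays (point_indices, tri_indices), one entry per candidate point-triangle pair.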
50 | cpdef query(self, double[:, :] points): 51 | assert(points.shape[1] == 2) 52 | cdef int n_points = points.shape[0] 53 | 54 | cdef vector[int] points_indices 55 | cdef vector[int] tri_indices 56 | # cdef int[:] points_indices_np 57 | # cdef int[:] tri_indices_np 58 | 59 | cdef int i_point, k, x, y 60 | cdef int spatial_idx 61 | 62 | for i_point in range(n_points): 63 | x = int(points[i_point, 0]) 64 | y = int(points[i_point, 1]) 65 | if not (0 <= x < self.resolution and 0 <= y < self.resolution): 66 | continue 67 | 68 | spatial_idx = self.resolution * x + y 69 | for i_tri in self.spatial_hash[spatial_idx]: 70 | points_indices.push_back(i_point) 71 | tri_indices.push_back(i_tri) 72 | 73 | points_indices_np = np.zeros(points_indices.size(), dtype=np.int32) 74 | tri_indices_np = np.zeros(tri_indices.size(), dtype=np.int32) 75 | 76 | cdef int[:] points_indices_view = points_indices_np 77 | cdef int[:] tri_indices_view = tri_indices_np 78 | 79 | for k in range(points_indices.size()): 80 | points_indices_view[k] = points_indices[k] 81 | 82 | for k in range(tri_indices.size()): 83 | tri_indices_view[k] = tri_indices[k] 84 | 85 | return points_indices_np, tri_indices_np 86 | -------------------------------------------------------------------------------- /hierarchical_primitives/fast_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._sampler import step_eta, step_omega, collect_etas, collect_omegas, \ 3 | fast_sample, fast_sample_on_batch 4 | -------------------------------------------------------------------------------- /hierarchical_primitives/fast_sampler/sampling.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include <cmath> 3 | #include <functional> 4 | #include <random> 5 | #include <vector> 6 | 7 | 8 | 9 | extern "C" { 10 | #include "sampling.hpp" 11 | } 12 | 13 | 14 | const float pi = std::acos(-1); 15 | const float pi_2 = pi/2; 16 | 17 | 18 | class prng { 19 | public: 20 | prng(int seed) : gen(seed), dis(0, 1) {} 21 | float operator()() { 22 | return dis(gen); 23 | } 24 | 25 | private: 26 | std::mt19937 gen; 27 | std::uniform_real_distribution<float> dis; 28 | }; 29 | 30 | 31 | struct recursion_params { 32 | float A[2]; 33 | float B[2]; 34 | float theta_a; 35 | float theta_b; 36 | int N; 37 | int offset; 38 | 39 | recursion_params( 40 | float a[2], 41 | float b[2], 42 | float t_a, 43 | float t_b, 44 | int n, 45 | int o 46 | ) { 47 | A[0] = a[0]; 48 | A[1] = a[1]; 49 | B[0] = b[0]; 50 | B[1] = b[1]; 51 | theta_a = t_a; 52 | theta_b = t_b; 53 | N = n; 54 | offset = o; 55 | } 56 | }; 57 | 58 | 59 | inline float fexp(float x, float p) { 60 | return std::copysign(std::pow(std::abs(x), p), x); 61 | } 62 | 63 | 64 | inline void xy(float theta, float a1, float a2, float e, float C[2]) { 65 | C[0] = a1 * fexp(std::cos(theta), e); 66 | C[1] = a2 * fexp(std::sin(theta), e); 67 | } 68 | 69 | inline float distance(float A[2], float B[2]) { 70 | float d1 = A[0]-B[0]; 71 | float d2 = A[1]-B[1]; 72 | return std::sqrt(d1*d1 + d2*d2); 73 | } 74 | 75 |
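// Tabulate buffer.size() angles in [theta_a, theta_b] whose surface points are approximately equally spaced along the superellipse: recursively bisect the arc and split the remaining point budget between the two halves in proportion to their chord lengths.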
76 | void sample_superellipse_divide_conquer( 77 | float a1, 78 | float a2, 79 | float e, 80 | float theta_a, 81 | float theta_b, 82 | std::vector<float> &buffer, 83 | std::vector<recursion_params> &stack 84 | ) { 85 | float A[2], B[2], C[2], theta, dA, dB; 86 | int nA, nB; 87 | 88 | xy(theta_a, a1, a2, e, A); 89 | xy(theta_b, a1, a2, e, B); 90 | buffer[0] = theta_a; 91 | stack.emplace_back(A, B, theta_a, theta_b, buffer.size()-2, 1); 92 | 93 | while (stack.size() > 0) { 94 | recursion_params params = stack.back(); 95 | stack.pop_back(); 96 | 97 | if (params.N <= 0) { 98 | continue; 99 | } 100 | 101 | theta = (params.theta_a + params.theta_b)/2; 102 | xy(theta, a1, a2, e, C); 103 | dA = distance(params.A, C); 104 | dB = distance(C, params.B); 105 | nA = static_cast<int>(std::round((dA/(dA+dB))*(params.N-1))); 106 | nB = params.N - nA - 1; 107 | 108 | buffer[nA+params.offset] = theta; 109 | 110 | stack.emplace_back( 111 | params.A, C, 112 | params.theta_a, theta, 113 | nA, 114 | params.offset 115 | ); 116 | stack.emplace_back( 117 | C, params.B, 118 | theta, params.theta_b, 119 | nB, 120 | params.offset + nA + 1 121 | ); 122 | } 123 | 124 | buffer[buffer.size()-1] = theta_b; 125 | } 126 | 127 | 128 | void sample_etas( 129 | std::function<float()> rand, 130 | float a1a2, 131 | float e1, 132 | std::vector<float> &buffer, 133 | std::vector<float> &cdf, 134 | float *etas, 135 | int N 136 | ) { 137 | const float smoothing = 0.001; 138 | float s; 139 | 140 | // Make the sampling distribution's CDF 141 | cdf[0] = smoothing; 142 | // NOTE: the remainder of this function was garbled in this copy of the file; the lines below are a labelled reconstruction (assumption), not the verbatim original. 143 | for (unsigned int i=1; i<cdf.size(); i++) { 144 | cdf[i] = cdf[i-1] + smoothing + a1a2*std::abs(fexp(std::cos(buffer[i]), e1)); 145 | } 146 | s = cdf.back(); 147 | 148 | // Draw N etas by inverting the CDF 149 | for (int i=0; i<N; i++) { 150 | float u = rand()*s; 151 | unsigned int idx = 0; 152 | while (idx < cdf.size()-1 && cdf[idx] < u) { 153 | idx++; 154 | } 155 | etas[i] = buffer[idx]; 156 | } 157 | } 158 | 159 | 160 | void sample_on_batch( 161 | float *shapes, 162 | float *epsilons, 163 | float *etas, 164 | float *omegas, 165 | int B, 166 | int M, 167 | int N, 168 | int buffer_size, 169 | int seed 170 | ) { 171 | prng rand(seed); 172 | std::vector<float> buffer(buffer_size); 173 | std::vector<float> eta_cdf(buffer_size); 174 | std::vector<recursion_params> stack; 175 | 176 | // NOTE: the body of the loops below was also garbled; reconstructed under the same caveat. 177 | for (int b=0; b<B; b++) { 178 | for (int m=0; m<M; m++) { 179 | float a1 = shapes[(b*M + m)*3 + 0]; 180 | float a2 = shapes[(b*M + m)*3 + 1]; 181 | float a3 = shapes[(b*M + m)*3 + 2]; 182 | float e1 = epsilons[(b*M + m)*2 + 0]; 183 | float e2 = epsilons[(b*M + m)*2 + 1]; 184 | 185 | // Tabulate the eta-superellipse and sample N etas from it 186 | sample_superellipse_divide_conquer(std::sqrt(a1*a1 + a2*a2), a3, e1, -pi_2, pi_2, buffer, stack); 187 | sample_etas(std::ref(rand), a1*a2, e1, buffer, eta_cdf, etas + (b*M + m)*N, N); 188 | 189 | // Tabulate the omega-superellipse and sample N omegas uniformly from the table 190 | sample_superellipse_divide_conquer(a1, a2, e2, -pi, pi, buffer, stack); 191 | for (int n=0; n<N; n++) { 192 | omegas[(b*M + m)*N + n] = buffer[static_cast<int>(rand()*buffer_size)]; 193 | } 194 | } 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /hierarchical_primitives/fast_sampler/sampling.hpp: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_HPP_ 2 | #define _SAMPLING_HPP_ 3 | 4 | 5 | void sample_on_batch( 6 | float *shapes, 7 | float *epsilons, 8 | float *etas, 9 | float *omegas, 10 | int B, 11 | int M, 12 | int N, 13 | int buffer_size, 14 | int seed 15 | ); 16 | 17 | 18 | #endif // _SAMPLING_HPP_ 19 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | 4 | from ..utils.value_registry import ValueRegistry 5 | from .implicit_surface_loss import implicit_surface_loss 6 | from .loss_functions import euclidean_dual_loss 7 | from .coverage import cluster_coverage_with_reconstruction 8 | from .implicit_surface_loss_with_partition import implicit_surface_loss_with_partition 9 | from .regularizers import get as _get_regularizer_function 10 | 11 | 12 | def _get_loss_function(loss, sampler, options): 13 | if loss == "euclidean_dual_loss": 14 | return partial( 15 | euclidean_dual_loss, 16 | options=options, 17 | sampler=sampler 18 | ) 19 | elif loss == "implicit_surface_loss": 20 | return partial( 21 | implicit_surface_loss, 22 | options=options 23 | ) 24 | elif loss == "implicit_surface_loss_with_chamfer_loss": 25 | return partial( 26 | implicit_surface_loss_with_chamfer_loss, 27 | options=options, 28 | sampler=sampler 29 | ) 30 | elif loss == "cluster_coverage": 31 | return partial( 32 | cluster_coverage_with_reconstruction, 33 | options=options 34 | ) 35 | elif loss == "implicit_surface_loss_with_partition": 36 | return partial( 37 | implicit_surface_loss_with_partition, 38 | options=options 39 | ) 40 | 41 | 42 | def get_loss(loss, regularizers, sampler, options): 43 | loss_fn = _get_loss_function(loss, sampler, options) 44 | regularizers = [ 45 | (_get_regularizer_function(regularizer, options), weight) 46 | for regularizer, weight in regularizers 47 | ] 48 | 49 | def inner(y_hat, y_target): 50 | ValueRegistry.get_instance("loss_intermediate_values").clear() 51 | loss = loss_fn(y_hat, y_target) 52 | for
regularizer, weight in regularizers: 53 | loss = loss + weight*regularizer(y_hat) 54 | 55 | return loss 56 | 57 | return inner 58 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/chamfer_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..primitives import deform, transform_to_world_coordinates_system, \ 4 | transform_to_primitives_centric_system, inside_outside_function 5 | 6 | 7 | def sample_points_on_predicted_shape( 8 | prim_params, 9 | sampler, 10 | sharpness, 11 | use_cuboid=False 12 | ): 13 | """ 14 | Arguments: 15 | ---------- 16 | prim_params: PrimitiveParameters object containing the predictions 17 | of the network 18 | """ 19 | # Declare some variables 20 | B = prim_params.batch_size 21 | M = prim_params.n_primitives 22 | S = sampler.n_samples 23 | 24 | probs = prim_params.probs 25 | translations = prim_params.translations_r 26 | rotations = prim_params.rotations_r 27 | alphas = prim_params.sizes_r 28 | epsilons = prim_params.shapes_r 29 | if prim_params.deformations is None: 30 | tapering_params = probs.new_zeros(B, M, 2) 31 | else: 32 | tapering_params = prim_params.tapering_params_r 33 | # Get the coordinates of the sampled points on the surfaces of the SQs, 34 | # with size BxMxSx3 35 | X_SQ, _ = sampler.sample_points_on_primitive(use_cuboid, alphas, epsilons) 36 | X_SQ = deform(X_SQ, alphas, tapering_params) 37 | # Make sure that everything has the right size 38 | assert X_SQ.shape == (B, M, S, 3) 39 | 40 | # Transform SQs to world coordinates 41 | X_SQ_world = transform_to_world_coordinates_system( 42 | X_SQ, 43 | translations, 44 | rotations 45 | ) 46 | # Make sure that everything has the right size 47 | assert X_SQ_world.shape == (B, M, S, 3) 48 | # Transform the points on the SQs to the other SQs 49 | X_SQ_transformed = transform_to_primitives_centric_system( 50 | X_SQ_world.view(B, M*S, 3), 51 | translations.view(B, M, 3), 52 | rotations.view(B, M, 4) 53 | ) 54 | assert X_SQ_transformed.shape == (B, M*S, M, 3) 55 | 56 | # Compute the inside outside function for every point on every primitive to 57 | # every other primitive 58 | F = inside_outside_function( 59 | X_SQ_transformed, 60 | alphas.detach(), # numerical reasons for the detach ;) 61 | epsilons.view(B, M, 2).detach() 62 | ) 63 | assert F.shape == (B, M*S, M) 64 | F = F.view(B, M, S, M) 65 | F = torch.sigmoid(sharpness*(1.0 - F)) 66 | f = torch.max(F, dim=-1)[0] 67 | assert f.shape == (B, M, S) 68 | isolevel = F.new_tensor(0.49) 69 | mask = f <= isolevel 70 | 71 | assert B == 1 72 | X_SQ_mask = X_SQ_world[mask] 73 | assert len(X_SQ_mask.shape) == 2 74 | assert X_SQ_mask.shape[1] == 3 75 | return X_SQ_mask 76 | 77 | 78 | def chamfer_loss( 79 | y_hat, 80 | y_target, 81 | sampler, 82 | options, 83 | use_l1=False 84 | ): 85 | """ 86 | Implement the loss function using the implicit surface function of SQs 87 | 88 | Arguments: 89 | ---------- 90 | y_hat: List of Tensors containing the predictions of the network 91 | y_target: Tensor with size BxNx6 with N points from the target object 92 | and their corresponding normals 93 | options: A dictionary with various options 94 | 95 | Returns: 96 | ------- 97 | the loss 98 | """ 99 | sharpness = options.get("sharpness", 5.0) 100 | use_cuboid = options.get("use_cuboid", False) 101 | 102 | gt_points = y_target[:, :, :3] 103 | assert gt_points.shape[-1] == 3 104 | 105 | # Declare some variables 106 | N = gt_points.shape[1] # number of points per 
sample 107 | 108 | X_SQ = sample_points_on_predicted_shape( 109 | y_hat, sampler, sharpness, use_cuboid 110 | ) 111 | V = torch.abs(X_SQ.unsqueeze(0) - gt_points[0].unsqueeze(1)) 112 | assert V.shape == (N, X_SQ.shape[0], 3) 113 | 114 | if use_l1: 115 | D = torch.sum(V, -1) 116 | else: 117 | D = torch.sum((V)**2, -1) 118 | 119 | D_pcl_to_prim = D.min(-1)[0].mean() 120 | D_prim_to_pcl = D.min(0)[0].mean() 121 | loss = D_pcl_to_prim + D_prim_to_pcl 122 | 123 | # Return the symmetric Chamfer loss 124 | return loss.mean() 125 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def loss_weights_on_depths(loss_type, start_value, end_value, N): 6 | start_value = float(start_value) 7 | end_value = float(end_value) 8 | 9 | if loss_type == "equal": 10 | return [start_value]*N 11 | elif loss_type == "linear": 12 | return np.linspace(start_value, end_value, N).tolist() 13 | 14 | 15 | def clustering_distance(clustering_type, dists, precision_matrix=None): 16 | if clustering_type == "euclidean": 17 | return torch.sqrt(torch.einsum( 18 | "bnmj,bmjk,bnmk->bnm", 19 | [dists, precision_matrix, dists] 20 | )) 21 | elif clustering_type == "mahalanobis": 22 | return dists.sum(-1) 23 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/coverage.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from .regularizers import overlapping_on_depths 5 | from ..networks.primitive_parameters import PrimitiveParameters 6 | from ..primitives import get_implicit_surface, _compute_accuracy_and_recall 7 | from ..utils.stats_logger import StatsLogger 8 | from ..utils.value_registry import ValueRegistry 9 | from ..utils.metrics import compute_iou 10 | 11 | 12 | def cluster_coverage_with_reconstruction(prim_params, y_target, options): 13 | def _coverage_inner(p, pparent, X, labels): 14 | M = p.n_primitives # number of primitives 15 | B, N, _ = X.shape 16 | translations = p.translations_r 17 | splits = 2 if M > 1 else 1 18 | assert labels.shape == (B, N, M//splits) 19 | 20 | # First assign points from the labels to each of the siblings 21 | dists = ((X.unsqueeze(2) - translations.unsqueeze(1))**2).sum(-1) 22 | assert dists.shape == (B, N, M) 23 | if M > 1: 24 | assign_left = (dists[:, :, ::2] < dists[:, :, 1::2]).float() 25 | assign_right = 1-assign_left 26 | assignments = torch.stack([ 27 | assign_left * labels, 28 | assign_right * labels 29 | ], dim=-1).view(B, N, M) 30 | else: 31 | assignments = labels 32 | assert assignments.shape == (B, N, M) 33 | assert assignments.sum(-1).max().item() == 1 34 | 35 | # Now compute the sum of squared distances as the loss 36 | loss = (dists * assignments).sum(-1).mean() 37 | 38 | return loss, assignments 39 | 40 | def _fit_shape_inner(pr, X, X_labels, X_weights): 41 | M = pr.n_primitives # number of primitives 42 | B, N, _ = X.shape 43 | 44 | assert X_labels.shape == (B, N, 1) 45 | assert X_weights.shape == (B, N, 1) 46 | 47 | translations = pr.translations_r 48 | rotations = pr.rotations_r 49 | alphas = pr.sizes_r 50 | epsilons = pr.shapes_r 51 | sharpness = pr.sharpness_r 52 | 53 | # Compute the implicit surface function for each primitive 54 | F, _ = get_implicit_surface( 55 | X, translations, rotations, alphas, epsilons, sharpness 56 | ) 57 | assert F.shape == (B, N, M) 58
| 59 | f = torch.max(F, dim=-1, keepdim=True)[0] 60 | sm = F.new_tensor(1e-6) 61 | t1 = torch.log(torch.max(f, sm)) 62 | t2 = torch.log(torch.max(1.0 - f, sm)) 63 | cross_entropy_loss = - X_labels * t1 - (1.0 - X_labels) * t2 64 | loss = X_weights * cross_entropy_loss 65 | 66 | return loss.mean(), F 67 | 68 | def _fit_parent_inner(p, X, labels, X_weights, F): 69 | M = p.n_primitives # number of primitives 70 | B, N, _ = X.shape 71 | 72 | assert labels.shape == (B, N, M) 73 | 74 | translations = p.translations_r 75 | rotations = p.rotations_r 76 | alphas = p.sizes_r 77 | epsilons = p.shapes_r 78 | sharpness = p.sharpness_r 79 | 80 | sm = F.new_tensor(1e-6) 81 | t1 = labels * torch.log(torch.max(F, sm)) 82 | t2 = (1-labels) * torch.log(torch.max(1-F, sm)) 83 | ce = - t1 - t2 84 | # 5 is a very important number that is necessary for the code to work!!! 85 | # Do not change!! (This is there to avoid having empty primitives :-)) 86 | loss_mask = (labels.sum(1, keepdim=True) > 5).float() 87 | loss = (loss_mask*ce*X_weights).mean() 88 | 89 | # Compute the quality of the current SQ 90 | target_iou = compute_iou( 91 | F.transpose(2, 1).reshape(-1, N), 92 | labels.transpose(2, 1).reshape(-1, N), 93 | average=False 94 | ).view(B, M).detach() 95 | mse_qos_loss = ((p.qos - target_iou)**2).mean() 96 | 97 | return loss, mse_qos_loss, F 98 | 99 | # Extract the arguments to local variables 100 | gt_points, gt_labels, gt_weights = y_target 101 | _, P = prim_params.space_partition 102 | sharpness = prim_params.sharpness_r 103 | 104 | # Compute the coverage loss given the partition 105 | labels = [gt_labels] 106 | coverage_loss = 0 107 | for i in range(len(P)): 108 | pcurrent = P[i] 109 | if i == 0: 110 | precision_m = gt_points.new_zeros(3, 3).fill_diagonal_(1).reshape( 111 | 1, 3, 3).repeat((gt_points.shape[0], 1, 1, 1) 112 | ) 113 | pparent = PrimitiveParameters.from_existing( 114 | PrimitiveParameters.empty(), 115 | precision_matrix=precision_m 116 | ) 117 | else: 118 | pparent = P[i-1] 119 | loss, next_labels = _coverage_inner( 120 | pcurrent, pparent, gt_points, labels[-1] 121 | ) 122 | labels.append(next_labels) 123 | coverage_loss = coverage_loss + loss 124 | 125 | F_intermediate = [] 126 | fit_loss = 0 127 | pr_loss = 0 128 | for pr, ps in zip(prim_params.fit, P): 129 | floss, F = _fit_shape_inner(pr, gt_points, gt_labels, gt_weights) 130 | fit_loss = fit_loss + 1e-1 * floss 131 | F_intermediate.append(F) 132 | 133 | # Compute the proximity loss between the centroids and the centers of the 134 | # primitives 135 | s_tr = ps.translations_r.detach() 136 | r_tr = pr.translations_r 137 | pr_loss = pr_loss + ((s_tr - r_tr)**2).sum(-1).mean() 138 | 139 | # Compute the disjoint loss between the siblings 140 | intermediates = ValueRegistry.get_instance("loss_intermediate_values") 141 | intermediates["F_intermediate"] = F_intermediate 142 | intermediates["labels"] = labels 143 | intermediates["gt_points"] = gt_points 144 | 145 | # Compute the quality of the reconstruction 146 | qos_loss = 0 147 | for i, pr in enumerate(prim_params.fit): 148 | floss, qloss, F = _fit_parent_inner( 149 | pr, gt_points, labels[i+1], gt_weights, F_intermediate[i] 150 | ) 151 | fit_loss = fit_loss + 1e-2 * floss 152 | qos_loss = qos_loss + 1e-3 * qloss 153 | 154 | # Compute some metrics to report during training 155 | F_leaves = F_intermediate[-1] 156 | iou = compute_iou( 157 | gt_labels.squeeze(-1), 158 | torch.max(F_leaves, dim=-1)[0] 159 | ) 160 | accuracy, positive_accuracy = _compute_accuracy_and_recall( 161 | F_leaves, 162 | 
F_leaves.new_ones(F_leaves.shape[0], F_leaves.shape[-1]), 163 | gt_labels, 164 | gt_weights 165 | ) 166 | 167 | stats = StatsLogger.instance() 168 | stats["losses.coverage"] = coverage_loss.item() 169 | stats["losses.fit"] = fit_loss.item() 170 | stats["losses.prox"] = pr_loss.item() 171 | stats["losses.qos"] = qos_loss.item() 172 | stats["metrics.iou"] = iou 173 | stats["metrics.accuracy"] = accuracy.item() 174 | stats["metrics.positive_accuracy"] = positive_accuracy.item() 175 | 176 | return coverage_loss + pr_loss + fit_loss + qos_loss 177 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/implicit_surface_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .common import loss_weights_on_depths 4 | from .loss_functions import euclidean_dual_loss, _euclidean_dual_loss_impl 5 | from ..primitives import get_implicit_surface, _accuracy_and_recall 6 | 7 | from ..utils.metrics import compute_iou 8 | from ..utils.stats_logger import StatsLogger 9 | from ..utils.value_registry import ValueRegistry 10 | 11 | 12 | def _implicit_surface_loss_impl(X, X_weights, labels, probs, F): 13 | # Get the sizes: batch size (B), number of points (N) and number of 14 | # primitives 15 | B, N, M = F.shape 16 | assert X.shape == (B, N, 3) 17 | assert X_weights.shape == (B, N, 1) 18 | assert labels.shape == (B, N, 1) 19 | 20 | # Sort F in descending order 21 | f, idxs = torch.sort(F, dim=-1, descending=True) 22 | 23 | # Start by computing the cumulative product 24 | # Sort based on the indices 25 | probs = torch.cat([ 26 | probs[i].take(idxs[i]).unsqueeze(0) for i in range(len(idxs)) 27 | ]) 28 | neg_cumprod = torch.cumprod(1-probs, dim=-1) 29 | neg_cumprod = torch.cat( 30 | [neg_cumprod.new_ones((B, N, 1)), neg_cumprod[:, :, :-1].clone()], 31 | dim=-1 32 | ) 33 | 34 | # minprob[i, j, k] is the probability that for sample i and point j the 35 | # k-th primitive has the minimum loss 36 | minprob = probs.mul(neg_cumprod) 37 | 38 | intermediate_values = ValueRegistry.get_instance( 39 | "loss_intermediate_values" 40 | ) 41 | intermediate_values["F_sorted"] = f 42 | intermediate_values["minprob"] = minprob 43 | 44 | # Compute the classification loss using binary cross entropy loss 45 | sm = probs.new_tensor(1e-6) 46 | t1 = torch.log(torch.max(f, sm)) 47 | t2 = torch.log(torch.max(1.0 - f, sm)) 48 | cross_entropy_loss = - labels * t1 - (1.0 - labels) * t2 49 | cross_entropy_loss = X_weights * cross_entropy_loss 50 | loss = torch.einsum("ijk,ijk->i", [cross_entropy_loss, minprob]) 51 | loss = loss / N 52 | 53 | return loss 54 | 55 | 56 | def implicit_surface_loss(prim_params, y_target, options): 57 | """ 58 | Implement the loss function using the implicit surface function of SQs 59 | 60 | Arguments: 61 | ---------- 62 | prim_params: PrimitiveParameters object containing the predictions 63 | of the network 64 | y_target: A tensor of shape BxNx4 containing the points and occupancy 65 | labels concatenated in the last dimension [x_i; o_i] 66 | options: A dictionary with various options 67 | 68 | Returns: 69 | ------- 70 | the loss 71 | """ 72 | gt_points, gt_labels, gt_weights = y_target 73 | 74 | # Declare some variables 75 | B = gt_points.shape[0] # batch size 76 | N = gt_points.shape[1] # number of points per sample 77 | M = prim_params.n_primitives # number of primitives 78 | 79 | probs = prim_params.probs 80 | translations = prim_params.translations_r 81 | rotations = 
prim_params.rotations_r 82 | alphas = prim_params.sizes_r 83 | epsilons = prim_params.shapes_r 84 | sharpness = prim_params.sharpness_r 85 | 86 | # Compute the implicit surface function for each primitive 87 | F, X_transformed = get_implicit_surface( 88 | gt_points, translations, rotations, alphas, epsilons, sharpness 89 | ) 90 | intermediate_values = ValueRegistry.get_instance( 91 | "loss_intermediate_values" 92 | ) 93 | intermediate_values["F"] = F 94 | intermediate_values["X"] = X_transformed 95 | assert F.shape == (B, N, M) 96 | 97 | loss = _implicit_surface_loss_impl( 98 | gt_points, gt_weights, gt_labels, probs, F 99 | ) 100 | 101 | # Compute some metrics to report during training 102 | iou = compute_iou( 103 | gt_labels.squeeze(-1), 104 | torch.max(F, dim=-1)[0] 105 | ) 106 | 107 | accuracy, positive_accuracy = _accuracy_and_recall( 108 | intermediate_values["F_sorted"], 109 | intermediate_values["minprob"], 110 | gt_labels, 111 | gt_weights 112 | ) 113 | stats = StatsLogger.instance() 114 | stats["losses.reconstruction"] = loss.mean().item() 115 | stats["metrics.accuracy"] = accuracy.item() 116 | stats["metrics.positive_accuracy"] = positive_accuracy.item() 117 | stats["metrics.iou"] = iou 118 | stats["metrics.exp_n_prims"] = probs.sum(-1).mean().item() 119 | 120 | return loss.mean() 121 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/implicit_surface_loss_with_partition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..primitives import get_implicit_surface, _compute_accuracy_and_recall 4 | from ..utils.stats_logger import StatsLogger 5 | from ..utils.metrics import compute_iou 6 | 7 | 8 | def implicit_surface_loss_with_partition(prim_params, y_target, options): 9 | """ 10 | Arguments: 11 | ---------- 12 | prim_params: PrimitiveParameters object containing the predictions 13 | of the network 14 | y_target: Tuple containing the target points (BxNx3), their occupancy 15 | labels (BxNx1) and the per-point weights (BxNx1) 16 | options: A dictionary with various options 17 | 18 | Returns: 19 | ------- 20 | the loss 21 | """ 22 | def _structure_loss_inner(p, X, X_labels): 23 | M = p.n_primitives # number of primitives 24 | B, N, _ = X.shape 25 | translations = p.translations_r 26 | 27 | # Compute the euclidean distance between the geometric centroids and 28 | # the target points 29 | dists = ((X.unsqueeze(2) - translations.unsqueeze(1))**2).sum(-1) 30 | assert dists.shape == (B, N, M) 31 | # For every point in the target points find the index of its closest 32 | # geometric centroid and assign it to this centroid. 33 | min_dists, idxs = torch.min(dists, dim=-1) 34 | X_labels_new = torch.eye(M, device=X.device)[idxs] * X_labels 35 | assert X_labels_new.shape == (B, N, M) 36 | assert X_labels_new.sum(-1).max().item() == 1 37 | 38 | # Now compute the sum of squared distances as the loss 39 | loss = (min_dists * X_labels.squeeze(-1)).mean() 40 | return loss, X_labels_new 41 | 42 | def _fit_part_inner(F, X, labels, X_weights): 43 | B, N, M = F.shape 44 | assert X_weights.shape == (B, N, 1) 45 | assert labels.shape == (B, N, M) 46 | 47 | # Now compute the loss 48 | sm = F.new_tensor(1e-6) 49 | t1 = torch.log(torch.max(F, sm)) 50 | t2 = torch.log(torch.max(1-F, sm)) 51 | loss = - labels * t1 - (1.0 - labels) * t2 52 | # Only fit primitives that are assigned more than 5 target points; this 53 | # minimum-support threshold prevents empty primitives from affecting the loss 54 | loss_mask = (labels.sum(1, keepdim=True) > 5).float() 55 | loss = (loss_mask*loss*X_weights).mean() 56 | 57 | return loss.mean() 58 | 59 | def _fit_shape_inner(F, X, X_labels, X_weights): 60 | B, N, M = F.shape 61 | assert X_labels.shape == (B, N, 1) 62 | assert X_weights.shape == (B, N, 1) 63 | 64 | # Simply compute the cross entropy loss 65 | f = torch.max(F, dim=-1, keepdim=True)[0] 66 | sm = F.new_tensor(1e-6) 67 | t1 = torch.log(torch.max(f, sm)) 68 | t2 = torch.log(torch.max(1.0 - f, sm)) 69 | cross_entropy_loss = - X_labels * t1 - (1.0 - X_labels) * t2 70 | loss = X_weights * cross_entropy_loss 71 | 72 | return loss.mean(), f 73 | 74 | # Extract the arguments to local variables 75 | X, X_labels, X_weights = y_target 76 | _, P = prim_params.space_partition 77 | 78 | # Compute the structure loss and assign labels to the points given 79 | # the partition 80 | structure_loss, new_labels = _structure_loss_inner(P[-1], X, X_labels) 81 | 82 | # Declare some variables 83 | M = prim_params.n_primitives 84 | B, N, _ = X.shape 85 | translations = prim_params.translations_r 86 | rotations = prim_params.rotations_r 87 | alphas = prim_params.sizes_r 88 | epsilons = prim_params.shapes_r 89 | sharpness = prim_params.sharpness_r 90 | # Compute the implicit surface function for each primitive 91 | F, _ = get_implicit_surface( 92 | X, translations, rotations, alphas, epsilons, sharpness 93 | ) 94 | assert F.shape == (B, N, M) 95 | 96 | # Compute the geometry loss 97 | # Fit every primitive to the part of the object it represents 98 | fit_loss_parts = _fit_part_inner(F, X, new_labels, X_weights) 99 | 100 | fit_loss_shape, F_max = _fit_shape_inner(F, X, X_labels, X_weights) 101 | 102 | # Compute the proximity loss between the centroids and the centers of the 103 | # primitives 104 | s_tr = P[-1].translations_r.detach() 105 | r_tr = prim_params.translations_r 106 | prox_loss = ((s_tr - r_tr)**2).sum(-1).mean() 107 | 108 | # Compute some metrics to report during training 109 | iou = compute_iou( 110 | X_labels.squeeze(-1), F_max.squeeze(-1), X_weights.squeeze(-1) 111 | ) 112 | accuracy, positive_accuracy = _compute_accuracy_and_recall( 113 | F, 114 | F.new_ones(F.shape[0], F.shape[-1]), 115 | X_labels, 116 | X_weights 117 | ) 118 | 119 | stats = StatsLogger.instance() 120 | stats["losses.structure"] = structure_loss.item() 121 | stats["losses.fit_parts"] = fit_loss_parts.item() 122 | stats["losses.fit_shape"] = fit_loss_shape.item() 123 | stats["losses.prox"] = prox_loss.item() 124 | stats["metrics.iou"] = iou 125 | stats["metrics.accuracy"] = accuracy.item() 126 | stats["metrics.positive_accuracy"] = positive_accuracy.item() 127 | 128 | w1 = options["loss_weights"].get("structure_loss_weight", 0.0) 129 | w2 = options["loss_weights"].get("fit_shape_loss_weight", 0.0) 130 | w3 = options["loss_weights"].get("fit_parts_loss_weight", 0.0) 131 | w4 = options["loss_weights"].get("proximity_loss_weight", 0.0) 132 | 133 | loss = w1 * structure_loss + w2 * fit_loss_shape + w3 * fit_loss_parts 134 | loss = loss + w4 * prox_loss 135 | 136 | return loss 137 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/loss_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from ..primitives import inside_outside_function, points_to_cuboid_distances, \ 6 | transform_to_primitives_centric_system, deform, \ 
7 | ray_plane_intersections 8 | 9 | from ..utils.stats_logger import StatsLogger 10 | 11 | 12 | def _euclidean_dual_loss_impl( 13 | X, 14 | prim_params, 15 | sampler, 16 | use_chamfer=True, 17 | use_cuboid=False 18 | ): 19 | """ 20 | Arguments: 21 | ---------- 22 | X: Tensor of size BxNx3 containing the points sampled from the surface 23 | of the target mesh 24 | prim_params: PrimitiveParameters object containing the predictions 25 | of the network 26 | """ 27 | # Get some sizes: batch size (B), number of points (N), (M) number of 28 | # primitives and S the number of points sampled from the SQs 29 | B, N, _ = X.shape 30 | M = prim_params.n_primitives 31 | S = sampler.n_samples 32 | 33 | probs = prim_params.probs 34 | translations = prim_params.translations_r 35 | rotations = prim_params.rotations_r 36 | alphas = prim_params.sizes_r 37 | epsilons = prim_params.shapes_r 38 | if prim_params.deformations is None: 39 | # Initialize with zero deformations 40 | tapering_params = probs.new_zeros(B, M, 2) 41 | else: 42 | tapering_params = prim_params.deformations_r 43 | 44 | 45 | # Transform the 3D points from world-coordinates to primitive-centric 46 | # coordinates with size BxNxMx3 47 | X_transformed = transform_to_primitives_centric_system( 48 | X, 49 | translations, 50 | rotations 51 | ) 52 | # Get the coordinates of the sampled points on the surfaces of the SQs, 53 | # with size BxMxSx3 54 | X_SQ, normals = sampler.sample_points_on_primitive(use_cuboid, alphas, epsilons) 55 | X_SQ = deform(X_SQ, alphas, tapering_params) 56 | # Make the normals unit vectors 57 | normals_norm = normals.norm(dim=-1).view(B, M, S, 1) 58 | normals = normals / normals_norm 59 | 60 | # Make sure that everything has the right size 61 | assert X_SQ.shape == (B, M, S, 3) 62 | assert normals.shape == (B, M, S, 3) 63 | assert X_transformed.shape == (B, N, M, 3) 64 | # Make sure that the normals are unit vectors 65 | assert torch.sqrt(torch.sum(normals ** 2, -1)).sum() == B*M*S 66 | # Compute the pairwise Euclidean distances between points sampled on the 67 | # surface of the SQ (X_SQ) with points sampled on the surface of the target 68 | # object (X_transformed) 69 | V = (X_SQ.unsqueeze(3) - (X_transformed.permute(0, 2, 1, 3)).unsqueeze(2)) 70 | assert V.shape == (B, M, S, N, 3) 71 | D = torch.sum((V)**2, -1) 72 | 73 | cvrg_loss, inside = euclidean_coverage_loss( 74 | [probs, translations, rotations, alphas, epsilons, tapering_params], 75 | X_transformed, 76 | D, 77 | use_cuboid, 78 | use_chamfer 79 | ) 80 | assert inside is None or inside.shape == (B, N, M) 81 | 82 | cnst_loss = euclidean_consistency_loss( 83 | prim_params, 84 | V, 85 | normals, 86 | inside, 87 | D, 88 | use_chamfer 89 | ) 90 | 91 | return cvrg_loss, cnst_loss, X_SQ 92 | 93 | 94 | def euclidean_dual_loss( y_hat, y_target, sampler, options): 95 | """ 96 | Arguments: 97 | ---------- 98 | y_hat: PrimitiveParameters object containing the predictions 99 | of the network 100 | y_target: Tensor with size BxNx6 with the N points from the target 101 | object and their corresponding normals 102 | sampler: An object of either CuboidSampler or EqualDistanceSampler 103 | depending on the type of the primitive we are using 104 | options: A dictionary with various options 105 | 106 | Returns: 107 | -------- 108 | the loss 109 | """ 110 | use_cuboid = options.get("use_cuboid", False) 111 | use_chamfer = options.get("use_chamfer", False) 112 | loss_weights = options["loss_weights"] 113 | 114 | gt_points = y_target[:, :, :3] 115 | # Make sure that everything has the 
right shape 116 | assert gt_points.shape[-1] == 3 117 | 118 | cvrg_loss, cnst_loss, X_SQ = _euclidean_dual_loss_impl( 119 | gt_points, y_hat, sampler 120 | ) 121 | 122 | stats = StatsLogger.instance() 123 | stats["losses.cvrg"] = cvrg_loss.item() 124 | stats["losses.cnst"] = cnst_loss.item() 125 | stats["metrics.exp_n_prims"] = y_hat.probs.sum(-1).mean().item() 126 | 127 | w1 = loss_weights["coverage_loss_weight"] 128 | w2 = loss_weights["consistency_loss_weight"] 129 | return w1 * cvrg_loss + w2 * cnst_loss 130 | 131 | 132 | def euclidean_coverage_loss( 133 | y_hat, 134 | X_transformed, 135 | D, 136 | use_cuboid=False, 137 | use_chamfer=False 138 | ): 139 | """ 140 | Arguments: 141 | ---------- 142 | y_hat: List of Tensors containing the predictions of the network 143 | X_transformed: Tensor with size BxNxMx3 with the N points from the 144 | target object transformed in the M primitive-centric 145 | coordinate systems 146 | D: Tensor of size BxMxSxN that contains the pairwise distances between 147 | points on the surface of the SQ to the points on the target object 148 | """ 149 | # Declare some variables 150 | B = X_transformed.shape[0] # batch size 151 | N = X_transformed.shape[1] # number of points per sample 152 | M = X_transformed.shape[2] # number of primitives 153 | 154 | shapes = y_hat[3].view(B, M, 3) 155 | epsilons = y_hat[4].view(B, M, 2) 156 | probs = y_hat[0] 157 | 158 | # Get the relative position of points with respect to the SQs using the 159 | # inside-outside function 160 | F = shapes.new_tensor(0) 161 | inside = None 162 | if not use_chamfer: 163 | if use_cuboid: 164 | F = points_to_cuboid_distances(X_transformed, shapes) 165 | inside = F <= 0 166 | else: 167 | F = inside_outside_function( 168 | X_transformed, 169 | shapes, 170 | epsilons 171 | ) 172 | inside = F <= 1 173 | 174 | D = torch.min(D, 2)[0].permute(0, 2, 1) # size BxNxM 175 | assert D.shape == (B, N, M) 176 | 177 | if not use_chamfer: 178 | D[inside] = 0.0 179 | distances, idxs = torch.sort(D, dim=-1) 180 | 181 | # Start by computing the cumulative product 182 | # Sort based on the indices 183 | probs = torch.cat([ 184 | probs[i].take(idxs[i]).unsqueeze(0) for i in range(len(idxs)) 185 | ]) 186 | neg_cumprod = torch.cumprod(1-probs, dim=-1) 187 | neg_cumprod = torch.cat( 188 | [neg_cumprod.new_ones((B, N, 1)), neg_cumprod[:, :, :-1]], 189 | dim=-1 190 | ) 191 | 192 | # minprob[i, j, k] is the probability that for sample i and point j the 193 | # k-th primitive has the minimum loss 194 | minprob = probs.mul(neg_cumprod) 195 | 196 | loss = torch.einsum("ijk,ijk->", [distances, minprob]) 197 | loss = loss / B / N 198 | 199 | StatsLogger.instance()["F"] = F 200 | return loss, inside 201 | 202 | 203 | def euclidean_consistency_loss(y_hat, V, normals, inside, D, 204 | use_chamfer=False): 205 | """ 206 | Arguments: 207 | ---------- 208 | y_hat: List of Tensors containing the predictions of the network 209 | V: Tensor with size BxMxSxNx3 with the vectors from the points on SQs to 210 | the points on the target object's surface. 
211 | normals: Tensor with size BxMxSx3 with the normals at every sampled 212 | points on the surfaces of the M primitives 213 | inside: A mask containing 1 if a point is inside the corresponding 214 | shape 215 | D: Tensor of size BxMxSxN that contains the pairwise distances between 216 | points on the surface of the SQ to the points on the target object 217 | """ 218 | B = V.shape[0] # batch size 219 | M = V.shape[1] # number of primitives 220 | S = V.shape[2] # number of points sampled on the SQ 221 | N = V.shape[3] # number of points sampled on the target object 222 | probs = y_hat[0] 223 | 224 | assert D.shape == (B, M, S, N) 225 | 226 | # We need to compute the distance to the closest point from the target 227 | # object for every point S 228 | # min_D = D.min(-1)[0] # min_D has size BxMxS 229 | if not use_chamfer: 230 | outside = (1-inside).permute(0, 2, 1).unsqueeze(2).float() 231 | assert outside.shape == (B, M, 1, N) 232 | D = D + (outside*1e30) 233 | # Compute the minimum distances D, with size BxMxS 234 | D = D.min(-1)[0] 235 | D[D >= 1e30] = 0.0 236 | assert D.shape == (B, M, S) 237 | 238 | # Compute an approximate area of the superellipsoid as if it were an 239 | # ellipsoid 240 | shapes = y_hat[3].view(B, M, 3) 241 | area = 4 * np.pi * ( 242 | (shapes[:, :, 0] * shapes[:, :, 1])**1.6 / 3 + 243 | (shapes[:, :, 0] * shapes[:, :, 2])**1.6 / 3 + 244 | (shapes[:, :, 1] * shapes[:, :, 2])**1.6 / 3 245 | )**0.625 246 | area = M * area / area.sum(dim=-1, keepdim=True) 247 | 248 | # loss = torch.einsum("ij,ij,ij->", [torch.max(D, -1)[0], probs, volumes]) 249 | # loss = torch.einsum("ij,ij,ij->", [torch.mean(D, -1), probs, volumes]) 250 | # loss = torch.einsum("ij,ij->", [torch.max(D, -1)[0], probs]) 251 | loss = torch.einsum("ij,ij,ij->", [torch.mean(D, -1), probs, area]) 252 | loss = loss / B / M 253 | 254 | return loss 255 | -------------------------------------------------------------------------------- /hierarchical_primitives/losses/regularizers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | from ..utils.stats_logger import StatsLogger 6 | from ..utils.value_registry import ValueRegistry 7 | from ..primitives import sq_volumes 8 | 9 | 10 | def volumes(parameters): 11 | """Ensure that the primitives will be small""" 12 | volumes = sq_volumes(parameters) 13 | return volumes.mean() 14 | 15 | 16 | def sparsity(parameters, minimum_number_of_primitives, 17 | maximum_number_of_primitives, w1, w2, a1, a2): 18 | """Ensure that we have at least that many primitives in expectation""" 19 | expected_primitives = parameters[0].sum(-1) 20 | 21 | lower_bound = minimum_number_of_primitives - expected_primitives 22 | upper_bound = expected_primitives - maximum_number_of_primitives 23 | zero = expected_primitives.new_tensor(0) 24 | 25 | t1 = torch.max(lower_bound, zero) * lower_bound**a1 26 | t2 = torch.max(upper_bound, zero) * upper_bound**a2 27 | 28 | return (w1*t1 + w2*t2).mean() 29 | 30 | 31 | def entropy_bernoulli(parameters): 32 | """Minimize the entropy of each bernoulli variable pushing them to either 1 33 | or 0""" 34 | probs = parameters[0] 35 | sm = probs.new_tensor(1e-3) 36 | 37 | t1 = torch.log(torch.max(probs, sm)) 38 | t2 = torch.log(torch.max(1 - probs, sm)) 39 | 40 | return torch.mean((-probs * t1 - (1-probs) * t2).sum(-1)) 41 | 42 | 43 | def parsimony(parameters): 44 | """Penalize the use of more primitives""" 45 | expected_primitives = parameters[0].sum(-1) 46 | 47 | return 
expected_primitives.mean() 48 | 49 | 50 | def siblings_proximity(parameters, root_depth=1, maximum_distance=0.1): 51 | """Make sure that two primitives that have the same parent will also be 52 | close in space. 53 | """ 54 | _, P = parameters.space_partition 55 | zero = P[0].translations.new_tensor(0.0) 56 | max_depth = len(P) 57 | D = 0 58 | # Iterate over the depths 59 | for d in range(root_depth, max_depth): 60 | t = P[d].translations_r 61 | t1 = t[:, 0::2] 62 | t2 = t[:, 1::2] 63 | D = D + torch.max( 64 | torch.sqrt(torch.sum((t1-t2)**2, dim=-1)) - maximum_distance, 65 | zero 66 | ).sum() 67 | N = P[0].translations.new_tensor( 68 | (2**torch.arange(root_depth-1, max_depth-1)).sum() 69 | ).float() 70 | return D / N 71 | 72 | 73 | def overlapping(parameters, tau=2): 74 | """Make sure that at most tau primitives will overlap 75 | """ 76 | intermediate_values = ValueRegistry.get_instance( 77 | "loss_intermediate_values" 78 | ) 79 | F = intermediate_values["F"] 80 | probs = parameters.probs 81 | # Make sure that everything has the right size 82 | assert probs.shape[0] == F.shape[0] 83 | assert probs.shape[1] == F.shape[2] 84 | zero = probs.new_tensor(0.0) 85 | 86 | # Only consider primitives that exist 87 | t = (probs.unsqueeze(1) * F).sum(dim=-1) 88 | return torch.max(zero, t - tau).mean() 89 | 90 | 91 | def overlapping_on_depths(parameters, tau=1): 92 | intermediate_values = ValueRegistry.get_instance( 93 | "loss_intermediate_values" 94 | ) 95 | F_intermediate = intermediate_values["F_intermediate"] 96 | _, P = parameters.space_partition 97 | zero = parameters.probs.new_tensor(0.0) 98 | 99 | reg_terms = [] 100 | for pcurrent, F in zip(P[1:], F_intermediate[1:]): 101 | probs = pcurrent.probs 102 | # Make sure that everything has the right size 103 | assert probs.shape[0] == F.shape[0] 104 | assert probs.shape[1] == F.shape[2] 105 | 106 | # Only consider primitives that exist 107 | mask = (F >= 0.5).float() 108 | t = (probs.unsqueeze(1) * mask * F).sum(dim=-1) 109 | reg_terms.append(torch.max(zero, t - tau).mean()/pcurrent.n_primitives) 110 | 111 | return sum(reg_terms) 112 | 113 | 114 | def get(regularizer, options): 115 | regs = { 116 | "parsimony": parsimony, 117 | "entropy_bernoulli": entropy_bernoulli, 118 | "overlapping": lambda y_hat: overlapping( 119 | y_hat, 120 | options.get("tau", 2) 121 | ), 122 | "overlapping_on_depths": overlapping_on_depths, 123 | "sparsity": lambda y_hat: sparsity( 124 | y_hat, 125 | options.get("minimum_number_of_primitives", 5), 126 | options.get("maximum_number_of_primitives", 5000), 127 | options.get("w1", 0.005), 128 | options.get("w2", 0.005), 129 | options.get("a1", 4.0), 130 | options.get("a2", 2.0) 131 | ), 132 | 133 | "siblings_proximity": siblings_proximity, 134 | "volumes": volumes 135 | } 136 | 137 | def inner(y_hat): 138 | reg_value = regs[regularizer](y_hat) 139 | reg_key = "regularizers." 
+ regularizer 140 | StatsLogger.instance()[reg_key] = reg_value.item() 141 | 142 | return reg_value 143 | 144 | return inner 145 | -------------------------------------------------------------------------------- /hierarchical_primitives/mesh.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import trimesh 4 | 5 | 6 | class Mesh(object): 7 | def __init__(self, mesh, normalize=False): 8 | self.mesh = mesh 9 | # Normalize points such that they are in the unit cube 10 | if normalize: 11 | bbox = self.mesh.bounding_box.bounds 12 | # Compute location and scale 13 | loc = (bbox[0] + bbox[1]) / 2 14 | scale = (bbox[1] - bbox[0]).max() # / (1 - 0.05) 15 | 16 | # Transform input mesh 17 | self.mesh.apply_translation(-loc) 18 | self.mesh.apply_scale(1 / scale) 19 | 20 | # Make sure that the input meshes are watertight 21 | assert self.mesh.is_watertight 22 | 23 | self._vertices = None 24 | self._vertex_normals = None 25 | self._faces = None 26 | self._face_normals = None 27 | 28 | @property 29 | def vertices(self): 30 | if self._vertices is None: 31 | self._vertices = np.array(self.mesh.vertices) 32 | return self._vertices 33 | 34 | @property 35 | def vertex_normals(self): 36 | if self._vertex_normals is None: 37 | self._vertex_normals = np.array(self.mesh.vertex_normals) 38 | return self._vertex_normals 39 | 40 | @property 41 | def faces(self): 42 | if self._faces is None: 43 | self._faces = np.array(self.mesh.faces) 44 | return self._faces 45 | 46 | @property 47 | def face_normals(self): 48 | if self._face_normals is None: 49 | self._face_normals = np.array(self.mesh.face_normals) 50 | return self._face_normals 51 | 52 | def sample_faces(self, N=10000): 53 | P, t = trimesh.sample.sample_surface(self.mesh, N) 54 | return np.hstack([ 55 | P, self.face_normals[t, :] 56 | ]) 57 | 58 | @classmethod 59 | def from_file(cls, filename, normalize): 60 | return cls(trimesh.load(filename, process=False), normalize) 61 | 62 | 63 | def read_mesh_file(filename, normalize): 64 | return Mesh.from_file(filename, normalize) 65 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paschalidoud/hierarchical_primitives/2fa5409ad29f92bedfcaa4cba5de1fa808e43e9b/hierarchical_primitives/networks/__init__.py -------------------------------------------------------------------------------- /hierarchical_primitives/networks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | import yaml 6 | try: 7 | from yaml import CLoader as Loader 8 | except ImportError: 9 | from yaml import Loader 10 | 11 | from .feature_extractors import get_feature_extractor 12 | from .primitive_layer import get_primitive_network 13 | from .utils import FrozenBatchNorm2d 14 | from ..utils.value_registry import ValueRegistry 15 | 16 | 17 | class Network(nn.Module): 18 | """A module used to represent the general network architecture, which 19 | consists of a features extractor and a primitive layer. The 20 | features_extractor takes an input (it can be anything) and estimates a set 21 | of features. The primitive_layer, encodes these features to Mx13 22 | parameters, which correspond to the M primitives. 
23 | """ 24 | def __init__(self, feature_extractor, primitive_layer): 25 | super(Network, self).__init__() 26 | self._feature_extractor = feature_extractor 27 | self._primitive_layer = primitive_layer 28 | 29 | def forward(self, X): 30 | return self._primitive_layer(self._feature_extractor(X)) 31 | 32 | 33 | class NetworkBuilder(object): 34 | def __init__(self, config): 35 | self._config = config 36 | 37 | @classmethod 38 | def from_yaml_file(cls, filepath): 39 | with open(filepath, "r") as f: 40 | config = yaml.load(f, Loader=Loader) 41 | return cls(config) 42 | 43 | @property 44 | def config(self): 45 | return self._config 46 | 47 | @property 48 | def feature_extractor(self): 49 | return get_feature_extractor( 50 | self.config["feature_extractor"]["type"], 51 | self.config["data"]["n_primitives"], 52 | self.config["feature_extractor"].get("freeze_bn", False) 53 | ) 54 | 55 | @property 56 | def primitive_layer(self): 57 | return get_primitive_network( 58 | self.config.get("primitive_network", "default"), 59 | self.feature_extractor, 60 | self.config 61 | ) 62 | 63 | @property 64 | def network(self): 65 | return Network(self.feature_extractor, self.primitive_layer) 66 | 67 | 68 | def train_on_batch( 69 | network, 70 | optimizer, 71 | loss_fn, 72 | X, 73 | y_target, 74 | current_epoch 75 | ): 76 | """Perform a forward and backward pass on a batch of samples and compute 77 | the loss and the primitive parameters. 78 | """ 79 | training_stats = ValueRegistry.get_instance("training_stats") 80 | training_stats["current_epoch"] = current_epoch 81 | optimizer.zero_grad() 82 | # Do the forward pass to predict the primitive_parameters 83 | y_hat = network(X) 84 | loss = loss_fn(y_hat, y_target) 85 | # Do the backpropagation 86 | loss.backward() 87 | nn.utils.clip_grad_norm_(network.parameters(), 1) 88 | # Do the update 89 | optimizer.step() 90 | 91 | return ( 92 | loss.item(), 93 | [x.data if hasattr(x, "data") else x for x in y_hat], 94 | ) 95 | 96 | 97 | def validate_on_batch( 98 | network, 99 | loss_fn, 100 | X, 101 | y_target 102 | ): 103 | """Perform a forward pass on a batch of samples and compute 104 | the loss and the metrics. 
105 | """ 106 | # Do the forward pass to predict the primitive_parameters 107 | y_hat = network(X) 108 | loss = loss_fn(y_hat, y_target) 109 | return ( 110 | loss.item(), 111 | [x.data if hasattr(x, "data") else x for x in y_hat], 112 | ) 113 | 114 | 115 | def optimizer_factory(config, model): 116 | """Based on the input arguments create a suitable optimizer object 117 | """ 118 | params = model.parameters() 119 | 120 | optimizer = config["loss"].get("optimizer", "Adam") 121 | lr = config["loss"].get("lr", 1e-3) 122 | momentum = config["loss"].get("momentum", 0.9) 123 | 124 | if optimizer == "SGD": 125 | return optim.SGD(params, lr=lr, momentum=momentum) 126 | elif optimizer == "Adam": 127 | return optim.Adam(params, lr=lr) 128 | else: 129 | raise NotImplementedError() 130 | 131 | 132 | def build_network(config_file, weight_file, device="cpu"): 133 | network = NetworkBuilder.from_yaml_file(config_file).network 134 | # Move the network architecture to the device to be used 135 | network.to(device) 136 | # Check whether there is a weight file provided to continue training from 137 | if weight_file is not None: 138 | network.load_state_dict( 139 | torch.load(weight_file, map_location=device) 140 | ) 141 | return network 142 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/feature_extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | from .primitive_parameters import PrimitiveParameters 6 | from .utils import FrozenBatchNorm2d 7 | 8 | 9 | class BaseFeatureExtractor(nn.Module): 10 | """Hold some common functions among all feature extractor networks. 11 | """ 12 | @property 13 | def feature_shape(self): 14 | return self._feature_shape 15 | 16 | def forward(self, X): 17 | return self._feature_extractor(X) 18 | 19 | 20 | class TulsianiFeatures(BaseFeatureExtractor): 21 | """Build a variation of the feature extractor implemented in the volumetric 22 | primitives paper of Shubham Tulsiani. 
23 | 24 | https://arxiv.org/pdf/1612.00404.pdf 25 | """ 26 | def __init__(self, freeze_bn): 27 | super(TulsianiFeatures, self).__init__() 28 | 29 | # Declare and initialize some useful variables 30 | n_filters = 4 31 | input_channels = 1 32 | encoder_layers = [] 33 | # Create an encoder using a stack of convolutions 34 | for i in range(5): 35 | encoder_layers.append( 36 | nn.Conv3d(input_channels, n_filters, kernel_size=3, padding=1) 37 | ) 38 | if not freeze_bn: 39 | encoder_layers.append(nn.BatchNorm3d(n_filters)) 40 | encoder_layers.append(nn.LeakyReLU(0.2, True)) 41 | encoder_layers.append(nn.MaxPool3d(kernel_size=2, stride=2)) 42 | 43 | input_channels = n_filters 44 | # Double the number of filters after every layer 45 | n_filters *= 2 46 | 47 | # Add the two fully connected layers 48 | input_channels = n_filters // 2 49 | n_filters = 100 50 | for i in range(2): 51 | encoder_layers.append(nn.Conv3d(input_channels, n_filters, 1)) 52 | #encoder_layers.append(nn.BatchNorm3d(n_filters)) 53 | encoder_layers.append(nn.LeakyReLU(0.2, True)) 54 | 55 | input_channels = n_filters 56 | 57 | self._feature_extractor = nn.Sequential(*encoder_layers[:-1]) 58 | self._feature_shape = n_filters 59 | 60 | def forward(self, X): 61 | return self._feature_extractor(X).view(X.shape[0], -1) 62 | 63 | 64 | class ResNet18(BaseFeatureExtractor): 65 | """Build a feature extractor using the pretrained ResNet18 architecture for 66 | image based inputs. 67 | """ 68 | def __init__(self, freeze_bn): 69 | super(ResNet18, self).__init__() 70 | self._feature_extractor = models.resnet18(pretrained=True) 71 | if freeze_bn: 72 | FrozenBatchNorm2d.freeze(self._feature_extractor) 73 | 74 | self._feature_extractor.fc = nn.Sequential( 75 | nn.Linear(512, 512), nn.ReLU(), 76 | nn.Linear(512, 512), nn.ReLU() 77 | ) 78 | self._feature_extractor.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 79 | self._feature_shape = 512 80 | 81 | 82 | def get_feature_extractor(name, n_primitives, freeze_bn=False): 83 | """Based on the name return the appropriate feature extractor""" 84 | return { 85 | "tulsiani": lambda: TulsianiFeatures(freeze_bn), 86 | "resnet18": lambda: ResNet18(freeze_bn) 87 | }[name]() 88 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/primitive_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .primitive_parameters import PrimitiveParameters 5 | from .probability import probs 6 | from .rotation import rotations 7 | from .shape import shapes 8 | from .size import sizes 9 | from .translation import translations 10 | from .qos import qos 11 | from .sharpness import sharpness 12 | from .simple_constant_sq import simple_constant_sq 13 | 14 | from ..primitives import compose_quaternions 15 | 16 | 17 | class _BasePrimitiveNetwork(nn.Module): 18 | """A simple way to reuse functions between primitive networks.""" 19 | def _get_inputs(self, X): 20 | if isinstance(X, tuple): 21 | X, p = X 22 | else: 23 | p = PrimitiveParameters.empty() 24 | 25 | return X, p 26 | 27 | 28 | class PrimitiveNetwork(_BasePrimitiveNetwork): 29 | """A PrimitiveNetwork creates a PrimitiveParameters object and passes it to 30 | a set of layers to fill it.""" 31 | def __init__(self, layers): 32 | super(PrimitiveNetwork, self).__init__() 33 | for i, l in enumerate(layers): 34 | self.add_module("layer{}".format(i), l) 35 | self._layers = layers 36 | 37 | def forward(self, X): 38 | X, p = self._get_inputs(X) 
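# Every layer below follows the same functional contract: it takes the
# feature tensor X and the running PrimitiveParameters p, and returns a new
# PrimitiveParameters with one more field filled in. A minimal sketch of
# that contract (the layer names here are purely illustrative, not taken
# from this repository):
#
#   p = PrimitiveParameters.empty()
#   p = translation_layer(X, p)   # fills p.translations
#   p = rotation_layer(X, p)      # fills p.rotations
#   p = size_layer(X, p)          # fills p.sizes
#
# so the list F below collects one progressively richer snapshot per layer.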
39 | F = [] 40 | for layer in self._layers: 41 | p = layer(X, p) 42 | F.append(p) 43 | 44 | return PrimitiveParameters.from_existing(F[-1], fit=F) 45 | 46 | 47 | class SpacePartitionerPrimitiveNetwork(_BasePrimitiveNetwork): 48 | """ 49 | Arguments 50 | --------- 51 | layers: list of nn.Module to predict the primitive parameters 52 | n_primitives: the depth of the binary tree that defines the number of 53 | leaves and hence the maximum number of primitives 54 | feature_shape: the dimensions of the features provided by the feature 55 | extractor 56 | """ 57 | def __init__(self, layers_fit, layers_partition, n_primitives, feature_shape): 58 | super(SpacePartitionerPrimitiveNetwork, self).__init__() 59 | for i, l in enumerate(layers_fit): 60 | self.add_module("layer_fit{}".format(i), l) 61 | for i, l in enumerate(layers_partition): 62 | self.add_module("layer_partition{}".format(i), l) 63 | 64 | self._layers_fit = layers_fit 65 | self._layers_partition = layers_partition 66 | 67 | self._partitioning = nn.Sequential( 68 | nn.Linear(feature_shape, 512), 69 | nn.ReLU(), 70 | nn.Linear(512, 2*feature_shape) 71 | ) 72 | 73 | self._n_primitives = n_primitives 74 | 75 | def forward(self, X): 76 | F, p = self._get_inputs(X) 77 | 78 | # Initialize some variables 79 | B, D = F.shape 80 | 81 | # Produce the partitioning of the feature 82 | C = [F.unsqueeze(1)] 83 | for d in range(1, self._n_primitives): 84 | C.append(self._partitioning(C[d-1]).view(B, -1, D)) 85 | 86 | # Predict primitives for each partition 87 | P = [] 88 | for Ci in C: 89 | pcurrent = p 90 | for layer in self._layers_partition: 91 | pcurrent = layer(Ci, pcurrent) 92 | P.append(pcurrent) 93 | 94 | F = [] 95 | for Ci in C: 96 | pcurrent = p 97 | for layer in self._layers_fit: 98 | pcurrent = layer(Ci, pcurrent) 99 | F.append(pcurrent) 100 | 101 | return PrimitiveParameters.from_existing( 102 | F[-1], 103 | space_partition=[C, P], 104 | fit=F 105 | ) 106 | 107 | 108 | class SingleDepthSpacePartitionerPrimitiveNetwork(_BasePrimitiveNetwork): 109 | """ 110 | Arguments: 111 | ---------- 112 | layers_fit: list of nn.Module to predict the primitive parameters 113 | layers_partition: list of nn.Modules to predict the space partitioning 114 | n_primitives: the depth of the binary tree that defines the number of 115 | leaves and hence the maximum number of primitives 116 | feature_shape: the dimensions of the features provided by the feature 117 | extractor 118 | """ 119 | def __init__( 120 | self, layers_fit, layers_partition, n_primitives, feature_shape 121 | ): 122 | super(SingleDepthSpacePartitionerPrimitiveNetwork, self).__init__() 123 | for i, l in enumerate(layers_fit): 124 | self.add_module("layer_fit{}".format(i), l) 125 | for i, l in enumerate(layers_partition): 126 | self.add_module("layer_partition{}".format(i), l) 127 | 128 | self._layers_fit = layers_fit 129 | self._layers_partition = layers_partition 130 | self._n_primitives = n_primitives 131 | 132 | def forward(self, X): 133 | X, p = self._get_inputs(X) 134 | 135 | # Predict primitives for each partition 136 | P = [] 137 | pcurrent = p 138 | for layer in self._layers_partition: 139 | pcurrent = layer(X, pcurrent) 140 | P.append(pcurrent) 141 | 142 | F = [] 143 | pcurrent = p 144 | for layer in self._layers_fit: 145 | pcurrent = layer(X, pcurrent) 146 | F.append(pcurrent) 147 | 148 | return PrimitiveParameters.from_existing( 149 | F[-1], 150 | space_partition=[X, P], 151 | fit=F 152 | ) 153 | 154 | 155 | class ConstantSQ(nn.Module): 156 | def forward(self, X, primitive_params): 
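# A sketch of what this layer emits, assuming X is a BxMxD per-primitive
# feature tensor whose values are ignored (only the batch size B and the
# number of primitives M are read from its shape): every primitive gets
# sizes alpha = (0.1, 0.1, 0.1), shape exponents epsilons = (1, 1), i.e.
# an ellipsoid-like superquadric, and the identity quaternion (1, 0, 0, 0),
# since rotations is flattened as [w, x, y, z] per primitive and only every
# fourth entry is set to one.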
157 | B, M, _ = X.shape 158 | rotations = X.new_zeros(B, M*4) 159 | rotations[:, ::4] = 1. 160 | return PrimitiveParameters.from_existing( 161 | primitive_params, 162 | sizes=X.new_ones(B, M*3)*0.1, 163 | shapes=X.new_ones(B, M*2), 164 | rotations=rotations 165 | ) 166 | 167 | 168 | def get_primitive_network(network, feature_extractor, config): 169 | n_primitives = config["data"]["n_primitives"] 170 | layers = config["primitive_layer"] 171 | 172 | return dict( 173 | default=lambda: PrimitiveNetwork( 174 | get_layer_instances(layers, feature_extractor, n_primitives, config) 175 | ), 176 | hierarchical=lambda: HierarchicalPrimitiveNetwork( 177 | get_layer_instances(layers, feature_extractor, n_primitives, config), 178 | n_primitives, 179 | feature_extractor.feature_shape 180 | ), 181 | space_partitioner=lambda: SpacePartitionerPrimitiveNetwork( 182 | get_layer_instances(layers, feature_extractor, n_primitives, config), 183 | get_layer_instances( 184 | config["structure_layer"], feature_extractor, n_primitives, config 185 | ), 186 | n_primitives, 187 | feature_extractor.feature_shape 188 | ), 189 | single_space_partitioner=lambda: SingleDepthSpacePartitionerPrimitiveNetwork( 190 | get_layer_instances(layers, feature_extractor, n_primitives, config), 191 | get_layer_instances( 192 | config["structure_layer"], feature_extractor, n_primitives, config 193 | ), 194 | n_primitives, 195 | feature_extractor.feature_shape 196 | ) 197 | )[network]() 198 | 199 | 200 | def get_layer_instances(layers, feature_extractor, n_primitives, config): 201 | factories = { 202 | "probs": probs, 203 | "rotations": rotations, 204 | "shapes": shapes, 205 | "sizes": sizes, 206 | "translations": translations, 207 | "qos": qos, 208 | "sharpness": sharpness, 209 | "simple_constant_sq": simple_constant_sq, 210 | "constant": lambda *args: ConstantSQ() 211 | } 212 | layer_instances = [] 213 | for name in layers: 214 | category, layer = name.split(":") 215 | layer_instances.append(factories[category]( 216 | layer, 217 | feature_extractor, 218 | n_primitives, 219 | config 220 | )) 221 | 222 | return layer_instances 223 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/primitive_parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PrimitiveParameters(object): 5 | """Represents the \lambda_m.""" 6 | def __init__(self, probs, translations, rotations, sizes, shapes, 7 | space_partition=None, fit=None, qos=None, sharpness=None): 8 | self.probs = probs 9 | self.translations = translations 10 | self.rotations = rotations 11 | self.sizes = sizes 12 | self.shapes = shapes 13 | self.fit = fit 14 | self.space_partition = space_partition 15 | self.qos = qos 16 | self.sharpness = sharpness 17 | 18 | def __getattr__(self, name): 19 | if not name.endswith("_r"): 20 | raise AttributeError() 21 | 22 | prop = getattr(self, name[:-2]) 23 | if not torch.is_tensor(prop): 24 | raise AttributeError() 25 | 26 | return prop.view(self.batch_size, self.n_primitives, -1) 27 | 28 | @property 29 | def members(self): 30 | return ( 31 | self.probs, 32 | self.translations, 33 | self.rotations, 34 | self.sizes, 35 | self.shapes, 36 | self.space_partition, 37 | self.fit, 38 | self.qos, 39 | self.sharpness 40 | ) 41 | 42 | @property 43 | def batch_size(self): 44 | return self.sizes.shape[0] 45 | 46 | @property 47 | def n_primitives(self): 48 | return self.sizes.shape[1] // 3 49 | 50 | def __len__(self): 51 | return 
len(self.members) 52 | 53 | def __getitem__(self, i): 54 | return self.members[i] 55 | 56 | @classmethod 57 | def empty(cls): 58 | return cls( 59 | probs=None, 60 | translations=None, 61 | rotations=None, 62 | sizes=None, 63 | shapes=None, 64 | space_partition=None, 65 | fit=None, 66 | qos=None, 67 | sharpness=None 68 | ) 69 | 70 | @classmethod 71 | def from_existing(cls, other, **kwargs): 72 | params = dict() 73 | params["probs"] = other.probs 74 | params["translations"] = other.translations 75 | params["rotations"] = other.rotations 76 | params["sizes"] = other.sizes 77 | params["shapes"] = other.shapes 78 | params["space_partition"] = other.space_partition 79 | params["fit"] = other.fit 80 | params["qos"] = other.qos 81 | params["sharpness"] = other.sharpness 82 | for key, value in list(kwargs.items()): 83 | if key in params: 84 | params[key] = value 85 | return cls(**params) 86 | 87 | @classmethod 88 | def with_keys(cls, **kwargs): 89 | p = cls.empty() 90 | return cls.from_existing(p, **kwargs) 91 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/probability.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class Probability(nn.Module): 10 | """Use the features to predict the existence probabilities for all the 11 | primitives. 12 | 13 | The shape of the Probability tensor should be BxM, where B is the batch size 14 | and M is the number of primitives. 15 | """ 16 | def __init__(self, input_dims, n_primitives): 17 | super(Probability, self).__init__() 18 | self.fc = nn.Linear(input_dims, n_primitives) 19 | 20 | def forward(self, X, primitive_params): 21 | probs = torch.sigmoid(self.fc(X)) 22 | 23 | return PrimitiveParameters.from_existing( 24 | primitive_params, 25 | probs=probs 26 | ) 27 | 28 | 29 | class DeepProbability(nn.Module): 30 | """Use the features to predict the existence probabilities for all the 31 | primitives with a deeper architecture. 32 | 33 | The shape of the Probability tensor should be BxM, where B is the batch size 34 | and M is the number of primitives. 
35 | """ 36 | def __init__(self, input_dims, n_primitives): 37 | super(DeepProbability, self).__init__() 38 | self.fc_0 = nn.Linear(input_dims, input_dims) 39 | self.nonlin_0 = nn.LeakyReLU(0.2, True) 40 | self.fc_1 = nn.Linear(input_dims, n_primitives) 41 | 42 | def forward(self, X, primitive_params): 43 | probs = torch.sigmoid( 44 | self.fc_1(self.nonlin_0(self.fc_0(X))) 45 | ) 46 | 47 | return PrimitiveParameters.from_existing( 48 | primitive_params, 49 | probs=probs 50 | ) 51 | 52 | 53 | class AttProbability(nn.Module): 54 | """Use the features to predict the existence probabilities for all the 55 | primitives.""" 56 | def __init__(self, input_dims): 57 | super(AttProbability, self).__init__() 58 | self.fc = nn.Linear(input_dims, 1) 59 | 60 | def forward(self, X, primitive_params): 61 | # Reshape it to BxM to be compatible with the rest 62 | probs = torch.sigmoid(self.fc(X)).squeeze(-1) 63 | 64 | return PrimitiveParameters.from_existing( 65 | primitive_params, 66 | probs=probs 67 | ) 68 | 69 | 70 | class AllOnes(nn.Module): 71 | """By default all primitives exist thus existence probabilities are 1.""" 72 | def forward(self, X, primitive_params): 73 | probs = X.new_ones((X.shape[0], primitive_params.n_primitives)) 74 | 75 | return PrimitiveParameters.from_existing( 76 | primitive_params, 77 | probs=probs 78 | ) 79 | 80 | 81 | class ProbabilityFromTransition(nn.Module): 82 | def __init__(self, n_primitives): 83 | super(ProbabilityFromTransition, self).__init__() 84 | self._n_primitives = n_primitives 85 | 86 | def forward(self, X, primitive_params): 87 | _transitions = primitive_params.transitions 88 | probs = transitions[:, :self._n_primitives, :self._n_primitives] 89 | 90 | return PrimitiveParameters.from_existing( 91 | primitive_params, 92 | probs=probs 93 | ) 94 | 95 | 96 | class TerminationProbability(nn.Module): 97 | """Use the features to predict the termination probabilities for all the 98 | primitives.""" 99 | def __init__(self, input_dims): 100 | super(TerminationProbability, self).__init__() 101 | self.fc = nn.Linear(input_dims, 1) 102 | 103 | def forward(self, X, primitive_params): 104 | ones = X.new_ones(X.shape[0], 1) 105 | probs = torch.sigmoid(self.fc(X)).squeeze(-1) 106 | probs = torch.cat([probs[:, :-1], ones], dim=1) 107 | 108 | return PrimitiveParameters.from_existing( 109 | primitive_params, 110 | termination_probs=probs 111 | ) 112 | 113 | 114 | def probs(name, fe, n_primitives, config): 115 | layers = dict( 116 | prob=partial(Probability, fe.feature_shape, n_primitives), 117 | all_ones=AllOnes, 118 | deep_prob=partial(DeepProbability, fe.feature_shape, n_primitives), 119 | att_prob=partial(AttProbability, fe.feature_shape), 120 | termination_prob=partial(TerminationProbability, fe.feature_shape), 121 | prob_from_transition=partial(ProbabilityFromTransition, n_primitives) 122 | ) 123 | 124 | return layers[name]() 125 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/qos.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class QualityOfSuperquadric(nn.Module): 10 | def __init__(self, input_dims, n_layers, hidden_units): 11 | super(QualityOfSuperquadric, self).__init__() 12 | # Keep the layers based on the n_layers 13 | l = [] 14 | in_features = input_dims 15 | for i in range(n_layers-1): 16 | 
l.append(nn.Linear(in_features, hidden_units)) 17 | l.append(nn.ReLU()) 18 | in_features = hidden_units 19 | l.append(nn.Linear(in_features, 1)) 20 | self.fc = nn.Sequential(*l) 21 | 22 | 23 | def forward(self, X, primitive_params): 24 | qos = torch.sigmoid(self.fc(X)).squeeze(-1) 25 | return PrimitiveParameters.from_existing( 26 | primitive_params, 27 | qos=qos 28 | ) 29 | 30 | def qos(name, fe, n_primitives, config): 31 | layers = dict( 32 | att_qos=partial( 33 | QualityOfSuperquadric, 34 | fe.feature_shape, 35 | n_layers=config["data"].get("n_layers", 1), 36 | hidden_units=config["data"].get("hidden_units", 128) 37 | ) 38 | ) 39 | return layers[name]() 40 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/rotation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class Rotation(nn.Module): 10 | """Use the features to predict the rotation as a quaternion for all 11 | primitives. 12 | 13 | The shape of the Rotation tensor should be BxM*4, where B is the batch size 14 | and M is the number of primitives. 15 | """ 16 | def __init__(self, input_dims, n_primitives): 17 | super(Rotation, self).__init__() 18 | self._n_primitives = n_primitives 19 | self.fc = nn.Linear(input_dims, n_primitives*4) 20 | 21 | def forward(self, X, primitive_params): 22 | quats = self.fc(X).view(X.shape[0], self._n_primitives, 4) 23 | # Apply an L2-normalization non-linearity to enforce the unit norm 24 | # constraint 25 | rotations = quats / torch.norm(quats, 2, -1, keepdim=True) 26 | rotations = rotations.view(X.shape[0], self._n_primitives*4) 27 | 28 | return PrimitiveParameters.from_existing( 29 | primitive_params, 30 | rotations=rotations 31 | ) 32 | 33 | 34 | class DeepRotation(nn.Module): 35 | """Use the features to predict the rotation as a quaternion for all 36 | primitives using a deeper architecture. 37 | 38 | The shape of the DeepRotation tensor should be BxM*4, where B is the batch 39 | size and M is the number of primitives. 40 | """ 41 | def __init__(self, input_dims, n_primitives): 42 | super(DeepRotation, self).__init__() 43 | self._n_primitives = n_primitives 44 | 45 | self.fc_0 = nn.Linear(input_dims, input_dims) 46 | self.nonlin_0 = nn.LeakyReLU(0.2, True) 47 | self.fc_1 = nn.Linear(input_dims, n_primitives*4) 48 | 49 | def forward(self, X, primitive_params): 50 | quats = self.fc_1( 51 | self.nonlin_0(self.fc_0(X)) 52 | ).view(X.shape[0], self._n_primitives, 4) 53 | # Apply an L2-normalization non-linearity to enforce the unit norm 54 | # constraint 55 | rotations = quats / torch.norm(quats, 2, -1, keepdim=True) 56 | rotations = rotations.view(X.shape[0], self._n_primitives*4) 57 | 58 | return PrimitiveParameters.from_existing( 59 | primitive_params, 60 | rotations=rotations 61 | ) 62 | 63 | 64 | class AttRotation(nn.Module): 65 | """Use the features to predict the rotation for all primitives for an 66 | attention-based architecture. 
67 | """ 68 | def __init__(self, input_dims, n_layers, hidden_units): 69 | super(AttRotation, self).__init__() 70 | 71 | # Keep the layers based on the n_layers 72 | l = [] 73 | in_features = input_dims 74 | for i in range(n_layers-1): 75 | l.append(nn.Linear(in_features, hidden_units)) 76 | l.append(nn.ReLU()) 77 | in_features = hidden_units 78 | l.append(nn.Linear(in_features, 4)) 79 | self.fc = nn.Sequential(*l) 80 | 81 | def forward(self, X, primitive_params): 82 | quats = self.fc(X) 83 | # Apply an L2-normalization non-linearity to enforce the unit norm 84 | # constrain 85 | rotations = quats / torch.norm(quats, 2, -1, keepdim=True) 86 | rotations = rotations.view(X.shape[0], -1) 87 | 88 | return PrimitiveParameters.from_existing( 89 | primitive_params, 90 | rotations=rotations 91 | ) 92 | 93 | 94 | class NoRotation(nn.Module): 95 | def __init__(self, n_primitives): 96 | super(NoRotation, self).__init__() 97 | self._n_primitives = n_primitives 98 | 99 | def forward(self, X, primitive_params): 100 | rotations = X.new_zeros(X.shape[0], self._n_primitives, 4) 101 | rotations[:, :, 0] = 1.0 102 | 103 | return PrimitiveParameters.from_existing( 104 | primitive_params, 105 | rotations=rotations.view(X.shape[0], self._n_primitives*4) 106 | ) 107 | 108 | 109 | def rotations(name, fe, n_primitives, config): 110 | layers = dict( 111 | default_rotation=partial(Rotation, fe.feature_shape, n_primitives), 112 | no_rotation=partial(NoRotation, n_primitives), 113 | deep_rotation=partial(DeepRotation, fe.feature_shape, n_primitives), 114 | att_rotation=partial( 115 | AttRotation, 116 | fe.feature_shape, 117 | n_layers=config["data"].get("n_layers", 1), 118 | hidden_units=config["data"].get("hidden_units", 128) 119 | ) 120 | ) 121 | return layers[name]() 122 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/shape.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class SQShape(nn.Module): 10 | """Use the features to predict the shape for all primitives. 11 | 12 | The shape of the SQShape tensor should be BxM*2, where B is the batch size 13 | and M is the number of primitives. 14 | """ 15 | def __init__(self, input_dims, n_primitives, min_e=0.4, max_e=1.1): 16 | super(SQShape, self).__init__() 17 | self.fc = nn.Linear(input_dims, n_primitives*2) 18 | self.max_e = max_e 19 | self.min_e = min_e 20 | 21 | def forward(self, X, primitive_params): 22 | shapes = torch.sigmoid(self.fc(X)) * self.max_e + self.min_e 23 | 24 | return PrimitiveParameters.from_existing( 25 | primitive_params, 26 | shapes=shapes 27 | ) 28 | 29 | 30 | class DeepSQShape(nn.Module): 31 | """Use the features to predict the shape for all primitives using a deeper 32 | architecture. 33 | 34 | The shape of the DeepSQShape tensor should be BxM*2, where B is the batch 35 | size and M is the number of primitives. 
36 | """ 37 | def __init__(self, input_dims, n_primitives, min_e=0.4, max_e=1.1): 38 | super(DeepSQShape, self).__init__() 39 | self.fc_0 = nn.Linear(input_dims, input_dims) 40 | self.nonlin_0 = nn.LeakyReLU(0.2, True) 41 | self.fc_1 = nn.Linear(input_dims, n_primitives*2) 42 | self.max_e = max_e 43 | self.min_e = min_e 44 | 45 | def forward(self, X, primitive_params): 46 | shapes = torch.sigmoid( 47 | self.fc_1(self.nonlin_0(self.fc_0(X))) 48 | ) * self.max_e + self.min_e 49 | 50 | return PrimitiveParameters.from_existing( 51 | primitive_params, 52 | shapes=shapes 53 | ) 54 | 55 | 56 | class CubeShape(nn.Module): 57 | """By default all primitives are cubes, thus their shape is 0.25""" 58 | def __init__(self, n_primitives): 59 | super(CubeShape, self).__init__() 60 | self._n_primitives = n_primitives 61 | 62 | def forward(self, X, primitive_params): 63 | # Shapes should have shape BxM*2 64 | shapes = X.new_ones((X.shape[0], self._n_primitives*2)) * 0.25 65 | 66 | return PrimitiveParameters.from_existing( 67 | primitive_params, 68 | shapes=shapes 69 | ) 70 | 71 | 72 | class AttSQShape(nn.Module): 73 | """Use the features to predict the shape for all primitives. 74 | 75 | The shape of the AttentionSQShape tensor should be BxM*2, where B is the 76 | batch size and M is the number of primitives. 77 | """ 78 | def __init__( 79 | self, input_dims, n_layers, hidden_units, min_e=0.4, max_e=1.1 80 | ): 81 | super(AttSQShape, self).__init__() 82 | self.max_e = max_e 83 | self.min_e = min_e 84 | 85 | # Keep the layers based on the n_layers 86 | l = [] 87 | in_features = input_dims 88 | for i in range(n_layers-1): 89 | l.append(nn.Linear(in_features, hidden_units)) 90 | l.append(nn.ReLU()) 91 | in_features = hidden_units 92 | l.append(nn.Linear(in_features, 2)) 93 | self.fc = nn.Sequential(*l) 94 | 95 | def forward(self, X, primitive_params): 96 | shapes = torch.sigmoid(self.fc(X)) * self.max_e + self.min_e 97 | shapes = shapes.view(X.shape[0], -1) 98 | 99 | return PrimitiveParameters.from_existing( 100 | primitive_params, 101 | shapes=shapes 102 | ) 103 | 104 | 105 | def shapes(name, fe, n_primitives, config): 106 | layers = dict( 107 | sq=partial(SQShape, fe.feature_shape, n_primitives), 108 | cuboid=partial(CubeShape, n_primitives), 109 | deep_sq=partial(DeepSQShape, fe.feature_shape, n_primitives), 110 | att_sq=partial( 111 | AttSQShape, 112 | fe.feature_shape, 113 | n_layers=config["data"].get("n_layers", 1), 114 | hidden_units=config["data"].get("hidden_units", 128) 115 | ) 116 | ) 117 | return layers[name]() 118 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/sharpness.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class ConstantSharpness(nn.Module): 10 | def __init__(self, sharpness_value): 11 | super(ConstantSharpness, self).__init__() 12 | self.sv = sharpness_value 13 | 14 | def forward(self, X, primitive_params): 15 | sharpness = X.new_ones( 16 | (X.shape[0], primitive_params.n_primitives, 2) 17 | ) * X.new_tensor(self.sv) 18 | 19 | return PrimitiveParameters.from_existing( 20 | primitive_params, 21 | sharpness=sharpness 22 | ) 23 | 24 | 25 | class MixedConstantSharpness(nn.Module): 26 | def __init__(self, sharpness_value_pos, sharpness_value_neg): 27 | super(MixedConstantSharpness, self).__init__() 28 | self.sv_pos = 
sharpness_value_pos 29 | self.sv_neg = sharpness_value_neg 30 | 31 | def forward(self, X, primitive_params): 32 | sharpness = X.new_ones( 33 | (X.shape[0], primitive_params.n_primitives, 2) 34 | ) 35 | sharpness[:, :, 0] = self.sv_pos 36 | sharpness[:, :, 1] = self.sv_neg 37 | 38 | return PrimitiveParameters.from_existing( 39 | primitive_params, 40 | sharpness=sharpness 41 | ) 42 | 43 | 44 | class Sharpness(nn.Module): 45 | def __init__(self, input_dims, n_primitives, max_sv=10.0): 46 | super(Sharpness, self).__init__() 47 | self.max_sv = max_sv 48 | self.fc = nn.Linear(input_dims, n_primitives) 49 | 50 | def forward(self, X, primitive_params): 51 | sv = torch.sigmoid(self.fc(X)) * self.max_sv 52 | return PrimitiveParameters.from_existing( 53 | primitive_params, 54 | sharpness=sv.unsqueeze(-1).expand(X.shape[0], -1, 2) 55 | ) 56 | 57 | 58 | class MixedSharpness(nn.Module): 59 | def __init__( 60 | self, input_dims, n_primitives, max_sv_pos=10.0, max_sv_neg=10.0 61 | ): 62 | super(MixedSharpness, self).__init__() 63 | self.max_sv_pos = max_sv_pos 64 | self.max_sv_neg = max_sv_neg 65 | self.fc = nn.Linear(input_dims, 2*n_primitives) 66 | 67 | def forward(self, X, primitive_params): 68 | s = torch.sigmoid(self.fc(X)).view(X.shape[0], -1, 2) 69 | s = s * s.new_tensor([[[self.max_sv_pos, self.max_sv_neg]]]) 70 | 71 | return PrimitiveParameters.from_existing( 72 | primitive_params, 73 | sharpness=s 74 | ) 75 | 76 | 77 | def sharpness(name, fe, n_primitives, config): 78 | layers = dict( 79 | constant_sharpness=partial( 80 | ConstantSharpness, 81 | sharpness_value=config["loss"].get("sharpness", 10.0) 82 | ), 83 | mixed_constant_sharpness=partial( 84 | MixedConstantSharpness, 85 | sharpness_value_pos=config["loss"].get("sharpness", 10.0), 86 | sharpness_value_neg=config["loss"].get("sharpness_neg", 10.0) 87 | ), 88 | sharpness=partial( 89 | Sharpness, 90 | fe.feature_shape, 91 | n_primitives, 92 | max_sv=config["loss"].get("sharpness", 10.0) 93 | ), 94 | mixed_sharpness=partial( 95 | MixedSharpness, 96 | fe.feature_shape, 97 | n_primitives, 98 | max_sv_pos=config["loss"].get("sharpness", 10.0), 99 | max_sv_neg=config["loss"].get("sharpness_neg", 10.0) 100 | ), 101 | ) 102 | 103 | return layers[name]() 104 | 105 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/simple_constant_sq.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class SimpleConstantSQ(nn.Module): 10 | def __init__(self, n_primitives): 11 | super(SimpleConstantSQ, self).__init__() 12 | self._n_primitives = n_primitives 13 | 14 | def forward(self, X, primitive_params): 15 | B = X.shape[0] 16 | M = self._n_primitives 17 | rotations = X.new_zeros(B, M*4) 18 | rotations[:, ::4] = 1. 
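# Shape sketch (hypothetical B=2, M=3): rotations starts as zeros of shape
# (2, 12) and the strided write above sets columns 0, 4 and 8 to one, so
# each primitive's flattened quaternion becomes (1, 0, 0, 0), the identity
# rotation in the w-first convention used elsewhere in this package (see
# NoRotation in networks/rotation.py).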
19 | return PrimitiveParameters.from_existing( 20 | primitive_params, 21 | sizes=X.new_ones(B, M*3)*0.1, 22 | shapes=X.new_ones(B, M*2), 23 | rotations=rotations 24 | ) 25 | 26 | def simple_constant_sq(name, fe, n_primitives, config): 27 | layers = dict( 28 | default_sq=partial(SimpleConstantSQ, n_primitives), 29 | ) 30 | return layers[name]() 31 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/size.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class Size(nn.Module): 10 | """Use the features to predict the size of all primitives. 11 | 12 | The shape of the Size tensor should be BxM*3, where B is the batch size and 13 | M is the number of primitives. 14 | """ 15 | def __init__(self, input_dims, n_primitives, min_a=0.005, max_a=0.5): 16 | super(Size, self).__init__() 17 | self._n_primitives = n_primitives 18 | self.fc = nn.Linear(input_dims, n_primitives*3) 19 | self.min_a = min_a 20 | self.max_a = max_a 21 | 22 | def forward(self, X, primitive_params): 23 | sizes = torch.sigmoid(self.fc(X)) * self.max_a + self.min_a 24 | 25 | return PrimitiveParameters.from_existing( 26 | primitive_params, 27 | sizes=sizes 28 | ) 29 | 30 | 31 | class DeepSize(nn.Module): 32 | """Use the features to predict the size of all primitives. 33 | 34 | The shape of the DeepSize tensor should be BxM*3, where B is the batch size 35 | and M is the number of primitives. 36 | """ 37 | def __init__(self, input_dims, n_primitives, min_a=0.005, max_a=0.5): 38 | super(DeepSize, self).__init__() 39 | self._n_primitives = n_primitives 40 | 41 | self.fc_0 = nn.Linear(input_dims, input_dims) 42 | self.nonlin_0 = nn.LeakyReLU(0.2, True) 43 | self.fc_1 = nn.Linear(input_dims, n_primitives*3) 44 | self.min_a = min_a 45 | self.max_a = max_a 46 | 47 | def forward(self, X, primitive_params): 48 | sizes = torch.sigmoid( 49 | self.fc_1(self.nonlin_0(self.fc_0(X))) 50 | ) * self.max_a + self.min_a 51 | 52 | return PrimitiveParameters.from_existing( 53 | primitive_params, 54 | sizes=sizes 55 | ) 56 | 57 | 58 | class AttSize(nn.Module): 59 | def __init__( 60 | self, input_dims, n_layers, hidden_units, min_a=0.005, max_a=0.5 61 | ): 62 | super(AttSize, self).__init__() 63 | self.min_a = min_a 64 | self.max_a = max_a 65 | 66 | # Keep the layers based on the n_layers 67 | l = [] 68 | in_features = input_dims 69 | for i in range(n_layers-1): 70 | l.append(nn.Linear(in_features, hidden_units)) 71 | l.append(nn.ReLU()) 72 | in_features = hidden_units 73 | l.append(nn.Linear(in_features, 3)) 74 | self.fc = nn.Sequential(*l) 75 | 76 | def forward(self, X, primitive_params): 77 | sizes = torch.sigmoid(self.fc(X)) * self.max_a + self.min_a 78 | sizes = sizes.view(X.shape[0], -1) 79 | 80 | return PrimitiveParameters.from_existing( 81 | primitive_params, 82 | sizes=sizes 83 | ) 84 | 85 | 86 | def sizes(name, fe, n_primitives, config): 87 | layers = dict( 88 | default_size=partial(Size, fe.feature_shape, n_primitives), 89 | deep_size=partial(DeepSize, fe.feature_shape, n_primitives), 90 | att_size=partial( 91 | AttSize, 92 | fe.feature_shape, 93 | n_layers=config["data"].get("n_layers", 1), 94 | hidden_units=config["data"].get("hidden_units", 128) 95 | ) 96 | ) 97 | 98 | return layers[name]() 99 | -------------------------------------------------------------------------------- 
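Note on the regressor heads in `size.py` above: `Size`, `DeepSize` and `AttSize` all share one squashing recipe, a linear map followed by a sigmoid that is scaled by `max_a` and shifted by `min_a`. The minimal standalone sketch below (not part of the package; `B`, `D` and `M` are made-up dimensions) makes the resulting range explicit. Note that, as written, `sigmoid(x) * max_a + min_a` lies in the open interval `(min_a, min_a + max_a)` rather than `(min_a, max_a)`; the same applies to the shape heads in `shape.py`.

```
import torch
import torch.nn as nn

# Sketch of the squashing used by the Size/DeepSize/AttSize heads above.
B, D, M = 4, 256, 8          # hypothetical batch size, feature dims, primitives
min_a, max_a = 0.005, 0.5    # the defaults used by Size/DeepSize

fc = nn.Linear(D, M * 3)     # one scalar per axis per primitive
X = torch.randn(B, D)
sizes = torch.sigmoid(fc(X)) * max_a + min_a

assert sizes.shape == (B, M * 3)
assert (sizes > min_a).all() and (sizes < min_a + max_a).all()
```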
/hierarchical_primitives/networks/translation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .primitive_parameters import PrimitiveParameters 7 | 8 | 9 | class Translation(nn.Module): 10 | """Use the features to predict the translation vectors for all primitives. 11 | 12 | The shape of the Translation tensor should be BxM*3, where B is the batch 13 | size and M is the number of primitives. 14 | """ 15 | def __init__(self, input_dims, n_primitives): 16 | super(Translation, self).__init__() 17 | self.fc = nn.Linear(input_dims, n_primitives*3) 18 | 19 | def forward(self, X, primitive_params): 20 | # Everything lies in the unit cube, so the maximum translation vector 21 | # is 0.51 22 | translations = torch.tanh(self.fc(X)) * 0.51 23 | 24 | return PrimitiveParameters.from_existing( 25 | primitive_params, 26 | translations=translations 27 | ) 28 | 29 | 30 | class DeepTranslation(nn.Module): 31 | """Use the features to predict the translation vectors for all primitives 32 | using a deeper architecture. 33 | 34 | The shape of the Translation tensor should be BxM*3, where B is the batch 35 | size and M is the number of primitives. 36 | """ 37 | def __init__(self, input_dims, n_primitives): 38 | super(DeepTranslation, self).__init__() 39 | self.fc_0 = nn.Linear(input_dims, input_dims) 40 | self.nonlin_0 = nn.LeakyReLU(0.2, True) 41 | self.fc_1 = nn.Linear(input_dims, n_primitives*3) 42 | 43 | def forward(self, X, primitive_params): 44 | # Everything lies in the unit cube, so the maximum translation vector 45 | # is 0.51 46 | translations = torch.tanh( 47 | self.fc_1(self.nonlin_0(self.fc_0(X))) 48 | ) * 0.51 49 | 50 | return PrimitiveParameters.from_existing( 51 | primitive_params, 52 | translations=translations 53 | ) 54 | 55 | 56 | class AttTranslation(nn.Module): 57 | def __init__(self, input_dims, n_layers, hidden_units): 58 | super(AttTranslation, self).__init__() 59 | 60 | # Keep the layers based on the n_layers 61 | l = [] 62 | in_features = input_dims 63 | for i in range(n_layers-1): 64 | l.append(nn.Linear(in_features, hidden_units)) 65 | l.append(nn.ReLU()) 66 | in_features = hidden_units 67 | l.append(nn.Linear(in_features, 3)) 68 | self.fc = nn.Sequential(*l) 69 | 70 | def forward(self, X, primitive_params): 71 | translations = torch.tanh(self.fc(X)) * 0.51 72 | 73 | # Reshape to BxM*3 74 | translations = translations.view(X.shape[0], -1) 75 | return PrimitiveParameters.from_existing( 76 | primitive_params, 77 | translations=translations 78 | ) 79 | 80 | 81 | class RelativeTranslation(nn.Module): 82 | def __init__(self, n_primitives): 83 | super(RelativeTranslation, self).__init__() 84 | self._n_primitives = n_primitives 85 | 86 | def forward(self, X, primitive_params): 87 | # Get the translations, the probs and the Pi_n tensors from the 88 | # primitive_params 89 | _translations = primitive_params.translations 90 | _probs = primitive_params.probs 91 | _Pi_n = primitive_params.Pi_n 92 | 93 | # Denote some variables for convenience 94 | B = X.shape[0] 95 | M = self._n_primitives 96 | 97 | # Compute the global translations from the local ones 98 | mask = X.new_tensor(1.) 
- torch.eye(M).to(X.device) 99 | g_translations = _translations.view(B, M, -1) 100 | for P in _Pi_n: 101 | g_translations = ( 102 | g_translations + 103 | torch.einsum( 104 | "ikc,ijk,ijl->ijl", 105 | [ 106 | _probs.unsqueeze(-1), 107 | P*mask, 108 | _translations.view(B, M, -1) 109 | ] 110 | ) 111 | ) 112 | 113 | return PrimitiveParameters.from_existing( 114 | primitive_params, 115 | translations=g_translations, 116 | local_translations=_translations 117 | ) 118 | 119 | 120 | class NoTranslation(nn.Module): 121 | """By default the translation tensor is set to 0.0.""" 122 | def __init__(self, n_primitives): 123 | super(NoTranslation, self).__init__() 124 | self._n_primitives = n_primitives 125 | 126 | def forward(self, X, primitive_params): 127 | translations = X.new_zeros(X.shape[0], self._n_primitives*3) 128 | 129 | return PrimitiveParameters.from_existing( 130 | primitive_params, 131 | translations=translations 132 | ) 133 | 134 | 135 | def translations(name, fe, n_primitives, config): 136 | layers = dict( 137 | default_translation=partial( 138 | Translation, 139 | fe.feature_shape, 140 | n_primitives 141 | ), 142 | deep_translation=partial( 143 | DeepTranslation, 144 | fe.feature_shape, 145 | n_primitives 146 | ), 147 | att_translation=partial( 148 | AttTranslation, 149 | fe.feature_shape, 150 | n_layers=config["data"].get("n_layers", 1), 151 | hidden_units=config["data"].get("hidden_units", 128) 152 | ), 153 | no_translation=partial(NoTranslation, n_primitives), 154 | relative_translation=partial(RelativeTranslation, n_primitives) 155 | ) 156 | return layers[name]() 157 | -------------------------------------------------------------------------------- /hierarchical_primitives/networks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """A BatchNorm2d wrapper for Pytorch's BatchNorm2d where the batch 8 | statictis are fixed. 9 | """ 10 | def __init__(self, num_features): 11 | super(FrozenBatchNorm2d, self).__init__() 12 | self.num_features = num_features 13 | self.register_parameter("weight", Parameter(torch.ones(num_features))) 14 | self.register_parameter("bias", Parameter(torch.zeros(num_features))) 15 | self.register_buffer("running_mean", torch.zeros(num_features)) 16 | self.register_buffer("running_var", torch.ones(num_features)) 17 | 18 | def extra_repr(self): 19 | return '{num_features}'.format(**self.__dict__) 20 | 21 | @classmethod 22 | def from_batch_norm(cls, bn): 23 | fbn = cls(bn.num_features) 24 | # Update the weight and biases based on the corresponding weights and 25 | # biases of the pre-trained bn layer 26 | with torch.no_grad(): 27 | fbn.weight[...] = bn.weight 28 | fbn.bias[...] = bn.bias 29 | fbn.running_mean[...] = bn.running_mean 30 | fbn.running_var[...] 
= bn.running_var + bn.eps
31 |         return fbn
32 | 
33 |     @staticmethod
34 |     def _getattr_nested(m, module_names):
35 |         if len(module_names) == 1:
36 |             return getattr(m, module_names[0])
37 |         else:
38 |             return FrozenBatchNorm2d._getattr_nested(
39 |                 getattr(m, module_names[0]), module_names[1:]
40 |             )
41 | 
42 |     @staticmethod
43 |     def freeze(m):
44 |         for (name, layer) in m.named_modules():
45 |             if isinstance(layer, nn.BatchNorm2d):
46 |                 nest = name.split(".")
47 |                 if len(nest) == 1:
48 |                     setattr(m, name, FrozenBatchNorm2d.from_batch_norm(layer))
49 |                 else:
50 |                     setattr(
51 |                         FrozenBatchNorm2d._getattr_nested(m, nest[:-1]),
52 |                         nest[-1],
53 |                         FrozenBatchNorm2d.from_batch_norm(layer)
54 |                     )
55 | 
56 |     def forward(self, x):
57 |         # Cast all fixed parameters to half() if necessary. Work on local
58 |         # variables: assigning a plain tensor back to a registered
59 |         # nn.Parameter attribute (weight/bias) raises a TypeError.
60 |         weight, bias = self.weight, self.bias
61 |         running_mean, running_var = self.running_mean, self.running_var
62 |         if x.dtype == torch.float16:
63 |             weight = weight.half()
64 |             bias = bias.half()
65 |             running_mean = running_mean.half()
66 |             running_var = running_var.half()
67 | 
68 |         scale = weight * running_var.rsqrt()
69 |         bias = bias - running_mean * scale
70 |         scale = scale.reshape(1, -1, 1, 1)
71 |         bias = bias.reshape(1, -1, 1, 1)
72 |         return x * scale + bias
73 | 
74 | 
75 | def freeze_network(network, freeze=False):
76 |     if freeze:
77 |         for p in network.parameters():
78 |             p.requires_grad = False
79 |     return network
80 | 
--------------------------------------------------------------------------------
/hierarchical_primitives/sample_points_on_primitive.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .equal_distance_sampler_sq import EqualDistanceSamplerSQ
4 | from .primitives import fexp
5 | 
6 | 
7 | def sample_uniformly_on_sq(
8 |     alphas,
9 |     epsilons,
10 |     sq_sampler
11 | ):
12 |     """
13 |     Given the sampling steps in the parametric space, we want to get the actual
14 |     3D points on the sq.
15 | 16 | Arguments: 17 | ---------- 18 | alphas: Tensor with size BxMx3, containing the size along each 19 | axis for the M primitives 20 | epsilons: Tensor with size BxMx2, containing the shape along the 21 | latitude and the longitude for the M primitives 22 | 23 | Returns: 24 | --------- 25 | P: Tensor of size BxMxSx3 that contains S sampled points from the 26 | surface of each primitive 27 | N: Tensor of size BxMxSx3 that contains the normals of the S sampled 28 | points from the surface of each primitive 29 | """ 30 | # Allocate memory to store the sampling steps 31 | B = alphas.shape[0] # batch size 32 | M = alphas.shape[1] # number of primitives 33 | S = sq_sampler.n_samples 34 | 35 | etas, omegas = sq_sampler.sample_on_batch( 36 | alphas.detach().cpu().numpy(), 37 | epsilons.detach().cpu().numpy() 38 | ) 39 | # Make sure we don't get nan for gradients 40 | etas[etas == 0] += 1e-6 41 | omegas[omegas == 0] += 1e-6 42 | 43 | # Move to tensors 44 | etas = alphas.new_tensor(etas) 45 | omegas = alphas.new_tensor(omegas) 46 | 47 | # Make sure that all tensors have the right shape 48 | a1 = alphas[:, :, 0].unsqueeze(-1) # size BxMx1 49 | a2 = alphas[:, :, 1].unsqueeze(-1) # size BxMx1 50 | a3 = alphas[:, :, 2].unsqueeze(-1) # size BxMx1 51 | e1 = epsilons[:, :, 0].unsqueeze(-1) # size BxMx1 52 | e2 = epsilons[:, :, 1].unsqueeze(-1) # size BxMx1 53 | 54 | x = a1 * fexp(torch.cos(etas), e1) * fexp(torch.cos(omegas), e2) 55 | y = a2 * fexp(torch.cos(etas), e1) * fexp(torch.sin(omegas), e2) 56 | z = a3 * fexp(torch.sin(etas), e1) 57 | 58 | # Make sure we don't get INFs 59 | # x[torch.abs(x) <= 1e-9] = 1e-9 60 | # y[torch.abs(y) <= 1e-9] = 1e-9 61 | # z[torch.abs(z) <= 1e-9] = 1e-9 62 | x = ((x > 0).float() * 2 - 1) * torch.max(torch.abs(x), x.new_tensor(1e-6)) 63 | y = ((y > 0).float() * 2 - 1) * torch.max(torch.abs(y), x.new_tensor(1e-6)) 64 | z = ((z > 0).float() * 2 - 1) * torch.max(torch.abs(z), x.new_tensor(1e-6)) 65 | 66 | # Compute the normals of the SQs 67 | nx = (torch.cos(etas)**2) * (torch.cos(omegas)**2) / x 68 | ny = (torch.cos(etas)**2) * (torch.sin(omegas)**2) / y 69 | nz = (torch.sin(etas)**2) / z 70 | 71 | return torch.stack([x, y, z], -1), torch.stack([nx, ny, nz], -1) 72 | 73 | 74 | def sample_uniformly_on_cube(alphas, sampler): 75 | """ 76 | Given the sampling steps in the parametric space, we want to ge the actual 77 | 3D points on the surface of the cube. 78 | 79 | Arguments: 80 | ---------- 81 | alphas: Tensor with size BxMx3, containing the size along each 82 | axis for the M primitives 83 | 84 | Returns: 85 | --------- 86 | P: Tensor of size BxMxSx3 that contains S sampled points from the 87 | surface of each primitive 88 | """ 89 | # TODO: Make sure that this is the proper way to do this! 
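# The strategy below splits the S samples evenly across the six faces of
# each cuboid: on every face one coordinate is pinned to +/-alpha along the
# corresponding axis while the other two are drawn uniformly over the face,
# so the sampling is uniform per face but only approximately uniform over
# the whole surface when the face areas differ.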
90 |     # Check the device of the angles and move all the tensors to that device
91 |     device = alphas.device
92 | 
93 |     # Allocate memory to store the sampling steps
94 |     B = alphas.shape[0]  # batch size
95 |     M = alphas.shape[1]  # number of primitives
96 |     S = sampler.n_samples
97 |     N = S // 6  # integer number of points per face; assumes S is divisible by 6
98 | 
99 |     X_SQ = torch.zeros(B, M, S, 3).to(device)
100 | 
101 |     for b in range(B):
102 |         for m in range(M):
103 |             x_max = alphas[b, m, 0]
104 |             y_max = alphas[b, m, 1]
105 |             z_max = alphas[b, m, 2]
106 |             x_min = -x_max
107 |             y_min = -y_max
108 |             z_min = -z_max
109 | 
110 |             X_SQ[b, m] = torch.stack([
111 |                 torch.stack([
112 |                     torch.ones((N, 1)).to(device)*x_min,
113 |                     torch.rand(N, 1).to(device)*(y_max-y_min) + y_min,
114 |                     torch.rand(N, 1).to(device)*(z_max-z_min) + z_min
115 |                 ], dim=-1).squeeze(),
116 |                 torch.stack([
117 |                     torch.ones((N, 1)).to(device)*x_max,
118 |                     torch.rand(N, 1).to(device)*(y_max-y_min) + y_min,
119 |                     torch.rand(N, 1).to(device)*(z_max-z_min) + z_min
120 |                 ], dim=-1).squeeze(),
121 |                 torch.stack([
122 |                     torch.rand(N, 1).to(device)*(x_max-x_min) + x_min,
123 |                     torch.ones((N, 1)).to(device)*y_min,
124 |                     torch.rand(N, 1).to(device)*(z_max-z_min) + z_min
125 |                 ], dim=-1).squeeze(),
126 |                 torch.stack([
127 |                     torch.rand(N, 1).to(device)*(x_max-x_min) + x_min,
128 |                     torch.ones((N, 1)).to(device)*y_max,
129 |                     torch.rand(N, 1).to(device)*(z_max-z_min) + z_min
130 |                 ], dim=-1).squeeze(),
131 |                 torch.stack([
132 |                     torch.rand(N, 1).to(device)*(x_max-x_min) + x_min,
133 |                     torch.rand(N, 1).to(device)*(y_max-y_min) + y_min,
134 |                     torch.ones((N, 1)).to(device)*z_min,
135 |                 ], dim=-1).squeeze(),
136 |                 torch.stack([
137 |                     torch.rand(N, 1).to(device)*(x_max-x_min) + x_min,
138 |                     torch.rand(N, 1).to(device)*(y_max-y_min) + y_min,
139 |                     torch.ones((N, 1)).to(device)*z_max,
140 |                 ], dim=-1).squeeze()
141 |             ]).view(-1, 3)
142 | 
143 |     normals = X_SQ.new_zeros(X_SQ.shape)
144 |     normals[:, :, 0*N:1*N, 0] = -1
145 |     normals[:, :, 1*N:2*N, 0] = 1
146 |     normals[:, :, 2*N:3*N, 1] = -1
147 |     normals[:, :, 3*N:4*N, 1] = 1
148 |     normals[:, :, 4*N:5*N, 2] = -1
149 |     normals[:, :, 5*N:6*N, 2] = 1
150 | 
151 |     # make sure that X_SQ has the expected shape
152 |     assert X_SQ.shape == (B, M, S, 3)
153 |     return X_SQ, normals
154 | 
155 | 
156 | class CuboidSampler(object):
157 |     def __init__(self, n_samples):
158 |         self._n_samples = n_samples
159 | 
160 |     @property
161 |     def n_samples(self):
162 |         return self._n_samples
163 | 
164 |     def sample(self, a1, a2, a3):
165 |         pass
166 | 
167 |     def sample_on_batch(self, shapes, epsilons):
168 |         pass
169 | 
170 | 
171 | class PrimitiveSampler(object):
172 |     def __init__(self, n_samples):
173 |         self._n_samples = n_samples
174 | 
175 |     @property
176 |     def n_samples(self):
177 |         return self._n_samples
178 | 
179 |     def sample_points_on_primitive(self, use_cuboid, alphas, epsilons):
180 |         if not use_cuboid:
181 |             return sample_uniformly_on_sq(
182 |                 alphas,
183 |                 epsilons,
184 |                 EqualDistanceSamplerSQ(self._n_samples)
185 |             )
186 |         else:
187 |             return sample_uniformly_on_cube(
188 |                 alphas,
189 |                 CuboidSampler(self._n_samples)
190 |             )
191 | 
--------------------------------------------------------------------------------
/hierarchical_primitives/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | def compose(*funcs):
5 |     """Compose any number of functions by passing the output of one as argument
6 |     to the other.
7 | 
8 |     TODO: Do we need support for multiple outputs?
For instance in the case of 9 | f(g(*args)), can g return multiple values? 10 | """ 11 | def inner(*args, **kwargs): 12 | r = funcs[-1](*args, **kwargs) 13 | for f in reversed(funcs[:-1]): 14 | r = f(r) 15 | return r 16 | 17 | return inner 18 | 19 | 20 | def ensure_parent_directory_exists(path_to_file): 21 | directory = os.path.dirname(path_to_file) 22 | if not os.path.exists(directory): 23 | os.makedirs(directory) 24 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/filter_sqs.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from ..networks.primitive_parameters import PrimitiveParameters 5 | from ..primitives import sq_volumes 6 | 7 | 8 | def always(*args): 9 | return True 10 | 11 | 12 | def qos_less(qos_th): 13 | """Split iff qos is less than qos_th.""" 14 | def inner(P, depth, idx): 15 | return P[depth].qos[0, idx] < qos_th 16 | return inner 17 | 18 | 19 | def volume_larger(vol_th): 20 | """Accepts iff the volume is larger that vol_th.""" 21 | def inner(P, depth, idx): 22 | return sq_volumes(P[depth])[0, idx] > vol_th 23 | return inner 24 | 25 | 26 | def filter_primitives(P, predicate_split, predicate_accept): 27 | """Compute the indices of the leaves selected based on the provided 28 | predicates for splitting and accepting a primitive.""" 29 | # P should only contain one primitive 30 | for p in P: 31 | assert len(p.sizes) == 1 32 | 33 | # Do a depth first search with filters for split and accept 34 | primitives = [] 35 | nodes = [(0, 0)] 36 | max_depth = len(P)-1 37 | while nodes: 38 | depth, idx = nodes.pop() 39 | if depth == max_depth or not predicate_split(P, depth, idx): 40 | if predicate_accept(P, depth, idx): 41 | primitives.append((depth, idx)) 42 | else: 43 | nodes.append((depth+1, 2*idx)) 44 | nodes.append((depth+1, 2*idx+1)) 45 | 46 | return primitives 47 | 48 | 49 | def primitive_parameters_from_indices(P, indices): 50 | B = 1 51 | M = len(indices) 52 | return PrimitiveParameters.with_keys( 53 | probs=torch.ones(B, M), 54 | translations=torch.stack([ 55 | P[depth].translations_r[:, idx] 56 | for depth, idx in indices 57 | ], dim=1).view(B, -1), 58 | rotations=torch.stack([ 59 | P[depth].rotations_r[:, idx] 60 | for depth, idx in indices 61 | ], dim=1).view(B, -1), 62 | sizes=torch.stack([ 63 | P[depth].sizes_r[:, idx] 64 | for depth, idx in indices 65 | ], dim=1).view(B, -1), 66 | shapes=torch.stack([ 67 | P[depth].shapes_r[:, idx] 68 | for depth, idx in indices 69 | ], dim=1).view(B, -1), 70 | sharpness=torch.stack([ 71 | P[depth].shapes_r[:, idx] 72 | for depth, idx in indices 73 | ], dim=1).view(B, -1) 74 | ) 75 | 76 | 77 | def get_primitives_indices(P): 78 | """Compute the indices of the leaves""" 79 | # P should only contain one primitive 80 | for p in P: 81 | assert len(p.sizes) == 1 82 | 83 | # Do a depth first search with filters for split and accept 84 | primitives = [] 85 | nodes = [(0, 0)] 86 | max_depth = len(P)-1 87 | while nodes: 88 | depth, idx = nodes.pop() 89 | if depth == max_depth: 90 | primitives.append((depth, idx)) 91 | else: 92 | nodes.append((depth+1, 2*idx)) 93 | nodes.append((depth+1, 2*idx+1)) 94 | 95 | return primitives 96 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def compute_iou(occ1, occ2, weights=None, 
average=True): 6 | """Compute the intersection over union (IoU) for two sets of occupancy 7 | values. 8 | 9 | Arguments: 10 | ---------- 11 | occ1: Tensor of size BxN containing the first set of occupancy values 12 | occ2: Tensor of size BxN containing the first set of occupancy values 13 | 14 | Returns: 15 | ------- 16 | the IoU 17 | """ 18 | if not torch.is_tensor(occ1): 19 | occ1 = torch.tensor(occ1) 20 | occ2 = torch.tensor(occ2) 21 | 22 | if weights is None: 23 | weights = occ1.new_ones(occ1.shape) 24 | 25 | assert len(occ1.shape) == 2 26 | assert occ1.shape == occ2.shape 27 | 28 | # Convert them to boolean 29 | occ1 = occ1 >= 0.5 30 | occ2 = occ2 >= 0.5 31 | 32 | # Compute IoU 33 | area_union = (occ1 | occ2).float() 34 | area_union = (weights * area_union).sum(dim=-1) 35 | area_union = torch.max(area_union.new_tensor(1.0), area_union) 36 | area_intersect = (occ1 & occ2).float() 37 | area_intersect = (weights * area_intersect).sum(dim=-1) 38 | iou = (area_intersect / area_union) 39 | 40 | if average: 41 | return iou.mean().item() 42 | else: 43 | return iou 44 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/progbar.py: -------------------------------------------------------------------------------- 1 | 2 | import collections 3 | import sys 4 | import time 5 | 6 | import numpy as np 7 | 8 | 9 | class Progbar(object): 10 | """Displays a progress bar. 11 | # Arguments 12 | target: Total number of steps expected, None if unknown. 13 | width: Progress bar width on screen. 14 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 15 | stateful_metrics: Iterable of string names of metrics that 16 | should *not* be averaged over time. Metrics in this list 17 | will be displayed as-is. All others will be averaged 18 | by the progbar before display. 19 | interval: Minimum visual progress update interval (in seconds). 20 | """ 21 | 22 | def __init__(self, target, width=30, verbose=1, interval=0.05, 23 | stateful_metrics=None): 24 | self.target = target 25 | self.width = width 26 | self.verbose = verbose 27 | self.interval = interval 28 | if stateful_metrics: 29 | self.stateful_metrics = set(stateful_metrics) 30 | else: 31 | self.stateful_metrics = set() 32 | 33 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 34 | sys.stdout.isatty()) or 35 | 'ipykernel' in sys.modules) 36 | self._total_width = 0 37 | self._seen_so_far = 0 38 | self._values = collections.OrderedDict() 39 | self._start = time.time() 40 | self._last_update = 0 41 | 42 | def update(self, current, values=None): 43 | """Updates the progress bar. 44 | # Arguments 45 | current: Index of current step. 46 | values: List of tuples: 47 | `(name, value_for_last_step)`. 48 | If `name` is in `stateful_metrics`, 49 | `value_for_last_step` will be displayed as-is. 50 | Else, an average of the metric over time will be displayed. 51 | """ 52 | values = values or [] 53 | for k, v in values: 54 | if k not in self.stateful_metrics: 55 | if k not in self._values: 56 | self._values[k] = [v * (current - self._seen_so_far), 57 | current - self._seen_so_far] 58 | else: 59 | self._values[k][0] += v * (current - self._seen_so_far) 60 | self._values[k][1] += (current - self._seen_so_far) 61 | else: 62 | # Stateful metrics output a numeric value. This representation 63 | # means "take an average from a single value" but keeps the 64 | # numeric formatting. 
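# Concretely, storing the pair [v, 1] makes the display code below compute
# v / max(1, 1) == v, so a stateful metric passes through unaveraged while
# reusing the same formatting path as the running averages.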
65 | self._values[k] = [v, 1] 66 | self._seen_so_far = current 67 | 68 | now = time.time() 69 | info = ' - %.0fs' % (now - self._start) 70 | if self.verbose == 1: 71 | if (now - self._last_update < self.interval and 72 | self.target is not None and current < self.target): 73 | return 74 | 75 | prev_total_width = self._total_width 76 | if self._dynamic_display: 77 | sys.stdout.write('\b' * prev_total_width) 78 | sys.stdout.write('\r') 79 | else: 80 | sys.stdout.write('\n') 81 | 82 | if self.target is not None: 83 | numdigits = int(np.floor(np.log10(self.target))) + 1 84 | barstr = '%%%dd/%d [' % (numdigits, self.target) 85 | bar = barstr % current 86 | prog = float(current) / self.target 87 | prog_width = int(self.width * prog) 88 | if prog_width > 0: 89 | bar += ('=' * (prog_width - 1)) 90 | if current < self.target: 91 | bar += '>' 92 | else: 93 | bar += '=' 94 | bar += ('.' * (self.width - prog_width)) 95 | bar += ']' 96 | else: 97 | bar = '%7d/Unknown' % current 98 | 99 | self._total_width = len(bar) 100 | sys.stdout.write(bar) 101 | 102 | if current: 103 | time_per_unit = (now - self._start) / current 104 | else: 105 | time_per_unit = 0 106 | if self.target is not None and current < self.target: 107 | eta = time_per_unit * (self.target - current) 108 | if eta > 3600: 109 | eta_format = ('%d:%02d:%02d' % 110 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 111 | elif eta > 60: 112 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 113 | else: 114 | eta_format = '%ds' % eta 115 | 116 | info = ' - ETA: %s' % eta_format 117 | else: 118 | if time_per_unit >= 1: 119 | info += ' %.0fs/step' % time_per_unit 120 | elif time_per_unit >= 1e-3: 121 | info += ' %.0fms/step' % (time_per_unit * 1e3) 122 | else: 123 | info += ' %.0fus/step' % (time_per_unit * 1e6) 124 | 125 | for k in self._values: 126 | info += ' - %s:' % k 127 | if isinstance(self._values[k], list): 128 | avg = np.mean( 129 | self._values[k][0] / max(1, self._values[k][1])) 130 | if abs(avg) > 1e-3: 131 | info += ' %.4f' % avg 132 | else: 133 | info += ' %.4e' % avg 134 | else: 135 | info += ' %s' % self._values[k] 136 | 137 | self._total_width += len(info) 138 | if prev_total_width > self._total_width: 139 | info += (' ' * (prev_total_width - self._total_width)) 140 | 141 | if self.target is not None and current >= self.target: 142 | info += '\n' 143 | 144 | sys.stdout.write(info) 145 | sys.stdout.flush() 146 | 147 | elif self.verbose == 2: 148 | if self.target is None or current >= self.target: 149 | for k in self._values: 150 | info += ' - %s:' % k 151 | avg = np.mean( 152 | self._values[k][0] / max(1, self._values[k][1])) 153 | if avg > 1e-3: 154 | info += ' %.4f' % avg 155 | else: 156 | info += ' %.4e' % avg 157 | info += '\n' 158 | 159 | sys.stdout.write(info) 160 | sys.stdout.flush() 161 | 162 | self._last_update = now 163 | 164 | def add(self, n, values=None): 165 | self.update(self._seen_so_far + n, values) 166 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/sq_mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import trimesh 5 | 6 | from ..primitives import transform_to_world_coordinates_system, \ 7 | transform_to_primitives_centric_system, inside_outside_function,\ 8 | quaternions_to_rotation_matrices 9 | 10 | 11 | def single_sq_mesh(alpha, epsilon, translation, rotation): 12 | """Create mesh for a superquadric with the provided primitive 13 | configuration. 
14 | 15 | Arguments 16 | --------- 17 | alpha: Array of 3 sizes, along each axis 18 | epsilon: Array of 2 shapes, along each a 19 | translation: Array of 3 dimensional center 20 | rotation: Array of size 3x3 containing the rotations 21 | """ 22 | def fexp(x, p): 23 | return np.sign(x)*(np.abs(x)**p) 24 | 25 | def sq_surface(a1, a2, a3, e1, e2, eta, omega): 26 | x = a1 * fexp(np.cos(eta), e1) * fexp(np.cos(omega), e2) 27 | y = a2 * fexp(np.cos(eta), e1) * fexp(np.sin(omega), e2) 28 | z = a3 * fexp(np.sin(eta), e1) 29 | return x, y, z 30 | 31 | # triangulate the sphere to be used with the SQs 32 | eta = np.linspace(-np.pi/2, np.pi/2, 100, endpoint=True) 33 | omega = np.linspace(-np.pi, np.pi, 100, endpoint=True) 34 | triangles = [] 35 | for o1, o2 in zip(np.roll(omega, 1), omega): 36 | triangles.extend([ 37 | (eta[0], 0), 38 | (eta[1], o2), 39 | (eta[1], o1), 40 | ]) 41 | for e in range(1, len(eta)-2): 42 | for o1, o2 in zip(np.roll(omega, 1), omega): 43 | triangles.extend([ 44 | (eta[e], o1), 45 | (eta[e+1], o2), 46 | (eta[e+1], o1), 47 | (eta[e], o1), 48 | (eta[e], o2), 49 | (eta[e+1], o2), 50 | ]) 51 | for o1, o2 in zip(np.roll(omega, 1), omega): 52 | triangles.extend([ 53 | (eta[-1], 0), 54 | (eta[-2], o1), 55 | (eta[-2], o2), 56 | ]) 57 | triangles = np.array(triangles) 58 | eta, omega = triangles[:, 0], triangles[:, 1] 59 | 60 | # collect the pretriangulated vertices of each SQ 61 | vertices = [] 62 | a, e, t, R = list(map( 63 | np.asarray, 64 | [alpha, epsilon, translation, rotation] 65 | )) 66 | M, _ = a.shape # number of superquadrics 67 | assert R.shape == (M, 3, 3) 68 | assert t.shape == (M, 3) 69 | for i in range(M): 70 | a1, a2, a3 = a[i] 71 | e1, e2 = e[i] 72 | x, y, z = sq_surface(a1, a2, a3, e1, e2, eta, omega) 73 | # Get points on the surface of each SQ 74 | V = np.stack([x, y, z], axis=-1) 75 | V = R[i].T.dot(V.T).T + t[i].reshape(1, 3) 76 | vertices.append(V) 77 | 78 | # Finalize the mesh 79 | vertices = np.vstack(vertices) 80 | faces = np.arange(len(vertices)).reshape(-1, 3) 81 | return trimesh.Trimesh(vertices=vertices, faces=faces) 82 | 83 | 84 | def sq_meshes(primitive_params, indices=None): 85 | translations = primitive_params.translations_r 86 | rotations = primitive_params.rotations_r 87 | Rs = quaternions_to_rotation_matrices( 88 | primitive_params.rotations.view(-1, 4) 89 | ).view(1, -1, 3, 3) 90 | alphas = primitive_params.sizes_r 91 | epsilons = primitive_params.shapes_r 92 | probs = primitive_params.probs 93 | 94 | M = primitive_params.n_primitives 95 | if indices is None: 96 | indices = range(M) 97 | 98 | return [ 99 | single_sq_mesh( 100 | alphas[:, i, :].cpu().detach().numpy(), 101 | epsilons[:, i, :].cpu().detach().numpy(), 102 | translations[:, i, :].cpu().detach().numpy(), 103 | Rs[:, i].cpu().detach().numpy() 104 | ) for i in indices 105 | ] 106 | 107 | 108 | 109 | def sq_mesh_from_primitive_params(primitive_params, S=100000, normals=False, 110 | prim_indices=False): 111 | translations = primitive_params.translations_r 112 | rotations = primitive_params.rotations_r 113 | Rs = quaternions_to_rotation_matrices( 114 | primitive_params.rotations.view(-1, 4) 115 | ).view(1, -1, 3, 3) 116 | alphas = primitive_params.sizes_r 117 | epsilons = primitive_params.shapes_r 118 | probs = primitive_params.probs 119 | M = primitive_params.n_primitives 120 | meshes = sq_meshes(primitive_params) 121 | areas = np.array([m.area for m in meshes]) 122 | areas /= areas.sum() 123 | 124 | P = np.empty((0, 3)) 125 | N = np.empty((0, 3)) 126 | I = np.empty((0, 1)) 127 | cnt = 0 
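# Rejection-sampling loop: draw surface samples from every primitive's mesh
# in proportion to its surface area, discard samples that fall inside any
# *other* primitive (where the inside-outside function is F <= 1), and
# repeat until at least S points survive; np.random.choice below then trims
# the surplus down to exactly S.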
128 | while cnt < S: 129 | n_points = np.random.multinomial(S, areas) 130 | for i in range(M): 131 | points, faces = trimesh.sample.sample_surface(meshes[i], n_points[i]) 132 | if len(points) == 0: 133 | continue 134 | # Filter anything that is in an SQ other than i 135 | X_SQ = torch.from_numpy(points) 136 | X_SQ = X_SQ.unsqueeze(0).float().to(alphas.device) 137 | 138 | # Transform the points on the SQs to the other SQs 139 | X_SQ_transformed = transform_to_primitives_centric_system( 140 | X_SQ, translations, rotations 141 | ) 142 | # Compute the inside outside function for every point on every 143 | # primitive to every other primitive 144 | F = inside_outside_function( 145 | X_SQ_transformed, alphas.detach(), epsilons.detach() 146 | ) 147 | F[:, :, i] = 2.0 148 | mask = (F>1).all(dim=-1)[0].cpu().numpy().astype(bool) 149 | points = points[mask] 150 | P = np.vstack([P, points]) 151 | N = np.vstack([N, meshes[i].face_normals[faces[mask]]]) 152 | I = np.vstack([I, np.ones((len(points), 1))*i]) 153 | cnt += len(points) 154 | idxs = np.random.choice(len(P), S, replace=False) 155 | retval = (P[idxs],) 156 | if normals: 157 | retval += (N[idxs],) 158 | if prim_indices: 159 | retval += (I[idxs].astype(int),) 160 | 161 | if len(retval) == 1: 162 | retval = retval[0] 163 | 164 | return retval 165 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/stats_logger.py: -------------------------------------------------------------------------------- 1 | """Stats logger provides a value registry for logging training stats. It is a 2 | separate object for backwards compatibility.""" 3 | 4 | from .value_registry import ValueRegistry 5 | 6 | 7 | class StatsLogger(object): 8 | @staticmethod 9 | def instance(): 10 | return ValueRegistry.get_instance("stats_logger") 11 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/value_registry.py: -------------------------------------------------------------------------------- 1 | """Value registry provides instances of global dictionary like variables. 2 | Although an antipattern it allows for fast prototyping and experimenting with 3 | rapidly changing interfaces.""" 4 | 5 | 6 | class ValueRegistry(object): 7 | _instances = {} 8 | 9 | @staticmethod 10 | def get_instance(key): 11 | if key not in ValueRegistry._instances: 12 | ValueRegistry._instances[key] = ValueRegistry(key) 13 | return ValueRegistry._instances[key] 14 | 15 | def __init__(self, key): 16 | if key in self._instances: 17 | raise RuntimeError("There can be only one! 
(imdb:0091203)") 18 | self._data = {} 19 | 20 | def clear(self): 21 | self._data.clear() 22 | 23 | def __contains__(self, key): 24 | return key in self._data 25 | 26 | def __getitem__(self, key): 27 | return self._data[key] 28 | 29 | def __setitem__(self, key, value): 30 | return self.update(key, value) 31 | 32 | def update(self, key, value): 33 | self._data[key] = value 34 | 35 | def increment(self, key, value): 36 | self._data[key] = self.get(key, 0) + value 37 | 38 | def get(self, key, value): 39 | return self._data.get(key, value) 40 | -------------------------------------------------------------------------------- /hierarchical_primitives/utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fexp(x, p): 5 | return np.sign(x)*(np.abs(x)**p) 6 | 7 | 8 | def sq_surface(a1, a2, a3, e1, e2, eta, omega): 9 | x = a1 * fexp(np.cos(eta), e1) * fexp(np.cos(omega), e2) 10 | y = a2 * fexp(np.cos(eta), e1) * fexp(np.sin(omega), e2) 11 | z = a3 * fexp(np.sin(eta), e1) 12 | return x, y, z 13 | 14 | 15 | def points_on_sq_surface(a1, a2, a3, e1, e2, R, t, Kx, Ky, n_samples=100): 16 | """Sample a set of points on the surface of SQ specified using a set of 17 | parameters that denote its shape [a1, a2, a3], size [e1, e2] and 18 | pose [R,t] 19 | """ 20 | assert R.shape == (3, 3) 21 | assert t.shape == (3, 1) 22 | 23 | eta = np.linspace(-np.pi/2, np.pi/2, n_samples, endpoint=True) 24 | omega = np.linspace(-np.pi, np.pi, n_samples, endpoint=True) 25 | eta, omega = np.meshgrid(eta, omega) 26 | x, y, z = sq_surface(a1, a2, a3, e1, e2, eta, omega) 27 | 28 | # Apply the deformations 29 | fx = Kx * z / a3 30 | fx += 1 31 | fy = Ky * z / a3 32 | fy += 1 33 | fz = 1 34 | 35 | x = x * fx 36 | y = y * fy 37 | z = z * fz 38 | 39 | # Get an array of size 3x10000 that contains the points of the SQ 40 | points = np.stack([x, y, z]).reshape(3, -1) 41 | points_transformed = R.T.dot(points) + t 42 | 43 | x_tr = points_transformed[0].reshape(n_samples, n_samples) 44 | y_tr = points_transformed[1].reshape(n_samples, n_samples) 45 | z_tr = points_transformed[2].reshape(n_samples, n_samples) 46 | 47 | return x_tr, y_tr, z_tr, points_transformed 48 | 49 | 50 | def points_on_cuboid(a1, a2, a3, e1, e2, R, t, n_samples=100): 51 | """Sample a set of points on the surface of a cuboid specified using a set 52 | of parameters that denote its shape [a1, a2, a3], size [e1, e2] and pose 53 | [R,t] 54 | """ 55 | assert R.shape == (3, 3) 56 | assert t.shape == (3, 1) 57 | 58 | X = np.array([ 59 | [0, 1, 1, 0, 0, 0, 0, 0, 0], 60 | [0, 1, 1, 0, 0, 1, 1, 1, 1] 61 | ], dtype=np.float32) 62 | X[X == 1.0] = a1 63 | X[X == 0.0] = -a1 64 | 65 | Y = np.array([ 66 | [0, 0, 0, 0, 0, 1, 1, 0, 0], 67 | [1, 1, 1, 1, 1, 1, 1, 0, 0] 68 | ], dtype=np.float32) 69 | Y[Y == 1.0] = a2 70 | Y[Y == 0.0] = -a2 71 | 72 | Z = np.array([ 73 | [1, 1, 0, 0, 1, 1, 0, 0, 1], 74 | [1, 1, 0, 0, 1, 1, 0, 0, 1] 75 | ], dtype=np.float32) 76 | Z[Z == 1.0] = a3 77 | Z[Z == 0.0] = -a3 78 | 79 | points = np.stack([X, Y, Z]).reshape(3, -1) 80 | points_transformed = R.T.dot(points) + t 81 | 82 | assert points.shape == (3, 18) 83 | 84 | x_tr = points_transformed[0].reshape(2, 9) 85 | y_tr = points_transformed[1].reshape(2, 9) 86 | z_tr = points_transformed[2].reshape(2, 9) 87 | return x_tr, y_tr, z_tr, points_transformed 88 | 89 | 90 | def get_shape_configuration(use_cuboid): 91 | if use_cuboid: 92 | return points_on_cuboid 93 | else: 94 | return points_on_sq_surface 95 | 
-------------------------------------------------------------------------------- /img/chair.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paschalidoud/hierarchical_primitives/2fa5409ad29f92bedfcaa4cba5de1fa808e43e9b/img/chair.gif -------------------------------------------------------------------------------- /img/human_punching.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paschalidoud/hierarchical_primitives/2fa5409ad29f92bedfcaa4cba5de1fa808e43e9b/img/human_punching.gif -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paschalidoud/hierarchical_primitives/2fa5409ad29f92bedfcaa4cba5de1fa808e43e9b/img/teaser.png -------------------------------------------------------------------------------- /scripts/arguments.py: -------------------------------------------------------------------------------- 1 | def add_dataset_parameters(parser): 2 | parser.add_argument( 3 | "--model_tags", 4 | type=lambda x: x.split(","), 5 | default=[], 6 | help="Tags to the models to be used" 7 | ) 8 | parser.add_argument( 9 | "--category_tags", 10 | type=lambda x: x.split(","), 11 | default=[], 12 | help="Category tags to the models to be used" 13 | ) 14 | parser.add_argument( 15 | "--random_subset", 16 | type=float, 17 | default=1.0, 18 | help="Percentage of dataset to be used for evaluation" 19 | ) 20 | parser.add_argument( 21 | "--val_random_subset", 22 | type=float, 23 | default=1.0, 24 | help="Percentage of dataset to be used for validation" 25 | ) 26 | -------------------------------------------------------------------------------- /scripts/compute_metrics.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from evaluate import MeshEvaluator 4 | 5 | from hierarchical_primitives.common.dataset import PointsAndLabels, PointsOnMesh, \ 6 | DatasetCollection 7 | from hierarchical_primitives.common.model_factory import DatasetBuilder 8 | from hierarchical_primitives.primitives import \ 9 | compute_accuracy_and_recall_from_primitive_params 10 | 11 | 12 | def report_metrics( 13 | prim_params, 14 | config, 15 | dataset_type, 16 | model_tags, 17 | dataset_directory 18 | ): 19 | data_source = (DatasetBuilder(config) 20 | .with_dataset(dataset_type) 21 | .filter_tags(model_tags) 22 | .build(dataset_directory)) 23 | in_bbox = PointsAndLabels(data_source) 24 | on_surface = PointsOnMesh(data_source) 25 | dataset = DatasetCollection(in_bbox, on_surface) 26 | for y_target in DataLoader(dataset, batch_size=1, num_workers=4): 27 | accuracy, positive_accuracy =\ 28 | compute_accuracy_and_recall_from_primitive_params( 29 | y_target[:-1], prim_params) 30 | print(("accuracy:%.7f - recall:%.7f") % (accuracy, positive_accuracy)) 31 | 32 | metrics = MeshEvaluator().eval_mesh_with_primitive_params( 33 | prim_params, 34 | y_target[0], 35 | y_target[1].squeeze(-1), 36 | y_target[2].squeeze(-1), 37 | y_target[3][..., :3], 38 | config 39 | ) 40 | print(("chamfer-l1: %.7f - iou %.7f") % ( 41 | metrics["ch_l1"], metrics["iou"] 42 | )) 43 | -------------------------------------------------------------------------------- /scripts/evaluate_to_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 
2 | """Script used to evaluate the predicted mesh and save the results in an sqlite 3 | database.""" 4 | 5 | import argparse 6 | from hashlib import sha256 7 | 8 | import os 9 | import sys 10 | import time 11 | 12 | import mysql.connector 13 | from mysql.connector import errorcode 14 | import numpy as np 15 | import torch 16 | 17 | from arguments import add_dataset_parameters 18 | from evaluate import MeshEvaluator 19 | from training_utils import get_loss_options, load_config 20 | 21 | from hierarchical_primitives.common.base import build_dataset 22 | from hierarchical_primitives.common.dataset import DatasetWithTags 23 | from hierarchical_primitives.networks.base import build_network 24 | from hierarchical_primitives.utils.filter_sqs import filter_primitives, \ 25 | primitive_parameters_from_indices, qos_less, volume_larger 26 | from hierarchical_primitives.utils.progbar import Progbar 27 | 28 | 29 | def hash_file(filepath): 30 | h = sha256() 31 | with open(filepath, "rb") as f: 32 | h.update(f.read()) 33 | return h.hexdigest() 34 | 35 | 36 | def get_db(dbhost): 37 | CREATE_TABLE = """CREATE TABLE results ( 38 | model_tag VARCHAR(255), 39 | weight_file VARCHAR(255), 40 | config VARCHAR(255), 41 | run INT, 42 | epoch INT, 43 | ch_l1 REAL, 44 | iou REAL, 45 | subset REAL, 46 | CONSTRAINT pk PRIMARY KEY (model_tag, weight_file, config, run) 47 | ); 48 | """ 49 | conn = mysql.connector.connect( 50 | user="evalscript", password=os.getenv("DBPASS"), 51 | host=dbhost, 52 | database="shapenet_evaluation" 53 | ) 54 | cursor = conn.cursor() 55 | try: 56 | cursor.execute(CREATE_TABLE) 57 | except mysql.connector.Error as err: 58 | if err.errno != errorcode.ER_TABLE_EXISTS_ERROR: 59 | raise 60 | conn.commit() 61 | 62 | return conn 63 | 64 | 65 | def start_run(conn, model_tag, weight_file, config, run): 66 | tags = None 67 | while True: 68 | cursor = conn.cursor() 69 | cursor.execute( 70 | ("SELECT model_tag FROM results " 71 | "WHERE weight_file=%s AND config=%s AND run=%s"), 72 | (weight_file, config, run) 73 | ) 74 | tags = set(t[0] for t in cursor) 75 | try: 76 | cursor.execute( 77 | ("INSERT INTO results (model_tag, weight_file, config, run) " 78 | "VALUES (%s, %s, %s, %s)"), 79 | (model_tag, weight_file, config, run) 80 | ) 81 | conn.commit() 82 | return True, tags 83 | except mysql.connector.IntegrityError: 84 | return False, tags 85 | except mysql.connector.Error as err: 86 | conn.rollback() 87 | print(err) 88 | time.sleep(1) 89 | finally: 90 | cursor.close() 91 | 92 | 93 | def fill_run(conn, model_tag, weight_file, config, run, epoch, stats, subset): 94 | while True: 95 | try: 96 | cursor = conn.cursor() 97 | cursor.execute( 98 | ("UPDATE results " 99 | "SET epoch=%s, ch_l1=%s, iou=%s, subset=%s " 100 | "WHERE model_tag=%s AND weight_file=%s AND config=%s " 101 | "AND run=%s"), 102 | (epoch, float(stats["chamfer"]), float(stats["iou"]), subset, 103 | model_tag, weight_file, config, run) 104 | ) 105 | conn.commit() 106 | break 107 | except mysql.connector.Error as err: 108 | conn.rollback() 109 | print(err) 110 | time.sleep(1) 111 | finally: 112 | cursor.close() 113 | 114 | 115 | def get_started_tags(conn, weight_file, config, run): 116 | cursor = conn.cursor() 117 | cursor.execute( 118 | ("SELECT model_tag FROM results " 119 | "WHERE weight_file=%s AND config=%s AND run=%s"), 120 | (weight_file, config, run) 121 | ) 122 | return set(t[0] for t in cursor) 123 | 124 | 125 | def main(argv): 126 | parser = argparse.ArgumentParser( 127 | description="Do the forward pass and estimate a set of 
primitives" 128 | ) 129 | parser.add_argument( 130 | "dataset_directory", 131 | help="Path to the directory containing the dataset" 132 | ) 133 | parser.add_argument( 134 | "train_test_splits_file", 135 | help="Path to the train-test splits file" 136 | ) 137 | parser.add_argument( 138 | "output_db", 139 | help="Save the results in this sqlite database" 140 | ) 141 | parser.add_argument( 142 | "config_file", 143 | help="Path to the file that contains the experiment configuration" 144 | ) 145 | parser.add_argument( 146 | "--weight_file", 147 | default=None, 148 | help="The path to the previously trainined model to be used" 149 | ) 150 | parser.add_argument( 151 | "--eval_on_train", 152 | action="store_true", 153 | help="When true evaluate on training set" 154 | ) 155 | parser.add_argument( 156 | "--run", 157 | type=int, 158 | default=0, 159 | help="Run id to be able to evaluate many times the same model" 160 | ) 161 | 162 | add_dataset_parameters(parser) 163 | args = parser.parse_args(argv) 164 | 165 | # Get the database connection 166 | conn = get_db(args.output_db) 167 | 168 | # Build the network architecture to be used for training 169 | config = load_config(args.config_file) 170 | network = build_network(args.config_file, args.weight_file, device="cpu") 171 | network.eval() 172 | 173 | eval_config = config.get("eval", {}) 174 | config_hash = hash_file(args.config_file) 175 | captured_at_epoch = ( 176 | -1 if args.weight_file is None else 177 | int(args.weight_file.split("/")[-1].split("_")[-1]) 178 | ) 179 | 180 | dataset = build_dataset( 181 | config, 182 | args.dataset_directory, 183 | args.dataset_type, 184 | args.train_test_splits_file, 185 | args.model_tags, 186 | args.category_tags, 187 | config["data"].get("test_split", ["test"]) if not args.eval_on_train else ["train"], 188 | random_subset=args.random_subset 189 | ) 190 | dataset = DatasetWithTags(dataset) 191 | 192 | prog = Progbar(len(dataset)) 193 | tagset = get_started_tags( 194 | conn, 195 | args.weight_file or "", 196 | config_hash, 197 | args.run 198 | ) 199 | i = 0 200 | for sample in dataset: 201 | if sample[-1] in tagset: 202 | continue 203 | start, tagset = start_run( 204 | conn, 205 | sample[-1], 206 | args.weight_file or "", 207 | config_hash, 208 | args.run 209 | ) 210 | if not start: 211 | continue 212 | 213 | X = sample[0].unsqueeze(0) 214 | y_target = [yi.unsqueeze(0) for yi in sample[1:-1]] 215 | 216 | # Do the forward pass and estimate the primitive parameters 217 | y_hat = network(X) 218 | if ( 219 | "qos_threshold" in eval_config or 220 | "vol_threshold" in eval_config 221 | ): 222 | primitive_indices = filter_primitives( 223 | y_hat.fit, 224 | qos_less(float(eval_config.get("qos_threshold", 1))), 225 | volume_larger(float(eval_config.get("vol_threshold", 0))) 226 | ) 227 | if len(primitive_indices) == 0: 228 | continue 229 | y_hat = primitive_parameters_from_indices( 230 | y_hat.fit, 231 | primitive_indices 232 | ) 233 | 234 | metrics = MeshEvaluator().eval_mesh_with_primitive_params( 235 | y_hat, 236 | y_target[0], 237 | y_target[1].squeeze(-1), 238 | y_target[2].squeeze(-1), 239 | get_loss_options(config) 240 | ) 241 | fill_run( 242 | conn, 243 | sample[-1], 244 | args.weight_file or "", 245 | config_hash, 246 | args.run, 247 | captured_at_epoch, 248 | metrics, 249 | args.random_subset 250 | ) 251 | 252 | # Update progress bar 253 | prog.update(i+1) 254 | i += 1 255 | prog.update(len(dataset)) 256 | 257 | 258 | if __name__ == "__main__": 259 | main(sys.argv[1:]) 260 | 
-------------------------------------------------------------------------------- /scripts/generate_surface_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script used for exporting occupancy pairs for training 3 | """ 4 | import argparse 5 | import os 6 | import subprocess 7 | import sys 8 | 9 | import numpy as np 10 | import yaml 11 | try: 12 | from yaml import CLoader as Loader 13 | except ImportError: 14 | from yaml import Loader 15 | 16 | from hierarchical_primitives.common.model_factory import DatasetBuilder 17 | from hierarchical_primitives.common.base import splits_factory 18 | from hierarchical_primitives.external.libmesh import check_mesh_contains 19 | from hierarchical_primitives.mesh import Trimesh 20 | from hierarchical_primitives.utils.progbar import Progbar 21 | 22 | 23 | def export_surface_pairs(path_to_mesh_file, N, normalize): 24 | m = Trimesh(path_to_mesh_file, normalize=normalize) 25 | P = m.sample_faces(N) 26 | points = P[:, :3] 27 | normals = P[:, 3:] 28 | 29 | s = np.random.randn(N, 1) * 0.01 30 | points_hat = points + s * normals 31 | labels = check_mesh_contains(m.mesh, points_hat) 32 | 33 | return points_hat, labels 34 | 35 | 36 | def export_volume_pairs(path_to_mesh_file, N, normalize): 37 | m = Trimesh(path_to_mesh_file, normalize=normalize) 38 | random_points = np.random.rand(N, 3) - 0.5 39 | labels = check_mesh_contains(m.mesh, random_points.reshape(-1, 3)) 40 | 41 | return random_points, labels 42 | 43 | 44 | def export_occupancy_pairs(occupancy_type): 45 | return { 46 | "surface": export_surface_pairs, 47 | "volume": export_volume_pairs, 48 | }[occupancy_type] 49 | 50 | 51 | def occupancy_pairs_subdir(occupancy_type, normalize_mesh): 52 | subdir = { 53 | "surface": "surface_points_seq", 54 | "volume": "points_seq" 55 | }[occupancy_type] 56 | return subdir 57 | 58 | 59 | def ensure_parent_directory_exists(filepath): 60 | try: 61 | os.mkdir(os.path.dirname(filepath)) 62 | except FileExistsError: 63 | pass 64 | 65 | 66 | class DirLock(object): 67 | def __init__(self, dirpath): 68 | self._dirpath = dirpath 69 | self._acquired = False 70 | 71 | @property 72 | def is_acquired(self): 73 | return self._acquired 74 | 75 | def acquire(self): 76 | if self._acquired: 77 | return 78 | try: 79 | os.mkdir(self._dirpath) 80 | self._acquired = True 81 | except FileExistsError: 82 | pass 83 | 84 | def release(self): 85 | if not self._acquired: 86 | return 87 | try: 88 | os.rmdir(self._dirpath) 89 | self._acquired = False 90 | except FileNotFoundError: 91 | self._acquired = False 92 | except OSError: 93 | pass 94 | 95 | def __enter__(self): 96 | self.acquire() 97 | return self 98 | 99 | def __exit__(self, exc_type, exc_value, traceback): 100 | self.release() 101 | 102 | 103 | def main(argv): 104 | parser = argparse.ArgumentParser( 105 | description="Export occupancy pairs for training" 106 | ) 107 | parser.add_argument( 108 | "dataset_directory", 109 | help="Path to the directory containing the dataset" 110 | ) 111 | parser.add_argument( 112 | "train_test_splits_file", 113 | help="Path to the train-test splits file" 114 | ) 115 | parser.add_argument( 116 | "--dataset_type", 117 | default="shapenet_v1", 118 | choices=[ 119 | "shapenet_quad", 120 | "shapenet_v1", 121 | "shapenet_v2", 122 | "surreal_bodies", 123 | "dynamic_faust", 124 | ], 125 | help="The type of the dataset type to be used" 126 | ) 127 | parser.add_argument( 128 | "--occupancy_type", 129 | default="surface", 130 | choices=[ 131 | 
"surface", 132 | "volume" 133 | ], 134 | help="Choose whether to export occpairs from surface or volume" 135 | ) 136 | parser.add_argument( 137 | "--model_tags", 138 | type=lambda x: x.split(","), 139 | default=[], 140 | help="Tags to the models to be used" 141 | ) 142 | parser.add_argument( 143 | "--category_tags", 144 | type=lambda x: x.split(","), 145 | default=[], 146 | help="Category tags to the models to be used" 147 | ) 148 | parser.add_argument( 149 | "--n_surface_samples", 150 | type=int, 151 | default=100000, 152 | help="Number of points to be sampled from the surface of the mesh" 153 | ) 154 | parser.add_argument( 155 | "--normalize_mesh", 156 | action="store_true", 157 | help="When set normalize mesh while loading" 158 | ) 159 | 160 | args = parser.parse_args(argv) 161 | 162 | dataset = (DatasetBuilder(dict(data={})) 163 | .with_dataset(args.dataset_type) 164 | .filter_train_test( 165 | splits_factory(args.dataset_type)(args.train_test_splits_file), 166 | ["train", "test", "val"] 167 | ) 168 | .filter_category_tags(args.category_tags) 169 | .filter_tags(args.model_tags) 170 | .build(args.dataset_directory)) 171 | 172 | prog = Progbar(len(dataset)) 173 | i = 0 174 | for sample in dataset: 175 | # Update progress bar 176 | prog.update(i+1) 177 | i += 1 178 | 179 | # Assemble the target path and ensure the parent dir exists 180 | category_tag = sample.tag.split(":")[0] 181 | model_tag = sample.tag.split(":")[-1] 182 | path_to_file = os.path.join( 183 | args.dataset_directory, 184 | category_tag, 185 | occupancy_pairs_subdir(args.occupancy_type, args.normalize_mesh), 186 | "{}.npz".format(model_tag) 187 | ) 188 | ensure_parent_directory_exists(path_to_file) 189 | 190 | # Make sure we are the only ones creating this file 191 | with DirLock(path_to_file + ".lock") as lock: 192 | if not lock.is_acquired: 193 | continue 194 | if os.path.exists(path_to_file): 195 | continue 196 | 197 | points, labels = export_occupancy_pairs(args.occupancy_type)( 198 | sample.path_to_mesh_file, 199 | args.n_surface_samples, 200 | normalize=args.normalize_mesh 201 | ) 202 | 203 | np.savez( 204 | path_to_file, 205 | points=points.reshape(-1, 3), 206 | occupancies=labels 207 | ) 208 | 209 | prog.update(len(dataset)) 210 | 211 | 212 | if __name__ == "__main__": 213 | main(sys.argv[1:]) 214 | -------------------------------------------------------------------------------- /scripts/output_logger.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from hierarchical_primitives.utils.progbar import Progbar 4 | from hierarchical_primitives.utils.stats_logger import StatsLogger 5 | 6 | 7 | _REGULARIZERS = { 8 | "regularizers.sparsity": "sparsity", 9 | "regularizers.entropy_bernoulli": "entropy", 10 | "regularizers.parsimony": "parsimony", 11 | "regularizers.overlapping": "overl", 12 | "regularizers.proximity": "proximity", 13 | "regularizers.siblings_proximity": "siblings_prox", 14 | "regularizers.overlapping_on_depths": "part-overl", 15 | "regularizers.volumes": "vol" 16 | } 17 | 18 | _LOSSES = { 19 | "losses.cvrg": "cvrg", 20 | "losses.cnst": "cnst", 21 | "losses.partition": "part", 22 | "losses.hierarchy": "hier", 23 | "losses.vae": "vae", 24 | "losses.coverage": "cvrg", 25 | "losses.fit": "fit", 26 | "losses.qos": "qos", 27 | "losses.prox": "prox", 28 | "losses.reconstruction": "rec", 29 | "losses.kinematic": "kin", 30 | "losses.structure": "str", 31 | "losses.fit_parts": "fit_parts", 32 | "losses.fit_shape": "fit_shape" 33 | } 34 | 35 | _METRICS = { 36 | 
"metrics.positive_accuracy": "pos_acc", 37 | "metrics.accuracy": "acc", 38 | "metrics.iou": "iou", 39 | "metrics.chl1": "chl1", 40 | "metrics.exp_n_prims": "exp_n_prims" 41 | } 42 | 43 | 44 | class LossLogger(object): 45 | def __init__(self, epochs, steps_per_epoch, keys, messages, 46 | stats_filepath, prefix="Epoch {}/{}"): 47 | self._prefix = prefix 48 | self._epochs = epochs 49 | self._steps_per_epoch = steps_per_epoch 50 | self._keys = keys 51 | self._messages = messages 52 | 53 | self._stats = StatsLogger.instance() 54 | self._stats_fp = open(stats_filepath, "a") 55 | self._epoch = 0 56 | 57 | def new_epoch(self, e): 58 | print(self._prefix.format(e, self._epochs)) 59 | self._progbar = Progbar(self._steps_per_epoch) 60 | self._stats.clear() 61 | self._epoch = e 62 | 63 | def new_batch(self, batch_index, batch_loss): 64 | stats = [("loss", batch_loss)] + [ 65 | (m, self._stats[k]) for k, m in zip(self._keys, self._messages) 66 | if k in self._stats 67 | ] 68 | self._progbar.update(batch_index+1, stats) 69 | self._save_to_file(self._epoch, batch_index, stats) 70 | self._stats.clear() 71 | 72 | def _save_to_file(self, epoch, batch, stats): 73 | if epoch == 0 and batch == 0: 74 | print( 75 | " ".join(["epoch", "batch"] + [s[0] for s in stats]), 76 | file=self._stats_fp, 77 | ) 78 | print( 79 | " ".join([str(epoch), str(batch)] + [str(s[1]) for s in stats]), 80 | file=self._stats_fp, 81 | ) 82 | self._stats_fp.flush() 83 | 84 | 85 | def get_logger(epochs, steps_per_epoch, stats_filepath, prefix="Epoch {}/{}"): 86 | keys, messages = list(zip(*( 87 | list(_LOSSES.items()) + 88 | list(_REGULARIZERS.items()) + 89 | list(_METRICS.items()) 90 | ))) 91 | 92 | return LossLogger( 93 | epochs, steps_per_epoch, keys, messages, stats_filepath, prefix=prefix 94 | ) 95 | -------------------------------------------------------------------------------- /scripts/render_dfaust.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | from os import path 5 | import sys 6 | 7 | from simple_3dviz import Mesh 8 | from simple_3dviz.scenes import Scene 9 | from simple_3dviz.utils import save_frame 10 | 11 | 12 | def render_dfaust(scene, prev_renderable, seq, target): 13 | new_renderable = Mesh.from_file(seq) 14 | scene.remove(prev_renderable) 15 | scene.add(new_renderable) 16 | scene.render() 17 | save_frame(target, scene.frame) 18 | 19 | return new_renderable 20 | 21 | 22 | def get_scene(): 23 | scene = Scene((256, 256)) 24 | scene.camera_position = (1, 1.5, 3) 25 | scene.camera_target = (0, 0.5, 0) 26 | scene.light = (1, 1.5, 3) 27 | #scene = Scene((512, 512)) 28 | #scene.light = (1.0, 1.0, 3.0) 29 | #scene.camera_position = (0.4, 0.4, 1.4) 30 | #scene.camera_target = (0, 0.0, 0) 31 | scene.up_vector = (0, 1, 0) 32 | 33 | return scene 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser( 38 | description="Render the D-FAUST dataset" 39 | ) 40 | 41 | scene = get_scene() 42 | renderable = None 43 | for recording in sys.stdin: 44 | recording = recording.strip() 45 | mesh_base = path.join(recording, "mesh_seq") 46 | sequences = [ 47 | path.join(mesh_base, seq) 48 | for seq in os.listdir(mesh_base) 49 | if seq.endswith("obj") 50 | ] 51 | print("Rendering {}".format(path.basename(recording))) 52 | renderings = path.join(recording, "renderings-downsampled") 53 | if not path.exists(renderings): 54 | os.mkdir(renderings) 55 | print("0 / {}".format(len(sequences)), end="") 56 | for i, seq in enumerate(sequences): 57 | target = 
seq.replace("obj", "png") 58 | target = target.replace("mesh_seq", "renderings-downsampled") 59 | renderable = render_dfaust(scene, renderable, seq, target) 60 | print("\r{} / {}".format(i, len(sequences)), end="") 61 | print() 62 | -------------------------------------------------------------------------------- /scripts/single_mesh_from_primitives.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from subprocess import call 5 | import sys 6 | from tempfile import NamedTemporaryFile 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from hierarchical_primitives.utils.filter_sqs import filter_primitives, \ 12 | primitive_parameters_from_indices, qos_less, volume_larger 13 | from hierarchical_primitives.utils.sq_mesh import sq_mesh_from_primitive_params 14 | 15 | from visualization_utils import get_color 16 | 17 | 18 | def save_ply(f, points, normals, colors): 19 | assert len(points.shape) == 2 20 | assert len(points.shape) == len(normals.shape) 21 | assert len(points.shape) == len(colors.shape) 22 | 23 | header = "\n".join([ 24 | "ply", 25 | "format binary_{}_endian 1.0".format(sys.byteorder), 26 | "comment Raynet pointcloud!", 27 | "element vertex {}".format(len(points)), 28 | "property float x", 29 | "property float y", 30 | "property float z", 31 | "property float nx", 32 | "property float ny", 33 | "property float nz", 34 | "property uchar red", 35 | "property uchar green", 36 | "property uchar blue", 37 | "property uchar alpha", 38 | "end_header\n" 39 | ]) 40 | f.write(header.encode("ascii")) 41 | colors = (colors*255).astype(np.uint8) 42 | for p, n, c in zip(points, normals, colors): 43 | p.astype(np.float32).tofile(f) 44 | n.astype(np.float32).tofile(f) 45 | c.tofile(f) 46 | f.flush() 47 | 48 | 49 | def main(argv=None): 50 | parser = argparse.ArgumentParser( 51 | description="Create a single mesh from the predicted SQs" 52 | ) 53 | parser.add_argument( 54 | "primitives_file", 55 | help="Path to the file containing the primitives" 56 | ) 57 | parser.add_argument( 58 | "recon_binary", 59 | help="Poisson reconstruction binary" 60 | ) 61 | parser.add_argument( 62 | "output_file", 63 | help="Save the mesh in this file" 64 | ) 65 | parser.add_argument( 66 | "--qos_threshold", 67 | type=float, 68 | default=1, 69 | help="Stop partitioning based on the predicted quality" 70 | ) 71 | parser.add_argument( 72 | "--vol_threshold", 73 | default=0, 74 | type=float, 75 | help="Discard primitives with volume smaller than this threshold" 76 | ) 77 | 78 | args = parser.parse_args(argv) 79 | 80 | [C, P], F = torch.load(args.primitives_file) 81 | active_primitives = filter_primitives( 82 | F, 83 | qos_less(args.qos_threshold), 84 | volume_larger(args.vol_threshold) 85 | ) 86 | primitives = primitive_parameters_from_indices( 87 | F, 88 | active_primitives 89 | ) 90 | pts, normals, prim_indices = sq_mesh_from_primitive_params( 91 | primitives, 92 | normals=True, 93 | prim_indices=True 94 | ) 95 | colors = np.array([ 96 | get_color(d, i) 97 | for d, i in active_primitives 98 | ])[prim_indices].reshape(-1, 4) 99 | with NamedTemporaryFile(suffix=".ply") as f: 100 | save_ply(f, pts, normals, colors) 101 | call([ 102 | args.recon_binary, 103 | "--in", f.name, 104 | "--out", args.output_file, 105 | "--colors" 106 | ]) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /scripts/training_utils.py: 
-------------------------------------------------------------------------------- 1 | import yaml 2 | try: 3 | from yaml import CLoader as Loader 4 | except ImportError: 5 | from yaml import Loader 6 | 7 | 8 | def get_loss_options(config): 9 | # Create a dictionary with the loss options based on the input arguments 10 | loss_weights = {} 11 | for k, v in config.items(): 12 | if "loss_weight" in k: 13 | loss_weights[k] = v 14 | 15 | loss_options = dict(loss_weights=loss_weights) 16 | # Update the loss options based on the config file 17 | for k, v in config["loss"].items(): 18 | loss_options[k] = v 19 | 20 | return loss_options 21 | 22 | 23 | def get_regularizer_options(config, current_epoch): 24 | def get_weight(w, epoch, current_epoch): 25 | if current_epoch < epoch: 26 | return 0.0 27 | else: 28 | return w 29 | 30 | # Parse the regularizer and its corresponding weight from the config 31 | regularizer_terms = [] 32 | regularizer_options = {} 33 | for r in config.get("regularizers", []): 34 | regularizer_weight = 0.0 35 | # Update the regularizer options based on the config file 36 | for k, v in config[r].items(): 37 | regularizer_options[k] = v 38 | if "weight" in k: 39 | regularizer_weight = v 40 | regularizer_terms.append( 41 | (r, 42 | get_weight( 43 | regularizer_weight, 44 | config[r].get("enable_regularizer_after_epoch", 0), 45 | current_epoch 46 | )) 47 | ) 48 | return regularizer_terms, regularizer_options 49 | 50 | 51 | def load_config(config_file): 52 | with open(config_file, "r") as f: 53 | config = yaml.load(f, Loader=Loader) 54 | return config 55 | -------------------------------------------------------------------------------- /scripts/utils.py: --------------------------------------------------------------------------------
1 | from itertools import combinations 2 | import os 3 | 4 | import numpy as np 5 | 6 | from hierarchical_primitives.common.base import build_dataloader 7 | from hierarchical_primitives.networks.base import build_network 8 | from hierarchical_primitives.primitives import quaternions_to_rotation_matrices 9 | from hierarchical_primitives.utils.visualization_utils import points_on_sq_surface 10 | 11 | import matplotlib 12 | matplotlib.use("agg") 13 | import matplotlib.pyplot as plt 14 | from mpl_toolkits.mplot3d import Axes3D 15 | 16 | 17 | def is_inside(pcl1, pcl2, threshold): 18 | # Compute the fraction of points of pcl2 that lie inside the axis-aligned 19 | # bounding box of pcl1 and return True if it exceeds the threshold 20 | assert pcl1.shape[0] == 3 21 | assert pcl2.shape[0] == 3 22 | # for every point in pcl2 check whether it lies inside the bbox of pcl1 23 | minimum = pcl1.min(1) 24 | maximum = pcl1.max(1) 25 | 26 | c1 = np.logical_and( 27 | pcl2[0, :] <= maximum[0], 28 | pcl2[0, :] >= minimum[0], 29 | ) 30 | c2 = np.logical_and( 31 | pcl2[1, :] <= maximum[1], 32 | pcl2[1, :] >= minimum[1], 33 | ) 34 | c3 = np.logical_and( 35 | pcl2[2, :] <= maximum[2], 36 | pcl2[2, :] >= minimum[2], 37 | ) 38 | c4 = np.logical_and(c1, np.logical_and(c2, c3)).sum() 39 | # c4 counts points of pcl2, hence normalize by the number of points in pcl2 40 | return float(c4) / pcl2.shape[1] > threshold 41 | 42 | 43 | def get_non_overlapping_primitives(y_hat, active_prims, insidness=0.6): 44 | n_primitives = y_hat.n_primitives 45 | points_from_prims = [] 46 | 47 | R = quaternions_to_rotation_matrices( 48 | y_hat.rotations.view(-1, 4) 49 | ).to("cpu").detach() 50 | translations = y_hat.translations.to("cpu").view(-1, 3) 51 | translations = translations.detach().numpy() 52 | 53 | shapes = y_hat.sizes.view(-1, 3).detach().numpy() 54 | epsilons = y_hat.shapes.to("cpu").view(-1, 
2).detach().numpy() 55 | taperings = np.zeros((n_primitives, 2)) 56 | 57 | prim_pts = [] 58 | for i in active_prims: 59 | x_tr, y_tr, z_tr, prim_pts =\ 60 | points_on_sq_surface( 61 | shapes[i, 0], 62 | shapes[i, 1], 63 | shapes[i, 2], 64 | epsilons[i, 0], 65 | epsilons[i, 1], 66 | R[i].numpy(), 67 | translations[i].reshape(-1, 1), 68 | taperings[i, 0], 69 | taperings[i, 1] 70 | ) 71 | points_from_prims.append(prim_pts) 72 | 73 | cmp1 = combinations(active_prims, 2) 74 | cmp2 = combinations(points_from_prims, 2) 75 | non_overlapping_prims = active_prims[:] 76 | for (i, j), (pcl1, pcl2) in zip(cmp1, cmp2): 77 | if is_inside(pcl1, pcl2, insidness) and j in non_overlapping_prims: 78 | non_overlapping_prims.remove(j) 79 | return non_overlapping_prims 80 | 81 | 82 | def build_dataloader_and_network_from_args(args, config, device="cpu"): 83 | # Create a dataloader instance to generate the samples for training 84 | dataloader = build_dataloader( 85 | config, 86 | args.model_tags, 87 | args.category_tags, 88 | split=["train", "test", "val"], 89 | batch_size=1, 90 | n_processes=4, 91 | random_subset=args.random_subset, 92 | ) 93 | 94 | # Build the network architecture to be used for training 95 | network = build_network(args.config_file, args.weight_file, device=device) 96 | network.eval() 97 | 98 | return dataloader, network 99 | -------------------------------------------------------------------------------- /scripts/visualization_utils.py: --------------------------------------------------------------------------------
1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from pyquaternion import Quaternion 6 | import trimesh 7 | 8 | from simple_3dviz import Mesh 9 | from simple_3dviz.scripts.mesh_viewer import tab20 10 | 11 | from hierarchical_primitives.networks.primitive_parameters import\ 12 | PrimitiveParameters 13 | from hierarchical_primitives.utils.filter_sqs import filter_primitives, \ 14 | qos_less, volume_larger, get_primitives_indices, \ 15 | primitive_parameters_from_indices 16 | 17 | from utils import get_non_overlapping_primitives 18 | 19 | 20 | def scene_init(mesh, up_vector, camera_position, camera_target, background): 21 | def inner(scene): 22 | scene.background = background 23 | scene.up_vector = up_vector 24 | scene.camera_position = camera_position 25 | scene.camera_target = camera_target 26 | if mesh is not None: 27 | scene.add(mesh) 28 | return inner 29 | 30 | 31 | def load_ground_truth(mesh_file): 32 | m = Mesh.from_file(mesh_file, color=(0.8, 0.8, 0.8, 0.3)) 33 | return m 34 | 35 | 36 | def save_renderables_as_ply( 37 | renderables, 38 | args, 39 | filepath="/tmp/prediction.ply" 40 | ): 41 | """ 42 | Arguments: 43 | ---------- 44 | renderables: simple_3dviz.Mesh objects that were rendered 45 | args: unused; kept so that all visualization helpers share a signature 46 | filepath: path to save the combined mesh to 47 | """ 48 | m = None 49 | for r in renderables: 50 | # Build a mesh object using the vertices loaded before and get its 51 | # convex hull 52 | _m = trimesh.Trimesh(vertices=r._vertices).convex_hull 53 | # Apply color 54 | for i in range(len(_m.faces)): 55 | _m.visual.face_colors[i] = (r._colors[i] * 255).astype(np.uint8) 56 | m = trimesh.util.concatenate(_m, m) 57 | 58 | m.export(filepath, file_type="ply") 59 | 60 | print("Saved prediction as ply file in {}".format(filepath)) 61 | 62 | 63 | def get_color(d, i, M=0): 64 | # Pick a color from the tab20 colormap based on the depth d and the 65 | # index i of the primitive within its depth 66 | colors = np.array(tab20) 67 | order = np.arange(len(tab20)) 68 | d = np.asarray(d) 69 | i = np.asarray(i) 70 | return colors[order[(2**d-1+i) % len(colors)]] 71 | 72 | 73 | def get_color_groupped(d, i, max_depth): 74 | if d <= max_depth: 75 | return get_color(d, i, 0) 76 | else: 77 | return get_color_groupped(d-1, i // 2, max_depth) 78 | 79 | 80 | def _unpack(p): 81 | alpha = p.sizes.view(-1, 3).to("cpu").detach().numpy() 82 | epsilon = p.shapes.view(-1, 2).to("cpu").detach().numpy() 83 | t = p.translations.view(-1, 3).to("cpu").detach().numpy() 84 | R = np.stack([ 85 | Quaternion(r.to("cpu").detach().numpy()).rotation_matrix 86 | for r in p.rotations.view(-1, 4) 87 | ], axis=0) 88 | return alpha, epsilon, t, R 89 | 90 | 91 | def _renderables_from_fit(y_hat, args): 92 | F = y_hat.fit 93 | active_prims = filter_primitives( 94 | F, 95 | qos_less(args.qos_threshold), 96 | volume_larger(args.vol_threshold), 97 | ) 98 | active_prims_map = {p: i for (i, p) in enumerate(active_prims)} 99 | return [ 100 | Mesh.from_superquadrics(*_unpack( 101 | PrimitiveParameters.with_keys( 102 | translations=F[depth].translations_r[:, idx], 103 | rotations=F[depth].rotations_r[:, idx], 104 | sizes=F[depth].sizes_r[:, idx], 105 | shapes=F[depth].shapes_r[:, idx] 106 | )), 107 | get_color_groupped(depth, idx, args.max_depth) if args.group_color 108 | else get_color(depth, idx) 109 | ) 110 | for depth, idx in active_prims_map 111 | ], active_prims 112 | 113 | 114 | def _renderables_from_partition(y_hat, args): 115 | [C, P] = y_hat.space_partition 116 | active_prims_map = { 117 | p: i for (i, p) in enumerate(get_primitives_indices(P)) 118 | } 119 | return [ 120 | Mesh.from_superquadrics(*_unpack( 121 | PrimitiveParameters.with_keys( 122 | translations=P[depth].translations_r[:, idx], 123 | rotations=P[depth].rotations_r[:, idx], 124 | sizes=P[depth].sizes_r[:, idx], 125 | shapes=P[depth].shapes_r[:, idx] 126 | )), 127 | get_color_groupped(depth, idx, args.max_depth) if args.group_color 128 | else get_color(depth, idx) 129 | ) 130 | for depth, idx in active_prims_map 131 | ], get_primitives_indices(P) 132 | 133 | 134 | def _renderables_from_flat_partition(y_hat, args): 135 | _, P = y_hat.space_partition 136 | # Collect the sqs that have prob larger than threshold 137 | indices = [ 138 | i for i in range(y_hat.n_primitives) 139 | if y_hat.probs_r[0, i] >= args.prob_threshold 140 | ] 141 | active_prims_map = { 142 | (0, j): i for i, j in enumerate(indices) 143 | } 144 | return [ 145 | Mesh.from_superquadrics(*_unpack( 146 | PrimitiveParameters.with_keys( 147 | translations=P[-1].translations_r[0, indices], 148 | rotations=P[-1].rotations_r[0, indices], 149 | sizes=P[-1].sizes_r[0, indices] / 2.0, 150 | shapes=P[-1].shapes_r[0, indices] 151 | )), 152 | get_color(0, indices) 153 | ) 154 | ], indices 155 | 156 | 157 | def _renderables_from_flat_primitives(y_hat, args): 158 | # Collect the sqs that have prob larger than threshold 159 | indices = [ 160 | i for i in range(y_hat.n_primitives) 161 | if y_hat.probs_r[0, i] >= args.prob_threshold 162 | ] 163 | active_prims_map = { 164 | (0, j): i for i, j in enumerate(indices) 165 | } 166 | 167 | if args.with_post_processing: 168 | indices = get_non_overlapping_primitives(y_hat, indices) 169 | 170 | return [Mesh.from_superquadrics( 171 | *_unpack( 172 | PrimitiveParameters.with_keys( 173 | translations=y_hat.translations_r[0, indices], 174 | rotations=y_hat.rotations_r[0, indices], 175 | sizes=y_hat.sizes_r[0, indices], 176 | shapes=y_hat.shapes_r[0, indices] 177 | )), 178 | get_color(0, indices) 179 | )], indices 180 | 181 | 182 | def get_renderables(y_hat, args): 183 | """Depending on the arguments compute which primitives should be rendered 184 | """ 185 | # len(y_hat.fit) == 1 means that we do not have a hierarchy 186 | if len(y_hat.fit) == 1: 187 | if args.from_flat_partition: 188 | return _renderables_from_flat_partition(y_hat, args) 189 | else: 190 | return _renderables_from_flat_primitives(y_hat, args) 191 | 192 | if args.from_fit: 193 | return _renderables_from_fit(y_hat, args) 194 | else: 195 | return _renderables_from_partition(y_hat, args) 196 | 197 | 198 | def get_primitive_parameters_from_indices(y_hat, active_prims, args): 199 | # len(y_hat.fit) == 1 means that we do not have a hierarchy 200 | if len(y_hat.fit) == 1: 201 | return primitive_parameters_from_indices( 202 | y_hat.fit, 203 | [(-1, ap) for ap in active_prims] 204 | ) 205 | if args.from_fit: 206 | return primitive_parameters_from_indices(y_hat.fit, active_prims) 207 | else: 208 | [C, P] = y_hat.space_partition 209 | return primitive_parameters_from_indices(P, active_prims) 210 | 211 | 212 | def visualize_sharpness(sharpness, epoch): 213 | import seaborn as sns 214 | f = plt.figure(figsize=(8, 6)) 215 | sns.barplot( 216 | np.arange(sharpness.shape[0]), sharpness 217 | ) 218 | plt.title("Epoch {}".format(epoch)) 219 | plt.ylim([0, 10.5]) 220 | plt.ylabel("Sharpness") 221 | plt.xlabel("Primitive id") 222 | plt.savefig("/tmp/sharpness_{:03d}.png".format(epoch)) 223 | plt.close() 224 | -------------------------------------------------------------------------------- /scripts/visualize_partition.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 2 | """Script used for visualizing the partitioning process 3 | """ 4 | import argparse 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from pyquaternion import Quaternion 9 | import torch 10 | 11 | from arguments import add_dataset_parameters 12 | from training_utils import load_config 13 | from utils import build_dataloader_and_network_from_args 14 | from visualization_utils import scene_init, load_ground_truth, get_color 15 | 16 | from simple_3dviz import Mesh 17 | from simple_3dviz.behaviours.misc import CycleThroughObjects, LightToCamera 18 | from simple_3dviz.behaviours.movements import CameraTrajectory 19 | from simple_3dviz.behaviours.trajectory import Circle 20 | from simple_3dviz.behaviours.io import SaveFrames 21 | from simple_3dviz.window import simple_window 22 | 23 | 24 | def _unpack(p, i, max_depth, colors, color_siblings=False): 25 | alpha = p.sizes.view(-1, 3).to("cpu").detach().numpy() 26 | epsilon = p.shapes.view(-1, 2).to("cpu").detach().numpy() 27 | t = p.translations.view(-1, 3).to("cpu").detach().numpy() 28 | R = np.stack([ 29 | Quaternion(r.to("cpu").detach().numpy()).rotation_matrix 30 | for r in p.rotations.view(-1, 4) 31 | ], axis=0) 32 | M = alpha.shape[0] 33 | if max_depth > 0: 34 | colors = [] 35 | for idx in range(M): 36 | colors.append(get_color(i, idx, max_depth)) 37 | else: 38 | if color_siblings and i != 0: 39 | c = (((np.arange(0, M, 2)+i) % len(colors)).tolist()) 40 | c_tiled = [] 41 | for ci, m in zip(c, [2]*len(c)): 42 | c_tiled.extend([ci]*m) 43 | colors = colors[c_tiled] 44 | else: 45 | colors = colors[((np.arange(M)+i) % len(colors)).tolist()] 46 | return 
alpha, epsilon, t, R, colors 47 | 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser( 51 | description="Visualize the space partitioning" 52 | ) 53 | parser.add_argument( 54 | "dataset_directory", 55 | help="Path to the directory containing the dataset" 56 | ) 57 | parser.add_argument( 58 | "train_test_splits_file", 59 | default=None, 60 | help="Path to the train-test splits file" 61 | ) 62 | parser.add_argument( 63 | "--weight_file", 64 | default=None, 65 | help="The path to the previously trained model to be used" 66 | ) 67 | parser.add_argument( 68 | "--config_file", 69 | default="../config/default.yaml", 70 | help="Path to the file that contains the experiment configuration" 71 | ) 72 | parser.add_argument( 73 | "--run_on_gpu", 74 | action="store_true", 75 | help="Use GPU" 76 | ) 77 | parser.add_argument( 78 | "--mesh", 79 | type=load_ground_truth, 80 | help="File of ground truth mesh" 81 | ) 82 | parser.add_argument( 83 | "--interval", 84 | type=lambda x: int(float(x)*60), 85 | default=30, 86 | help="Set the interval between partition updates in seconds" 87 | ) 88 | parser.add_argument( 89 | "--save_frames", 90 | help="Path to save the visualization frames to" 91 | ) 92 | parser.add_argument( 93 | "--save_frequency", 94 | type=int, 95 | default=5, 96 | help="Save one frame every this many frames" 97 | ) 98 | parser.add_argument( 99 | "--with_rotating_camera", 100 | action="store_true", 101 | help="Use a camera rotating around the object" 102 | ) 103 | parser.add_argument( 104 | "--background", 105 | type=lambda x: list(map(float, x.split(","))), 106 | default="0,0,0,1", 107 | help="Set the background of the scene" 108 | ) 109 | parser.add_argument( 110 | "--up_vector", 111 | type=lambda x: tuple(map(float, x.split(","))), 112 | default="0,0,1", 113 | help="Up vector of the scene" 114 | ) 115 | parser.add_argument( 116 | "--camera_position", 117 | type=lambda x: tuple(map(float, x.split(","))), 118 | default="-2.0,-2.0,-2.0", 119 | help="Camera position in the scene" 120 | ) 121 | parser.add_argument( 122 | "--camera_target", 123 | type=lambda x: tuple(map(float, x.split(","))), 124 | default="0,0,0", 125 | help="Set the target for the camera" 126 | ) 127 | parser.add_argument( 128 | "--color_siblings", 129 | action="store_true", 130 | help="Use the same color to depict siblings" 131 | ) 132 | parser.add_argument( 133 | "--from_fit", 134 | action="store_true", 135 | help="Visualize everything based on primitive_params.fit" 136 | ) 137 | 138 | add_dataset_parameters(parser) 139 | args = parser.parse_args() 140 | 141 | if args.run_on_gpu and torch.cuda.is_available(): 142 | device = torch.device("cuda:0") 143 | else: 144 | device = torch.device("cpu") 145 | print("Running code on", device) 146 | 147 | config = load_config(args.config_file) 148 | dataloader, network = build_dataloader_and_network_from_args( 149 | args, config, device=device 150 | ) 151 | 152 | for sample in dataloader: 153 | # Do the forward pass and estimate the primitive parameters 154 | X = sample[0].to(device) 155 | y_hat = network(X) 156 | F = y_hat.fit 157 | [C, P] = y_hat.space_partition 158 | n_primitives = y_hat.n_primitives 159 | colors = torch.tensor(np.array( 160 | plt.cm.jet(np.linspace(0, 1, 16)) 161 | )) 162 | 163 | if args.from_fit: 164 | max_depth = 2**(len(F) - 1) 165 | else: 166 | max_depth = -1 167 | meshes = [ 168 | [Mesh.from_superquadrics( 169 | *_unpack(p, i, max_depth, colors, args.color_siblings) 170 | )] 171 | for i, p in enumerate(F if args.from_fit else P) 172 | ] 173 | behaviours 
= [ 174 | CycleThroughObjects(meshes, interval=args.interval), 175 | LightToCamera() 176 | ] 177 | if args.save_frames: 178 | behaviours += [ 179 | SaveFrames(args.save_frames, args.save_frequency) 180 | ] 181 | if args.with_rotating_camera: 182 | behaviours += [ 183 | CameraTrajectory( 184 | Circle([0, 0, 1], [4, 0, 1], [0, 0, 1]), 0.001 185 | ), 186 | ] 187 | simple_window( 188 | scene_init(args.mesh, args.up_vector, args.camera_position, 189 | args.camera_target, args.background) 190 | ).add_behaviours(behaviours).show() 191 | -------------------------------------------------------------------------------- /scripts/visualize_predictions.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 2 | """Script used for performing a forward pass on a previously trained model and 3 | visualizing the predicted primitives. 4 | """ 5 | import argparse 6 | import os 7 | import pickle 8 | import sys 9 | 10 | import torch 11 | 12 | from arguments import add_dataset_parameters 13 | from compute_metrics import report_metrics 14 | from training_utils import load_config 15 | from visualization_utils import scene_init, load_ground_truth, \ 16 | get_renderables, get_primitive_parameters_from_indices, \ 17 | save_renderables_as_ply, visualize_sharpness 18 | from utils import build_dataloader_and_network_from_args, \ 19 | get_non_overlapping_primitives 20 | from hierarchical_primitives.utils.sq_mesh import sq_meshes 21 | 22 | from simple_3dviz.behaviours import SceneInit 23 | from simple_3dviz.behaviours.misc import LightToCamera 24 | from simple_3dviz.behaviours.keyboard import SnapshotOnKey 25 | from simple_3dviz.behaviours.movements import CameraTrajectory 26 | from simple_3dviz.behaviours.trajectory import Circle 27 | from simple_3dviz.behaviours.io import SaveFrames, SaveGif 28 | from simple_3dviz.utils import render 29 | from simple_3dviz.window import show 30 | 31 | 32 | def main(argv): 33 | parser = argparse.ArgumentParser( 34 | description="Do the forward pass and estimate a set of primitives" 35 | ) 36 | parser.add_argument( 37 | "config_file", 38 | help="Path to the file that contains the experiment configuration" 39 | ) 40 | parser.add_argument( 41 | "output_directory", 42 | help="Save the output files in that directory" 43 | ) 44 | parser.add_argument( 45 | "--weight_file", 46 | default=None, 47 | help="The path to the previously trained model to be used" 48 | ) 49 | parser.add_argument( 50 | "--run_on_gpu", 51 | action="store_true", 52 | help="Use GPU" 53 | ) 54 | parser.add_argument( 55 | "--qos_threshold", 56 | default=1.0, 57 | type=float, 58 | help="Split primitives if the predicted qos is less than this threshold" 59 | ) 60 | parser.add_argument( 61 | "--vol_threshold", 62 | default=0.0, 63 | type=float, 64 | help="Discard primitives with volume smaller than this threshold" 65 | ) 66 | parser.add_argument( 67 | "--prob_threshold", 68 | default=0.0, 69 | type=float, 70 | help="Discard primitives with probability smaller than this threshold" 71 | ) 72 | parser.add_argument( 73 | "--with_post_processing", 74 | action="store_true", 75 | help="Remove overlapping primitives" 76 | ) 77 | parser.add_argument( 78 | "--mesh", 79 | type=load_ground_truth, 80 | help="File of ground truth mesh" 81 | ) 82 | parser.add_argument( 83 | "--save_frames", 84 | help="Path to save the visualization frames to" 85 | ) 86 | parser.add_argument( 87 | "--without_screen", 88 | action="store_true", 89 | help="Perform no screen rendering" 90 | ) 91 | 
parser.add_argument( 92 | "--n_frames", 93 | type=int, 94 | default=200, 95 | help="Number of frames to be rendered" 96 | ) 97 | parser.add_argument( 98 | "--background", 99 | type=lambda x: list(map(float, x.split(","))), 100 | default="0,0,0,1", 101 | help="Set the background of the scene" 102 | ) 103 | parser.add_argument( 104 | "--up_vector", 105 | type=lambda x: tuple(map(float, x.split(","))), 106 | default="0,0,1", 107 | help="Up vector of the scene" 108 | ) 109 | parser.add_argument( 110 | "--camera_target", 111 | type=lambda x: tuple(map(float, x.split(","))), 112 | default="0,0,0", 113 | help="Set the target for the camera" 114 | ) 115 | parser.add_argument( 116 | "--camera_position", 117 | type=lambda x: tuple(map(float, x.split(","))), 118 | default="-2.0,-2.0,-2.0", 119 | help="Camera position in the scene" 120 | ) 121 | parser.add_argument( 122 | "--max_depth", 123 | type=int, 124 | default=3, 125 | help="Maximum depth to visualize" 126 | ) 127 | parser.add_argument( 128 | "--window_size", 129 | type=lambda x: tuple(map(int, x.split(","))), 130 | default="512,512", 131 | help="Define the size of the scene and the window" 132 | ) 133 | parser.add_argument( 134 | "--from_fit", 135 | action="store_true", 136 | help="Visualize everything based on primitive_params.fit" 137 | ) 138 | parser.add_argument( 139 | "--from_flat_partition", 140 | action="store_true", 141 | help=("Visualize everything based on primitive_params.space_partition" 142 | " with a single depth") 143 | ) 144 | parser.add_argument( 145 | "--group_color", 146 | action="store_true", 147 | help="Color the active prims based on the group" 148 | ) 149 | parser.add_argument( 150 | "--with_rotating_camera", 151 | action="store_true", 152 | help="Use a camera rotating around the object" 153 | ) 154 | parser.add_argument( 155 | "--visualize_sharpness", 156 | action="store_true", 157 | help="When set visualize the sharpness together with the prediction" 158 | ) 159 | 160 | add_dataset_parameters(parser) 161 | args = parser.parse_args(argv) 162 | 163 | # Check if output directory exists and if it doesn't create it 164 | if not os.path.exists(args.output_directory): 165 | os.makedirs(args.output_directory) 166 | 167 | if args.run_on_gpu and torch.cuda.is_available(): 168 | device = torch.device("cuda:0") 169 | else: 170 | device = torch.device("cpu") 171 | print("Running code on", device) 172 | 173 | config = load_config(args.config_file) 174 | dataloader, network = build_dataloader_and_network_from_args( 175 | args, config, device=device 176 | ) 177 | 178 | for sample in dataloader: 179 | # Do the forward pass and estimate the primitive parameters 180 | X = sample[0].to(device) 181 | y_hat = network(X) 182 | 183 | renderables, active_prims = get_renderables(y_hat, args) 184 | with open(os.path.join(args.output_directory, "renderables.pkl"), "wb") as f: 185 | pickle.dump(renderables, f) 186 | print("Active primitives:", active_prims) 187 | 188 | behaviours = [ 189 | SceneInit( 190 
| scene_init( 191 | args.mesh, 192 | args.up_vector, 193 | args.camera_position, 194 | args.camera_target, 195 | args.background 196 | ) 197 | ), 198 | LightToCamera(), 199 | ] 200 | if args.with_rotating_camera: 201 | behaviours += [ 202 | CameraTrajectory( 203 | Circle( 204 | args.camera_target, 205 | args.camera_position, 206 | args.up_vector 207 | ), 208 | speed=1/180 209 | ) 210 | ] 211 | if args.without_screen: 212 | behaviours += [ 213 | SaveFrames(args.save_frames, 2), 214 | SaveGif("/tmp/out.gif", 2) 215 | ] 216 | render(renderables, size=args.window_size, behaviours=behaviours, 217 | n_frames=args.n_frames) 218 | else: 219 | behaviours += [ 220 | SnapshotOnKey(path=args.save_frames, keys={"<ctrl>", "S"}) 221 | ] 222 | show(renderables, size=args.window_size, behaviours=behaviours) 223 | 224 | # Based on the active primitives report the metrics 225 | active_primitive_params = \ 226 | get_primitive_parameters_from_indices(y_hat, active_prims, args) 227 | report_metrics( 228 | active_primitive_params, 229 | config, 230 | config["data"]["dataset_type"], 231 | args.model_tags, 232 | config["data"]["dataset_directory"] 233 | ) 234 | if args.with_post_processing: 235 | indices = get_non_overlapping_primitives(y_hat, active_prims) 236 | else: 237 | indices = None 238 | for i, m in enumerate(sq_meshes(y_hat, indices)): 239 | m.export( 240 | os.path.join(args.output_directory, "predictions-{}.ply").format(i), 241 | file_type="ply" 242 | ) 243 | 244 | if y_hat.space_partition is not None: 245 | torch.save( 246 | [y_hat.space_partition, y_hat.fit], 247 | os.path.join(args.output_directory, "space_partition.pkl") 248 | ) 249 | if args.visualize_sharpness: 250 | visualize_sharpness( 251 | y_hat.sharpness_r.squeeze(0).detach().numpy()[:, 0], 252 | int(args.weight_file.split("/")[-1].split("_")[-1]) 253 | ) 254 | 255 | 256 | if __name__ == "__main__": 257 | main(sys.argv[1:]) 258 | -------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | """Setup hierarchical primitives.""" 2 | 3 | from setuptools import setup 4 | from Cython.Build import cythonize 5 | from distutils.extension import Extension 6 | 7 | from itertools import dropwhile 8 | import numpy as np 9 | from os import path 10 | 11 | 12 | def collect_docstring(lines): 13 | """Return document docstring if it exists""" 14 | lines = dropwhile(lambda x: not x.startswith('"""'), lines) 15 | doc = "" 16 | for line in lines: 17 | doc += line 18 | if doc.endswith('"""\n'): 19 | break 20 | 21 | return doc[3:-4].replace("\r", "").replace("\n", " ") 22 | 23 | 24 | def collect_metadata(): 25 | meta = {} 26 | with open(path.join("hierarchical_primitives", "__init__.py")) as f: 27 | lines = iter(f) 28 | meta["description"] = collect_docstring(lines) 29 | for line in lines: 30 | if line.startswith("__"): 31 | key, value = map(lambda x: x.strip(), line.split("=")) 32 | meta[key[2:-2]] = value[1:-1] 33 | 34 | return meta 35 | 36 | 37 | def get_extensions(): 38 | return cythonize([ 39 | Extension( 40 | "hierarchical_primitives.fast_sampler._sampler", 41 | [ 42 | "hierarchical_primitives/fast_sampler/_sampler.pyx", 43 | "hierarchical_primitives/fast_sampler/sampling.cpp" 44 | ], 45 | language="c++",  # the C++11 standard is selected via extra_compile_args 46 | libraries=["stdc++"], 47 | include_dirs=[np.get_include()], 48 | extra_compile_args=["-std=c++11", "-O3"] 49 | ), 50 | Extension( 51 | "hierarchical_primitives.external.libmesh.triangle_hash", 52 | 
sources=["hierarchical_primitives/external/libmesh/triangle_hash.pyx"], 53 | include_dirs=[np.get_include()], 54 | libraries=["m"] # Unix-like specific 55 | ) 56 | ]) 57 | 58 | 59 | def get_install_requirements(): 60 | return [ 61 | "numpy", 62 | "trimesh", 63 | "torch", 64 | "torchvision", 65 | "cython", 66 | "Pillow", 67 | "pyquaternion", 68 | "pykdtree", 69 | "pyyaml", 70 | "matplotlib", 71 | "simple-3dviz" 72 | ] 73 | 74 | 75 | def setup_package(): 76 | with open("README.md") as f: 77 | long_description = f.read() 78 | meta = collect_metadata() 79 | setup( 80 | name="hierarchical_primitives", 81 | version=meta["version"], 82 | description=meta["description"], 83 | long_description=long_description, 84 | long_description_content_type="text/markdown", 85 | maintainer=meta["maintainer"], 86 | maintainer_email=meta["email"], 87 | url=meta["url"], 88 | license=meta["license"], 89 | classifiers=[ 90 | "Intended Audience :: Science/Research", 91 | "Intended Audience :: Developers", 92 | "License :: OSI Approved :: MIT License", 93 | "Topic :: Scientific/Engineering", 94 | "Programming Language :: Python", 95 | "Programming Language :: Python :: 3", 96 | "Programming Language :: Python :: 3.6", 97 | ], 98 | install_requires=get_install_requirements(), 99 | ext_modules=get_extensions() 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | setup_package() 105 | --------------------------------------------------------------------------------