├── .github └── workflows │ └── build.yaml ├── .gitignore ├── LICENSE ├── README.md ├── models ├── README.md ├── do_model_updates.py ├── draw_collision_geometry.py └── unzip_all.sh ├── packages └── spatial_grammar_models │ ├── kitchen_models │ ├── cupboard.sdf │ └── extra_heavy_duty_table_surface_only_collision.sdf │ └── package.xml ├── requirements.txt ├── sandbox ├── .gitignore ├── bingham_distribution_prototype.ipynb ├── bingham_quaternion_constraints.ipynb ├── cip_parsing_figure_plotting.ipynb ├── debug_parsing_issue.py ├── density_recovery_from_projections.ipynb ├── dist_fitting.ipynb ├── drake_static_equilibrium_problem.py ├── draw_model_in_meshcat_cpp.py ├── gen_tree_inference.ipynb ├── geometric_only_grammar.py ├── greedy_parsing_prototype.ipynb ├── greedy_proposal_parsing_prototype.ipynb ├── gurobi_solution_pool.ipynb ├── latent_node_param_collapse.ipynb ├── mi_scene_parsing.ipynb ├── mi_scene_parsing_enumerated_max_entropy.ipynb ├── mi_scene_parsing_max_entropy.ipynb ├── mip_parsable_ssg.ipynb ├── mip_rotation_constraint_solution_uniqueness.ipynb ├── mmd_sandbox.ipynb ├── mug_visual.obj ├── plate_11in.obj ├── pyro_uniform_and_gaussian_hmc.ipynb ├── sdp_sandbox.ipynb ├── simple_3d_ssg.py ├── smc_with_minimal_grammar.ipynb ├── ssg_test_grammar_oriented_clusters.py ├── ssg_test_grammar_rotation.py ├── ssg_test_grammar_sink.py └── test_mip_rotation_constraint.ipynb ├── setup.py ├── setup.sh ├── spatial_scene_grammars ├── __init__.py ├── constraints.py ├── dataset.py ├── distributions.py ├── drake_interop.py ├── nodes.py ├── parameter_estimation.py ├── parsing.py ├── random_walk_kernel.py ├── rules.py ├── sampling.py ├── scene_grammar.py ├── test │ ├── __init__.py │ ├── gaussian_grammar.py │ ├── grammar.py │ ├── test_distributions.py │ ├── test_grammar.py │ ├── test_nodes.py │ ├── test_parameter_estimation.py │ ├── test_parsing.py │ ├── test_rules.py │ ├── test_sampling.py │ └── test_torch_util.py ├── torch_utils.py └── visualization.py └── 
spatial_scene_grammars_examples ├── __init__.py ├── dish_bin ├── .gitattributes ├── .gitignore ├── baseline_em_mmds_precomputed.pickle ├── baseline_em_mmds_total_precomputed.pickle ├── collect_parsing_statistics.ipynb ├── demo_baseline.ipynb ├── demo_grammar_fitting.ipynb ├── demo_sampling.ipynb ├── demo_simple_grammar_fitting.ipynb ├── draw_every_training_example.py ├── draw_samples_from_fit_grammar.py ├── em_mmds_precomputed.pickle ├── em_mmds_total_precomputed.pickle ├── fit_em.pickle ├── fit_em_baseline.pickle ├── fit_grammar.torch ├── fit_grammar_baseline.torch ├── fit_simple_grammar.torch ├── fitting_history.pickle ├── grammar.py ├── grammar_baseline.py ├── parse_single_scene.ipynb ├── parsing_statistics_runs.pickle ├── parsing_statistics_runs_pre_update.pickle ├── prototype_dependency_grammar.ipynb ├── rnn_baseline.ipynb ├── simple_grammar.py ├── sink │ ├── saved_mixed_scenes.yaml │ ├── saved_outlier_scenes.yaml │ └── saved_scenes.yaml ├── sink_dataset.zip ├── sink_dataset_v2.zip ├── test │ └── test_sink_grammar.py └── utils.py ├── gmm ├── .gitattributes ├── create_comparison_figure.ipynb ├── demo_gmm_fitting_vs_em.ipynb ├── grammar.py ├── test_runs.pickle └── test_runs_with_restarts.pickle ├── oriented_clusters ├── .gitignore ├── demo_grammar_and_parsing.ipynb ├── demo_parsing_with_outliers.py.ipynb ├── grammar.py ├── grammar_with_extra_pathway.py └── test │ └── test_oriented_clusters_example.py ├── packages ├── .gitattributes ├── .gitignore ├── all_observable_grammar.py ├── boxes.zip ├── demo_sampling.ipynb └── grammar.py ├── planar_clusters ├── .gitignore ├── demo_grammar.ipynb ├── demo_grammar_fitting.ipynb ├── grammar.py └── test │ └── test_planar_clusters_example.py ├── planar_clusters_gaussians ├── .gitignore ├── demo_grammar.ipynb ├── demo_grammar_fitting.ipynb ├── do_profile_vi.sh ├── grammar.py ├── print_profile_results.py ├── profile_vi.py └── test │ └── test_planar_clusters_gaussians_example.py ├── planar_clusters_no_rotation ├── grammar.py └── 
prototype_mip_parameter_fitting.ipynb ├── planar_fixed_structure ├── grammar.py ├── parameter_fitting_vi.ipynb └── test_mcmc_diversity.ipynb ├── restaurant ├── demo_sampling.ipynb └── grammar.py ├── singles_pairs ├── .gitattributes ├── collect_scaling_performance.ipynb ├── constituency_scaling_random_runs.pickle ├── dependency_scaling_random_runs.pickle ├── grammar_constituency.py ├── grammar_dependency.py ├── models │ ├── deathstar.obj │ ├── deathstar.png │ ├── deathstar.sdf │ ├── model.dae │ ├── xwing.blend │ ├── xwing.blend1 │ ├── xwing.dae │ ├── xwing.mtl │ ├── xwing.obj │ ├── xwing.png │ └── xwing.sdf ├── parse_single_scene.ipynb ├── single_scene.pickle └── test │ └── test_grammar.py └── table ├── .gitattributes ├── .gitignore ├── baseline_em_mmds_precomputed.pickle ├── baseline_post_fit_grammar_draws.pickle ├── baseline_post_fit_grammar_state_dict.torch ├── baseline_pre_fit_grammar_draws.pickle ├── baseline_pre_fit_grammar_state_dict.torch ├── demo_baseline.ipynb ├── demo_parameter_estimation.ipynb ├── demo_sampling.ipynb ├── em_mmds_precomputed.pickle ├── example_feasible_sampled_scene.pickle ├── fit_em.pickle ├── fit_em_baseline.pickle ├── fit_grammar.torch ├── fit_grammar_baseline.torch ├── generate_decorated_scenes.ipynb ├── generate_target_dataset.ipynb ├── grammar.py ├── grammar_baseline.py ├── grammar_decoration.py ├── models.zip ├── post_fit_grammar_draws.pickle ├── post_fit_grammar_draws_decorated.pickle ├── post_fit_grammar_state_dict.torch ├── pre_fit_grammar_draws.pickle ├── pre_fit_grammar_state_dict.torch ├── render_scenes_in_blender.ipynb ├── renders ├── example_raw_blender_output.jpg └── tweaked_in_blender.jpg ├── structure_constraint_dataset_grammar_state_dict.torch ├── structure_constraint_examples.pickle ├── target_dataset_examples.pickle ├── target_dataset_grammar_state_dict.torch └── utils.py /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | 
pull_request: 5 | schedule: 6 | - cron: '0 8 * * 2' 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-18.04 11 | strategy: 12 | matrix: 13 | DRAKE_URL: ["https://drake-packages.csail.mit.edu/drake/nightly/drake-latest-bionic.tar.gz", 14 | "https://drake-packages.csail.mit.edu/drake/nightly/drake-20211209-bionic.tar.gz"] 15 | env: 16 | PYTHONPATH: "/opt/drake/lib/python3.6/site-packages" 17 | ROS_PACKAGE_PATH: "/opt/drake/share/drake/examples" 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | 22 | - uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.6' 25 | architecture: 'x64' # optional x64 or x86. Defaults to x64 if not specified 26 | 27 | - name: "Install apt dependencies" 28 | run: "sudo apt-get update && sudo apt install libxml2-utils graphviz libgraphviz-dev" 29 | 30 | - name: "Download and install Drake." 31 | run: | 32 | curl -o drake.tar.gz ${{ matrix.DRAKE_URL }} && sudo tar -xzf drake.tar.gz -C /opt 33 | yes | sudo /opt/drake/share/drake/setup/install_prereqs 34 | 35 | - name: "Install python dependencies" 36 | run: | 37 | pip install wheel 38 | pip install -r requirements.txt 39 | pip install torch-sparse torch-scatter torch-cluster torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.1%2Bcu102.html 40 | pip install -e . 41 | 42 | - name: "Run tests" 43 | run: | 44 | source ./setup.sh 45 | pytest --pyargs spatial_scene_grammars --cov=./ --cov-report=xml --continue-on-collection-errors 46 | 47 | - name: "Codecov upload" 48 | uses: codecov/codecov-action@v2 49 | with: 50 | fail_ci_if_error: false # optional (default = false) 51 | verbose: true # optional (default = false) 52 | env_vars: OS,PYTHON,DRAKE_URL 53 | files: ./coverage.xml 54 | 55 | - run: echo "🍏 This job's status is ${{ job.status }}." 
56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .coverage 3 | *.png 4 | *.dat 5 | *.ipynb_checkpoints 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Greg Izatt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Assets for Scenes 2 | ----------------- 3 | 4 | As a tool for collecting assets for building scenes, I've written a small 5 | pipeline script for converting SDFs to make them simulate-able by Drake, and 6 | visualize-able by Meshcat / Drake visualizer. It assumes assets were downloaded 7 | from the [IgnitionRobotics collection](https://app.ignitionrobotics.org/dashboard) -- 8 | even more specifically, I [boldly] assume that the model comes in with this file 9 | structure exactly: 10 | 11 | ``` 12 | 13 | ├── model.sdf 14 | ├── meshes 15 | │ ├── model.obj 16 | ├── materials 17 | │ ├── textures 18 | │ ├── texture.png 19 | ``` 20 | 21 | If the model has an mtl, it's ignored; it's assumed that the model UVs are set up 22 | such that `texture.png` is a good baked texture for that object. 23 | 24 | Running the `do_model_updates.py` script (you can change its target model folders 25 | by editing the script at the bottom) will open each `model.sdf` and create a new 26 | `model_simplified.sdf` next to it that references automatically-created altered 27 | models and objects. No original files are altered -- new ones are made next to them. 28 | 29 | This whole pipeline is subject to significant change very soon -- probably to make 30 | it create ROS-package-wrapped model data folders rather than mixing in the newly 31 | created assets with the originals. 32 | 33 | # Software prereqs 34 | 35 | Install deps with `pip install trimesh open3d opencv-python imutils`. 36 | 37 | You'll need a backend that can do convex decomposition, too -- IIRC, having 38 | `blender` on your path *might* be enough. 
But I usually copy the code out of 39 | [this setup file from trimesh](https://github.com/mikedh/trimesh/blob/master/docker/builds/vhacd.bash) 40 | and download a `VHACD` binary and just install it to my system path. 41 | -------------------------------------------------------------------------------- /models/draw_collision_geometry.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import pydrake 6 | from pydrake.all import ( 7 | AddMultibodyPlantSceneGraph, 8 | DiagramBuilder, 9 | Meshcat, 10 | MeshcatVisualizerCpp, 11 | MeshcatVisualizerParams, 12 | MultibodyPlant, 13 | Parser, 14 | Role, 15 | RigidTransform, 16 | Simulator 17 | ) 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser() 21 | 22 | # Required positional argument 23 | parser.add_argument("model_path", help="File to vis") 24 | parser.add_argument("--port", help="Meshcat port", type=int, default=None) 25 | 26 | args = parser.parse_args() 27 | 28 | builder = DiagramBuilder() 29 | mbp, scene_graph = AddMultibodyPlantSceneGraph( 30 | builder, MultibodyPlant(time_step=0.001)) 31 | parser = Parser(mbp) 32 | model_id = parser.AddModelFromFile(args.model_path) 33 | mbp.Finalize() 34 | 35 | print(args.port) 36 | meshcat = Meshcat(port=args.port) 37 | params = MeshcatVisualizerParams() 38 | params.role = Role.kProximity 39 | params.prefix = "geometry_draw" 40 | params.delete_on_initialization_event = False 41 | vis = MeshcatVisualizerCpp.AddToBuilder( 42 | builder, scene_graph, meshcat, params 43 | ) 44 | 45 | diagram = builder.Build() 46 | diagram.Publish(diagram.CreateDefaultContext()) 47 | 48 | input() -------------------------------------------------------------------------------- /models/unzip_all.sh: -------------------------------------------------------------------------------- 1 | # Utility for unzipping all zip files in directories under the first 2 | # argument into folders with the same 
name as the zip file, alongside 3 | # wherever the zip file is. 4 | find $1 -name '*.zip' -exec sh -c 'unzip -d "${1%.*}" "$1"' _ {} \; 5 | 6 | -------------------------------------------------------------------------------- /packages/spatial_grammar_models/kitchen_models/cupboard.sdf: -------------------------------------------------------------------------------- 1 | 2 | 6 | 7 | 8 | 9 | 10 | 0 0.292 0 0 0 0 11 | 12 | 13 | 0.3 0.016 0.783 14 | 15 | 16 | 17 | 18 | 0 -0.292 0 0 0 0 19 | 20 | 21 | 0.3 0.016 0.783 22 | 23 | 24 | 25 | 26 | 0 0.292 0 0 0 0 27 | 28 | 29 | 0.3 0.016 0.783 30 | 31 | 32 | 33 | 34 | 0 -0.292 0 0 0 0 35 | 36 | 37 | 0.3 0.016 0.783 38 | 39 | 40 | 41 | 42 | 43 | 44 | 0 0 -0.3995 0 0 0 45 | 46 | 47 | 0.3 0.6 0.016 48 | 49 | 50 | 51 | 52 | 0 0 0.3995 0 0 0 53 | 54 | 55 | 0.3 0.6 0.016 56 | 57 | 58 | 59 | 60 | 0 0 -0.13115 0 0 0 61 | 62 | 63 | 0.3 0.6 0.016 64 | 65 | 66 | 67 | 68 | 0 0 0.13115 0 0 0 69 | 70 | 71 | 0.3 0.6 0.016 72 | 73 | 74 | 75 | 76 | 0 0 -0.3995 0 0 0 77 | 78 | 79 | 0.3 0.6 0.016 80 | 81 | 82 | 83 | 84 | 0 0 0.3995 0 0 0 85 | 86 | 87 | 0.3 0.6 0.016 88 | 89 | 90 | 91 | 92 | 0 0 -0.13115 0 0 0 93 | 94 | 95 | 0.3 0.6 0.016 96 | 97 | 98 | 99 | 100 | 0 0 0.13115 0 0 0 101 | 102 | 103 | 0.3 0.6 0.016 104 | 105 | 106 | 107 | 108 | 109 | 110 | top_and_bottom 111 | cupboard_body 112 | 113 | 114 | 115 | left_door 116 | cupboard_body 117 | -0.008 -0.1395 0 0 0 0 118 | 119 | 0 0 1 120 | 121 | 0 122 | 123 | 124 | 1.0 125 | 126 | 127 | 128 | 129 | 0.158 -0.1445 0 0 0 0 130 | 131 | 1 132 | 133 | 0.01042 134 | 0 135 | 0 136 | 0.00542 137 | 0 138 | 0.00542 139 | 140 | 141 | 142 | 0.033 0.1245 0 0 0 0 143 | 144 | 145 | 0.14 146 | 0.005 147 | 148 | 149 | 150 | 151 | 152 | 153 | 0.016 0.279 0.815 154 | 155 | 156 | 157 | 1 0 0 1.0 158 | 159 | 160 | 161 | 0.033 0.1245 0 0 0 0 162 | 163 | 164 | 0.14 165 | 0.005 166 | 167 | 168 | 169 | 170 | 171 | 172 | 0.016 0.279 0.815 173 | 174 | 175 | 176 | 177 | 178 | 179 | right_door 180 | cupboard_body 181 | 
-0.008 0.1395 0 0 0 0 182 | 183 | 0 0 1 184 | 185 | 0 186 | 187 | 188 | 1.0 189 | 190 | 191 | 192 | 193 | 0.158 0.1445 0 0 0 0 194 | 195 | 1 196 | 197 | 0.01042 198 | 0 199 | 0 200 | 0.00542 201 | 0 202 | 0.00542 203 | 204 | 205 | 206 | 0.033 -0.1245 0 0 0 0 207 | 208 | 209 | 0.14 210 | 0.005 211 | 212 | 213 | 214 | 215 | 216 | 217 | 0.016 0.279 0.815 218 | 219 | 220 | 221 | 1 0 0 1.0 222 | 223 | 224 | 225 | 0.033 -0.1245 0 0 0 0 226 | 227 | 228 | 0.14 229 | 0.005 230 | 231 | 232 | 233 | 234 | 235 | 236 | 0.016 0.279 0.815 237 | 238 | 239 | 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /packages/spatial_grammar_models/package.xml: -------------------------------------------------------------------------------- 1 | 2 | spatial_grammar_models 3 | 0.0.1 4 | 5 | Some of the models used in the kitchen grammar example. 6 | 7 | Greg Izatt 8 | MIT 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | lxml 4 | matplotlib 5 | networkx 6 | codecov 7 | pygraphviz 8 | pytest-cov 9 | pytest-subtests 10 | pytest-dependency 11 | pyro-ppl==1.7.0 12 | torchviz 13 | torch==1.9.1 14 | torch_geometric==1.7.0 15 | pytorch3d==0.3 16 | meshcat 17 | -------------------------------------------------------------------------------- /sandbox/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /sandbox/cip_parsing_figure_plotting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is 
already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "import networkx as nx\n", 22 | "import numpy as np" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Define a GMM grammar for illustrative purposes.\n", 32 | "\n" 33 | ] 34 | } 35 | ], 36 | "metadata": { 37 | "kernelspec": { 38 | "display_name": "py36_pyro", 39 | "language": "python", 40 | "name": "py36_pyro" 41 | }, 42 | "language_info": { 43 | "codemirror_mode": { 44 | "name": "ipython", 45 | "version": 3 46 | }, 47 | "file_extension": ".py", 48 | "mimetype": "text/x-python", 49 | "name": "python", 50 | "nbconvert_exporter": "python", 51 | "pygments_lexer": "ipython3", 52 | "version": "3.6.13" 53 | } 54 | }, 55 | "nbformat": 4, 56 | "nbformat_minor": 4 57 | } 58 | -------------------------------------------------------------------------------- /sandbox/debug_parsing_issue.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import pickle 6 | import time 7 | 8 | import torch 9 | torch.set_default_tensor_type(torch.DoubleTensor) 10 | 11 | from spatial_scene_grammars.nodes import * 12 | from spatial_scene_grammars.rules import * 13 | from spatial_scene_grammars.scene_grammar import * 14 | from spatial_scene_grammars.visualization import * 15 | from spatial_scene_grammars.parsing import * 16 | from spatial_scene_grammars.sampling import * 17 | from spatial_scene_grammars.parameter_estimation import * 18 | from spatial_scene_grammars.dataset import * 19 | 20 | import meshcat 21 | import meshcat.geometry as meshcat_geom 22 | 23 | import glob 24 | import os 25 | from functools import lru_cache 26 | 27 | import torch 28 | from spatial_scene_grammars.nodes import * 
29 | from spatial_scene_grammars.rules import * 30 | from spatial_scene_grammars.scene_grammar import * 31 | from spatial_scene_grammars.drake_interop import * 32 | 33 | import pydrake 34 | import pydrake.geometry as pydrake_geom 35 | from pydrake.all import ( 36 | RollPitchYaw, 37 | RigidTransform 38 | ) 39 | 40 | ''' 41 | 42 | Simple grammar that forces parsing to make a choice between two explanations for a scene: 43 | 44 | Root -> Mode 1 -> Observed 45 | Root -> Mode 2 -> Observed 46 | 47 | ''' 48 | 49 | class Observed(TerminalNode): 50 | def __init__(self, tf): 51 | super().__init__( 52 | tf=tf, 53 | physics_geometry_info=None, 54 | observed=True 55 | ) 56 | 57 | class Mode1(AndNode): 58 | def __init__(self, tf): 59 | super().__init__( 60 | tf=tf, 61 | physics_geometry_info=None, 62 | observed=False 63 | ) 64 | @classmethod 65 | def generate_rules(cls): 66 | return [ 67 | ProductionRule( 68 | child_type=Observed, 69 | xyz_rule=WorldFrameGaussianOffsetRule( 70 | mean=torch.tensor([0.0, 0.1, 0.2]), 71 | variance=torch.tensor([1., 0.005, 3.]), 72 | ), 73 | # Assume world-frame vertically-oriented plate stacks 74 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 75 | RotationMatrix(), [1000., 10., 0.1] 76 | ) 77 | ) 78 | ] 79 | 80 | class Mode2(AndNode): 81 | def __init__(self, tf): 82 | super().__init__( 83 | tf=tf, 84 | physics_geometry_info=None, 85 | observed=False 86 | ) 87 | @classmethod 88 | def generate_rules(cls): 89 | return [ 90 | ProductionRule( 91 | child_type=Observed, 92 | xyz_rule=WorldFrameGaussianOffsetRule( 93 | mean=torch.tensor([0.0, 0.1, 0.2]), 94 | variance=torch.tensor([1., 0.005, 0.1]), 95 | ), 96 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 97 | RotationMatrix(), [1000., 10., 0.1] 98 | ) 99 | ) 100 | ] 101 | 102 | class Root(OrNode): 103 | def __init__(self, tf): 104 | super().__init__( 105 | tf=tf, 106 | rule_probs=torch.tensor([0.9, 0.1]), 107 | physics_geometry_info=None, 
108 | observed=False 109 | ) 110 | @classmethod 111 | def generate_rules(cls): 112 | return [ 113 | ProductionRule( 114 | child_type=Mode1, 115 | xyz_rule=SamePositionRule(), 116 | rotation_rule=SameRotationRule() 117 | ), 118 | ProductionRule( 119 | child_type=Mode2, 120 | xyz_rule=SamePositionRule(), 121 | rotation_rule=SameRotationRule() 122 | ), 123 | ] 124 | 125 | 126 | # Set up grammar 127 | grammar = SpatialSceneGrammar( 128 | root_node_type = Root, 129 | root_node_tf = drake_tf_to_torch_tf(RigidTransform()) 130 | ) 131 | 132 | observed_nodes = [Observed(tf=torch.eye(4))] 133 | results = infer_mle_tree_with_mip(grammar, observed_nodes, N_solutions=5, verbose=False, use_random_rotation_offset=False) 134 | trees = get_optimized_trees_from_mip_results(results) 135 | for k, tree in enumerate(trees): 136 | print("Computed score %f, optimization score %f" % (tree.score(verbose=0), results.optim_result.get_suboptimal_objective(k))) -------------------------------------------------------------------------------- /sandbox/drake_static_equilibrium_problem.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pydrake 5 | 6 | from pydrake.common.cpp_param import List as DrakeBindingList 7 | from pydrake.all import ( 8 | AngleAxis, 9 | AddMultibodyPlantSceneGraph, 10 | BasicVector, 11 | BodyIndex, 12 | ConnectMeshcatVisualizer, 13 | CoulombFriction, 14 | DiagramBuilder, 15 | LeafSystem, 16 | MultibodyPlant, 17 | RigidTransform, 18 | RotationMatrix, 19 | Simulator, 20 | SpatialForce, 21 | SpatialInertia, 22 | UnitInertia, 23 | StaticEquilibriumProblem, 24 | SnoptSolver, 25 | AddUnitQuaternionConstraintOnPlant, 26 | ExternallyAppliedSpatialForce, 27 | Value 28 | ) 29 | import pydrake.geometry as pydrake_geom 30 | 31 | from spatial_scene_grammars.drake_interop import DecayingForceToDesiredConfigSystem 32 | 33 | 34 | def build_mbp(seed=0, verts_geom=False, 
convex_collision_geom=True): 35 | # Make some random lumpy objects 36 | trimeshes = [] 37 | np.random.seed(42) 38 | for k in range(3): 39 | # Make a small random number of triangles and chull it 40 | # to get a lumpy object 41 | mesh = trimesh.creation.random_soup(5) 42 | mesh = trimesh.convex.convex_hull(mesh) 43 | trimeshes.append(mesh) 44 | 45 | # Create Drake geometry from those objects by adding a small 46 | # sphere at each vertex 47 | sphere_rad = 0.05 48 | cmap = plt.cm.get_cmap('jet') 49 | 50 | builder = DiagramBuilder() 51 | mbp, scene_graph = AddMultibodyPlantSceneGraph( 52 | builder, MultibodyPlant(time_step=0.001)) 53 | 54 | # Add ground 55 | friction = CoulombFriction(0.9, 0.8) 56 | g = pydrake_geom.Box(100., 100., 0.5) 57 | tf = RigidTransform(p=[0., 0., -0.25]) 58 | mbp.RegisterVisualGeometry( 59 | body=mbp.world_body(), 60 | X_BG=tf, 61 | shape=g, 62 | name="ground", 63 | diffuse_color=[1.0, 1.0, 1.0, 1.0] 64 | ) 65 | mbp.RegisterCollisionGeometry( 66 | body=mbp.world_body(), 67 | X_BG=tf, 68 | shape=g, 69 | name="ground", 70 | coulomb_friction=friction 71 | ) 72 | 73 | for i, mesh in enumerate(trimeshes): 74 | inertia = SpatialInertia( 75 | mass=1.0, 76 | p_PScm_E=np.zeros(3), 77 | G_SP_E=UnitInertia(0.01, 0.01, 0.01) 78 | ) 79 | body = mbp.AddRigidBody(name="body_%d" % i, 80 | M_BBo_B=inertia) 81 | color = cmap(np.random.random()) 82 | if verts_geom: 83 | for j, vert in enumerate(mesh.vertices): 84 | g = pydrake_geom.Sphere(radius=sphere_rad) 85 | tf = RigidTransform(p=vert) 86 | mbp.RegisterVisualGeometry( 87 | body=body, 88 | X_BG=tf, 89 | shape=g, 90 | name="body_%d_color_%d" % (i, j), 91 | diffuse_color=color) 92 | mbp.RegisterCollisionGeometry( 93 | body=body, 94 | X_BG=tf, 95 | shape=g, 96 | name="body_%d_collision_%d" % (i, j), 97 | coulomb_friction=friction) 98 | # And add mesh itself for vis 99 | path = "/tmp/part_%d.obj" % i 100 | trimesh.exchange.export.export_mesh(mesh, path) 101 | g = pydrake_geom.Convex(path) 102 | 
mbp.RegisterVisualGeometry( 103 | body=body, 104 | X_BG=RigidTransform(), 105 | shape=g, 106 | name="body_%d_base" % i, 107 | diffuse_color=color 108 | ) 109 | if convex_collision_geom: 110 | mbp.RegisterCollisionGeometry( 111 | body=body, 112 | X_BG=RigidTransform(), 113 | shape=g, 114 | name="body_%d_base_col" % i, 115 | coulomb_friction=friction 116 | ) 117 | mbp.SetDefaultFreeBodyPose(body, RigidTransform(p=[i % 3, i / 3., 1.])) 118 | mbp.Finalize() 119 | return builder, mbp, scene_graph 120 | 121 | if __name__ == "__main__": 122 | seed = 42 123 | 124 | # This code would try to run a StaticEquilibriumProblem -- 125 | # but it relies on Autodiff-converted MBP/SG, which doesn't support 126 | # the full matrix of geometry collisions. (Mostly does sphere/* 127 | # collisions.) 128 | #builder, mbp, scene_graph = build_mbp(seed=seed) 129 | #diagram = builder.Build() 130 | #diagram_ad = diagram.ToAutoDiffXd() 131 | #mbp_ad = diagram_ad.GetSubsystemByName("plant") # Default name for MBP 132 | #diagram_ad_context = diagram_ad.CreateDefaultContext() 133 | #mbp_ad_context = diagram_ad.GetMutableSubsystemContext(mbp_ad, diagram_ad_context) 134 | #opt = StaticEquilibriumProblem(mbp_ad, mbp_ad_context, ignored_collision_pairs=set()) 135 | #q_vars = opt.q_vars() 136 | #prog = opt.get_mutable_prog() 137 | #AddUnitQuaternionConstraintOnPlant( 138 | # mbp_ad, q_vars, prog) 139 | #q_targ = mbp.GetPositions(mbp.CreateDefaultContext()) 140 | ## Penalize deviation from target configuration 141 | #prog.AddQuadraticErrorCost(np.eye(q_targ.shape[0]), q_targ, q_vars) 142 | #prog.SetInitialGuess(q_vars, q_targ) 143 | #solver = SnoptSolver() 144 | #result = solver.Solve(opt.prog()) 145 | #print(result, result.is_success()) 146 | #q_0 = result.GetSolution(q_vars) 147 | #print("Q found: ", q_0) 148 | 149 | builder, mbp, scene_graph = build_mbp(seed=seed) 150 | q_des = mbp.GetPositions(mbp.CreateDefaultContext()) 151 | forcer = builder.AddSystem(DecayingForceToDesiredConfigSystem(mbp, 
q_des)) 152 | builder.Connect(mbp.get_state_output_port(), 153 | forcer.get_input_port(0)) 154 | builder.Connect(forcer.get_output_port(0), 155 | mbp.get_applied_spatial_force_input_port()) 156 | 157 | visualizer = ConnectMeshcatVisualizer(builder, scene_graph, 158 | zmq_url="default") 159 | diagram = builder.Build() 160 | diag_context = diagram.CreateDefaultContext() 161 | mbp_context = diagram.GetMutableSubsystemContext(mbp, diag_context) 162 | mbp.SetPositions(mbp_context, np.random.random(q_des.shape)*10.0) 163 | sim = Simulator(diagram, diag_context) 164 | #sim.set_target_realtime_rate(1.0) 165 | sim.AdvanceTo(10.0) 166 | 167 | -------------------------------------------------------------------------------- /sandbox/draw_model_in_meshcat_cpp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pydrake 4 | from pydrake.all import ( 5 | AddMultibodyPlantSceneGraph, 6 | DiagramBuilder, 7 | Meshcat, 8 | MeshcatVisualizerParams, 9 | MeshcatVisualizerCpp, 10 | Parser, 11 | Role, 12 | Simulator, 13 | MultibodyPlant 14 | ) 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser(description='Do interactive placement of objects.') 18 | parser.add_argument('model_path', 19 | help='Path to model SDF/URDF.') 20 | args = parser.parse_args() 21 | 22 | # Build MBP 23 | builder = DiagramBuilder() 24 | mbp, scene_graph = AddMultibodyPlantSceneGraph( 25 | builder, MultibodyPlant(time_step=1E-3) 26 | ) 27 | # Parse requested file 28 | parser = Parser(mbp, scene_graph) 29 | model_id = parser.AddModelFromFile(args.model_path) 30 | 31 | mbp.Finalize() 32 | 33 | # Visualizer 34 | meshcat = Meshcat() 35 | vis = MeshcatVisualizerCpp.AddToBuilder(builder, scene_graph, meshcat=meshcat) 36 | 37 | diagram = builder.Build() 38 | diagram_context = diagram.CreateDefaultContext() 39 | mbp_context = diagram.GetSubsystemContext(mbp, diagram_context) 40 | simulator = Simulator(diagram, diagram_context) 41 | 
simulator.Initialize() 42 | simulator.set_target_realtime_rate(1.0) 43 | simulator.AdvanceTo(1000.) 44 | 45 | -------------------------------------------------------------------------------- /sandbox/geometric_only_grammar.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | import time 4 | import pydrake 5 | from pydrake.all import ( 6 | CommonSolverOption, 7 | MathematicalProgram, 8 | MakeSolver, 9 | GurobiSolver, 10 | Solve, 11 | SolverOptions, 12 | VPolytope 13 | ) 14 | import matplotlib.pyplot as plt 15 | import networkx as nx 16 | import numpy as np 17 | 18 | ''' 19 | Minimal simple geometric-production-only grammar definition. 20 | 21 | Minimal grammar definition: each node has: 22 | - a pose x 23 | - a type name 24 | 25 | A node class defines: 26 | - its child type (by name or None) 27 | - max # of children 28 | - the geometric stop prob p (1. = always 1 child, 0. = infinite children) 29 | - the region in which children will be produced (uniformly at random), as an offset from the node's position, in the form of an axis-aligned bounding box 30 | 31 | E.g. object groups in plane: 32 | - Root node produces object clusters at uniform random locations inside [0, 1]^2. 33 | - Each cluster produces points uniformly in a 0.1-length box centered at the cluster center. 34 | ''' 35 | 36 | from collections import namedtuple 37 | NodeDefinition = namedtuple("NodeDefinition", 38 | ["child_type", # string name of child type (or None if no productions), 39 | "max_children", # max # of children 40 | "p", # geometric stop probability (1. = always 1 child; smaller = more children, capped at max_children) 41 | "bounds" # Generation bounds for children relative to node x 42 | ]) 43 | class Node(): 44 | ''' 45 | x: np.array position of node. 2D in this module. 46 | type: string, type name of node.
47 | ''' 48 | def __init__(self, type, x): 49 | self.x = x 50 | self.type = type 51 | 52 | def sample_tree(grammar): 53 | # Given a grammar description as a dict of {node name: NodeDefinition} pairs, 54 | # sample a scene tree as a networkx DiGraph. 55 | tree = nx.DiGraph() 56 | root = Node("root", x=np.array([0., 0.])) 57 | tree.add_node(root) 58 | node_queue = [root] 59 | 60 | while len(node_queue) > 0: 61 | parent = node_queue.pop(0)  # FIFO pop: the tree is expanded breadth-first. 62 | assert parent.type in grammar.keys() 63 | parent_def = grammar[parent.type] 64 | if parent_def.child_type is None: 65 | continue 66 | n_children = min(np.random.geometric(parent_def.p), parent_def.max_children)  # np.random.geometric returns >= 1, so a producing node always gets at least one child. 67 | 68 | for k in range(n_children): 69 | child_x = parent.x + np.random.uniform(*parent_def.bounds) 70 | child = Node(parent_def.child_type, child_x) 71 | tree.add_node(child) 72 | tree.add_edge(parent, child) 73 | node_queue.append(child) 74 | return tree 75 | 76 | def get_observed_nodes(tree, observed_types): 77 | # Given a scene tree (nx.DiGraph) and a list of observed 78 | # node type names (list of strings), pulls out only nodes 79 | # in the tree with matching type into a list of Nodes. 80 | return [n for n in tree if n.type in observed_types] 81 | 82 | # Drawing utilities for trees. 83 | def draw_tree(tree, draw_pos=True, with_labels=False, node_color_dict=None, alpha=0.5, node_size=200, **kwargs): 84 | # Decide a coloring for the node types.
85 | unique_types = sorted(list(set([n.type for n in tree]))) 86 | n_types = len(unique_types) 87 | cm = plt.get_cmap("viridis") 88 | if node_color_dict is not None: 89 | node_color = [node_color_dict[node] for node in tree] 90 | else: 91 | color_mapping = {unique_type: cm(float(k)/n_types) for k, unique_type in enumerate(unique_types)} 92 | node_color = [color_mapping[node.type] for node in tree] 93 | if draw_pos: 94 | pos={node: node.x for node in tree} 95 | else: 96 | pos=None 97 | nx.draw_networkx( 98 | tree, 99 | labels={node: node.type for node in tree}, 100 | with_labels=with_labels, 101 | pos=pos, 102 | node_size=node_size, 103 | node_color=node_color, 104 | alpha=alpha, 105 | **kwargs 106 | ) 107 | plt.gca().set_xlim([-0.1, 1.1]) 108 | plt.gca().set_ylim([-0.1, 1.1]) 109 | 110 | def draw_observed_nodes(nodes): 111 | tree = nx.Graph() 112 | tree.add_nodes_from(nodes) 113 | draw_tree(tree) -------------------------------------------------------------------------------- /sandbox/greedy_parsing_prototype.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import time\n", 10 | "import networkx as nx\n", 11 | "import numpy as np\n", 12 | "from copy import deepcopy\n", 13 | "from collections import namedtuple\n", 14 | "import torch\n", 15 | "import pyro\n", 16 | "\n", 17 | "from spatial_scene_grammars.nodes import *\n", 18 | "from spatial_scene_grammars.rules import *\n", 19 | "from spatial_scene_grammars.scene_grammar import *\n", 20 | "from spatial_scene_grammars.parsing import *" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "WARNING:root:Detaching BinghamDistribution parameters.\n", 33 | "WARNING:root:Prior over parameters of 
WorldFrameBinghamRotationRule are Deltas.\n" 34 | ] 35 | }, 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Observed 8 objects\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "from spatial_scene_grammars_examples.singles_pairs.grammar_constituency import *\n", 46 | "pyro.set_rng_seed(42)\n", 47 | "\n", 48 | "grammar = SpatialSceneGrammar(\n", 49 | " root_node_type = Root,\n", 50 | " root_node_tf = torch.eye(4)\n", 51 | ")\n", 52 | "ground_truth_tree = grammar.sample_tree(detach=True)\n", 53 | "observed_nodes = ground_truth_tree.get_observed_nodes()\n", 54 | "print(\"Observed %d objects\" % len(observed_nodes))" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Found tree with score -1372.632082 in 0.124831s\n", 67 | "FINAL SCORE: tensor([-1372.6321])\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "pyro.set_rng_seed(42)\n", 73 | "parse_tree, score = sample_likely_tree_with_greedy_parsing(grammar, observed_nodes, verbose=True)\n", 74 | "print(\"FINAL SCORE: \", score)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [] 83 | } 84 | ], 85 | "metadata": { 86 | "kernelspec": { 87 | "display_name": "py36_pyro", 88 | "language": "python", 89 | "name": "py36_pyro" 90 | }, 91 | "language_info": { 92 | "codemirror_mode": { 93 | "name": "ipython", 94 | "version": 3 95 | }, 96 | "file_extension": ".py", 97 | "mimetype": "text/x-python", 98 | "name": "python", 99 | "nbconvert_exporter": "python", 100 | "pygments_lexer": "ipython3", 101 | "version": "3.6.13" 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 4 106 | } 107 | -------------------------------------------------------------------------------- /sandbox/mip_rotation_constraint_solution_uniqueness.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 111, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "import numpy as np\n", 11 | "from sklearn import mixture\n", 12 | "import scipy as sp\n", 13 | "import pydrake\n", 14 | "from pydrake.all import (\n", 15 | " RandomGenerator,\n", 16 | " MixedIntegerRotationConstraintGenerator,\n", 17 | " UniformlyRandomRotationMatrix,\n", 18 | " IntervalBinning,\n", 19 | " MathematicalProgram,\n", 20 | " GurobiSolver,\n", 21 | " SolverOptions,\n", 22 | " RotationMatrix,\n", 23 | " RollPitchYaw\n", 24 | ")\n", 25 | "from tqdm.notebook import tqdm" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 112, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "R: [[1. 0. 0.]\n", 38 | " [0. 1. 0.]\n", 39 | " [0. 0. 1.]] -> # sols: 20\n", 40 | "R: [[ 1.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 41 | " [ 0.00000000e+00 9.99999500e-01 -9.99999833e-04]\n", 42 | " [ 0.00000000e+00 9.99999833e-04 9.99999500e-01]] -> # sols: 8\n", 43 | "R: [[ 9.99999500e-01 -9.99999333e-04 9.99999667e-07]\n", 44 | " [ 9.99999833e-04 9.99999000e-01 -9.99999333e-04]\n", 45 | " [ 0.00000000e+00 9.99999833e-04 9.99999500e-01]] -> # sols: 2\n", 46 | "R: [[ 9.99999000e-01 -9.98999334e-04 1.00099883e-03]\n", 47 | " [ 9.99999333e-04 9.99999001e-01 -9.98999334e-04]\n", 48 | " [-9.99999833e-04 9.99999333e-04 9.99999000e-01]] -> # sols: 1\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "def get_sols(R_gt, verbose=False):\n", 54 | " prog = MathematicalProgram()\n", 55 | "\n", 56 | " #R_gt = RotationMatrix(np.eye(3))\n", 57 | "\n", 58 | " R_dec = prog.NewContinuousVariables(3, 3, \"R\")\n", 59 | " mip_rot_gen = MixedIntegerRotationConstraintGenerator(\n", 60 | " approach = 
MixedIntegerRotationConstraintGenerator.Approach.kBilinearMcCormick,\n", 61 | " num_intervals_per_half_axis=2,\n", 62 | " interval_binning = IntervalBinning.kLogarithmic\n", 63 | " )\n", 64 | " mip_rot_gen.AddToProgram(R_dec, prog)\n", 65 | " prog.AddBoundingBoxConstraint(R_gt.matrix().flatten(), R_gt.matrix().flatten(), R_dec.flatten())\n", 66 | " \n", 67 | " solver = GurobiSolver()\n", 68 | " options = SolverOptions()\n", 69 | " options.SetOption(solver.id(), \"LogFile\", \"gurobi.log\")\n", 70 | " options.SetOption(solver.id(), \"PoolSolutions\", 100)\n", 71 | " options.SetOption(solver.id(), \"PoolSearchMode\", 2)\n", 72 | "\n", 73 | " result = solver.Solve(prog, None, options)\n", 74 | "\n", 75 | " N_sols = getattr(result, \"num_suboptimal_solution()\")()\n", 76 | " if verbose or N_sols > 1:\n", 77 | " print(\"R: \", result.GetSolution(R_dec), \" -> # sols: \", N_sols)\n", 78 | " #with open(\"gurobi.log\") as f:\n", 79 | " # print(f.read())\n", 80 | "\n", 81 | " #for solution_k in range(getattr(result, \"num_suboptimal_solution()\")()):\n", 82 | " # sol = result.GetSuboptimalSolution(R_dec, solution_k)\n", 83 | " # print(\"Sol %d: %s\" % (solution_k, sol))\n", 84 | "\n", 85 | "# RPY=0 gets multiple sols\n", 86 | "perturb = np.zeros(3)\n", 87 | "get_sols(RotationMatrix(RollPitchYaw(perturb)), verbose=True)\n", 88 | "\n", 89 | "# No rotation around any given axis gets multiple sols\n", 90 | "perturb = np.array([1E-3, 0., 0.])\n", 91 | "get_sols(RotationMatrix(RollPitchYaw(perturb)), verbose=True)\n", 92 | "\n", 93 | "perturb = np.array([1E-3, 0., 1E-3])\n", 94 | "get_sols(RotationMatrix(RollPitchYaw(perturb)), verbose=True)\n", 95 | "\n", 96 | "# Rotating around all axes (no zeros in rotation matrix) gets unique sol\n", 97 | "perturb = np.array([1E-3, 1E-3, 1E-3])\n", 98 | "get_sols(RotationMatrix(RollPitchYaw(perturb)), verbose=True)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 113, 104 | "metadata": {}, 105 | "outputs": [ 106 | 
{ 107 | "data": { 108 | "application/vnd.jupyter.widget-view+json": { 109 | "model_id": "e31e1d9ff49245378627bbd3209916a5", 110 | "version_major": 2, 111 | "version_minor": 0 112 | }, 113 | "text/plain": [ 114 | " 0%| | 0/100 [00:00=3.6', 22 | ) -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | export ROS_PACKAGE_PATH=$ROS_PACKAGE_PATH:$DIR/packages 5 | -------------------------------------------------------------------------------- /spatial_scene_grammars/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars/__init__.py -------------------------------------------------------------------------------- /spatial_scene_grammars/constraints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | import torch 4 | from .nodes import TerminalNode, Node 5 | 6 | 7 | ''' 8 | 9 | Clearly, this is gonna be an important and hard part of this system 10 | to get right. 11 | 12 | Unorganized thoughts: 13 | - HMC should work for factors on continuous properties, I think? 14 | - How do I tell which registered constraints can be HMC-d? What if 15 | a node's continuous properties affect the way it'll produce children? 16 | - Simple "existence of node" or "existence of subtree" constraints 17 | may be an easier class than general topology constraints? 18 | 19 | ''' 20 | 21 | class Constraint(): 22 | ''' 23 | To be used in combination with constraint-wrapping 24 | Gibbs distributions to produce factors? 
25 | ''' 26 | def __init__(self, lower_bound, upper_bound): 27 | assert isinstance(lower_bound, torch.Tensor) 28 | assert isinstance(upper_bound, torch.Tensor) 29 | assert lower_bound.shape == upper_bound.shape 30 | self.lower_bound = lower_bound 31 | self.upper_bound = upper_bound 32 | super().__init__() 33 | 34 | def add_to_ik_prog(self, scene_tree, ik, mbp, mbp_context, node_to_free_body_ids_map): 35 | # Add this constraint to a Drake InverseKinematics object ik. 36 | raise NotImplementedError() 37 | 38 | def eval(self, scene_tree): 39 | # Output should be torch-autodiffable from scene_tree 40 | # params and variables. It should either match the shape 41 | # of lower_bound / upper_bound, or be batched with leftmost 42 | # dims matching the shape of lower/upper bound. 43 | raise NotImplementedError() 44 | 45 | def eval_violation(self, scene_tree): 46 | # Return (max violation, lower_violation, upper_violation) vectors 47 | val = self.eval(scene_tree) 48 | if self.lower_bound.shape != torch.Size() and val.shape != self.lower_bound.shape: 49 | lb_n_elems = len(self.lower_bound.shape) 50 | assert val.shape[-lb_n_elems:] == self.lower_bound.shape, (val.shape, self.lower_bound.shape) 51 | lower_violation = self.lower_bound - val 52 | lower_violation[..., torch.isinf(self.lower_bound)] = 0. 53 | upper_violation = val - self.upper_bound 54 | upper_violation[..., torch.isinf(self.upper_bound)] = 0. 
55 | max_violation = torch.max(lower_violation, upper_violation) 56 | return max_violation, lower_violation, upper_violation 57 | 58 | 59 | class PoseConstraint(Constraint): 60 | pass 61 | 62 | class StructureConstraint(Constraint): 63 | pass 64 | 65 | class ObjectCountConstraint(StructureConstraint): 66 | def __init__(self, object_type, min_count, max_count): 67 | self.object_type = object_type 68 | super().__init__(lower_bound=torch.tensor(min_count), upper_bound=torch.tensor(max_count)) 69 | def eval(self, scene_tree): 70 | num = len(list(scene_tree.find_nodes_by_type(self.object_type))) 71 | return torch.tensor(num) 72 | 73 | class ChildCountConstraint(StructureConstraint): 74 | def __init__(self, parent_type, min_count, max_count): 75 | self.parent_type = parent_type 76 | super().__init__(lower_bound=torch.tensor(min_count), upper_bound=torch.tensor(max_count)) 77 | def eval(self, scene_tree): 78 | child_counts = [] 79 | for parent_node in scene_tree.find_nodes_by_type(self.parent_type): 80 | child_counts.append(len(scene_tree.successors(parent_node))) 81 | return torch.tensor(child_counts) -------------------------------------------------------------------------------- /spatial_scene_grammars/dataset.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from collections import namedtuple 3 | import numpy as np 4 | import logging 5 | 6 | import yaml 7 | try: 8 | from yaml import CLoader as Loader, CDumper as Dumper 9 | except ImportError: 10 | from yaml import Loader, Dumper 11 | 12 | import pydrake 13 | from pydrake.all import RigidTransform, RollPitchYaw 14 | 15 | from .drake_interop import drake_tf_to_torch_tf 16 | 17 | 18 | def convert_scenes_yaml_to_observed_nodes(dataset_file, type_map={}, model_map={}): 19 | ''' 20 | Converts a scene list YAML file of the format described at 21 | https://github.com/gizatt/drake_hydra_interact/blob/master/environments/README.md 22 | into a list of nodes, 
using the dictionary type_map to map from metadata "type" 23 | string to Node types. 24 | ''' 25 | 26 | with open(dataset_file, "r") as f: 27 | scenes_dict = yaml.load(f, Loader=Loader) 28 | observed_node_sets = [] 29 | for scene_name, scene_info in scenes_dict.items(): 30 | observed_nodes = [] 31 | for object_info in scene_info["objects"]: 32 | class_name = object_info["metadata"]["class"] 33 | model_name = object_info["model_file"] 34 | if model_name in model_map.keys(): 35 | this_type = model_map[model_name] 36 | elif class_name in type_map.keys(): 37 | this_type = type_map[class_name] 38 | else: 39 | logging.warn("Environment had unknown model name / class name: %s, %s", model_name, class_name) 40 | continue 41 | tf = RigidTransform( 42 | p = object_info["pose"]["xyz"], 43 | rpy = RollPitchYaw(object_info["pose"]["rpy"]) 44 | ) 45 | observed_nodes.append( 46 | this_type(drake_tf_to_torch_tf(tf)) 47 | ) 48 | for object_info in scene_info["world_description"]["models"]: 49 | if "metadata" in object_info.keys(): 50 | class_name = object_info["metadata"]["class"] 51 | model_name = object_info["model_file"] 52 | if model_name in model_map.keys(): 53 | this_type = model_map[model_name] 54 | elif class_name in type_map.keys(): 55 | this_type = type_map[class_name] 56 | else: 57 | logging.warn("Environment had unknown model name / class name: %s, %s", model_name, class_name) 58 | continue 59 | tf = RigidTransform( 60 | p = object_info["pose"]["xyz"], 61 | rpy = RollPitchYaw(object_info["pose"]["rpy"]) 62 | ) 63 | observed_nodes.append( 64 | this_type(drake_tf_to_torch_tf(tf)) 65 | ) 66 | observed_node_sets.append(observed_nodes) 67 | return observed_node_sets 68 | -------------------------------------------------------------------------------- /spatial_scene_grammars/random_walk_kernel.py: -------------------------------------------------------------------------------- 1 | # Based on Pyro PPL MCMC kernel.s 2 | # See them for licenses. 
3 | 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | 9 | import pyro 10 | import pyro.distributions as dist 11 | from pyro.distributions.util import scalar_like 12 | from pyro.distributions import Normal 13 | 14 | from pyro.infer.autoguide import init_to_uniform 15 | from pyro.infer.mcmc.mcmc_kernel import MCMCKernel 16 | from pyro.infer.mcmc.util import initialize_model 17 | from pyro.util import optional, torch_isnan 18 | 19 | 20 | class RandomWalkKernel(MCMCKernel): 21 | r""" 22 | Simple random-walk MCMC kernel. 23 | 24 | :param model: Python callable containing Pyro primitives. 25 | :param potential_fn: Python callable calculating potential energy with input 26 | is a dict of real support parameters. 27 | :param float variance: Variance of the normal distribution used to select 28 | random walk steps. 29 | :param dict transforms: Optional dictionary that specifies a transform 30 | for a sample site with constrained support to unconstrained space. The 31 | transform should be invertible, and implement `log_abs_det_jacobian`. 32 | If not specified and the model has sites with constrained support, 33 | automatic transformations will be applied, as specified in 34 | :mod:`torch.distributions.constraint_registry`. 35 | :param int max_plate_nesting: Optional bound on max number of nested 36 | :func:`pyro.plate` contexts. This is required if model contains 37 | discrete sample sites that can be enumerated over in parallel. 38 | :param bool jit_compile: Optional parameter denoting whether to use 39 | the PyTorch JIT to trace the log density computation, and use this 40 | optimized executable trace in the integrator. 41 | :param dict jit_options: A dictionary contains optional arguments for 42 | :func:`torch.jit.trace` function. 43 | :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT 44 | tracer when ``jit_compile=True``. Default is False. 45 | :param callable init_strategy: A per-site initialization function. 
46 | See :ref:`autoguide-initialization` section for available functions. 47 | """ 48 | 49 | def __init__(self, 50 | model=None, 51 | potential_fn=None, 52 | variance=1.0, 53 | transforms=None, 54 | max_plate_nesting=None, 55 | jit_compile=False, 56 | jit_options=None, 57 | ignore_jit_warnings=False, 58 | init_strategy=init_to_uniform): 59 | if not ((model is None) ^ (potential_fn is None)): 60 | raise ValueError("Only one of `model` or `potential_fn` must be specified.") 61 | # NB: deprecating args - model, transforms 62 | self.model = model 63 | self.transforms = transforms 64 | self._max_plate_nesting = max_plate_nesting 65 | self._jit_compile = jit_compile 66 | self._jit_options = jit_options 67 | self._ignore_jit_warnings = ignore_jit_warnings 68 | self._init_strategy = init_strategy 69 | self.potential_fn = potential_fn 70 | self.variance = variance 71 | self._reset() 72 | super().__init__() 73 | 74 | def _initialize_model_properties(self, model_args, model_kwargs): 75 | init_params, potential_fn, transforms, trace = initialize_model( 76 | self.model, 77 | model_args, 78 | model_kwargs, 79 | transforms=self.transforms, 80 | max_plate_nesting=self._max_plate_nesting, 81 | jit_compile=self._jit_compile, 82 | jit_options=self._jit_options, 83 | skip_jit_warnings=self._ignore_jit_warnings, 84 | init_strategy=self._init_strategy, 85 | initial_params=None, 86 | ) 87 | self.potential_fn = potential_fn 88 | self.transforms = transforms 89 | self._initial_params = init_params 90 | self._prototype_trace = trace 91 | 92 | @property 93 | def initial_params(self): 94 | return self._initial_params 95 | 96 | @initial_params.setter 97 | def initial_params(self, params): 98 | self._initial_params = params 99 | 100 | def _reset(self): 101 | self._warmup_steps = None 102 | self._t = 0 103 | self._accept_cnt = 0 104 | self._mean_accept_prob = 0. 
105 | 106 | def setup(self, warmup_steps, *args, **kwargs): 107 | self._warmup_steps = warmup_steps 108 | if self.model is not None: 109 | self._initialize_model_properties(args, kwargs) 110 | 111 | def cleanup(self): 112 | self._reset() 113 | 114 | def sample(self, params): 115 | lp_old = -self.potential_fn(params) 116 | new_params = {} 117 | for site_name, v in params.items(): 118 | size = v.shape 119 | step_dist = Normal(torch.zeros(v.shape), torch.ones(v.shape)*self.variance) 120 | new_params[site_name] = v + pyro.sample( 121 | "{}_step".format(site_name), step_dist) 122 | lp_new = -self.potential_fn(new_params) 123 | accept_prob = (lp_new - lp_old).exp().clamp(max=1.) 124 | u = pyro.sample("accept_thresh", dist.Uniform(0., 1.)) 125 | 126 | # Do MH check for acceptance based on score of new params. 127 | if u <= accept_prob: 128 | self._accept_cnt += 1 129 | accepted = True 130 | params = new_params 131 | else: 132 | accepted = False 133 | 134 | # Compute diagnostics. 135 | self._t += 1 136 | if self._t > self._warmup_steps: 137 | n = self._t - self._warmup_steps 138 | if accepted: 139 | self._accept_cnt += 1 140 | else: 141 | n = self._t 142 | 143 | self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n 144 | 145 | return params 146 | 147 | def logging(self): 148 | return OrderedDict([ 149 | ("acc. 
prob", "{:.3f}".format(self._mean_accept_prob)) 150 | ]) 151 | -------------------------------------------------------------------------------- /spatial_scene_grammars/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars/test/__init__.py -------------------------------------------------------------------------------- /spatial_scene_grammars/test/gaussian_grammar.py: -------------------------------------------------------------------------------- 1 | import pyro 2 | import pyro.distributions as dist 3 | import torch 4 | import torch.distributions.constraints as constraints 5 | from torch.nn.parameter import Parameter 6 | 7 | from spatial_scene_grammars.torch_utils import ConstrainedParameter 8 | from spatial_scene_grammars.nodes import * 9 | from spatial_scene_grammars.rules import * 10 | 11 | torch.set_default_tensor_type(torch.DoubleTensor) 12 | pyro.enable_validation(True) 13 | 14 | ''' 15 | 16 | Simple grammar for testing parameter estimation / parsing 17 | that has infinite support (i.e. you can't back into a corner where 18 | a set of parameters will make parsing an observation impossible). 
A -> B and C
B -> D or E
D -> G with 50% prob
C -> Geometric repetitions of F

NOTE(review): the "infinite support" description above does not obviously
match the rules below, which (judging by their names) are bounded
bounding-box / uniform-bounded rules (WorldFrameBBoxRule,
WorldFrameBBoxOffsetRule, UniformBoundedRevoluteJointRule). TODO confirm
whether Gaussian rules were intended here.
'''

class NodeG(TerminalNode):
    '''Observed terminal node; produced by NodeD.'''
    def __init__(self, tf):
        super().__init__(observed=True, physics_geometry_info=None, tf=tf)

class NodeF(TerminalNode):
    '''Observed terminal node; repeated child of NodeC.'''
    def __init__(self, tf):
        super().__init__(observed=True, physics_geometry_info=None, tf=tf)

class NodeE(TerminalNode):
    '''Observed terminal node; one alternative of NodeB.'''
    def __init__(self, tf):
        super().__init__(observed=True, physics_geometry_info=None, tf=tf)

class NodeD(IndependentSetNode):
    '''Observed node that produces a NodeG with probability 0.5 ("D -> G with 50% prob").'''
    def __init__(self, tf):
        super().__init__(
            rule_probs=torch.tensor([0.5]),
            observed=True,
            physics_geometry_info=None,
            tf=tf
        )
    @classmethod
    def generate_rules(cls):
        return [
            ProductionRule(
                child_type=NodeG,
                xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)),
                rotation_rule=UnconstrainedRotationRule()
            )
        ]

class NodeC(RepeatingSetNode):
    '''Unobserved node producing a geometric number of NodeF children (at most 5).'''
    def __init__(self, tf):
        super().__init__(
            rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=torch.tensor(0.2), max_children=5),
            observed=False,
            physics_geometry_info=None,
            tf=tf
        )
    @classmethod
    def generate_rules(cls):
        return [
            ProductionRule(
                child_type=NodeF,
                xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)),
                rotation_rule=UnconstrainedRotationRule()
            )
        ]

class NodeB(OrNode):
    '''Unobserved node producing either a NodeD (prob 0.75) or a NodeE (prob 0.25).'''
    def __init__(self, tf):
        super().__init__(
            rule_probs=torch.tensor([0.75, 0.25]),
            observed=False,
            physics_geometry_info=None,
            tf=tf
        )
    @classmethod
    def generate_rules(cls):
        # Rule order must match rule_probs above: [NodeD, NodeE].
        return [
            ProductionRule(
                child_type=NodeD,
                xyz_rule=WorldFrameBBoxOffsetRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)),
                rotation_rule=UniformBoundedRevoluteJointRule.from_bounds(axis=torch.tensor([0., 0., 1.]), lb=-1., ub=1.)
            ),
            ProductionRule(
                child_type=NodeE,
                xyz_rule=WorldFrameBBoxOffsetRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)),
                rotation_rule=UnconstrainedRotationRule()
            )
        ]

class NodeA(AndNode):
    '''Unobserved root-type node that always produces both a NodeB and a NodeC.'''
    def __init__(self, tf):
        super().__init__(
            observed=False, physics_geometry_info=None, tf=tf
        )
    @classmethod
    def generate_rules(cls):
        return [
            ProductionRule(
                child_type=NodeB,
                xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)),
                rotation_rule=UnconstrainedRotationRule()
            ),
            ProductionRule(
                child_type=NodeC,
                xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)),
                rotation_rule=UnconstrainedRotationRule()
            )
        ]
-------------------------------------------------------------------------------- /spatial_scene_grammars/test/grammar.py: -------------------------------------------------------------------------------- 1 | import pyro 2 | import pyro.distributions as dist 3 | import torch 4 | import torch.distributions.constraints as constraints 5 | from torch.nn.parameter import Parameter 6 | 7 | from spatial_scene_grammars.torch_utils import ConstrainedParameter 8 | from spatial_scene_grammars.nodes import * 9 | from spatial_scene_grammars.rules import * 10 | 11 | torch.set_default_tensor_type(torch.DoubleTensor) 12 | pyro.enable_validation(True) 13 | 14 | ''' 15 | 16 | Simple grammar for testing that gets basic coverage of 17 | node types.
18 | 19 | A -> B and C 20 | B -> D or E 21 | D -> G with 50% prob 22 | C -> Geometric repetitions of F 23 | ''' 24 | 25 | class NodeG(TerminalNode): 26 | def __init__(self, tf): 27 | super().__init__(observed=True, physics_geometry_info=None, tf=tf) 28 | 29 | class NodeF(TerminalNode): 30 | def __init__(self, tf): 31 | super().__init__(observed=True, physics_geometry_info=None, tf=tf) 32 | 33 | class NodeE(TerminalNode): 34 | def __init__(self, tf): 35 | super().__init__(observed=True, physics_geometry_info=None, tf=tf) 36 | 37 | class NodeD(IndependentSetNode): 38 | def __init__(self, tf): 39 | super().__init__( 40 | rule_probs=torch.tensor([0.5]), 41 | observed=True, 42 | physics_geometry_info=None, 43 | tf=tf 44 | ) 45 | @classmethod 46 | def generate_rules(cls): 47 | return [ 48 | ProductionRule( 49 | child_type=NodeG, 50 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)), 51 | rotation_rule=SameRotationRule() 52 | ) 53 | ] 54 | 55 | class NodeC(RepeatingSetNode): 56 | def __init__(self, tf): 57 | super().__init__( 58 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=torch.tensor(0.2), max_children=5), 59 | observed=False, 60 | physics_geometry_info=None, 61 | tf=tf 62 | ) 63 | @classmethod 64 | def generate_rules(cls): 65 | return [ 66 | ProductionRule( 67 | child_type=NodeF, 68 | xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)), 69 | rotation_rule=UnconstrainedRotationRule() 70 | ) 71 | ] 72 | 73 | class NodeB(OrNode): 74 | def __init__(self, tf): 75 | super().__init__( 76 | rule_probs=torch.tensor([0.75, 0.25]), 77 | observed=False, 78 | physics_geometry_info=None, 79 | tf=tf 80 | ) 81 | @classmethod 82 | def generate_rules(cls): 83 | return [ 84 | ProductionRule( 85 | child_type=NodeD, 86 | xyz_rule=WorldFrameBBoxRule.from_bounds(lb=-torch.ones(3), ub=torch.ones(3)), 87 | rotation_rule=UniformBoundedRevoluteJointRule.from_bounds(axis=torch.tensor([0., 0., 1.]), lb=-1., ub=1.) 
88 | ), 89 | ProductionRule( 90 | child_type=NodeE, 91 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds(lb=-torch.ones(3), ub=torch.ones(3)), 92 | rotation_rule=UnconstrainedRotationRule() 93 | ) 94 | ] 95 | 96 | class NodeA(AndNode): 97 | def __init__(self, tf): 98 | super().__init__( 99 | observed=False, physics_geometry_info=None, tf=tf 100 | ) 101 | @classmethod 102 | def generate_rules(cls): 103 | return [ 104 | ProductionRule( 105 | child_type=NodeB, 106 | xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)), 107 | rotation_rule=UnconstrainedRotationRule() 108 | ), 109 | ProductionRule( 110 | child_type=NodeC, 111 | xyz_rule=WorldFrameBBoxRule.from_bounds(lb=torch.zeros(3), ub=torch.ones(3)), 112 | rotation_rule=UnconstrainedRotationRule() 113 | ) 114 | ] 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /spatial_scene_grammars/test/test_distributions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import torch 5 | import pyro 6 | 7 | from spatial_scene_grammars.distributions import * 8 | 9 | torch.set_default_tensor_type(torch.DoubleTensor) 10 | 11 | @pytest.fixture(params=range(10)) 12 | def set_seed(request): 13 | pyro.set_rng_seed(request.param) 14 | 15 | def test_UniformWithEqualityHandling(set_seed): 16 | lb = torch.tensor([0., 0., 0.]) 17 | ub = torch.tensor([0., 1., 2.]) 18 | expected_ll = torch.log(torch.tensor([1., 1., 0.5])) 19 | dist = UniformWithEqualityHandling(lb, ub, validate_args=True) 20 | val = dist.sample() 21 | assert all(val >= lb) 22 | assert all(val <= ub) 23 | val = dist.rsample() 24 | assert all(val >= lb) 25 | assert all(val <= ub) 26 | ll = dist.log_prob(val) 27 | assert all(torch.isclose(ll, expected_ll)), ll 28 | 29 | cdf = dist.cdf(torch.tensor([0., 0., 0.])) 30 | assert all(torch.isclose(cdf, torch.tensor([1., 0., 0.]))), cdf 31 | cdf = dist.cdf(torch.tensor([0., 1., 
1.])) 32 | assert all(torch.isclose(cdf, torch.tensor([1., 1., 0.5]))), cdf 33 | cdf = dist.cdf(torch.tensor([0., 1., 2.])) 34 | assert all(torch.isclose(cdf, torch.tensor([1., 1., 1.]))), cdf 35 | 36 | entropy = dist.entropy() 37 | expected_entropy = torch.zeros(3) 38 | expected_entropy[1:] = torch.distributions.Uniform(lb[1:], ub[1:]).entropy() 39 | assert all(torch.isclose(entropy, expected_entropy)), entropy 40 | 41 | # Make sure it works with different input shapes 42 | lb = 0. 43 | ub = 1. 44 | dist = UniformWithEqualityHandling(lb, ub, validate_args=True) 45 | val = dist.rsample() 46 | dist.log_prob(val) 47 | dist.cdf(val) 48 | 49 | lb = 1. 50 | ub = 1. 51 | dist = UniformWithEqualityHandling(lb, ub, validate_args=True) 52 | val = dist.rsample() 53 | dist.log_prob(val) 54 | dist.cdf(val) 55 | 56 | lb = torch.zeros(3, 3) 57 | ub = torch.ones(3, 3) 58 | ub[1, 1] = 0. 59 | dist = UniformWithEqualityHandling(lb, ub, validate_args=True) 60 | val = dist.rsample() 61 | dist.log_prob(val) 62 | dist.cdf(val) 63 | 64 | 65 | 66 | 67 | def test_left_sided_constraint(): 68 | constraint = LeftSidedConstraint() 69 | assert constraint.check(torch.zeros(0)) 70 | 71 | assert constraint.check(torch.zeros(1)) 72 | assert constraint.check(torch.zeros(2)) 73 | assert constraint.check(torch.zeros(10)) 74 | 75 | assert constraint.check(torch.ones(1)) 76 | assert constraint.check(torch.ones(2)) 77 | assert constraint.check(torch.ones(10)) 78 | 79 | 80 | assert constraint.check(torch.cat([torch.ones(1), torch.zeros(1)])) 81 | assert constraint.check(torch.cat([torch.ones(5), torch.zeros(5)])) 82 | assert constraint.check(torch.cat([torch.ones(9), torch.zeros(1)])) 83 | 84 | assert not constraint.check(torch.cat([torch.zeros(1), torch.ones(9)])) 85 | assert not constraint.check(torch.cat([torch.zeros(5), torch.ones(5)])) 86 | assert not constraint.check(torch.cat([torch.zeros(9), torch.ones(1)])) 87 | 88 | def test_VectorCappedGeometricDist(): 89 | p = 0.5 90 | k = 15 91 | dist = 
VectorCappedGeometricDist(geometric_prob=p, max_repeats=k, validate_args=True) 92 | vec = dist.sample() 93 | assert vec.shape == (k,) 94 | 95 | test_vec = torch.zeros(k).int() 96 | ll = dist.log_prob(test_vec).item() 97 | target_ll = np.log(p) 98 | assert np.allclose(ll, target_ll), "ll %f vs %f" % (ll, target_ll) 99 | 100 | for stop_k in range(k+1): 101 | test_vec = torch.zeros(k).int() 102 | test_vec[:stop_k] = 1 103 | ll = dist.log_prob(test_vec).item() 104 | target_p = (1. - p)**stop_k * p 105 | if stop_k == k: 106 | print(np.log(target_p)) 107 | target_p += (1. - p)**(k + 1) 108 | print(" changed to ", np.log(target_p)) 109 | target_ll = np.log(target_p) 110 | assert np.allclose(ll, target_ll), "ll %f vs %f at %d" % (ll, target_ll, stop_k) 111 | 112 | def test_LeftSidedRepeatingOnesDist(): 113 | N = 5 114 | weights = torch.arange(0, N+1) + 1. 115 | weights /= weights.sum() 116 | print(weights) 117 | dist = LeftSidedRepeatingOnesDist(categorical_probs=weights, validate_args=True) 118 | vec = dist.sample() 119 | assert vec.shape == (N,) 120 | 121 | for k in range(N+1): 122 | test_vec = torch.zeros(N).int() 123 | test_vec[:k] = 1 124 | ll = dist.log_prob(test_vec).item() 125 | target_ll = torch.log(weights[k]) 126 | assert np.allclose(ll, target_ll), "ll %f vs %f" % (ll, target_ll) 127 | 128 | def test_Bingham(set_seed): 129 | param_m = torch.eye(4) 130 | param_z = torch.tensor([-100., -10., -1., 0.]) 131 | dist = BinghamDistribution(param_m, param_z) 132 | samples = dist.sample(sample_shape=(100,)) 133 | assert samples.shape == (100, 4) 134 | sample_prob = dist.log_prob(samples) 135 | assert all(torch.isfinite(sample_prob)) and sample_prob.shape == (100,) 136 | 137 | if __name__ == "__main__": 138 | pytest.main() 139 | -------------------------------------------------------------------------------- /spatial_scene_grammars/test/test_grammar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 
4 | import pyro 5 | import pyro.poutine 6 | 7 | from spatial_scene_grammars.nodes import * 8 | from spatial_scene_grammars.rules import * 9 | from spatial_scene_grammars.scene_grammar import * 10 | 11 | from .grammar import * 12 | 13 | from torch.distributions import constraints 14 | 15 | from pytorch3d.transforms.rotation_conversions import ( 16 | axis_angle_to_matrix 17 | ) 18 | 19 | torch.set_default_tensor_type(torch.DoubleTensor) 20 | 21 | @pytest.fixture(params=range(10)) 22 | def set_seed(request): 23 | pyro.set_rng_seed(request.param) 24 | 25 | 26 | ## Basic grammar and tree functionality 27 | def test_grammar(set_seed): 28 | grammar = SpatialSceneGrammar( 29 | root_node_type = NodeA, 30 | root_node_tf = torch.eye(4) 31 | ) 32 | 33 | assert grammar.all_types == set([NodeA, NodeB, NodeC, NodeD, NodeE, NodeF, NodeG]) 34 | 35 | tree = grammar.sample_tree() 36 | 37 | assert(torch.isfinite(tree.score())) 38 | 39 | assert isinstance(tree, SceneTree) 40 | 41 | root = tree.get_root() 42 | assert isinstance(root, NodeA) 43 | assert len(list(tree.predecessors(root))) == 0 44 | assert len(list(tree.successors(root))) == 2 # AND rule with 2 children 45 | 46 | obs = tree.get_observed_nodes() 47 | if len(obs) > 0: 48 | assert all([isinstance(c, (NodeD, NodeE, NodeF, NodeG)) for c in obs]) 49 | 50 | assert len(tree.make_from_observed_nodes(obs).nodes) == len(obs) 51 | 52 | assert len(tree.find_nodes_by_type(NodeA)) == 1 53 | 54 | A = tree.find_nodes_by_type(NodeA)[0] 55 | B = tree.find_nodes_by_type(NodeB)[0] 56 | assert tree.get_parent(A) is None 57 | assert tree.get_parent(B) is A 58 | assert tree.get_rule_for_child(A, B) is A.rules[0] 59 | 60 | def test_tree_score(set_seed): 61 | grammar = SpatialSceneGrammar( 62 | root_node_type = NodeA, 63 | root_node_tf = torch.eye(4) 64 | ) 65 | 66 | trace = pyro.poutine.trace(grammar.sample_tree).get_trace() 67 | tree = trace.nodes["_RETURN"]["value"] 68 | expected_score = trace.log_prob_sum() 69 | score = tree.score() 70 | assert 
torch.isclose(expected_score, score), "%f vs %f" % (score, expected_score) 71 | 72 | def test_supertree(): 73 | grammar = SpatialSceneGrammar( 74 | root_node_type = NodeA, 75 | root_node_tf = torch.eye(4) 76 | ) 77 | # excessive recursion depth for this grammar 78 | super_tree = grammar.make_super_tree(max_recursion_depth=20) 79 | # Some basic, necessary-but-not-sufficient checks that the 80 | # tree isn't obviously wrong. 81 | assert isinstance(super_tree.get_root(), NodeA) 82 | C = super_tree.find_nodes_by_type(NodeC)[0] 83 | assert len(super_tree.find_nodes_by_type(NodeF)) == C.max_children 84 | Bs = super_tree.find_nodes_by_type(NodeB) 85 | assert len(Bs) == 1 86 | B = Bs[0] 87 | assert len(list(super_tree.successors(B))) == 2 # Takes both options 88 | 89 | # Test that recursion depth takes effect by making it too small 90 | super_tree = grammar.make_super_tree(max_recursion_depth=1) 91 | assert len(super_tree.find_nodes_by_type(NodeA)) == 1 92 | assert len(super_tree.find_nodes_by_type(NodeC)) == 1 93 | assert len(super_tree.find_nodes_by_type(NodeF)) == 0 94 | 95 | # No depends, but this test is depended on. 96 | @pytest.mark.dependency(name="test_param_prior") 97 | def test_param_prior(set_seed): 98 | # Generate a grammar with random parameters from 99 | # their priors. 100 | grammar = SpatialSceneGrammar( 101 | root_node_type = NodeA, 102 | root_node_tf = torch.eye(4), 103 | sample_params_from_prior=True 104 | ) 105 | # Scoring should still work identically. 106 | trace = pyro.poutine.trace(grammar.sample_tree).get_trace() 107 | tree = trace.nodes["_RETURN"]["value"] 108 | expected_score = trace.log_prob_sum() 109 | score = tree.score() 110 | assert torch.isclose(expected_score, score), "%f vs %f" % (score, expected_score) 111 | # But the underlying param values should be different from the 112 | # default values. 
113 | default_grammar = SpatialSceneGrammar( 114 | root_node_type = NodeA, 115 | root_node_tf = torch.eye(4), 116 | sample_params_from_prior=False 117 | ) 118 | 119 | all_identical = True 120 | for p1, p2 in zip(grammar.parameters(), default_grammar.parameters()): 121 | if not torch.allclose(p1, p2): 122 | all_identical = False 123 | break 124 | assert not all_identical, "Grammar draw from prior was exactly the same as default grammar." 125 | 126 | @pytest.mark.dependency(depends=["test_param_prior"]) 127 | def test_save_load_grammar(): 128 | # Generate a grammar with random parameters from 129 | # their priors. From `test_param_prior`, we know this'll 130 | # be unique from the default grammar. 131 | grammar = SpatialSceneGrammar( 132 | root_node_type = NodeA, 133 | root_node_tf = torch.eye(4), 134 | sample_params_from_prior=True 135 | ) 136 | # Default grammar, for comparison. 137 | default_grammar = SpatialSceneGrammar( 138 | root_node_type = NodeA, 139 | root_node_tf = torch.eye(4), 140 | sample_params_from_prior=False 141 | ) 142 | torch.save(grammar, "/tmp/test_saved_grammar.torch") 143 | 144 | loaded_grammar = torch.load("/tmp/test_saved_grammar.torch") 145 | 146 | all_identical = True 147 | for p1, p2 in zip(loaded_grammar.parameters(), default_grammar.parameters()): 148 | if not torch.allclose(p1, p2): 149 | all_identical = False 150 | break 151 | assert not all_identical, "Grammar draw from prior and saved/loaded was exactly the same as default grammar; didn't load right." 152 | 153 | all_identical = True 154 | for p1, p2 in zip(loaded_grammar.parameters(), grammar.parameters()): 155 | if not torch.allclose(p1, p2): 156 | all_identical = False 157 | break 158 | assert all_identical, "Grammar loaded from file had different params than one saved." 
159 | 160 | # Make sure loaded grammar still works 161 | trace = pyro.poutine.trace(loaded_grammar.sample_tree).get_trace() 162 | tree = trace.nodes["_RETURN"]["value"] 163 | expected_score = trace.log_prob_sum() 164 | score = tree.score() 165 | assert torch.isclose(expected_score, score), "%f vs %f" % (score, expected_score) 166 | 167 | def test_attach_reattach(): 168 | grammar = SpatialSceneGrammar( 169 | root_node_type = NodeA, 170 | root_node_tf = torch.eye(4), 171 | sample_params_from_prior=True 172 | ) 173 | tree = grammar.sample_tree(detach=True) 174 | before_score = tree.score() 175 | assert before_score.grad_fn is None 176 | grammar.update_tree_grammar_parameters(tree) 177 | after_score = tree.score() 178 | assert after_score.grad_fn is not None 179 | after_score.backward() 180 | assert torch.isclose(before_score, after_score) -------------------------------------------------------------------------------- /spatial_scene_grammars/test/test_nodes.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import pyro 5 | 6 | from spatial_scene_grammars.nodes import * 7 | from spatial_scene_grammars.rules import * 8 | from torch.distributions import constraints 9 | 10 | from pytorch3d.transforms.rotation_conversions import ( 11 | axis_angle_to_matrix 12 | ) 13 | 14 | from .grammar import * 15 | 16 | torch.set_default_tensor_type(torch.DoubleTensor) 17 | 18 | @pytest.fixture(params=range(10)) 19 | def set_seed(request): 20 | pyro.set_rng_seed(request.param) 21 | 22 | ## Base Node type 23 | def test_Node(): 24 | node = NodeA(tf=torch.eye(4)) 25 | xyz = node.translation 26 | assert isinstance(xyz, torch.Tensor) and xyz.shape == (3,) 27 | R = node.rotation 28 | assert isinstance(R, torch.Tensor) and R.shape == (3, 3) 29 | new_xyz = torch.tensor([1., 2., 3.]) 30 | new_R = axis_angle_to_matrix(torch.tensor([0.5, 0.5, 0.5]).unsqueeze(0))[0, ...] 
31 | node.translation = new_xyz 32 | node.rotation = new_R 33 | xyz = node.translation 34 | R = node.rotation 35 | assert isinstance(xyz, torch.Tensor) and xyz.shape == (3,) and torch.allclose(new_xyz, xyz) 36 | R = node.rotation 37 | assert isinstance(R, torch.Tensor) and R.shape == (3, 3) and torch.allclose(R, new_R) 38 | 39 | ## TerminalNode 40 | def test_TerminalNode(): 41 | node = TerminalNode( 42 | tf=torch.eye(4), 43 | observed=False, 44 | physics_geometry_info=None, 45 | do_sanity_checks=True 46 | ) 47 | assert node.sample_children() == [] 48 | 49 | ## AndNode 50 | def test_AndNode(): 51 | node = NodeA(tf=torch.eye(4)) 52 | trace = pyro.poutine.trace(node.sample_children).get_trace() 53 | children = trace.nodes["_RETURN"]["value"] 54 | score = node.score_child_set(children) 55 | expected_prob_hand = torch.zeros(1) 56 | # No sample nodes to get log prob from in this trace since AndNode 57 | # doesn't sample anything. 58 | assert torch.isclose(score, expected_prob_hand), "%s vs %s" % (expected_prob_hand, score) 59 | assert len(node.parameters) == 0 60 | 61 | ## OrNode 62 | def test_OrNode(set_seed): 63 | node = NodeB(tf=torch.eye(4)) 64 | trace = pyro.poutine.trace(node.sample_children).get_trace() 65 | children = trace.nodes["_RETURN"]["value"] 66 | trace_node = trace.nodes["OrNode_child"] 67 | expected_prob = trace_node["fn"].log_prob(trace_node["value"]) 68 | expected_prob_hand = torch.log(node.rule_probs[children[0].rule_k]) 69 | score = node.score_child_set(children) 70 | assert torch.isclose(score, expected_prob), "%s vs %s" % (expected_prob, score) 71 | assert torch.isclose(score, expected_prob_hand), "%s vs %s" % (expected_prob_hand, score) 72 | assert torch.allclose(node.parameters, node.rule_probs) 73 | 74 | 75 | ## RepeatingSetNode 76 | def test_RepeatingSetNode(set_seed): 77 | node = NodeC(tf=torch.eye(4)) 78 | trace = pyro.poutine.trace(node.sample_children).get_trace() 79 | children = trace.nodes["_RETURN"]["value"] 80 | trace_node = 
trace.nodes["RepeatingSetNode_n"] 81 | expected_prob = trace_node["fn"].log_prob(trace_node["value"]) 82 | assert len(children) <= node.max_children and len(children) >= 1 # Our default params for this node have no chance of 0 children. 83 | score = node.score_child_set(children) 84 | assert torch.isclose(score, expected_prob), "%s vs %s" % (expected_prob, score) 85 | assert torch.allclose(node.parameters, node.rule_probs) 86 | 87 | ## IndependentSetNode 88 | def test_IndependentSetNode(set_seed): 89 | node = NodeD(tf=torch.eye(4)) 90 | trace = pyro.poutine.trace(node.sample_children).get_trace() 91 | children = trace.nodes["_RETURN"]["value"] 92 | trace_node = trace.nodes["IndependentSetNode_n"] 93 | expected_prob = trace_node["fn"].log_prob(trace_node["value"]).sum() 94 | score = node.score_child_set(children) 95 | assert torch.isclose(score, expected_prob), "%s vs %s" % (expected_prob, score) 96 | assert torch.allclose(node.parameters, node.rule_probs) 97 | -------------------------------------------------------------------------------- /spatial_scene_grammars/test/test_parameter_estimation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import pyro 5 | 6 | from spatial_scene_grammars.scene_grammar import * 7 | from spatial_scene_grammars.parameter_estimation import * 8 | 9 | from .grammar import * 10 | 11 | from pydrake.all import ( 12 | SnoptSolver 13 | ) 14 | 15 | torch.set_default_tensor_type(torch.DoubleTensor) 16 | 17 | @pytest.fixture(params=range(1)) 18 | def set_seed(request): 19 | pyro.set_rng_seed(request.param) 20 | 21 | ## Get some quick rough coverage of EM algo by 22 | # running a few steps of parameter fitting on the grammar. 
23 | @pytest.mark.skipif(os.environ.get('GUROBI_PATH') is None or not SnoptSolver().available(), 24 | reason='This test relies on Gurobi and SNOPT.') 25 | def test_em(set_seed): 26 | grammar = SpatialSceneGrammar( 27 | root_node_type = NodeA, 28 | root_node_tf = torch.eye(4) 29 | ) 30 | observed_node_sets = [grammar.sample_tree(detach=True).get_observed_nodes() for k in range(3)] 31 | 32 | em = EMWrapper(grammar, observed_node_sets) 33 | # Can't do more than 1 iter of fitting in case the parameters jump 34 | # to a setting that makes parsing impossible (which is likely, since 35 | # we're only doing a very noisy few steps here.) 36 | em.do_iterated_em_fitting(em_iterations=3) 37 | 38 | # Make sure something happened + was logged 39 | assert len(em.grammar_iters) == 4 # 3 iters + original 40 | assert all(torch.isfinite(torch.tensor(em.log_evidence_iters).view(-1))), em.log_evidence_iters 41 | -------------------------------------------------------------------------------- /spatial_scene_grammars/test/test_sampling.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import pyro 5 | import pyro.poutine 6 | from pydrake.all import SnoptSolver 7 | 8 | from spatial_scene_grammars.nodes import * 9 | from spatial_scene_grammars.rules import * 10 | from spatial_scene_grammars.scene_grammar import * 11 | from spatial_scene_grammars.sampling import * 12 | 13 | from .grammar import * 14 | 15 | from torch.distributions import constraints 16 | 17 | 18 | torch.set_default_tensor_type(torch.DoubleTensor) 19 | 20 | @pytest.fixture(params=range(3)) 21 | def set_seed(request): 22 | pyro.set_rng_seed(request.param) 23 | 24 | # Proof-of-life of sampling routines 25 | @pytest.mark.skipif(not SnoptSolver().available(), 26 | reason='This test relies on Gurobi and SNOPT.') 27 | @pytest.mark.parametrize("perturb_in_config_space", [True, False]) 28 | @pytest.mark.parametrize("do_hit_and_run_postprocess", [True, 
False]) 29 | def test_sampling(set_seed, perturb_in_config_space, do_hit_and_run_postprocess): 30 | grammar = SpatialSceneGrammar( 31 | root_node_type = NodeA, 32 | root_node_tf = torch.eye(4) 33 | ) 34 | tree = grammar.sample_tree(detach=True) 35 | 36 | # Hack tree to be totally unobserved to get more code coverage 37 | for node in tree: 38 | node.observed = False 39 | N_samples = 5 40 | sampled_trees = do_fixed_structure_mcmc( 41 | grammar, tree, num_samples=N_samples, 42 | perturb_in_config_space=perturb_in_config_space, verbose=2, 43 | vis_callback=None, 44 | translation_variance=0.05, 45 | rotation_variance=0.05, 46 | do_hit_and_run_postprocess=do_hit_and_run_postprocess) 47 | 48 | assert len(sampled_trees) == N_samples 49 | new_tree = sampled_trees[-1] 50 | # Should be true that the last tree isn't exactly the same as the 51 | # first tree. 52 | has_difference = False 53 | for new_node, old_node in zip(new_tree.nodes, tree.nodes): 54 | if not (torch.allclose(new_node.translation, old_node.translation) and 55 | torch.allclose(new_node.rotation, old_node.rotation)): 56 | has_difference = True 57 | assert has_difference, "Sampling did not produce a unique / different tree." 
58 | -------------------------------------------------------------------------------- /spatial_scene_grammars/test/test_torch_util.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import pyro 5 | 6 | from spatial_scene_grammars.torch_utils import * 7 | from spatial_scene_grammars.drake_interop import drake_tf_to_torch_tf 8 | from torch.distributions import constraints 9 | 10 | from pytorch3d.transforms.rotation_conversions import ( 11 | euler_angles_to_matrix 12 | ) 13 | 14 | from pydrake.all import ( 15 | UniformlyRandomRotationMatrix, 16 | RandomGenerator, 17 | RigidTransform, 18 | RollPitchYaw 19 | ) 20 | 21 | torch.set_default_tensor_type(torch.DoubleTensor) 22 | 23 | @pytest.fixture(params=range(10)) 24 | def set_seed(request): 25 | pyro.set_rng_seed(request.param) 26 | return RandomGenerator(request.param) 27 | 28 | def test_inverse_tf(set_seed): 29 | # Test against drake 30 | t = np.random.uniform(-1, 1, size=3) 31 | R = UniformlyRandomRotationMatrix(set_seed) 32 | drake_tf = RigidTransform(p=t, R=R) 33 | tf = drake_tf_to_torch_tf(drake_tf) 34 | tf_inv = invert_torch_tf(tf) 35 | assert torch.allclose(tf_inv, drake_tf_to_torch_tf(drake_tf.inverse())) 36 | 37 | def test_inv_softplus(set_seed): 38 | x_test = torch.abs(torch.normal(0., 1., size=(100,))) + 0.01 39 | assert torch.allclose(inv_softplus(torch.nn.functional.softplus(x_test)), x_test) 40 | 41 | def test_inv_sigmoid(set_seed): 42 | x_test = torch.clip(torch.normal(0., 1., size=(100,)), -0.99, 0.99) 43 | x_test_pred = inv_sigmoid(torch.sigmoid(x_test)) 44 | assert torch.allclose(x_test_pred, x_test, atol=1E-5), x_test - x_test_pred 45 | 46 | def test_constrained_param(): 47 | # Make sure calling the constraints work for empty and nonempty init_values. 
48 | def do_tests(init_value, constraint): 49 | p = ConstrainedParameter(init_value, constraint=constraint) 50 | val = p() 51 | assert(torch.allclose(init_value, val)) 52 | p.set_unconstrained_as_param(init_value) 53 | val = p.get_unconstrained_value() 54 | assert(torch.allclose(init_value, val)) 55 | p.set(init_value) 56 | val = p() 57 | assert(torch.allclose(init_value, val)) 58 | do_tests(torch.tensor([]), constraint=constraints.real) 59 | do_tests(torch.tensor([]), constraint=constraints.simplex) 60 | do_tests(torch.ones(10), constraint=constraints.real) 61 | do_tests(torch.ones(10)/10., constraint=constraints.simplex) 62 | 63 | # Test that gradient propagate correctly after changing param value 64 | p = ConstrainedParameter(torch.ones(3), constraint=constraints.real) 65 | loss = p().square().sum() 66 | # grad = 2. * val 67 | loss.backward() 68 | assert torch.allclose(p.get_unconstrained_value().grad, 2.*torch.ones(3)), p.get_unconstrained_value().grad 69 | 70 | p.get_unconstrained_value().grad = None 71 | p.set(torch.ones(3)*2.) 72 | loss = p().square().sum() 73 | loss.backward() 74 | assert torch.allclose(p.get_unconstrained_value().grad, 2.*2.*torch.ones(3)), p.get_unconstrained_value().grad 75 | 76 | 77 | def test_interpolate_translation(): 78 | t1 = torch.tensor([0., 0., 0.]) 79 | t2 = torch.tensor([1., 1., 1]) 80 | 81 | t = interp_translation(t1, t2, interp_factor=0.) 82 | assert torch.allclose(t, t1), (t, t1) 83 | 84 | t = interp_translation(t1, t2, interp_factor=1.) 
85 | assert torch.allclose(t, t2), (t, t2) 86 | 87 | t = interp_translation(t1, t2, interp_factor=0.5) 88 | expected = torch.tensor([0.5, 0.5, 0.5]) 89 | assert torch.allclose(t, expected), (t, expected) 90 | 91 | t = interp_translation(t2, t2, interp_factor=0.75) 92 | assert torch.allclose(t, t2), (t, t2) 93 | 94 | 95 | 96 | def test_interpolate_rotation(): 97 | R1 = euler_angles_to_matrix(torch.tensor([0.0, 0.0, 0.0]), convention="ZYX") 98 | R2 = euler_angles_to_matrix(torch.tensor([np.pi/2., 0., 0.]), convention="ZYX") 99 | 100 | R = interp_rotation(R1, R2, interp_factor=0.) 101 | assert torch.allclose(R, R1), (R, R1) 102 | 103 | R = interp_rotation(R1, R2, interp_factor=1.) 104 | assert torch.allclose(R, R2), (R, R2) 105 | 106 | R = interp_rotation(R2, R2, interp_factor=0.75) 107 | assert torch.allclose(R, R2), (R, R2) 108 | 109 | R = interp_rotation(R1, R2, interp_factor=0.5) 110 | expected = euler_angles_to_matrix(torch.tensor([np.pi/4., 0., 0.]), convention="ZYX") 111 | assert torch.allclose(R, expected), (R, expected) 112 | 113 | R3 = euler_angles_to_matrix(torch.tensor([np.pi, 0., 0.]), convention="ZYX") 114 | R = interp_rotation(R1, R3, interp_factor=0.5) 115 | angle_distance = torch.arccos((torch.trace(torch.matmul(R.transpose(1, 0), R3)) - 1)/2.) 116 | assert torch.isclose(angle_distance, torch.tensor(np.pi/2.)), (angle_distance, np.pi/2.) 
117 | 118 | R4 = euler_angles_to_matrix(torch.tensor([3.*np.pi/2., 0., 0.]), convention="ZYX") 119 | R = interp_rotation(R1, R4, interp_factor=0.5) 120 | expected = euler_angles_to_matrix(torch.tensor([-np.pi/4., 0., 0.]), convention="ZYX") 121 | assert torch.allclose(R, expected), (R, expected) 122 | 123 | 124 | def test_se3_dist(set_seed): 125 | # Should be zero distance 126 | population_1 = drake_tf_to_torch_tf(RigidTransform(p=np.zeros(3))).unsqueeze(0) 127 | population_2 = drake_tf_to_torch_tf(RigidTransform(p=np.zeros(3))).unsqueeze(0) 128 | dists = se3_dist(population_1, population_2, beta=1., eps=0) 129 | assert dists.shape == (1, 1) and torch.isclose(dists[0, 0], torch.tensor(0.)) 130 | 131 | # Should be 1 distance 132 | population_1 = drake_tf_to_torch_tf(RigidTransform(p=np.zeros(3))).unsqueeze(0) 133 | population_2 = drake_tf_to_torch_tf(RigidTransform(p=np.array([0, 1, 0]))).unsqueeze(0) 134 | dists = se3_dist(population_1, population_2, beta=1., eps=0) 135 | assert dists.shape == (1, 1) and torch.isclose(dists[0, 0], torch.tensor(1.)) 136 | 137 | # Should be pi distance 138 | population_1 = drake_tf_to_torch_tf(RigidTransform(p=np.zeros(3))).unsqueeze(0) 139 | population_2 = drake_tf_to_torch_tf(RigidTransform(p=np.zeros(3), rpy=RollPitchYaw(np.pi, 0., 0.))).unsqueeze(0) 140 | dists = se3_dist(population_1, population_2, beta=1., eps=0) 141 | assert dists.shape == (1, 1) and torch.isclose(dists[0, 0], torch.tensor(np.pi)) 142 | 143 | # Make sure it works at scale 144 | M = 200 145 | N = 100 146 | population_1 = [] 147 | for k in range(M): 148 | t = np.random.uniform(-1, 1, size=3) 149 | R = UniformlyRandomRotationMatrix(set_seed) 150 | tf = drake_tf_to_torch_tf(RigidTransform(p=t, R=R)) 151 | population_1.append(tf) 152 | population_2 = [] 153 | for k in range(N): 154 | t = np.random.uniform(-1, 1, size=3) 155 | R = UniformlyRandomRotationMatrix(set_seed) 156 | tf = drake_tf_to_torch_tf(RigidTransform(p=t, R=R)) 157 | population_2.append(tf) 158 | 159 
| population_1 = torch.stack(population_1) 160 | population_2 = torch.stack(population_2) 161 | dists = se3_dist(population_1, population_2, beta=1., eps=0) 162 | assert dists.shape == (M, N) and torch.all(torch.isfinite(dists)) and torch.all(dists >= 0) 163 | 164 | def test_mmd_se3(set_seed): 165 | # Basic proof-of-life for at-scale samples 166 | M = 200 167 | N = 100 168 | population_1 = [] 169 | for k in range(M): 170 | t = np.random.uniform(-1, 1, size=3) 171 | R = UniformlyRandomRotationMatrix(set_seed) 172 | tf = drake_tf_to_torch_tf(RigidTransform(p=t, R=R)) 173 | population_1.append(tf) 174 | population_2 = [] 175 | for k in range(N): 176 | t = np.random.uniform(-1, 1, size=3) 177 | R = UniformlyRandomRotationMatrix(set_seed) 178 | tf = drake_tf_to_torch_tf(RigidTransform(p=t, R=R)) 179 | population_2.append(tf) 180 | 181 | population_1 = torch.stack(population_1) 182 | population_2 = torch.stack(population_2) 183 | population_1.requires_grad = True 184 | mmd = calculate_mmd(population_1, population_2, alphas=[0.1, 1.0, 10.0], use_se3_metric=True, beta=1.0) 185 | # Note: this MMD estimate can be negative. See page 729 of https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf. 
186 | assert torch.isfinite(mmd) 187 | 188 | # Check grad goes through 189 | mmd.backward() 190 | assert population_1.grad is not None 191 | 192 | if __name__ == "__main__": 193 | pytest.main() -------------------------------------------------------------------------------- /spatial_scene_grammars/visualization.py: -------------------------------------------------------------------------------- 1 | import os, contextlib 2 | import meshcat 3 | import meshcat.geometry as meshcat_geom 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import pydrake 8 | from pydrake.all import ( 9 | ConnectMeshcatVisualizer, 10 | ) 11 | from .drake_interop import * 12 | 13 | def rgb_2_hex(rgb): 14 | # Turn a list of R,G,B elements (any indexable list 15 | # of >= 3 elements will work), where each element is 16 | # specified on range [0., 1.], into the equivalent 17 | # 24-bit value 0xRRGGBB. 18 | val = 0 19 | for i in range(3): 20 | val += (256**(2 - i)) * int(255 * rgb[i]) 21 | return val 22 | 23 | def draw_scene_tree_contents_meshcat(scene_tree, prefix="scene", zmq_url=None, alpha=0.25, draw_clearance_geom=False, quiet=True): 24 | ''' Given a scene tree, draws it in meshcat at the requested ZMQ url. 25 | Can be configured to draw the tree geometry or the clearance geometry. 
''' 26 | 27 | if draw_clearance_geom: 28 | builder, mbp, scene_graph = compile_scene_tree_clearance_geometry_to_mbp_and_sg(scene_tree) 29 | else: 30 | builder, mbp, scene_graph, _, _, = compile_scene_tree_to_mbp_and_sg(scene_tree) 31 | mbp.Finalize() 32 | 33 | if quiet: 34 | with open(os.devnull, 'w') as devnull: 35 | with contextlib.redirect_stdout(devnull): 36 | vis = ConnectMeshcatVisualizer(builder, scene_graph, 37 | zmq_url=zmq_url or "default", prefix=prefix) 38 | else: 39 | vis = ConnectMeshcatVisualizer(builder, scene_graph, 40 | zmq_url=zmq_url or "default", prefix=prefix) 41 | vis.delete_prefix() 42 | diagram = builder.Build() 43 | context = diagram.CreateDefaultContext() 44 | vis.load(vis.GetMyContextFromRoot(context)) 45 | diagram.Publish(context) 46 | # Necessary to manually remove this meshcat visualizer now that we're 47 | # done with it, as a lot of Drake systems (that are involved with the 48 | # diagram builder) don't get properly garbage collected. See Drake issue #14387. 49 | # Meshcat collects sockets, so deleting this avoids a file descriptor 50 | # leak. 51 | del vis.vis 52 | 53 | def draw_scene_tree_structure_meshcat(scene_tree, prefix="scene_tree", zmq_url=None, 54 | alpha=0.775, node_sphere_size=0.05, linewidth=2, with_triad=True, quiet=True, 55 | color_by_score=None, delete=True): 56 | # Color by score can be a tuple of min, max score. It'll go from red at min score 57 | # to blue at max score. 58 | # Do actual drawing in meshcat. 59 | 60 | if quiet: 61 | with open(os.devnull, 'w') as devnull: 62 | with contextlib.redirect_stdout(devnull): 63 | vis = meshcat.Visualizer(zmq_url=zmq_url or "tcp://127.0.0.1:6000") 64 | else: 65 | vis = meshcat.Visualizer(zmq_url=zmq_url or "tcp://127.0.0.1:6000") 66 | 67 | if delete: 68 | vis[prefix].delete() 69 | 70 | # Assign functionally random colors to each new node 71 | # type we discover, or color my their scores. 
72 | node_class_to_color_dict = {} 73 | cmap = plt.cm.get_cmap('jet') 74 | cmap_counter = 0. 75 | 76 | 77 | k = 0 78 | for node in scene_tree.nodes: 79 | children, rules = scene_tree.get_children_and_rules(node) 80 | 81 | # 82 | if color_by_score is not None: 83 | assert len(color_by_score) == 2, "Color by score should be a tuple of (min, max)" 84 | score = node.score_child_set(children) 85 | print("Node score: ", score) 86 | score = (score - color_by_score[0]) / (color_by_score[1] - color_by_score[0]) 87 | score = 1. - np.clip(score.item(), 0., 1.) 88 | color = rgb_2_hex(cmap(score)) 89 | #color = 0x555555 90 | else: 91 | # Draw this node 92 | node_type_string = node.__class__.__name__ 93 | if node_type_string in node_class_to_color_dict.keys(): 94 | color = node_class_to_color_dict[node_type_string] 95 | else: 96 | color = rgb_2_hex(cmap(cmap_counter)) 97 | node_class_to_color_dict[node_type_string] = color 98 | cmap_counter = np.fmod(cmap_counter + np.pi*2., 1.) 99 | 100 | vis[prefix][node.name + "%d/sphere" % k].set_object( 101 | meshcat_geom.Sphere(node_sphere_size), 102 | meshcat_geom.MeshToonMaterial(color=color, opacity=alpha, transparent=(alpha != 1.), depthTest=False)) 103 | if with_triad: 104 | vis[prefix][node.name + "%d/triad" % k].set_object( 105 | meshcat_geom.triad(scale=node_sphere_size*5.) 106 | ) 107 | 108 | tf = node.tf.cpu().detach().numpy() 109 | vis[prefix][node.name + "%d" % k].set_transform(tf) 110 | 111 | # Draw connections to each child 112 | for child, rule in zip(children, rules): 113 | verts = [] 114 | verts.append(node.tf[:3, 3].cpu().detach().numpy()) 115 | verts.append(child.tf[:3, 3].cpu().detach().numpy()) 116 | verts = np.vstack(verts).T 117 | 118 | if color_by_score is not None: 119 | score = rule.score_child(node, child) 120 | print("Rule score: ", score) 121 | score = (score - color_by_score[0]) / (color_by_score[1] - color_by_score[0]) 122 | score = 1. - np.clip(score.item(), 0., 1.) 
123 | color = rgb_2_hex(cmap(score)) 124 | 125 | vis[prefix][node.name + "_to_" + child.name].set_object( 126 | meshcat_geom.Line(meshcat_geom.PointsGeometry(verts), 127 | meshcat_geom.LineBasicMaterial(linewidth=linewidth, color=color, depthTest=False))) 128 | k += 1 -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/__init__.py -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/.gitattributes: -------------------------------------------------------------------------------- 1 | parsing_statistics_runs.pickle filter=lfs diff=lfs merge=lfs -text 2 | fit_em.pickle filter=lfs diff=lfs merge=lfs -text 3 | fit_grammar.torch filter=lfs diff=lfs merge=lfs -text 4 | sink_dataset_v2.zip filter=lfs diff=lfs merge=lfs -text 5 | fit_em_baseline.pickle filter=lfs diff=lfs merge=lfs -text 6 | fit_grammar_baseline.torch filter=lfs diff=lfs merge=lfs -text 7 | fit_simple_grammar.torch filter=lfs diff=lfs merge=lfs -text 8 | baseline_em_mmds_precomputed.pickle filter=lfs diff=lfs merge=lfs -text 9 | em_mmds_precomputed.pickle filter=lfs diff=lfs merge=lfs -text 10 | em_mmds_total_precomputed.pickle filter=lfs diff=lfs merge=lfs -text 11 | fitting_history.pickle filter=lfs diff=lfs merge=lfs -text 12 | baseline_em_mmds_total_precomputed.pickle filter=lfs diff=lfs merge=lfs -text 13 | parsing_statistics_runs_pre_update.pickle filter=lfs diff=lfs merge=lfs -text 14 | sink_dataset.zip filter=lfs diff=lfs merge=lfs -text 15 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/.gitignore: 
-------------------------------------------------------------------------------- 1 | sink 2 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/baseline_em_mmds_precomputed.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0ba7ab94ee92fa00fdcd2ce4762b7bda339849cb87c2d64ca18778096b28a0e1 3 | size 34602 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/baseline_em_mmds_total_precomputed.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3cf9f40c39f94dead6752930dab8f52d3bd4605d420b22680ceadb211d0d3f93 3 | size 3343 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/draw_every_training_example.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import pickle 6 | import time 7 | from tqdm.notebook import tqdm 8 | 9 | import torch 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.constraints import * 13 | from spatial_scene_grammars.nodes import * 14 | from spatial_scene_grammars.rules import * 15 | from spatial_scene_grammars.scene_grammar import * 16 | from spatial_scene_grammars.visualization import * 17 | from spatial_scene_grammars_examples.dish_bin.grammar import * 18 | from spatial_scene_grammars.parsing import * 19 | from spatial_scene_grammars.sampling import * 20 | from spatial_scene_grammars.parameter_estimation import * 21 | from spatial_scene_grammars.dataset import * 22 | 23 | import meshcat 24 | import meshcat.geometry as meshcat_geom 25 | 26 | # Set up grammar 
27 | grammar = SpatialSceneGrammar( 28 | root_node_type = DishBin, 29 | root_node_tf = drake_tf_to_torch_tf(RigidTransform(p=[0.5, 0., 0.])) 30 | ) 31 | grammar.load_state_dict(torch.load("fit_grammar.torch")) 32 | 33 | vis = meshcat.Visualizer() 34 | vis.delete() 35 | print("Meshcat port: ", vis.url()) 36 | 37 | # Convert dataset to observed node sets (caching output) and draw a few examples. 38 | RECONVERT_DATASET = True 39 | DATASET_YAML_FILE = "sink/saved_scenes.yaml" 40 | DATASET_SAVE_FILE = "observed_node_sets.dat" 41 | 42 | if RECONVERT_DATASET or not os.path.exists(DATASET_SAVE_FILE): 43 | type_map = { 44 | "bin": DishBin 45 | } 46 | model_map = { 47 | } 48 | for model_type_set in [PlateModels, CupModels, BowlModels]: 49 | for model_type in model_type_set: 50 | # Have to cut off the "sink" folder to match model names; 51 | # dataset management is ugly and should get reorganized... 52 | model_map[os.path.join(*model_type.sdf.split("/")[1:])] = model_type 53 | observed_node_sets = convert_scenes_yaml_to_observed_nodes(DATASET_YAML_FILE, type_map, model_map) 54 | print("Saving...") 55 | with open(DATASET_SAVE_FILE, "wb") as f: 56 | pickle.dump(observed_node_sets, f) 57 | 58 | print("Loading...") 59 | with open(DATASET_SAVE_FILE, "rb") as f: 60 | observed_node_sets = pickle.load(f) 61 | 62 | for observed_node_set in observed_node_sets: 63 | vis.delete() 64 | draw_scene_tree_contents_meshcat( 65 | SceneTree.make_from_observed_nodes(observed_node_set), 66 | zmq_url=vis.window.zmq_url, prefix="scene" 67 | ) 68 | input() 69 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/draw_samples_from_fit_grammar.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import pickle 6 | import time 7 | from tqdm.notebook import tqdm 8 | 9 | import torch 10 | 
torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.constraints import * 13 | from spatial_scene_grammars.nodes import * 14 | from spatial_scene_grammars.rules import * 15 | from spatial_scene_grammars.scene_grammar import * 16 | from spatial_scene_grammars.visualization import * 17 | from spatial_scene_grammars_examples.dish_bin.grammar import * 18 | from spatial_scene_grammars.parsing import * 19 | from spatial_scene_grammars.sampling import * 20 | from spatial_scene_grammars.parameter_estimation import * 21 | from spatial_scene_grammars.dataset import * 22 | 23 | import meshcat 24 | import meshcat.geometry as meshcat_geom 25 | 26 | # Set up grammar 27 | grammar = SpatialSceneGrammar( 28 | root_node_type = DishBin, 29 | root_node_tf = drake_tf_to_torch_tf(RigidTransform(p=[0.5, 0., 0.])) 30 | ) 31 | grammar.load_state_dict(torch.load("fit_grammar.torch")) 32 | 33 | vis = meshcat.Visualizer() 34 | print("Meshcat port: ", vis.url()) 35 | 36 | class InBinConstraint(Constraint): 37 | # XY coord of each object inside .56 x .83 dish bin 38 | def __init__(self): 39 | lb = torch.tensor([-0.56/2+0.1, -0.83/2+0.1, 0.]) 40 | ub = torch.tensor([0.56/2-0.1, 0.83/2-0.1, 1.]) 41 | super().__init__( 42 | lower_bound=lb, 43 | upper_bound=ub 44 | ) 45 | def eval(self, scene_tree): 46 | xys = [] 47 | bin_pos = scene_tree.find_nodes_by_type(DishBin)[0].translation 48 | for node in scene_tree.find_nodes_by_type(ObjectModel): 49 | xys.append(node.translation - bin_pos) 50 | return torch.stack(xys, axis=0) 51 | def add_to_ik_prog(self, scene_tree, ik, mbp, mbp_context, node_to_free_body_ids_map): 52 | bin_pos = scene_tree.find_nodes_by_type(DishBin)[0].translation.detach().numpy() 53 | for node in scene_tree.find_nodes_by_type(ObjectModel): 54 | for body_id in node_to_free_body_ids_map[node]: 55 | body = mbp.get_body(body_id) 56 | print(bin_pos + self.lower_bound.detach().numpy(), bin_pos + self.upper_bound.detach().numpy()) 57 | 
ik.AddPositionConstraint( 58 | body.body_frame(), np.zeros(3), 59 | mbp.world_frame(), 60 | bin_pos + self.lower_bound.detach().numpy(), 61 | bin_pos + self.upper_bound.detach().numpy() 62 | ) 63 | constraints = [ 64 | InBinConstraint() 65 | ] 66 | 67 | # Draw a lot of projected environments 68 | torch.random.manual_seed(0) 69 | for k in range(100): 70 | tree = grammar.sample_tree(detach=True) 71 | tree, success = rejection_sample_structure_to_feasibility(tree, constraints=constraints) 72 | draw_scene_tree_contents_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="pre_projected_sample/contents") 73 | draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="pre_projected_sample/structure", alpha=0.25, node_sphere_size=0.01) 74 | tree = project_tree_to_feasibility_via_sim(tree, constraints=constraints, zmq_url=vis.window.zmq_url) 75 | draw_scene_tree_contents_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="projected_samples/content") 76 | #draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url, prefix="projected_samples/structure") 77 | input() -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/em_mmds_precomputed.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5242b7d87cee3dc1c12a048fdc2b86e4f9835da31ea60121804db92887662d71 3 | size 34594 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/em_mmds_total_precomputed.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2761f96472a992f91bcfd17fd37c4e10b2e962a9fcf0ec4005dc7471a8e38dac 3 | size 3343 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/fit_em.pickle: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:888cd70bf4f543325cbc42be80fba91befb75a605b543ca479312feac2a436f0 3 | size 937418 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/fit_em_baseline.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b1a1a8b9b81cba0bc74a7d07f8ce1e21adccde66ee3e6630035e855aae68e7b2 3 | size 659547 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/fit_grammar.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eddfb7f9cbc2a9c0cb7b43f04c99863c2267d790c6e82d34f2d727949122932c 3 | size 13609 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/fit_grammar_baseline.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:806c6d9681b2f592d52d1147019ff48b2a27c7022ed9cc958d4f979a3ad02476 3 | size 12012 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/fit_simple_grammar.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:53770d13e1330c3465bc1ef623ad3ce22d43f3ee8d2567a1e8f89e08be7a0ac4 3 | size 11948 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/fitting_history.pickle: -------------------------------------------------------------------------------- 1 | version 
https://git-lfs.github.com/spec/v1 2 | oid sha256:37186ef70d7c58006e27191f7c5202f1b44b382d74a046c392998b67f1dfd4c7 3 | size 720809 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/grammar_baseline.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from functools import lru_cache 4 | 5 | import torch 6 | from spatial_scene_grammars.nodes import * 7 | from spatial_scene_grammars.rules import * 8 | from spatial_scene_grammars.scene_grammar import * 9 | from spatial_scene_grammars.drake_interop import * 10 | from spatial_scene_grammars.constraints import * 11 | 12 | 13 | import pydrake 14 | import pydrake.geometry as pydrake_geom 15 | from pydrake.all import ( 16 | RollPitchYaw, 17 | RigidTransform 18 | ) 19 | 20 | ''' 21 | Same terminal set as in grammar.py, but 22 | without as much intermediate structure: 23 | TerminalCup, TerminalPlate, TerminalBowl gets 24 | its own distribution from the Sink, and that's it. 25 | ''' 26 | 27 | # Need full import path to match how these types are imported 28 | # in jupyter notebooks and saved out when pickling... seems dumb, 29 | # but don't change it. 
30 | from spatial_scene_grammars_examples.dish_bin.grammar import ( 31 | TerminalPlate, TerminalCup, TerminalBowl 32 | ) 33 | 34 | 35 | class TerminalPlates(RepeatingSetNode): 36 | def __init__(self, tf): 37 | super().__init__( 38 | tf=tf, 39 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.6, max_children=6, start_at_one=True), 40 | physics_geometry_info=None, 41 | observed=False 42 | ) 43 | @classmethod 44 | def generate_rules(cls): 45 | return [ 46 | ProductionRule( 47 | child_type=TerminalPlate, 48 | xyz_rule=ParentFrameGaussianOffsetRule( 49 | mean=torch.tensor([0.0, 0.0, 0.00]), 50 | variance=torch.tensor([0.025, 0.025, 0.0001])), 51 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 52 | RotationMatrix(), np.array([1, 1, 1]) 53 | ) 54 | ) 55 | ] 56 | class TerminalCups(RepeatingSetNode): 57 | def __init__(self, tf): 58 | super().__init__( 59 | tf=tf, 60 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.6, max_children=6, start_at_one=True), 61 | physics_geometry_info=None, 62 | observed=False 63 | ) 64 | @classmethod 65 | def generate_rules(cls): 66 | return [ 67 | ProductionRule( 68 | child_type=TerminalCup, 69 | xyz_rule=ParentFrameGaussianOffsetRule( 70 | mean=torch.tensor([0.0, 0.0, 0.00]), 71 | variance=torch.tensor([0.025, 0.025, 0.0001])), 72 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 73 | RotationMatrix(), np.array([1, 1, 1]) 74 | ) 75 | ) 76 | ] 77 | class TerminalBowls(RepeatingSetNode): 78 | def __init__(self, tf): 79 | super().__init__( 80 | tf=tf, 81 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.6, max_children=6, start_at_one=True), 82 | physics_geometry_info=None, 83 | observed=False 84 | ) 85 | @classmethod 86 | def generate_rules(cls): 87 | return [ 88 | ProductionRule( 89 | child_type=TerminalBowl, 90 | xyz_rule=ParentFrameGaussianOffsetRule( 91 | mean=torch.tensor([0.0, 0.0, 0.00]), 92 | variance=torch.tensor([0.025, 0.025, 0.0001])), 93 | 
rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 94 | RotationMatrix(), np.array([1, 1, 1]) 95 | ) 96 | ) 97 | ] 98 | 99 | class DishBinBaseline(IndependentSetNode): 100 | def __init__(self, tf): 101 | geom = PhysicsGeometryInfo(fixed=True) 102 | geom.register_model_file(torch.eye(4), "sink/bin.sdf") 103 | super().__init__( 104 | tf=tf, 105 | rule_probs=torch.tensor([0.9, 0.9, 0.9]), 106 | physics_geometry_info=geom, 107 | observed=True 108 | ) 109 | @classmethod 110 | def generate_rules(cls): 111 | return [ 112 | ProductionRule( 113 | child_type=TerminalPlates, 114 | xyz_rule=SamePositionRule(), 115 | rotation_rule=SameRotationRule() 116 | ), 117 | ProductionRule( 118 | child_type=TerminalCups, 119 | xyz_rule=SamePositionRule(), 120 | rotation_rule=SameRotationRule() 121 | ), 122 | ProductionRule( 123 | child_type=TerminalBowls, 124 | xyz_rule=SamePositionRule(), 125 | rotation_rule=SameRotationRule() 126 | ), 127 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/parsing_statistics_runs.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c13b659a3d7f9a5ad839e87d5f5eeb44b189bed1b644920e315a090eb5d1fe59 3 | size 55102539 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/parsing_statistics_runs_pre_update.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:651af9bf4c604e1c2f2f400ddf99e70215cd75213448808f76fc498ef9daca92 3 | size 55122246 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/sink_dataset.zip: -------------------------------------------------------------------------------- 1 | version 
https://git-lfs.github.com/spec/v1 2 | oid sha256:632ebc3da37465b441a07437b4bdb664e6b26dea320a31d1ec84423fe038be0c 3 | size 72493057 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/sink_dataset_v2.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:024cdf53b4f16b73fda9214ad835ef1fa2e963594494b306821b6709d04c2e41 3 | size 72490727 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/test/test_sink_grammar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import os 6 | import pickle 7 | import time 8 | 9 | import logging 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | import torch 14 | torch.set_default_tensor_type(torch.DoubleTensor) 15 | 16 | from spatial_scene_grammars.nodes import * 17 | from spatial_scene_grammars.rules import * 18 | from spatial_scene_grammars.scene_grammar import * 19 | from spatial_scene_grammars.visualization import * 20 | from spatial_scene_grammars_examples.dish_bin.grammar import * 21 | from spatial_scene_grammars_examples.dish_bin.grammar_baseline import * 22 | from spatial_scene_grammars_examples.dish_bin.utils import get_observed_node_sets 23 | from spatial_scene_grammars.parsing import * 24 | from spatial_scene_grammars.sampling import * 25 | from spatial_scene_grammars.parameter_estimation import * 26 | from spatial_scene_grammars.dataset import * 27 | 28 | import meshcat 29 | import meshcat.geometry as meshcat_geom 30 | 31 | from pydrake.all import SnoptSolver 32 | 33 | @pytest.fixture(params=range(3)) 34 | def set_seed(request): 35 | pyro.set_rng_seed(request.param) 36 | 37 | def test_sampling(set_seed): 38 | vis = meshcat.Visualizer() 39 | 40 | # 
Draw a random sample from the grammar and visualize it. 41 | grammar = SpatialSceneGrammar( 42 | root_node_type = DishBin, 43 | root_node_tf = torch.eye(4) 44 | ) 45 | torch.random.manual_seed(42) 46 | tree = grammar.sample_tree() 47 | 48 | assert torch.isfinite(tree.score(verbose=True)), "Sampled tree was infeasible." 49 | 50 | draw_scene_tree_contents_meshcat(tree, zmq_url=vis.window.zmq_url) 51 | draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url) 52 | 53 | 54 | def test_sampling_baseline(set_seed): 55 | vis = meshcat.Visualizer() 56 | 57 | # Draw a random sample from the grammar and visualize it. 58 | grammar = SpatialSceneGrammar( 59 | root_node_type = DishBinBaseline, 60 | root_node_tf = torch.eye(4) 61 | ) 62 | torch.random.manual_seed(42) 63 | tree = grammar.sample_tree() 64 | 65 | assert torch.isfinite(tree.score(verbose=True)), "Sampled tree was infeasible." 66 | 67 | 68 | @pytest.mark.skipif(os.environ.get('GUROBI_PATH') is None or not SnoptSolver().available(), 69 | reason='This test relies on Gurobi and SNOPT.') 70 | def test_parsing_mip(set_seed): 71 | # Try to parse an example of this grammar. 72 | grammar = SpatialSceneGrammar( 73 | root_node_type = DishBin, 74 | root_node_tf = torch.eye(4) 75 | ) 76 | torch.random.manual_seed(42) 77 | observed_tree = grammar.sample_tree(detach=True) 78 | observed_nodes = observed_tree.get_observed_nodes() 79 | 80 | inference_results = infer_mle_tree_with_mip( 81 | grammar, observed_nodes, verbose=True, 82 | max_scene_extent_in_any_dir=10. 83 | ) 84 | mip_optimized_tree = get_optimized_tree_from_mip_results(inference_results) 85 | refinement_results = optimize_scene_tree_with_nlp(grammar, mip_optimized_tree, verbose=True, 86 | max_scene_extent_in_any_dir=10.) 87 | refined_tree = refinement_results.refined_tree 88 | score = refined_tree.score(verbose=True) 89 | assert torch.isfinite(score), "Refined tree was infeasible." 
90 | 91 | 92 | @pytest.mark.skipif(os.environ.get('GUROBI_PATH') is None or not SnoptSolver().available(), 93 | reason='This test relies on Gurobi and SNOPT.') 94 | def test_parsing_ip(set_seed): 95 | # Try to parse an example of this grammar. 96 | grammar = SpatialSceneGrammar( 97 | root_node_type = DishBin, 98 | root_node_tf = torch.eye(4) 99 | ) 100 | torch.random.manual_seed(42) 101 | observed_tree = grammar.sample_tree(detach=True) 102 | observed_nodes = observed_tree.get_observed_nodes() 103 | 104 | N_solutions = 3 105 | parse_trees = infer_mle_tree_with_mip_from_proposals( 106 | grammar, observed_nodes, {}, verbose=False, N_solutions=N_solutions, 107 | min_ll_for_consideration=-1000. 108 | ) 109 | assert len(parse_trees) > 0 and len(parse_trees) <= N_solutions 110 | refinement_results = optimize_scene_tree_with_nlp(grammar, parse_trees[0], verbose=True, 111 | max_scene_extent_in_any_dir=10.) 112 | refined_tree = refinement_results.refined_tree 113 | score = refined_tree.score(verbose=True) 114 | assert torch.isfinite(score), "Refined tree was infeasible." 
115 | 116 | 117 | if __name__ == "__main__": 118 | pytest.main() -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/dish_bin/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import pickle 6 | import time 7 | import logging 8 | 9 | import torch 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.nodes import * 13 | from spatial_scene_grammars.rules import * 14 | from spatial_scene_grammars.scene_grammar import * 15 | from spatial_scene_grammars.visualization import * 16 | from spatial_scene_grammars_examples.dish_bin.grammar import * 17 | from spatial_scene_grammars.parsing import * 18 | from spatial_scene_grammars.sampling import * 19 | from spatial_scene_grammars.parameter_estimation import * 20 | from spatial_scene_grammars.dataset import * 21 | 22 | DATASET_YAML_FILE = "sink/saved_scenes.yaml" 23 | DATASET_SAVE_FILE = "observed_node_sets.dat" 24 | OUTLIER_DATASET_YAML_FILE = "sink/saved_outlier_scenes.yaml" 25 | OUTLIER_DATASET_SAVE_FILE = "observed_outlier_node_sets.dat" 26 | 27 | def _load_dataset(yaml_file, save_file, reconvert=False): 28 | if reconvert or not os.path.exists(save_file): 29 | type_map = { 30 | "bin": DishBin 31 | } 32 | model_map = { 33 | } 34 | for model_type_set in [PlateModels, CupModels, BowlModels]: 35 | for model_type in model_type_set: 36 | # Have to cut off the "sink" folder to match model names; 37 | # dataset management is ugly and should get reorganized... 
38 | model_map[os.path.join(*model_type.sdf.split("/")[1:])] = model_type 39 | observed_node_sets = convert_scenes_yaml_to_observed_nodes(yaml_file, type_map, model_map) 40 | logging.debug("Saving %s" % save_file) 41 | with open(save_file, "wb") as f: 42 | pickle.dump(observed_node_sets, f) 43 | 44 | logging.debug("Loading %s" % save_file) 45 | with open(save_file, "rb") as f: 46 | observed_node_sets = pickle.load(f) 47 | return observed_node_sets 48 | 49 | def get_observed_node_sets(reconvert=False): 50 | return _load_dataset(DATASET_YAML_FILE, DATASET_SAVE_FILE, reconvert=reconvert), \ 51 | _load_dataset(OUTLIER_DATASET_YAML_FILE, OUTLIER_DATASET_SAVE_FILE, reconvert=reconvert) 52 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/gmm/.gitattributes: -------------------------------------------------------------------------------- 1 | test_runs_with_restarts.pickle filter=lfs diff=lfs merge=lfs -text 2 | test_runs.pickle filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/gmm/grammar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | from pydrake.all import ( 10 | RollPitchYaw, 11 | RigidTransform 12 | ) 13 | 14 | ''' 15 | Gaussian mixture model. This is a little funky, since it mixes 16 | the parameters (mode means + variances) with variables (the 17 | sampled points), but it should work for inferring GMM parameters. 
18 | 19 | root --(rule implementing one of the mixtures)-> observed 20 | ''' 21 | 22 | class Point(TerminalNode): 23 | def __init__(self, tf): 24 | super().__init__( 25 | tf=tf, 26 | physics_geometry_info=None, 27 | observed=True 28 | ) 29 | 30 | class Root(OrNode): 31 | def __init__(self, tf): 32 | super().__init__( 33 | tf=tf, 34 | physics_geometry_info=None, 35 | observed=False, 36 | rule_probs=torch.ones(3) 37 | ) 38 | @classmethod 39 | def generate_rules(cls): 40 | return [ 41 | ProductionRule( 42 | child_type=Point, 43 | xyz_rule=WorldFrameGaussianOffsetRule( 44 | mean=torch.zeros(3), 45 | variance=torch.tensor([1.0, 1.0, 1.0])), 46 | rotation_rule=SameRotationRule() 47 | ) for k in range(3) 48 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/gmm/test_runs.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c891bbc9e47493e4032c807979f16241b4747914c145b3c1f44d80ff0e32b4ea 3 | size 3868174 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/gmm/test_runs_with_restarts.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:89e17dea58fb268372821cad162c4141c611a12b89709b5cdb432ca845a97770 3 | size 3382505 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/oriented_clusters/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/oriented_clusters/grammar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 
3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | 10 | ''' Root creates geometric number of oriented 11 | clusters with random rotations, with xyz distributed 12 | as a random normal. 13 | Each cluster produces a geometric number of long boxes with 14 | minor rotations around their non-long axes. ''' 15 | 16 | class LongBox(TerminalNode): 17 | def __init__(self, tf): 18 | geom = PhysicsGeometryInfo() 19 | geom.register_geometry( 20 | tf=torch.eye(4), 21 | geometry=pydrake_geom.Box(0.025, 0.025, 0.1), 22 | color=np.array([0.8, 0.5, 0.2, 1.0]) 23 | ) 24 | super().__init__( 25 | tf=tf, 26 | physics_geometry_info=geom, 27 | observed=True 28 | ) 29 | 30 | class OrientedCluster(RepeatingSetNode): 31 | def __init__(self, tf): 32 | super().__init__( 33 | tf=tf, 34 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.3, max_children=5), 35 | physics_geometry_info=None, 36 | observed=False 37 | ) 38 | @classmethod 39 | def generate_rules(cls): 40 | return [ProductionRule( 41 | child_type=LongBox, 42 | xyz_rule=ParentFrameGaussianOffsetRule( 43 | mean=torch.zeros(3), 44 | variance=torch.tensor([0.05, 0.05, 0.01]) 45 | ), 46 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 47 | RotationMatrix(RollPitchYaw(0., 0., 1.)), 48 | [100., 100., 10.] 
49 | ) 50 | )] 51 | 52 | class OrientedClusterRoot(RepeatingSetNode): 53 | def __init__(self, tf): 54 | super().__init__( 55 | tf=tf, 56 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.3, max_children=3), 57 | physics_geometry_info=None, 58 | observed=True 59 | ) 60 | @classmethod 61 | def generate_rules(cls): 62 | return [ProductionRule( 63 | child_type=OrientedCluster, 64 | xyz_rule=ParentFrameGaussianOffsetRule( 65 | mean=torch.zeros(3), 66 | variance=torch.tensor([1.0, 1.0, 0.01]) 67 | ), 68 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 69 | RotationMatrix(RollPitchYaw(1., 0., 0.)), 70 | [0.1, 0.1, 0.1] 71 | ) 72 | )] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/oriented_clusters/grammar_with_extra_pathway.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | 10 | ''' Root creates geometric number of oriented 11 | clusters with random rotations, with xyz distributed 12 | as a random normal. 13 | Root can also sometimes produce boxes on their own, to account 14 | for outliers. 15 | Each cluster produces a geometric number of long boxes with 16 | minor rotations around their non-long axes. 
''' 17 | 18 | class LongBox(TerminalNode): 19 | def __init__(self, tf): 20 | geom = PhysicsGeometryInfo() 21 | geom.register_geometry( 22 | tf=torch.eye(4), 23 | geometry=pydrake_geom.Box(0.025, 0.025, 0.1), 24 | color=np.array([0.8, 0.5, 0.2, 1.0]) 25 | ) 26 | super().__init__( 27 | tf=tf, 28 | physics_geometry_info=geom, 29 | observed=True 30 | ) 31 | 32 | class OrientedCluster(RepeatingSetNode): 33 | def __init__(self, tf): 34 | super().__init__( 35 | tf=tf, 36 | p=0.3, 37 | max_children=5, 38 | physics_geometry_info=None, 39 | observed=False 40 | ) 41 | @classmethod 42 | def generate_rules(cls): 43 | return [ProductionRule( 44 | child_type=LongBox, 45 | xyz_rule=SamePositionRule(),#WorldFrameGaussianOffsetRule( 46 | #mean=torch.zeros(3), 47 | #variance=torch.tensor([0.05, 0.05, 0.001]) 48 | #), 49 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 50 | RotationMatrix(RollPitchYaw(0., 0., 1.)), 51 | [1000., 1000., 1.] 52 | ) 53 | )] 54 | 55 | class AssortedOrientedClusters(RepeatingSetNode): 56 | def __init__(self, tf): 57 | super().__init__( 58 | tf=tf, 59 | p=0.3, 60 | max_children=3, 61 | physics_geometry_info=None, 62 | observed=True 63 | ) 64 | @classmethod 65 | def generate_rules(cls): 66 | return [ProductionRule( 67 | child_type=OrientedCluster, 68 | xyz_rule=SamePositionRule(),#WorldFrameGaussianOffsetRule( 69 | #mean=torch.zeros(3), 70 | #variance=torch.tensor([1.0, 1.0, 1.0]) 71 | #), 72 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 73 | RotationMatrix(RollPitchYaw(1., 0., 0.)), 74 | [0.1, 0.1, 0.1] 75 | ) 76 | )] 77 | 78 | 79 | class AssortedOutliers(RepeatingSetNode): 80 | def __init__(self, tf): 81 | super().__init__( 82 | tf=tf, 83 | p=0.3, 84 | max_children=3, 85 | physics_geometry_info=None, 86 | observed=True 87 | ) 88 | @classmethod 89 | def generate_rules(cls): 90 | return [ProductionRule( 91 | child_type=LongBox, 92 | xyz_rule=SamePositionRule(),#WorldFrameGaussianOffsetRule( 93 | 
#mean=torch.zeros(3), 94 | #variance=torch.tensor([1.0, 1.0, 1.0]) 95 | #), 96 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 97 | RotationMatrix(RollPitchYaw(1., 0., 0.)), 98 | [0.1, 0.1, 0.1] 99 | ) 100 | )] 101 | 102 | class OrientedClusterRoot(IndependentSetNode): 103 | def __init__(self, tf): 104 | super().__init__( 105 | tf=tf, 106 | rule_probs=torch.tensor([0.8, 0.8]), 107 | physics_geometry_info=None, 108 | observed=True 109 | ) 110 | @classmethod 111 | def generate_rules(cls): 112 | return [ 113 | ProductionRule( 114 | child_type=AssortedOutliers, 115 | xyz_rule=SamePositionRule(), 116 | rotation_rule=SameRotationRule() 117 | ), 118 | ProductionRule( 119 | child_type=AssortedOrientedClusters, 120 | xyz_rule=SamePositionRule(), 121 | rotation_rule=SameRotationRule() 122 | ) 123 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/oriented_clusters/test/test_oriented_clusters_example.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | import os 7 | import time 8 | 9 | import torch 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.nodes import * 13 | from spatial_scene_grammars.rules import * 14 | from spatial_scene_grammars.scene_grammar import * 15 | from spatial_scene_grammars.visualization import * 16 | from spatial_scene_grammars_examples.oriented_clusters.grammar import * 17 | from spatial_scene_grammars.parsing import * 18 | 19 | import meshcat 20 | import meshcat.geometry as meshcat_geom 21 | 22 | from pydrake.all import SnoptSolver 23 | 24 | def test_sampling(): 25 | vis = meshcat.Visualizer() 26 | 27 | # Draw a random sample from the grammar and visualize it. 
28 | grammar = SpatialSceneGrammar( 29 | root_node_type = OrientedCluster, 30 | root_node_tf = torch.eye(4) 31 | ) 32 | torch.random.manual_seed(42) 33 | tree = grammar.sample_tree() 34 | 35 | assert len(tree.find_nodes_by_type(OrientedCluster)) > 0, "Didn't sample any clusters." 36 | assert torch.isfinite(tree.score(verbose=True)), "Sampled tree was infeasible." 37 | 38 | draw_scene_tree_contents_meshcat(tree, zmq_url=vis.window.zmq_url) 39 | draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url) 40 | 41 | @pytest.mark.skipif(os.environ.get('GUROBI_PATH') is None or not SnoptSolver().available(), 42 | reason='This test relies on Gurobi and SNOPT.') 43 | def test_parsing(): 44 | # Try to parse an example of this grammar. 45 | grammar = SpatialSceneGrammar( 46 | root_node_type = OrientedCluster, 47 | root_node_tf = torch.eye(4) 48 | ) 49 | torch.random.manual_seed(42) 50 | observed_tree = grammar.sample_tree(detach=True) 51 | observed_nodes = observed_tree.get_observed_nodes() 52 | 53 | inference_results = infer_mle_tree_with_mip( 54 | grammar, observed_nodes, verbose=True, 55 | ) 56 | mip_optimized_tree = get_optimized_tree_from_mip_results(inference_results) 57 | refinement_results = optimize_scene_tree_with_nlp(grammar, mip_optimized_tree, verbose=True) 58 | refined_tree = refinement_results.refined_tree 59 | score = refined_tree.score(verbose=True) 60 | assert torch.isfinite(score), "Refined tree was infeasible." 
61 | 62 | if __name__ == "__main__": 63 | pytest.main() 64 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/packages/.gitattributes: -------------------------------------------------------------------------------- 1 | boxes.zip filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/packages/.gitignore: -------------------------------------------------------------------------------- 1 | boxes 2 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/packages/all_observable_grammar.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from functools import lru_cache 4 | 5 | import torch 6 | from spatial_scene_grammars.nodes import * 7 | from spatial_scene_grammars.rules import * 8 | from spatial_scene_grammars.scene_grammar import * 9 | from spatial_scene_grammars.drake_interop import * 10 | 11 | import pydrake 12 | import pydrake.geometry as pydrake_geom 13 | from pydrake.all import ( 14 | RollPitchYaw, 15 | RigidTransform 16 | ) 17 | 18 | ''' 19 | Grammar describing small piles and clusters of packages. 20 | 21 | Simple structure -- mostly here to prototype sampling code. 22 | 23 | Root -> [AssortedStacks, AssortedBoxes] 24 | AssortedStacks -> Stacks x N 25 | Stack -> Boxes x N, vertically 26 | AssortedBoxes -> Boxes x N 27 | Box -> Reified Box Geometry 28 | ''' 29 | 30 | eps = 1E-2 31 | 32 | class ObjectModel(TerminalNode): 33 | sdf = None 34 | def __init__(self, tf): 35 | assert self.sdf is not None, "Don't instantiate ObjectModel itself; use a reified version." 
36 | geom = PhysicsGeometryInfo(fixed=False) 37 | geom.register_model_file( 38 | drake_tf_to_torch_tf(RigidTransform(p=[0.0, 0., 0.])), 39 | self.sdf 40 | ) 41 | super().__init__( 42 | tf=tf, 43 | physics_geometry_info=geom, 44 | observed=True 45 | ) 46 | 47 | # Use caching to try to avoid re-generating objects, which messes with pickling 48 | # and node identity checking. 49 | @lru_cache(maxsize=None) 50 | def reify_models_from_folder_to_object_types(folder): 51 | print("Generating from folder ", folder) 52 | sdfs = glob.glob(os.path.join(folder, "*/box.sdf"))[:5] 53 | new_types = [] 54 | for sdf in sdfs: 55 | new_types.append( 56 | type( 57 | "%s" % os.path.split(sdf)[0].replace("/", ":"), 58 | (ObjectModel,), 59 | {"sdf": sdf} 60 | ) 61 | ) 62 | # Make these dynamically generated types pickle-able 63 | # by registering them globally. Dangerous -- make sure 64 | # their names are unique! 65 | # https://stackoverflow.com/questions/11658511/pickling-dynamically-generated-classes 66 | for new_type in new_types: 67 | print(new_type.__name__, ": ", new_type) 68 | globals()[new_type.__name__] = new_type 69 | 70 | return new_types 71 | 72 | 73 | BoxModels = reify_models_from_folder_to_object_types( 74 | "boxes" 75 | ) 76 | 77 | class BoxGroup(AndNode): 78 | def __init__(self, tf): 79 | geom = PhysicsGeometryInfo(fixed=True) 80 | geom.register_geometry( 81 | drake_tf_to_torch_tf(RigidTransform(p=[0.0, 0., -1.0])), 82 | pydrake_geom.Box(20., 20., 2.) 
83 | ) 84 | super().__init__( 85 | tf=tf, 86 | physics_geometry_info=geom, 87 | observed=True 88 | ) 89 | @classmethod 90 | def generate_rules(cls): 91 | return [ 92 | ProductionRule( 93 | child_type=box_model, 94 | xyz_rule=WorldFrameGaussianOffsetRule( 95 | mean=torch.tensor([0.0, 0.0, 0.5*k]), 96 | variance=torch.tensor([0.01, 0.01, 0.01])), 97 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 98 | RotationMatrix(), [1., 1., 1] 99 | ) 100 | ) 101 | for k, box_model in enumerate(BoxModels) 102 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/packages/boxes.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:72aa4701a8be4c247e54a1bf4d0db14971e8affe938c395a675c7a92543afe56 3 | size 16876537 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/packages/grammar.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from functools import lru_cache 4 | 5 | import torch 6 | from spatial_scene_grammars.nodes import * 7 | from spatial_scene_grammars.rules import * 8 | from spatial_scene_grammars.scene_grammar import * 9 | from spatial_scene_grammars.drake_interop import * 10 | 11 | import pydrake 12 | import pydrake.geometry as pydrake_geom 13 | from pydrake.all import ( 14 | RollPitchYaw, 15 | RigidTransform 16 | ) 17 | 18 | ''' 19 | Grammar describing small piles and clusters of packages. 20 | 21 | Simple structure -- mostly here to prototype sampling code. 
22 | 23 | Root -> [AssortedStacks, AssortedBoxes] 24 | AssortedStacks -> Stacks x N 25 | Stack -> Boxes x N, vertically 26 | AssortedBoxes -> Boxes x N 27 | Box -> Reified Box Geometry 28 | ''' 29 | 30 | eps = 1E-2 31 | 32 | class ObjectModel(TerminalNode): 33 | sdf = None 34 | def __init__(self, tf): 35 | assert self.sdf is not None, "Don't instantiate ObjectModel itself; use a reified version." 36 | geom = PhysicsGeometryInfo(fixed=False) 37 | geom.register_model_file( 38 | drake_tf_to_torch_tf(RigidTransform(p=[0.0, 0., 0.])), 39 | self.sdf 40 | ) 41 | super().__init__( 42 | tf=tf, 43 | physics_geometry_info=geom, 44 | observed=True 45 | ) 46 | 47 | # Use caching to try to avoid re-generating objects, which messes with pickling 48 | # and node identity checking. 49 | @lru_cache(maxsize=None) 50 | def reify_models_from_folder_to_object_types(folder): 51 | print("Generating from folder ", folder) 52 | sdfs = glob.glob(os.path.join(folder, "*/box.sdf")) 53 | new_types = [] 54 | for sdf in sdfs: 55 | new_types.append( 56 | type( 57 | "%s" % os.path.split(sdf)[0].replace("/", ":"), 58 | (ObjectModel,), 59 | {"sdf": sdf} 60 | ) 61 | ) 62 | # Make these dynamically generated types pickle-able 63 | # by registering them globally. Dangerous -- make sure 64 | # their names are unique! 65 | # https://stackoverflow.com/questions/11658511/pickling-dynamically-generated-classes 66 | for new_type in new_types: 67 | print(new_type.__name__, ": ", new_type) 68 | globals()[new_type.__name__] = new_type 69 | 70 | return new_types 71 | 72 | 73 | BoxModels = reify_models_from_folder_to_object_types( 74 | "boxes" 75 | ) 76 | class Box(OrNode): 77 | # One of any available box model. 
78 | def __init__(self, tf): 79 | super().__init__( 80 | rule_probs=torch.ones(len(BoxModels)), 81 | tf=tf, 82 | physics_geometry_info=None, 83 | observed=False 84 | ) 85 | @classmethod 86 | def generate_rules(cls): 87 | ModelRules = [ 88 | ProductionRule( 89 | child_type=model_type, 90 | xyz_rule=SamePositionRule(), 91 | rotation_rule=SameRotationRule() 92 | ) for model_type in BoxModels 93 | ] 94 | return ModelRules 95 | 96 | 97 | class MaybeStackBox(IndependentSetNode): 98 | def __init__(self, tf): 99 | super().__init__( 100 | rule_probs=torch.tensor([0.5]), 101 | tf=tf, 102 | physics_geometry_info=None, 103 | observed=False 104 | ) 105 | @classmethod 106 | def generate_rules(cls): 107 | ModelRules = [ 108 | ProductionRule( 109 | child_type=BoxAndStackBox, 110 | xyz_rule=WorldFrameGaussianOffsetRule( 111 | mean=torch.tensor([0.0, 0.0, 0.2]), 112 | variance=torch.tensor([0.01, 0.01, 0.05])), 113 | # Assume world-frame vertically-oriented plate stacks 114 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 115 | RotationMatrix(), [100., 100., 0.1] 116 | ) 117 | ) 118 | ] 119 | return ModelRules 120 | 121 | class BoxAndStackBox(AndNode): 122 | def __init__(self, tf): 123 | super().__init__( 124 | tf=tf, 125 | physics_geometry_info=None, 126 | observed=False 127 | ) 128 | @classmethod 129 | def generate_rules(cls): 130 | ModelRules = [ 131 | ProductionRule( 132 | child_type=Box, 133 | xyz_rule=SamePositionRule(), 134 | rotation_rule=SameRotationRule() 135 | ), 136 | ProductionRule( 137 | child_type=MaybeStackBox, 138 | xyz_rule=SamePositionRule(), 139 | rotation_rule=SameRotationRule() 140 | ) 141 | ] 142 | return ModelRules 143 | 144 | class AssortedBoxes(RepeatingSetNode): 145 | def __init__(self, tf): 146 | super().__init__( 147 | tf=tf, 148 | p=0.3, 149 | max_children=6, 150 | physics_geometry_info=None, 151 | observed=False 152 | ) 153 | @classmethod 154 | def generate_rules(cls): 155 | rule = ProductionRule( 156 | 
child_type=BoxAndStackBox, 157 | xyz_rule=WorldFrameGaussianOffsetRule( 158 | mean=torch.tensor([0.0, 0.0, 0.1]), 159 | variance=torch.tensor([0.25, 0.25, 0.001])), 160 | rotation_rule=WorldFrameBinghamRotationRule(torch.eye(4), torch.tensor([-1, -1, -1, 0.])) 161 | ) 162 | return [rule] 163 | 164 | class BoxGroup(AndNode): 165 | def __init__(self, tf): 166 | geom = PhysicsGeometryInfo(fixed=True) 167 | geom.register_geometry( 168 | drake_tf_to_torch_tf(RigidTransform(p=[0.0, 0., -1.0])), 169 | pydrake_geom.Box(20., 20., 2.) 170 | ) 171 | super().__init__( 172 | tf=tf, 173 | physics_geometry_info=geom, 174 | observed=True 175 | ) 176 | @classmethod 177 | def generate_rules(cls): 178 | return [ 179 | ProductionRule( 180 | child_type=AssortedBoxes, 181 | xyz_rule=SamePositionRule(), 182 | rotation_rule=SameRotationRule() 183 | ) 184 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters/.gitignore: -------------------------------------------------------------------------------- 1 | *.dat 2 | .ipynb_checkpoints 3 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters/test/test_planar_clusters_example.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | import os 7 | import time 8 | 9 | import torch 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.nodes import * 13 | from spatial_scene_grammars.rules import * 14 | from spatial_scene_grammars.scene_grammar import * 15 | from spatial_scene_grammars.visualization import * 16 | from spatial_scene_grammars_examples.planar_clusters.grammar import * 17 | from spatial_scene_grammars.parsing import * 18 | 19 | import meshcat 20 | import meshcat.geometry as meshcat_geom 21 | 22 | from 
pydrake.all import SnoptSolver 23 | 24 | def test_sampling(): 25 | vis = meshcat.Visualizer() 26 | 27 | # Draw a random sample from the grammar and visualize it. 28 | grammar = SpatialSceneGrammar( 29 | root_node_type = Desk, 30 | root_node_tf = torch.eye(4) 31 | ) 32 | torch.random.manual_seed(42) 33 | tree = grammar.sample_tree() 34 | 35 | assert torch.isfinite(tree.score(verbose=True)), "Sampled tree was infeasible." 36 | 37 | draw_scene_tree_contents_meshcat(tree, zmq_url=vis.window.zmq_url) 38 | draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url) 39 | 40 | 41 | @pytest.mark.skip(reason="This test times out. Parsing this grammar is hard now. What changed?") 42 | @pytest.mark.skipif(os.environ.get('GUROBI_PATH') is None or not SnoptSolver().available(), 43 | reason='This test relies on Gurobi and SNOPT.') 44 | def test_parsing(): 45 | # Try to parse an example of this grammar. 46 | grammar = SpatialSceneGrammar( 47 | root_node_type = Desk, 48 | root_node_tf = torch.eye(4) 49 | ) 50 | torch.random.manual_seed(42) 51 | observed_tree = grammar.sample_tree(detach=True) 52 | observed_nodes = observed_tree.get_observed_nodes() 53 | 54 | inference_results = infer_mle_tree_with_mip( 55 | grammar, observed_nodes, verbose=True, 56 | max_scene_extent_in_any_dir=10. 57 | ) 58 | mip_optimized_tree = get_optimized_tree_from_mip_results(inference_results) 59 | refinement_results = optimize_scene_tree_with_nlp(grammar, mip_optimized_tree, verbose=True, 60 | max_scene_extent_in_any_dir=10.) 61 | refined_tree = refinement_results.refined_tree 62 | score = refined_tree.score(verbose=True) 63 | assert torch.isfinite(score), "Refined tree was infeasible." 
64 | 65 | if __name__ == "__main__": 66 | pytest.main() 67 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_gaussians/.gitignore: -------------------------------------------------------------------------------- 1 | *.dat 2 | .ipynb_checkpoints 3 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_gaussians/do_profile_vi.sh: -------------------------------------------------------------------------------- 1 | python -m cProfile -o profile_results.dat profile_vi.py -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_gaussians/grammar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | from pydrake.all import ( 10 | RollPitchYaw, 11 | RigidTransform 12 | ) 13 | 14 | ''' Simple grammar that flexes most of the rule 15 | types while keeping all objects in the XY plane for 16 | easy visualization. 17 | 18 | Themed like clutter on a desk: a desk produces 19 | clusters of stuff on it. Clusters can be clusters 20 | of food waste, papers, or pencils. Each cluster 21 | has a geometric number of stuff distributed 22 | locally. 
23 | ''' 24 | 25 | eps = 1E-2 26 | 27 | ## Food waste 28 | class Plate(TerminalNode): 29 | def __init__(self, tf): 30 | geom = PhysicsGeometryInfo() 31 | geom.register_geometry( 32 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .025/2.])), 33 | geometry=pydrake_geom.Cylinder(radius=0.1, length=0.025), 34 | color=np.array([0.8, 0.5, 0.2, 1.0]) 35 | ) 36 | super().__init__( 37 | tf=tf, 38 | physics_geometry_info=geom, 39 | observed=True 40 | ) 41 | 42 | class Drink(TerminalNode): 43 | def __init__(self, tf): 44 | geom = PhysicsGeometryInfo() 45 | geom.register_geometry( 46 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .1/2.])), 47 | geometry=pydrake_geom.Cylinder(radius=0.05, length=0.1), 48 | color=np.array([0.3, 0.8, 0.5, 1.0]) 49 | ) 50 | super().__init__( 51 | tf=tf, 52 | physics_geometry_info=geom, 53 | observed=True 54 | ) 55 | 56 | 57 | class FoodWasteCluster(IndependentSetNode): 58 | # Might make either a plate or soda can, 59 | # the plate might make more stuff on it. 60 | Default_Rule_Probs = torch.tensor([0.5, 0.8]) 61 | def __init__(self, tf): 62 | super().__init__( 63 | rule_probs=self.Default_Rule_Probs, 64 | tf=tf, 65 | physics_geometry_info=None, 66 | observed=False 67 | ) 68 | @classmethod 69 | def generate_rules(cls): 70 | Stuff = [Plate, Drink] 71 | Rules = [ 72 | ProductionRule( 73 | child_type=stuff, 74 | xyz_rule=WorldFramePlanarGaussianOffsetRule( 75 | mean=torch.tensor([0.0, 0.0]), 76 | variance=torch.tensor([0.05, 0.05]), 77 | plane_transform=RigidTransform()), 78 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 79 | RotationMatrix(RollPitchYaw(0., 0., 1.)), [1000., 1000., 1.]) 80 | ) for stuff in Stuff 81 | ] 82 | return Rules 83 | 84 | ## Paper stack 85 | class Paper(TerminalNode): 86 | def __init__(self, tf): 87 | geom = PhysicsGeometryInfo() 88 | rgba = np.random.uniform([0.85, 0.85, 0.85, 1.0], [0.95, 0.95, 0.95, 1.0]) 89 | geom.register_geometry( 90 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 
0., .01/2.])), 91 | geometry=pydrake_geom.Box(0.2159, 0.2794, 0.01), # 8.5" x 11" 92 | color=rgba 93 | ) 94 | super().__init__( 95 | tf=tf, 96 | physics_geometry_info=geom, 97 | observed=True 98 | ) 99 | 100 | class PaperCluster(RepeatingSetNode): 101 | # Make a stack of papers 102 | def __init__(self, tf): 103 | super().__init__( 104 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.3, max_children=3), 105 | tf=tf, 106 | physics_geometry_info=None, 107 | observed=False 108 | ) 109 | @classmethod 110 | def generate_rules(cls): 111 | return [ProductionRule( 112 | child_type=Paper, 113 | xyz_rule=WorldFramePlanarGaussianOffsetRule( 114 | mean=torch.tensor([0.0, 0.0]), 115 | variance=torch.tensor([0.05, 0.05]), 116 | plane_transform=RigidTransform()), 117 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 118 | RotationMatrix(RollPitchYaw(0., 0., 1.)), [1000., 1000., 100.]) 119 | )] 120 | 121 | ## Pencils 122 | class Pencil(TerminalNode): 123 | def __init__(self, tf): 124 | geom = PhysicsGeometryInfo() 125 | rgba = np.random.uniform([0.85, 0.75, 0.45, 1.0], [0.95, 0.85, 0.55, 1.0]) 126 | geom.register_geometry( 127 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .01/2.], rpy=RollPitchYaw(0., np.pi/2., 0.))), 128 | geometry=pydrake_geom.Cylinder(radius=0.01, length=0.15), 129 | color=rgba, 130 | ) 131 | super().__init__( 132 | tf=tf, 133 | physics_geometry_info=geom, 134 | observed=True 135 | ) 136 | 137 | class PencilCluster(RepeatingSetNode): 138 | # Make a geometric cluster of roughly-aligned pencils 139 | def __init__(self, tf): 140 | super().__init__( 141 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.5, max_children=3), 142 | tf=tf, 143 | physics_geometry_info=None, 144 | observed=False 145 | ) 146 | @classmethod 147 | def generate_rules(cls): 148 | return [ProductionRule( 149 | child_type=Pencil, 150 | xyz_rule=WorldFramePlanarGaussianOffsetRule( 151 | mean=torch.tensor([0.0, 0.0]), 152 | 
variance=torch.tensor([0.01, 0.005]), 153 | plane_transform=RigidTransform()), 154 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 155 | RotationMatrix(RollPitchYaw(0., 0., 1.)), [1000., 1000., 1000.]) 156 | )] 157 | 158 | ## Desk and abstract cluster 159 | class ObjectCluster(OrNode): 160 | # Specialize into a type of cluster 161 | DefaultClusterTypeWeights = torch.tensor([1.0, 1.0, 1.0]) 162 | def __init__(self, tf): 163 | super().__init__( 164 | rule_probs=self.DefaultClusterTypeWeights, 165 | tf=tf, 166 | physics_geometry_info=None, 167 | observed=False 168 | ) 169 | 170 | @classmethod 171 | def generate_rules(cls): 172 | ClusterTypes = [FoodWasteCluster, PaperCluster, PencilCluster] 173 | ClusterRules = [ 174 | ProductionRule( 175 | child_type=cluster_type, 176 | xyz_rule=SamePositionRule(), 177 | rotation_rule=SameRotationRule() 178 | ) for cluster_type in ClusterTypes 179 | ] 180 | return ClusterRules 181 | 182 | class Desk(RepeatingSetNode): 183 | # Make geometric # of object clusters 184 | desk_size=[2., 2.] 
185 | def __init__(self, tf): 186 | geom = PhysicsGeometryInfo() 187 | geom.register_geometry( 188 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., -0.5])), 189 | geometry=pydrake_geom.Box(self.desk_size[0], self.desk_size[1], 1.0), 190 | color=np.array([0.3, 0.2, 0.2, 1.0]) 191 | ) 192 | super().__init__( 193 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.2, max_children=6), 194 | tf=tf, 195 | physics_geometry_info=geom, 196 | observed=True 197 | ) 198 | @classmethod 199 | def generate_rules(cls): 200 | lb = torch.tensor([0.2, 0.2, 0.0]) 201 | ub = torch.tensor([cls.desk_size[0] - 0.2, cls.desk_size[1] - 0.2, 0.0]) 202 | rule = ProductionRule( 203 | child_type=ObjectCluster, 204 | xyz_rule=WorldFramePlanarGaussianOffsetRule( 205 | mean=torch.tensor([0.0, 0.0]), 206 | variance=torch.tensor([0.2, 0.2]), 207 | plane_transform=RigidTransform()), 208 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 209 | RotationMatrix(RollPitchYaw(0., 0., 1.)), [1000., 1000., 1.]) 210 | ) 211 | return [rule] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_gaussians/print_profile_results.py: -------------------------------------------------------------------------------- 1 | import pstats 2 | p = pstats.Stats('profile_results.dat') 3 | p.strip_dirs().sort_stats(-1).print_stats() 4 | p.sort_stats('time').print_stats(100) -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_gaussians/profile_vi.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import os 4 | import pickle 5 | import time 6 | from tqdm import tqdm 7 | import torch 8 | torch.set_default_tensor_type(torch.DoubleTensor) 9 | 10 | from spatial_scene_grammars.nodes import * 11 | from spatial_scene_grammars.rules import * 12 | from 
spatial_scene_grammars.scene_grammar import * 13 | from spatial_scene_grammars.visualization import * 14 | from spatial_scene_grammars_examples.planar_clusters_gaussians.grammar import * 15 | from spatial_scene_grammars.parsing import * 16 | from spatial_scene_grammars.sampling import * 17 | from spatial_scene_grammars.parameter_estimation import * 18 | 19 | # Sample a dataset of scenes from the default grammar params. 20 | # Draw a random sample from the grammar and visualize it. 21 | # (Cache output.) 22 | torch.random.manual_seed(2) 23 | N_samples = 5 24 | 25 | ground_truth_grammar = SpatialSceneGrammar( 26 | root_node_type = Desk, 27 | root_node_tf = torch.eye(4) 28 | ) 29 | 30 | samples = [] 31 | for k in tqdm(range(N_samples)): 32 | tree = ground_truth_grammar.sample_tree(detach=True) 33 | observed_nodes = tree.get_observed_nodes() 34 | samples.append((tree, observed_nodes)) 35 | 36 | observed_node_sets = [x[1] for x in samples] 37 | 38 | # Randomly reset parameters and try to recover them. 
39 | torch.random.manual_seed(42) 40 | random_grammar = SpatialSceneGrammar( 41 | root_node_type = Desk, 42 | root_node_tf = torch.eye(4), 43 | sample_params_from_prior=True 44 | ) 45 | svi = SVIWrapper(random_grammar, [sample[1] for sample in samples]) 46 | svi.do_iterated_vi_fitting(major_iterations=2, minor_iterations=5, tqdm=tqdm, num_elbo_samples=5) -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_gaussians/test/test_planar_clusters_gaussians_example.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | import os 7 | import time 8 | 9 | import torch 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.nodes import * 13 | from spatial_scene_grammars.rules import * 14 | from spatial_scene_grammars.scene_grammar import * 15 | from spatial_scene_grammars.visualization import * 16 | from spatial_scene_grammars_examples.planar_clusters_gaussians.grammar import * 17 | from spatial_scene_grammars.parsing import * 18 | 19 | import meshcat 20 | import meshcat.geometry as meshcat_geom 21 | 22 | from pydrake.all import SnoptSolver 23 | 24 | def test_sampling(): 25 | vis = meshcat.Visualizer() 26 | 27 | # Draw a random sample from the grammar and visualize it. 28 | grammar = SpatialSceneGrammar( 29 | root_node_type = Desk, 30 | root_node_tf = torch.eye(4) 31 | ) 32 | torch.random.manual_seed(42) 33 | tree = grammar.sample_tree() 34 | 35 | assert torch.isfinite(tree.score(verbose=True)), "Sampled tree was infeasible." 
36 | 37 | draw_scene_tree_contents_meshcat(tree, zmq_url=vis.window.zmq_url) 38 | draw_scene_tree_structure_meshcat(tree, zmq_url=vis.window.zmq_url) 39 | 40 | 41 | 42 | @pytest.mark.skipif(os.environ.get('GUROBI_PATH') is None or not SnoptSolver().available(), 43 | reason='This test relies on Gurobi and SNOPT.') 44 | def test_parsing(): 45 | # Try to parse an example of this grammar. 46 | grammar = SpatialSceneGrammar( 47 | root_node_type = Desk, 48 | root_node_tf = torch.eye(4) 49 | ) 50 | torch.random.manual_seed(42) 51 | observed_tree = grammar.sample_tree(detach=True) 52 | observed_nodes = observed_tree.get_observed_nodes() 53 | 54 | inference_results = infer_mle_tree_with_mip( 55 | grammar, observed_nodes, verbose=True, 56 | max_scene_extent_in_any_dir=10. 57 | ) 58 | assert inference_results.optim_result.is_success() 59 | mip_optimized_tree = get_optimized_tree_from_mip_results(inference_results) 60 | refinement_results = optimize_scene_tree_with_nlp(grammar, mip_optimized_tree, verbose=True, 61 | max_scene_extent_in_any_dir=10.) 62 | assert refinement_results.optim_result.is_success() 63 | refined_tree = refinement_results.refined_tree 64 | score = refined_tree.score(verbose=True) 65 | assert torch.isfinite(score), "Refined tree was infeasible." 
66 | 67 | if __name__ == "__main__": 68 | pytest.main() 69 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_clusters_no_rotation/grammar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | from pydrake.all import ( 10 | RollPitchYaw, 11 | RigidTransform 12 | ) 13 | ''' Simple grammar that flexes most of the rule 14 | types while keeping all objects in the XY plane for 15 | easy visualization. 16 | 17 | Themed like clutter on a desk: a desk produces 18 | clusters of stuff on it. Clusters can be clusters 19 | of food waste, papers, or pencils. Each cluster 20 | has a geometric number of stuff distributed 21 | locally. 
22 | ''' 23 | 24 | ## Food waste 25 | class Plate(TerminalNode): 26 | def __init__(self, tf): 27 | geom = PhysicsGeometryInfo() 28 | geom.register_geometry( 29 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .025/2.])), 30 | geometry=pydrake_geom.Cylinder(radius=0.1, length=0.025), 31 | color=np.array([0.8, 0.5, 0.2, 1.0]) 32 | ) 33 | super().__init__( 34 | tf=tf, 35 | physics_geometry_info=geom, 36 | observed=True 37 | ) 38 | 39 | class Drink(TerminalNode): 40 | def __init__(self, tf): 41 | geom = PhysicsGeometryInfo() 42 | geom.register_geometry( 43 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .1/2.])), 44 | geometry=pydrake_geom.Cylinder(radius=0.05, length=0.1), 45 | color=np.array([0.3, 0.8, 0.5, 1.0]) 46 | ) 47 | super().__init__( 48 | tf=tf, 49 | physics_geometry_info=geom, 50 | observed=True 51 | ) 52 | 53 | 54 | class FoodWasteCluster(IndependentSetNode): 55 | # Might make either a plate or soda can, 56 | # the plate might make more stuff on it. 57 | Default_Rule_Probs = torch.tensor([0.5, 0.8]) 58 | def __init__(self, tf): 59 | super().__init__( 60 | rule_probs=self.Default_Rule_Probs, 61 | tf=tf, 62 | physics_geometry_info=None, 63 | observed=False 64 | ) 65 | @classmethod 66 | def generate_rules(cls): 67 | Stuff = [Plate, Drink] 68 | Rules = [ 69 | ProductionRule( 70 | child_type=stuff, 71 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds( 72 | lb=torch.tensor([-0.2, -0.2, 0.0]), 73 | ub=torch.tensor([0.2, 0.2, 0.0]) 74 | ), 75 | rotation_rule=SameRotationRule() 76 | ) for stuff in Stuff 77 | ] 78 | return Rules 79 | 80 | ## Paper stack 81 | class Paper(TerminalNode): 82 | def __init__(self, tf): 83 | geom = PhysicsGeometryInfo() 84 | rgba = np.random.uniform([0.85, 0.85, 0.85, 1.0], [0.95, 0.95, 0.95, 1.0]) 85 | geom.register_geometry( 86 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .01/2.])), 87 | geometry=pydrake_geom.Box(0.2159, 0.2794, 0.01), # 8.5" x 11" 88 | color=rgba 89 | ) 90 | super().__init__( 91 | tf=tf, 92 | 
physics_geometry_info=geom, 93 | observed=True 94 | ) 95 | 96 | class PaperCluster(RepeatingSetNode): 97 | # Make a stack of papers 98 | def __init__(self, tf): 99 | super().__init__( 100 | p=0.3, 101 | max_children=3, 102 | tf=tf, 103 | physics_geometry_info=None, 104 | observed=False 105 | ) 106 | @classmethod 107 | def generate_rules(cls): 108 | return [ProductionRule( 109 | child_type=Paper, 110 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds( 111 | lb=torch.tensor([-0.05, -0.05, 0.0]), 112 | ub=torch.tensor([0.05, 0.05, 0.0]) 113 | ), 114 | rotation_rule=SameRotationRule() 115 | )] 116 | 117 | ## Pencils 118 | class Pencil(TerminalNode): 119 | def __init__(self, tf): 120 | geom = PhysicsGeometryInfo() 121 | rgba = np.random.uniform([0.85, 0.75, 0.45, 1.0], [0.95, 0.85, 0.55, 1.0]) 122 | geom.register_geometry( 123 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., .01/2.], rpy=RollPitchYaw(0., np.pi/2., 0.))), 124 | geometry=pydrake_geom.Cylinder(radius=0.01, length=0.15), 125 | color=rgba, 126 | ) 127 | super().__init__( 128 | tf=tf, 129 | physics_geometry_info=geom, 130 | observed=True 131 | ) 132 | 133 | class PencilCluster(RepeatingSetNode): 134 | # Make a geometric cluster of roughly-aligned pencils 135 | def __init__(self, tf): 136 | super().__init__( 137 | p=0.5, 138 | max_children=3, 139 | tf=tf, 140 | physics_geometry_info=None, 141 | observed=False 142 | ) 143 | @classmethod 144 | def generate_rules(cls): 145 | return [ProductionRule( 146 | child_type=Pencil, 147 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds( 148 | lb=torch.tensor([-0.05, -0.05, 0.0]), 149 | ub=torch.tensor([0.05, 0.05, 0.0]) 150 | ), 151 | rotation_rule=SameRotationRule() 152 | )] 153 | 154 | ## Desk and abstract cluster 155 | class ObjectCluster(OrNode): 156 | # Specialize into a type of cluster 157 | DefaultClusterTypeWeights = torch.tensor([1.0, 1.0, 1.0]) 158 | def __init__(self, tf): 159 | super().__init__( 160 | rule_probs=self.DefaultClusterTypeWeights, 161 | tf=tf, 162 | 
physics_geometry_info=None, 163 | observed=False 164 | ) 165 | 166 | @classmethod 167 | def generate_rules(cls): 168 | ClusterTypes = [FoodWasteCluster, PaperCluster, PencilCluster] 169 | ClusterRules = [ 170 | ProductionRule( 171 | child_type=cluster_type, 172 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds( 173 | lb=torch.zeros(3), 174 | ub=torch.zeros(3) 175 | ), 176 | rotation_rule=SameRotationRule() 177 | ) for cluster_type in ClusterTypes 178 | ] 179 | return ClusterRules 180 | 181 | class Desk(RepeatingSetNode): 182 | # Make geometric # of object clusters 183 | desk_size=[1., 1.] 184 | def __init__(self, tf): 185 | geom = PhysicsGeometryInfo() 186 | geom.register_geometry( 187 | tf=drake_tf_to_torch_tf(RigidTransform(p=[self.desk_size[0]/2., self.desk_size[1]/2., -0.5])), 188 | geometry=pydrake_geom.Box(self.desk_size[0], self.desk_size[1], 1.0), 189 | color=np.array([0.3, 0.2, 0.2, 1.0]) 190 | ) 191 | super().__init__( 192 | tf=tf, 193 | p=0.2, 194 | max_children=6, 195 | physics_geometry_info=geom, 196 | observed=True 197 | ) 198 | @classmethod 199 | def generate_rules(cls): 200 | lb = torch.tensor([0.2, 0.2, 0.0]) 201 | ub = torch.tensor([cls.desk_size[0] - 0.2, cls.desk_size[1] - 0.2, 0.0]) 202 | rule = ProductionRule( 203 | child_type=ObjectCluster, 204 | xyz_rule=WorldFrameBBoxOffsetRule.from_bounds(lb=lb, ub=ub), 205 | rotation_rule=SameRotationRule() 206 | ) 207 | return [rule] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/planar_fixed_structure/grammar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | from pydrake.all import ( 10 | RollPitchYaw, 11 | 
RigidTransform 12 | ) 13 | ''' 14 | Minimal grammar that always comes out with a tree with the same structure, 15 | but with unobservable intermediate node pose. 16 | 17 | ROOT -> HIDDEN NODE -> VISIBLE NODE 18 | 19 | where both hidden node and visible node are offset with a unit normal 20 | from their parent in xyz, but have identical rotation. 21 | ''' 22 | 23 | # Deviation from planar epsilon 24 | eps = 1E-3 25 | 26 | ## Food waste 27 | class VisibleNode(TerminalNode): 28 | def __init__(self, tf): 29 | geom = PhysicsGeometryInfo() 30 | geom.register_geometry( 31 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., 0.05])), 32 | geometry=pydrake_geom.Box(0.1, 0.1, 0.1), 33 | color=np.array([0.8, 0.5, 0.2, 1.0]) 34 | ) 35 | super().__init__( 36 | tf=tf, 37 | physics_geometry_info=geom, 38 | observed=True 39 | ) 40 | 41 | 42 | class HiddenNode(AndNode): 43 | def __init__(self, tf): 44 | super().__init__( 45 | tf=tf, 46 | physics_geometry_info=None, 47 | observed=False 48 | ) 49 | @classmethod 50 | def generate_rules(cls): 51 | return [ 52 | ProductionRule( 53 | child_type=VisibleNode, 54 | xyz_rule=WorldFrameGaussianOffsetRule( 55 | mean=torch.zeros(3), 56 | variance=torch.tensor([1.0, 2.0, eps])), 57 | rotation_rule=SameRotationRule() 58 | ) 59 | ] 60 | 61 | 62 | class RootNode(AndNode): 63 | def __init__(self, tf): 64 | geom = PhysicsGeometryInfo() 65 | geom.register_geometry( 66 | tf=drake_tf_to_torch_tf(RigidTransform(p=[0., 0., -0.05])), 67 | geometry=pydrake_geom.Box(0.1, 0.1, 0.1), 68 | color=np.array([0.2, 0.5, 0.8, 1.0]) 69 | ) 70 | super().__init__( 71 | tf=tf, 72 | physics_geometry_info=geom, 73 | observed=True 74 | ) 75 | @classmethod 76 | def generate_rules(cls): 77 | return [ 78 | ProductionRule( 79 | child_type=HiddenNode, 80 | xyz_rule=WorldFrameGaussianOffsetRule( 81 | mean=torch.zeros(3), 82 | variance=torch.tensor([1.0, 2.0, eps])), 83 | rotation_rule=SameRotationRule() 84 | ) 85 | ] 
-------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/.gitattributes: -------------------------------------------------------------------------------- 1 | *.pickle filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/constituency_scaling_random_runs.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:212370ea48f6f83b41273272315a389a0d64f78db2fcbbfdcc7c61c2e4a732c7 3 | size 119363631 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/dependency_scaling_random_runs.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2e3f57092b0dd431f2b9462d9a50ee5f345c5e1c697f881f179d1f047a4e4b5f 3 | size 144021305 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/grammar_constituency.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | 10 | ''' Scene contains a number of oriented objects 11 | (represented by airplanes) that can appear either 12 | on their own, or in pairs with tightly coupled position 13 | and orientation. 
''' 14 | 15 | class Object(TerminalNode): # Observed leaf node drawn with models/xwing.sdf 16 | def __init__(self, tf): 17 | geom = PhysicsGeometryInfo() 18 | geom.register_model_file(torch.eye(4), "models/xwing.sdf") 19 | super().__init__( 20 | tf=tf, 21 | physics_geometry_info=geom, 22 | observed=True 23 | ) 24 | 25 | class Pair(AndNode): # Unobserved node that always produces two Objects near its own pose 26 | # TODO(gizatt) This node adds some parsing ambiguity: the children 27 | # are exchangeable, so both orderings are equivalent parses. Changing 28 | # this to a RepeatingSetNode with rule_probs [0, 0, 1] resolves this issue, 29 | # but I'm not making the change now as it makes for a more confusing-looking 30 | # grammar... 31 | PAIR_XYZ_VAR = 0.01 # per-axis variance of each child's xyz offset in the Pair frame 32 | def __init__(self, tf): 33 | super().__init__( 34 | tf=tf, 35 | physics_geometry_info=None, 36 | observed=False 37 | ) 38 | @classmethod 39 | def generate_rules(cls): 40 | return [ 41 | ProductionRule( 42 | child_type=Object, 43 | xyz_rule=ParentFrameGaussianOffsetRule( 44 | mean=torch.zeros(3), 45 | variance=torch.ones(3)*cls.PAIR_XYZ_VAR 46 | ), 47 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 48 | RotationMatrix(RollPitchYaw(0., 0., 1.)), # This 1 is weird, but I'm keeping it until I finish figure gen 49 | [100., 100., 100.] 50 | ) 51 | ), 52 | ProductionRule( 53 | child_type=Object, 54 | xyz_rule=ParentFrameGaussianOffsetRule( 55 | mean=torch.zeros(3), 56 | variance=torch.ones(3)*cls.PAIR_XYZ_VAR 57 | ), 58 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 59 | RotationMatrix(RollPitchYaw(0., 0., 1.)), # Likewise: same 1 rad yaw as the first child 60 | [100., 100., 100.]
61 | ) 62 | ) 63 | ] 64 | 65 | class Pairs(RepeatingSetNode): # Unobserved repeating set of up to MAX_N_PAIRS Pairs (geometric rule probs, p=0.3) 66 | MAX_N_PAIRS = 3 67 | def __init__(self, tf): 68 | super().__init__( 69 | tf=tf, 70 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.3, max_children=self.MAX_N_PAIRS), 71 | physics_geometry_info=None, 72 | observed=False 73 | ) 74 | @classmethod 75 | def generate_rules(cls): 76 | return [ProductionRule( 77 | child_type=Pair, 78 | xyz_rule=WorldFrameGaussianOffsetRule( 79 | mean=torch.zeros(3), 80 | variance=torch.tensor([1.0, 1.0, 1.0]) 81 | ), 82 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 83 | RotationMatrix(RollPitchYaw(1., 0., 0.)), # Likewise: odd 1 rad roll kept until figure gen is finished 84 | [0.1, 0.1, 0.1] 85 | ) 86 | )] 87 | 88 | class Singles(RepeatingSetNode): # Unobserved repeating set of up to MAX_N_SINGLES lone Objects (geometric rule probs, p=0.3) 89 | MAX_N_SINGLES = 3 90 | def __init__(self, tf): 91 | super().__init__( 92 | tf=tf, 93 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=0.3, max_children=self.MAX_N_SINGLES), 94 | physics_geometry_info=None, 95 | observed=False 96 | ) 97 | @classmethod 98 | def generate_rules(cls): 99 | return [ProductionRule( 100 | child_type=Object, 101 | xyz_rule=WorldFrameGaussianOffsetRule( 102 | mean=torch.zeros(3), 103 | variance=torch.tensor([1.0, 1.0, 1.0]) 104 | ), 105 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 106 | RotationMatrix(RollPitchYaw(1., 0., 0.)), # Likewise: odd 1 rad roll kept until figure gen is finished 107 | [0.1, 0.1, 0.1] 108 | ) 109 | )] 110 | 111 | class Root(IndependentSetNode): # Observed root (models/deathstar.sdf) whose children are Singles and Pairs 112 | def __init__(self, tf): 113 | geom = PhysicsGeometryInfo() 114 | geom.register_model_file(torch.eye(4), "models/deathstar.sdf") 115 | super().__init__( 116 | tf=tf, 117 | rule_probs=torch.tensor([0.8, 0.8]), # one entry per rule, in order: Singles, Pairs 118 | physics_geometry_info=geom, 119 | observed=True, 120 | 121 | ) 122 | @classmethod 123 | def generate_rules(cls): 124 | return [ 125 | ProductionRule( 126 | child_type=Singles, 127 | xyz_rule=SamePositionRule(), 128 | rotation_rule=SameRotationRule() 129 | ), 130 | ProductionRule( 131 | child_type=Pairs, 132 |
xyz_rule=SamePositionRule(), 133 | rotation_rule=SameRotationRule() 134 | ) 135 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/grammar_dependency.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spatial_scene_grammars.nodes import * 3 | from spatial_scene_grammars.rules import * 4 | from spatial_scene_grammars.scene_grammar import * 5 | from spatial_scene_grammars.drake_interop import * 6 | 7 | import pydrake 8 | import pydrake.geometry as pydrake_geom 9 | 10 | ''' Scene contains a number of oriented objects 11 | (represented by airplanes) that can appear either 12 | on their own, or in pairs with tightly coupled position 13 | and orientation. ''' 14 | 15 | class Object(TerminalNode): # Observed leaf node drawn with models/xwing.sdf 16 | def __init__(self, tf): 17 | geom = PhysicsGeometryInfo() 18 | geom.register_model_file(torch.eye(4), "models/xwing.sdf") 19 | super().__init__( 20 | tf=tf, 21 | physics_geometry_info=geom, 22 | observed=True 23 | ) 24 | 25 | class Pair(AndNode): # Unobserved node: first Object at the Pair's own pose, second offset from it 26 | PAIR_XYZ_VAR = 0.01 # the offset child uses 2x this variance (presumably to match the relative spread of two PAIR_XYZ_VAR children in grammar_constituency) 27 | def __init__(self, tf): 28 | super().__init__( 29 | tf=tf, 30 | physics_geometry_info=None, 31 | observed=False 32 | ) 33 | @classmethod 34 | def generate_rules(cls): 35 | return [ 36 | ProductionRule( 37 | child_type=Object, 38 | xyz_rule=SamePositionRule(), 39 | rotation_rule=SameRotationRule() 40 | ), 41 | ProductionRule( 42 | child_type=Object, 43 | xyz_rule=ParentFrameGaussianOffsetRule( 44 | mean=torch.zeros(3), 45 | variance=torch.ones(3)*2.*cls.PAIR_XYZ_VAR 46 | ), 47 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 48 | RotationMatrix(), # identity mean rotation relative to the Pair frame 49 | [100., 100., 100.]
50 | ) 51 | ) 52 | ] 53 | 54 | class Pairs(RepeatingSetNode): # Unobserved repeating set of up to MAX_N_PAIRS Pairs 55 | MAX_N_PAIRS = 3 56 | P = 0.3 # p passed to get_geometric_rule_probs 57 | def __init__(self, tf): 58 | super().__init__( 59 | tf=tf, 60 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=self.P, max_children=self.MAX_N_PAIRS), 61 | physics_geometry_info=None, 62 | observed=False 63 | ) 64 | @classmethod 65 | def generate_rules(cls): 66 | return [ProductionRule( 67 | child_type=Pair, 68 | xyz_rule=WorldFrameGaussianOffsetRule( 69 | mean=torch.zeros(3), 70 | variance=torch.tensor([1.0, 1.0, 1.0]) 71 | ), 72 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 73 | RotationMatrix(), 74 | [0.1, 0.1, 0.1] 75 | ) 76 | )] 77 | 78 | class Singles(RepeatingSetNode): # Unobserved repeating set of up to MAX_N_SINGLES lone Objects 79 | MAX_N_SINGLES = 3 80 | P = 0.3 # p passed to get_geometric_rule_probs 81 | def __init__(self, tf): 82 | super().__init__( 83 | tf=tf, 84 | rule_probs=RepeatingSetNode.get_geometric_rule_probs(p=self.P, max_children=self.MAX_N_SINGLES), 85 | physics_geometry_info=None, 86 | observed=False 87 | ) 88 | @classmethod 89 | def generate_rules(cls): 90 | return [ProductionRule( 91 | child_type=Object, 92 | xyz_rule=WorldFrameGaussianOffsetRule( 93 | mean=torch.zeros(3), 94 | variance=torch.tensor([1.0, 1.0, 1.0]) 95 | ), 96 | rotation_rule=WorldFrameBinghamRotationRule.from_rotation_and_rpy_variances( 97 | RotationMatrix(), 98 | [0.1, 0.1, 0.1] 99 | ) 100 | )] 101 | 102 | class Root(IndependentSetNode): # Observed root (models/deathstar.sdf) whose children are Singles and Pairs 103 | def __init__(self, tf): 104 | geom = PhysicsGeometryInfo() 105 | geom.register_model_file(torch.eye(4), "models/deathstar.sdf") 106 | super().__init__( 107 | tf=tf, 108 | rule_probs=torch.tensor([0.8, 0.8]), # one entry per rule, in order: Singles, Pairs 109 | physics_geometry_info=geom, 110 | observed=True, 111 | 112 | ) 113 | @classmethod 114 | def generate_rules(cls): 115 | return [ 116 | ProductionRule( 117 | child_type=Singles, 118 | xyz_rule=SamePositionRule(), 119 | rotation_rule=SameRotationRule() 120 | ), 121 | ProductionRule( 122 | child_type=Pairs, 123 | xyz_rule=SamePositionRule(), 124 | rotation_rule=SameRotationRule() 125
| ) 126 | ] -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/deathstar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/singles_pairs/models/deathstar.png -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/deathstar.sdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 1.0 7 | 8 | 0.001736 9 | 0 10 | 0 11 | 0.001098 12 | 0 13 | 0.002481 14 | 15 | 16 | 17 | 18 | 19 | deathstar.obj 20 | 0.25 0.25 0.25 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/xwing.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/singles_pairs/models/xwing.blend -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/xwing.blend1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/singles_pairs/models/xwing.blend1 -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/xwing.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'xwing.blend' 2 | # Material Count: 5 3 | 4 | newmtl __Chalk_ 5 | Ns 225.000000 6 | Ka 1.000000 1.000000
1.000000 7 | Kd 0.909804 0.905882 0.925490 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | 14 | newmtl __Charcoal_ 15 | Ns 225.000000 16 | Ka 1.000000 1.000000 1.000000 17 | Kd 0.137255 0.137255 0.137255 18 | Ks 0.500000 0.500000 0.500000 19 | Ke 0.000000 0.000000 0.000000 20 | Ni 1.450000 21 | d 1.000000 22 | illum 2 23 | 24 | newmtl __Concrete-rough_ 25 | Ns 225.000000 26 | Ka 1.000000 1.000000 1.000000 27 | Kd 0.658823 0.662745 0.647059 28 | Ks 0.500000 0.500000 0.500000 29 | Ke 0.000000 0.000000 0.000000 30 | Ni 1.450000 31 | d 1.000000 32 | illum 2 33 | 34 | newmtl __Glass_Blue_Tint_ 35 | Ns 225.000000 36 | Ka 1.000000 1.000000 1.000000 37 | Kd 0.035294 0.686275 0.945098 38 | Ks 0.500000 0.500000 0.500000 39 | Ke 0.000000 0.000000 0.000000 40 | Ni 1.450000 41 | d 0.478431 42 | illum 9 43 | 44 | newmtl __Tomato_ 45 | Ns 225.000000 46 | Ka 1.000000 1.000000 1.000000 47 | Kd 1.000000 0.388235 0.278431 48 | Ks 0.500000 0.500000 0.500000 49 | Ke 0.000000 0.000000 0.000000 50 | Ni 1.450000 51 | d 1.000000 52 | illum 2 53 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/xwing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/singles_pairs/models/xwing.png -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/models/xwing.sdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 1.0 7 | 8 | 0.001736 9 | 0 10 | 0 11 | 0.001098 12 | 0 13 | 0.002481 14 | 15 | 16 | 17 | 18 | 19 | xwing.obj 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- 
/spatial_scene_grammars_examples/singles_pairs/single_scene.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:00c49c03f14b0ba5633730d9487aaf52e3308652ccafaa92bf3b0ef0848a6d06 3 | size 34413 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/singles_pairs/test/test_grammar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | import os 7 | import time 8 | 9 | import torch 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | 12 | from spatial_scene_grammars.nodes import * 13 | from spatial_scene_grammars.rules import * 14 | from spatial_scene_grammars.scene_grammar import * 15 | from spatial_scene_grammars.visualization import * 16 | from spatial_scene_grammars.parsing import * 17 | 18 | import meshcat 19 | import meshcat.geometry as meshcat_geom 20 | 21 | from pydrake.all import SnoptSolver 22 | 23 | import spatial_scene_grammars_examples.singles_pairs.grammar_dependency as grammar_dependency 24 | import spatial_scene_grammars_examples.singles_pairs.grammar_constituency as grammar_constituency 25 | 26 | @pytest.fixture(params=range(3)) 27 | def set_seed(request): 28 | torch.manual_seed(request.param) 29 | 30 | @pytest.mark.parametrize("grammar_library", 31 | [grammar_dependency, grammar_constituency] 32 | ) 33 | def test_sampling(set_seed, grammar_library): 34 | # Draw a random sample from the grammar and visualize it. 35 | grammar = SpatialSceneGrammar( 36 | root_node_type = grammar_library.Root, 37 | root_node_tf = torch.eye(4) 38 | ) 39 | torch.random.manual_seed(42) 40 | tree = grammar.sample_tree() 41 | 42 | assert torch.isfinite(tree.score(verbose=True)), "Sampled tree was infeasible." 
43 | 44 | if __name__ == "__main__": 45 | pytest.main() 46 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/.gitattributes: -------------------------------------------------------------------------------- 1 | models.zip filter=lfs diff=lfs merge=lfs -text 2 | example_feasible_sampled_scene.pickle filter=lfs diff=lfs merge=lfs -text 3 | render*.blend* filter=lfs diff=lfs merge=lfs -text 4 | target_dataset_grammar_state_dict.torch filter=lfs diff=lfs merge=lfs -text 5 | target_dataset_examples.pickle filter=lfs diff=lfs merge=lfs -text 6 | post_fit_grammar_draws.pickle filter=lfs diff=lfs merge=lfs -text 7 | post_fit_grammar_state_dict.torch filter=lfs diff=lfs merge=lfs -text 8 | pre_fit_grammar_ filter=lfs diff=lfs merge=lfs -text 9 | fit_em_baseline.pickle filter=lfs diff=lfs merge=lfs -text 10 | fit_grammar_baseline.torch filter=lfs diff=lfs merge=lfs -text 11 | structure_constraint_examples.pickle filter=lfs diff=lfs merge=lfs -text 12 | structure_constraint_dataset_grammar_state_dict.torch filter=lfs diff=lfs merge=lfs -text 13 | baseline_pre_fit_grammar_draws.pickle filter=lfs diff=lfs merge=lfs -text 14 | baseline_pre_fit_grammar_state_dict.torch filter=lfs diff=lfs merge=lfs -text 15 | baseline_em_mmds_precomputed.pickle filter=lfs diff=lfs merge=lfs -text 16 | baseline_post_fit_grammar_draws.pickle filter=lfs diff=lfs merge=lfs -text 17 | baseline_post_fit_grammar_state_dict.torch filter=lfs diff=lfs merge=lfs -text 18 | em_mmds_precomputed.pickle filter=lfs diff=lfs merge=lfs -text 19 | post_fit_grammar_draws_decorated.pickle filter=lfs diff=lfs merge=lfs -text 20 | renders.zip filter=lfs diff=lfs merge=lfs -text 21 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/.gitignore: -------------------------------------------------------------------------------- 1 | renders 2 | models 3 | 
-------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/baseline_em_mmds_precomputed.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dfb2c56107fe330af6edb87a09e2f605c65d0e75e401c8db2a60140d22bbd5c1 3 | size 23649 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/baseline_post_fit_grammar_draws.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ddc933cb4c519093e2636a72917215a5256b01efcaae36c550249e213bb52b8f 3 | size 3641795 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/baseline_post_fit_grammar_state_dict.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6be45ba74ea765344aae9f41e6be543de8fb1f9f729a7121a5f08780052c4e02 3 | size 16328 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/baseline_pre_fit_grammar_draws.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:391e3e1a11a2df7cad96177b8c1b2eea32247106527ca27c19ae4c8f6839d8eb 3 | size 1111609 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/baseline_pre_fit_grammar_state_dict.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5d83852ebba516a819e371b7ccb06262fd7f564f1ac1b00039610ab0f2b1f2bf 3 | size 16328 4 | 
-------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/em_mmds_precomputed.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5c7522ae4898e6ef89334b5f09e17804170784e10d1697e16259226255e31814 3 | size 23633 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/example_feasible_sampled_scene.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4f3fc348a818941a464e415fff6b9d416c1ddc3b1c64743861d328fbc345bcd6 3 | size 4 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/fit_em.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/table/fit_em.pickle -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/fit_em_baseline.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:572d55d31f336cd81e1c3aad6f3b8662cf03c77b46f694dc15ae00cd2a8a8167 3 | size 1742482 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/fit_grammar.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/table/fit_grammar.torch -------------------------------------------------------------------------------- 
/spatial_scene_grammars_examples/table/fit_grammar_baseline.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a2bff532d06a7d72fd0fcbe406c54558cc82249f15094ab133362508959d9f55 3 | size 16328 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/grammar_decoration.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from functools import lru_cache 4 | 5 | import torch 6 | from spatial_scene_grammars.nodes import * 7 | from spatial_scene_grammars.rules import * 8 | from spatial_scene_grammars.scene_grammar import * 9 | from spatial_scene_grammars.drake_interop import * 10 | from spatial_scene_grammars.constraints import * 11 | 12 | 13 | import pydrake 14 | import pydrake.geometry as pydrake_geom 15 | from pydrake.all import ( 16 | RollPitchYaw, 17 | RigidTransform 18 | ) 19 | 20 | from spatial_scene_grammars_examples.table.grammar import * 21 | # Decoration node types: extra observed models attached to table-grammar nodes via decoration_mapping (end of file). 22 | class DumplingDecoration(TerminalNode): # Observed dumplings model (PhysicsGeometryInfo(fixed=True)) 23 | def __init__(self, tf): 24 | geom = PhysicsGeometryInfo(fixed=True) 25 | geom_tf = torch.eye(4) 26 | geom.register_model_file(geom_tf, "models/misc/dumplings/model.sdf") 27 | super().__init__( 28 | tf=tf, 29 | physics_geometry_info=geom, 30 | observed=True 31 | ) 32 | class EggBunDecoration(TerminalNode): # Observed egg buns model (PhysicsGeometryInfo(fixed=True)) 33 | def __init__(self, tf): 34 | geom = PhysicsGeometryInfo(fixed=True) 35 | geom_tf = torch.eye(4) 36 | geom.register_model_file(geom_tf, "models/misc/egg_buns/model.sdf") 37 | super().__init__( 38 | tf=tf, 39 | physics_geometry_info=geom, 40 | observed=True 41 | ) 42 | class RiceWrapDecoration(TerminalNode): 43 | # Lo mai gai (sticky rice wrap): observed model, PhysicsGeometryInfo(fixed=True) 44 | def __init__(self, tf): 45 | geom = PhysicsGeometryInfo(fixed=True) 46 | geom_tf = torch.eye(4) 47 | geom.register_model_file(geom_tf, "models/misc/rice_wrap/model.sdf") 48 | super().__init__( 49 | tf=tf, 50 |
physics_geometry_info=geom, 51 | observed=True 52 | ) 53 | class ShrimpDumplingsDecoration(TerminalNode): 54 | # Har gow (shrimp dumplings): observed model, PhysicsGeometryInfo(fixed=True) 55 | def __init__(self, tf): 56 | geom = PhysicsGeometryInfo(fixed=True) 57 | geom_tf = torch.eye(4) 58 | geom.register_model_file(geom_tf, "models/misc/shrimp_dumplings/model.sdf") 59 | super().__init__( 60 | tf=tf, 61 | physics_geometry_info=geom, 62 | observed=True 63 | ) 64 | class SteamerDecoration(OrNode): 65 | # Unobserved OR over the foods that could be inside the steamer; each is placed at the steamer's pose. 66 | def __init__(self, tf): 67 | super().__init__( 68 | tf=tf, 69 | physics_geometry_info=None, 70 | rule_probs=torch.ones(len(self.generate_rules())), # equal weight for every food rule 71 | observed=False 72 | ) 73 | @classmethod 74 | def generate_rules(cls): 75 | return [ 76 | ProductionRule( 77 | child_type=DumplingDecoration, 78 | xyz_rule=SamePositionRule(), 79 | rotation_rule=SameRotationRule() 80 | ), 81 | ProductionRule( 82 | child_type=EggBunDecoration, 83 | xyz_rule=SamePositionRule(), 84 | rotation_rule=SameRotationRule() 85 | ), 86 | ProductionRule( 87 | child_type=RiceWrapDecoration, 88 | xyz_rule=SamePositionRule(), 89 | rotation_rule=SameRotationRule() 90 | ), 91 | ProductionRule( 92 | child_type=ShrimpDumplingsDecoration, 93 | xyz_rule=SamePositionRule(), 94 | rotation_rule=SameRotationRule() 95 | ) 96 | ] 97 | 98 | 99 | class EggTartsDecoration(TerminalNode): # Observed egg tarts model (PhysicsGeometryInfo(fixed=True)) 100 | def __init__(self, tf): 101 | geom = PhysicsGeometryInfo(fixed=True) 102 | geom_tf = torch.eye(4) 103 | geom.register_model_file(geom_tf, "models/misc/egg_tarts/model.sdf") 104 | super().__init__( 105 | tf=tf, 106 | physics_geometry_info=geom, 107 | observed=True 108 | ) 109 | 110 | class ChairDecoration(TerminalNode): 111 | # Observed chair model (PhysicsGeometryInfo(fixed=True)); positioned by PlaceSettingDecoration. 112 | def __init__(self, tf): 113 | geom = PhysicsGeometryInfo(fixed=True) 114 | geom_tf = torch.eye(4) 115 | geom.register_model_file(geom_tf, "models/misc/chair/model.sdf") 116 | super().__init__( 117 | tf=tf, 118 | physics_geometry_info=geom, 119 | observed=True 120 | ) 121 | 122 | class
PersonalPlateDecoration(IndependentSetNode): # Unobserved; single FirstChopstick rule (prob 0.8) offset ~+0.05 in z of this node's frame 123 | def __init__(self, tf): 124 | super().__init__( 125 | tf=tf, 126 | physics_geometry_info=None, 127 | observed=False, 128 | rule_probs=torch.tensor([0.8]) 129 | ) 130 | @classmethod 131 | def generate_rules(cls): 132 | return [ 133 | ProductionRule( 134 | child_type=FirstChopstick, 135 | xyz_rule=ParentFrameGaussianOffsetRule( 136 | mean=torch.tensor([0.0, 0.0, 0.05]), 137 | variance=torch.tensor([0.0005, 0.0005, 0.0001])), 138 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 139 | RotationMatrix(RollPitchYaw(0., np.pi/2., 0.)), np.array([1000, 1000, 1]) 140 | ) 141 | ) 142 | ] 143 | 144 | class PlaceSettingDecoration(AndNode): # Unobserved; adds a ChairDecoration at mean (-0.4, 0, -TABLE_HEIGHT) in this node's frame 145 | TABLE_HEIGHT = 0.8 # z drop from the place-setting frame down to the chair frame 146 | def __init__(self, tf): 147 | super().__init__( 148 | tf=tf, 149 | physics_geometry_info=None, 150 | observed=False, 151 | ) 152 | @classmethod 153 | def generate_rules(cls): 154 | return [ 155 | ProductionRule( 156 | child_type=ChairDecoration, 157 | xyz_rule=ParentFrameGaussianOffsetRule( 158 | mean=torch.tensor([-0.4, 0.0, -cls.TABLE_HEIGHT]), 159 | variance=torch.tensor([0.01, 0.01, 0.0001])), 160 | rotation_rule=ParentFrameBinghamRotationRule.from_rotation_and_rpy_variances( 161 | RotationMatrix(RollPitchYaw(0., 0., -np.pi/2.)), np.array([1000, 1000, 100]) 162 | ) 163 | ) 164 | ] 165 | 166 | # Base table-grammar node type (presumably from the table.grammar * import) -> decoration node type attached to it. 167 | decoration_mapping = { 168 | PlaceSetting: PlaceSettingDecoration, 169 | PersonalPlate: PersonalPlateDecoration, 170 | SteamerBottom: SteamerDecoration, 171 | ServingDish: EggTartsDecoration 172 | } -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/models.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:de5859786b42fd1db11916a374be92b2254cd769f11b83846cb408d0b3b62fcd 3 | size 50955048 4 | --------------------------------------------------------------------------------
/spatial_scene_grammars_examples/table/post_fit_grammar_draws.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6ff520cb777b3b208c7bbfb502c974d3dbdbab1c5536253a8b623bd9a1c47c57 3 | size 12475300 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/post_fit_grammar_draws_decorated.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7851a5678b779a3ce1f919f87dbf7b1428bb4eb7194872c0bba5e2cf17f726e6 3 | size 5436620 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/post_fit_grammar_state_dict.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ddec2905b31cd4043f59d9c8dc702939b820061e1ae480c1cb7de12150a3911a 3 | size 23079 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/pre_fit_grammar_draws.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/table/pre_fit_grammar_draws.pickle -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/pre_fit_grammar_state_dict.torch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/table/pre_fit_grammar_state_dict.torch -------------------------------------------------------------------------------- 
/spatial_scene_grammars_examples/table/renders/example_raw_blender_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/table/renders/example_raw_blender_output.jpg -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/renders/tweaked_in_blender.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gizatt/spatial_scene_grammars/46d85219b116ce0abbb9d73306046518e124593f/spatial_scene_grammars_examples/table/renders/tweaked_in_blender.jpg -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/structure_constraint_dataset_grammar_state_dict.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:68171d1f8c89a218e6bbcc1d2e1f60291df479a39f314a427831bd9d99e5b771 3 | size 23079 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/structure_constraint_examples.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2dace9cd4ee8672af5d72cc100788ec408e03281d0ca855aeecb3a7138e29a6a 3 | size 3078836 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/target_dataset_examples.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:740d460e479b37df74b4b1b0449eb6f0118dffeb3febf4d390704cc118b4e4b0 3 | size 13222847 4 | 
-------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/target_dataset_grammar_state_dict.torch: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fc061f7e6291c581fe34440350b85aee88df8472fa57a4a5e3ac6720e4e99a71 3 | size 23079 4 | -------------------------------------------------------------------------------- /spatial_scene_grammars_examples/table/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def sample_realistic_scene(seed=None): 4 | if seed is not None: 5 | torch.random.manual_seed(seed) 6 | topology_constraints, continuous_constraints = split_constraints(constraints) 7 | if len(topology_constraints) > 0: 8 | tree, success = rejection_sample_under_constraints(grammar, topology_constraints, 1000) 9 | if not success: 10 | logging.error("Couldn't rejection sample a feasible tree config.") 11 | return None 12 | else: 13 | tree = grammar.sample_tree(detach=True) 14 | samples = do_fixed_structure_hmc_with_constraint_penalties( 15 | grammar, tree, num_samples=25, subsample_step=5, 16 | with_nonpenetration=False, zmq_url=vis.window.zmq_url, 17 | constraints=continuous_constraints, 18 | kernel_type="NUTS", max_tree_depth=6, target_accept_prob=0.8, adapt_step_size=True 19 | ) 20 | # Step through samples backwards in HMC process and pick out a tree that satisfies 21 | # the constraints. 22 | good_tree = None 23 | for candidate_tree in samples[::-1]: 24 | if eval_total_constraint_set_violation(candidate_tree, constraints) <= 0.: 25 | good_tree = candidate_tree 26 | break 27 | if good_tree == None: 28 | logging.error("No tree in samples satisfied constraints.") 29 | return None 30 | 31 | return project_tree_to_feasibility(good_tree, do_forward_sim=True, timestep=0.001, T=1.) --------------------------------------------------------------------------------