├── tests ├── __init__.py ├── testdata │ ├── sample_json.tar.xz │ └── sg_t16.stats.pkl.gz ├── conftest.py ├── test_constraint_checks.py ├── test_flat_array.py ├── test_graph_dataset.py ├── test_data.py ├── test_graph_model.py └── test_torch_extensions.py ├── sketchgraphs_models ├── graph │ ├── train │ │ ├── __main__.py │ │ ├── data_loading.py │ │ └── harness.py │ ├── __init__.py │ ├── entropy_rate.py │ ├── dataset │ │ └── benchmark.py │ ├── eval_likelihood.py │ └── model │ │ ├── message_passing.py │ │ └── numerical_features.py ├── autoconstraint │ ├── scripts │ │ ├── __init__.py │ │ ├── eval_statistics_stratified.py │ │ └── eval_statistics_mask.py │ ├── __init__.py │ ├── eval_likelihood.py │ └── dataset.py ├── torch_extensions │ ├── __init__.py │ ├── index.py │ ├── _repeat_interleave.py │ ├── segment_pool.py │ └── segment_ops.py ├── __init__.py ├── nn │ ├── distributed.py │ ├── data_util.py │ ├── functional.py │ ├── summary.py │ └── __init__.py └── distributed_utils.py ├── assets ├── sketchgraphs.gif └── sketch_w_graph.png ├── requirements.txt ├── sketchgraphs ├── onshape │ ├── creds │ │ └── creds.json.example │ ├── __init__.py │ ├── LICENSE │ ├── feature_template.json │ ├── utils.py │ ├── onshape.py │ └── call.py ├── pipeline │ ├── __init__.py │ ├── graph_model │ │ ├── __init__.py │ │ └── target.py │ ├── numerical_parameters.py │ ├── make_quantization_statistics.py │ └── make_sketch_dataset.py ├── __init__.py └── data │ ├── __init__.py │ ├── sketch.py │ ├── _plotting.py │ ├── dof.py │ └── constraint_checks.py ├── setup.py ├── .gitignore ├── docs ├── _templates │ └── autosummary │ │ ├── class.rst │ │ └── module.rst ├── Makefile ├── make.bat ├── onshape_setup.rst ├── models.rst ├── conf.py ├── index.rst └── data.rst ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/train/__main__.py: -------------------------------------------------------------------------------- 1 | from . 
import main 2 | main() 3 | -------------------------------------------------------------------------------- /assets/sketchgraphs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/assets/sketchgraphs.gif -------------------------------------------------------------------------------- /assets/sketch_w_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/assets/sketch_w_graph.png -------------------------------------------------------------------------------- /sketchgraphs_models/autoconstraint/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous scripts used to process the result.""" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.17 2 | matplotlib>=3.1 3 | zstandard>=0.14 4 | lz4>=3.1 5 | tqdm>=4.48 6 | requests>=2.25.1 7 | -------------------------------------------------------------------------------- /tests/testdata/sample_json.tar.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/tests/testdata/sample_json.tar.xz -------------------------------------------------------------------------------- /tests/testdata/sg_t16.stats.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/tests/testdata/sg_t16.stats.pkl.gz -------------------------------------------------------------------------------- /sketchgraphs_models/autoconstraint/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements the model, dataset and training procedure for the auto-constraint model.""" 2 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/creds/creds.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "https://cad.onshape.com": { 3 | "access_key": "cA...vT", 4 | "secret_key": "aU...Hw" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='sketchgraphs', 5 | version='0.1', 6 | packages=find_packages(), 7 | license='MIT' 8 | ) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | __pycache__/ 3 | 4 | docs/_autosummary 5 | docs/_build 6 | 7 | creds.json 8 | build/ 9 | *.cpython-37m-x86_64-linux-gnu.so 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /sketchgraphs/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements the transformations in our data pipeline, taking the raw JSON data obtained 2 | from Onshape into data that is suitable for model training. 
3 | """ 4 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :special-members: 8 | :exclude-members: __dict__,__weakref__,__repr__,__str__ 9 | -------------------------------------------------------------------------------- /sketchgraphs/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools for manipulating and learning with the SketchGraphs dataset. 2 | 3 | The sub-modules of this module implement the required operations to create datasets 4 | at various levels of abstraction from the source data. In particular, there are three 5 | main scripts used to produce training data for the models. 6 | 7 | """ 8 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains basic code to interact with the Onshape API. 2 | 3 | The Onshape API contains some important functionality to perform more advanced manipulation of sketches. 4 | In particular, the API provides functions for solving constraints in sketches. 5 | 6 | """ 7 | 8 | from .onshape import Onshape 9 | from .client import Client 10 | from .utils import log 11 | -------------------------------------------------------------------------------- /sketchgraphs/pipeline/graph_model/__init__.py: -------------------------------------------------------------------------------- 1 | """This module provides the data pipeline to convert from the stored sequence format to 2 | a format suitable for training our graph-based deep learning model. In particular, it features 3 | a number of important helpers for quantizing numerical features, and representing the graph in 4 | terms of tensors. 5 | 6 | """ 7 | 8 | from . import _graph_info 9 | from ._graph_info import * 10 | 11 | __all__ = _graph_info.__all__ 12 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """ This module contains the main implementation for the graph models. 2 | 3 | The graph models represent a family of models which recursively build 4 | the sketch by viewing the sketch as a graph of entities and constraints. 5 | As described in the paper, the baseline generative model does not consider 6 | entity (primitive) coordinates and relies on constraints to determine 7 | the final configuration of the solved sketch. 8 | 9 | """ 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Shared test fixtures in order to easily provide testing data to all tests.""" 2 | 3 | import json 4 | import tarfile 5 | 6 | import pytest 7 | 8 | from sketchgraphs.data.sketch import Sketch 9 | 10 | 11 | @pytest.fixture 12 | def sketches_json(): 13 | sample_file = 'tests/testdata/sample_json.tar.xz' 14 | 15 | result = [] 16 | 17 | with tarfile.open(sample_file, 'r:xz') as tar_archive: 18 | for f in tar_archive: 19 | if not f.isfile(): 20 | continue 21 | result.extend(json.load(tar_archive.extractfile(f))) 22 | 23 | return result 24 | 25 | @pytest.fixture 26 | def sketches(sketches_json): 27 | """Return a list of sample sketches.""" 28 | return [Sketch.from_fs_json(j) for j in sketches_json] 29 | -------------------------------------------------------------------------------- /tests/test_constraint_checks.py: -------------------------------------------------------------------------------- 1 | """Tests for constraint checking""" 2 | 3 | import itertools 4 | 5 | from sketchgraphs.data import constraint_checks, sketch_to_sequence, EdgeOp 6 | 7 | def test_constraint_check(sketches): 8 | for sketch_idx, sketch in enumerate(itertools.islice(sketches, 1000)): 9 | sequence = sketch_to_sequence(sketch) 10 | for op_idx, op in enumerate(sequence): 11 | if not isinstance(op, EdgeOp): 12 | continue 13 | try: 14 | check_value = constraint_checks.check_edge_satisfied(sequence, op) 15 | except ValueError: 16 | check_value = True 17 | 18 | if check_value is None: 19 | check_value = True 20 | 21 | assert check_value, "constraint check failed at sketch {0}, op {1}".format(sketch_idx, op_idx) 22 | -------------------------------------------------------------------------------- /sketchgraphs_models/torch_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | """ PyTorch extensions for our models. 2 | 3 | This module provides some important extensions that work with segmented data, which is 4 | used extensively in our models due to the nature of graph batching. The package provides 5 | a pure Python implementation, although it can leverage a C++ extension if present. This 6 | enables a speed-up of over 2x (on GPU) when it is present, so it is recommended to compile the C++ 7 | extensions if possible.
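As a rough illustration of the segmented representation used throughout this package, each segment of a batched tensor is described by a (start, length) row of a ``scopes`` tensor. The sketch below is an illustration only, assuming the per-segment reduction semantics implied by the use of `segment_logsumexp` in `sketchgraphs_models.nn.functional`, written here with a plain Python loop:

    import torch
    from sketchgraphs_models.torch_extensions import segment_logsumexp

    values = torch.randn(5)
    scopes = torch.tensor([[0, 3], [3, 2]])  # two segments: values[0:3] and values[3:5]
    result = segment_logsumexp(values, scopes)

    # Reference semantics: one logsumexp per segment, expressed as an explicit loop.
    expected = torch.stack([values[start:start + length].logsumexp(dim=0)
                            for start, length in scopes.tolist()])
    assert torch.allclose(result, expected)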
8 | 9 | """ 10 | 11 | from ._repeat_interleave import repeat_interleave 12 | from .segment_ops import segment_logsumexp 13 | from .segment_pool import segment_avg_pool1d, segment_max_pool1d 14 | 15 | from .index import segment_triu_indices, segment_cartesian_product 16 | 17 | __all__ = ['repeat_interleave', 'segment_logsumexp', 'segment_avg_pool1d', 'segment_max_pool1d', 18 | 'segment_triu_indices', 'segment_cartesian_product'] 19 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /sketchgraphs_models/torch_extensions/index.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def segment_triu_indices_loop(scopes, offset=0, dtype=torch.long, device='cpu'): 5 | indices = torch.cat([ 6 | torch.triu_indices( 7 | n, n, offset, dtype=dtype, device=device).t() + o for (o, n) in scopes 8 | ], dim=0) 9 | 10 | return indices 11 | 12 | 13 | def segment_cartesian_product_loop(values_a, values_b, scopes_a, scopes_b): 14 | result = [] 15 | 16 | for sa, sb in zip(scopes_a, scopes_b): 17 | va = values_a.narrow(0, sa[0], sa[1]).unsqueeze(1).expand((-1, sb[1], -1)) 18 | vb = values_b.narrow(0, sb[0], sb[1]).unsqueeze(0).expand((sa[1], -1, -1)) 19 | 20 | values_cat = torch.cat((va, vb), dim=-1) 21 | result.append(values_cat.flatten(end_dim=1)) 22 | 23 | return torch.cat(result, dim=0) 24 | 25 | 26 | segment_triu_indices = segment_triu_indices_loop 27 | segment_cartesian_product = segment_cartesian_product_loop 28 | -------------------------------------------------------------------------------- /sketchgraphs/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains the main data representations for the SketchGraphs dataset. 2 | 3 | There are two main projections for the data. 4 | 5 | The first is the `sketch` projection, which is close to the underlying FeatureScript JSON format from Onshape. 6 | This projection is favourable for interactions with Onshape's API, and serves as a first step for parsing 7 | the raw data from Onshape. 8 | 9 | The second projection in the `sequence` projection, and is more 10 | specialized towards applications in machine learning. It is more streamlined, and requires conversion 11 | to interact with Onshape's API. However, it interacts more naturally with machine learning applications. 
12 | 13 | """ 14 | 15 | from .sketch import * 16 | from .sequence import * 17 | 18 | from .sketch import Sketch, EntityType, SubnodeType, Entity, GenericEntity, Point, Line, Circle, Arc, Spline, Ellipse, ENTITY_TYPE_TO_CLASS 19 | from .sketch import render_sketch, render_graph 20 | -------------------------------------------------------------------------------- /sketchgraphs_models/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains supporting code to the main sketchgraphs dataset. In particular, it contains two models based 2 | on graph representations of the data, which are geared towards the task of generation (see `sketchgraphs_models.graph`) 3 | and autoconstrain (see `sketchgraphs_models.autoconstraint`). In addition, this module contains a number of submodules 4 | to help the implementation of these models. 5 | 6 | These models are tuned and operate on a subset of the full sketchgraphs dataset. In particular, they only handle 7 | constraints which relate at most two entities (e.g. the mirror constraint is not handled), and are trained on a subset 8 | of sketches which excludes sketches that are too large or too small. In addition, these model use a quantization 9 | strategy to model continuous parameters in the dataset, which must be pre-computed before training. To create 10 | quantization maps, see `sketchgraphs.pipeline.make_quantization_statistics`. 11 | 12 | """ 13 | -------------------------------------------------------------------------------- /tests/test_flat_array.py: -------------------------------------------------------------------------------- 1 | #pylint: disable=missing-module-docstring,missing-function-docstring 2 | 3 | import itertools 4 | 5 | import numpy as np 6 | 7 | from sketchgraphs.data import flat_array 8 | 9 | 10 | def test_flat_array_of_int(): 11 | x = [2, 3, 4, 5] 12 | 13 | x_flat = flat_array.save_list_flat(x) 14 | x_array = flat_array.FlatSerializedArray.from_flat_array(x_flat) 15 | 16 | assert len(x_array) == len(x) 17 | 18 | for i, j in itertools.zip_longest(x, x_array): 19 | assert i == j 20 | 21 | 22 | def test_flat_dictionary(): 23 | x = [2, 3, 4, 5] 24 | y = np.array([3, 5]) 25 | z = ["A", "python", "list"] 26 | 27 | x_flat = flat_array.save_list_flat(x) 28 | 29 | dict_flat = flat_array.pack_dictionary_flat({ 30 | 'x': x_flat, 31 | 'y': y, 32 | 'z': z 33 | }) 34 | 35 | result = flat_array.load_dictionary_flat(dict_flat) 36 | 37 | assert isinstance(result['x'], flat_array.FlatSerializedArray) 38 | assert len(result['x']) == len(x) 39 | 40 | assert result['z'] == z 41 | assert all(result['y'] == y) 42 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 PTC Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ari Seff, Yaniv Ovadia, Wenda Zhou, Ryan P. Adams 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/feature_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "feature": { 3 | "type": 151, 4 | "typeName": "BTMSketch", 5 | "message": { 6 | "entities": [], 7 | "constraints": [], 8 | "featureType": "newSketch", 9 | "name": "My Sketch", 10 | "parameters": [{ 11 | "type": 148, 12 | "typeName": "BTMParameterQueryList", 13 | "message": { 14 | "queries": [{ 15 | "type": 138, 16 | "typeName": "BTMIndividualQuery", 17 | "message": { 18 | "geometryIds": ["JEC"], 19 | "hasUserCode": false 20 | } 21 | }], 22 | "parameterId": "sketchPlane", 23 | "hasUserCode": false 24 | } 25 | }], 26 | "suppressed": false, 27 | "namespace": "", 28 | "subFeatures": [], 29 | "returnAfterSubfeatures": false, 30 | "suppressionState": { 31 | "type": 0 32 | }, 33 | "hasUserCode": false 34 | } 35 | }, 36 | "serializationVersion": "1.1.18", 37 | "sourceMicroversion": "57c52f87718e8008cce43b47", 38 | "rejectMicroversionSkew": false, 39 | "microversionSkew": false, 40 | "libraryVersion": 1324 41 | } 42 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block classes %} 6 | {% if classes %} 7 | .. rubric:: {{ _('Classes') }} 8 | 9 | .. 
autosummary:: 10 | :toctree: 11 | {% for item in classes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block attributes %} 18 | {% if attributes %} 19 | .. rubric:: Module Attributes 20 | 21 | {% for item in attributes %} 22 | .. autoattribute:: {{ item }} 23 | {%- endfor %} 24 | {% endif %} 25 | {% endblock %} 26 | 27 | {% block functions %} 28 | {% if functions %} 29 | .. rubric:: {{ _('Functions') }} 30 | 31 | {% for item in functions %} 32 | .. autofunction:: {{ item }} 33 | {%- endfor %} 34 | {% endif %} 35 | {% endblock %} 36 | 37 | 38 | {% block exceptions %} 39 | {% if exceptions %} 40 | .. rubric:: {{ _('Exceptions') }} 41 | 42 | .. autosummary:: 43 | {% for item in exceptions %} 44 | {{ item }} 45 | {%- endfor %} 46 | {% endif %} 47 | {% endblock %} 48 | 49 | {% block modules %} 50 | {% if modules %} 51 | .. rubric:: Modules 52 | 53 | .. autosummary:: 54 | :toctree: 55 | :recursive: 56 | {% for item in modules %} 57 | {{ item }} 58 | {%- endfor %} 59 | {% endif %} 60 | {% endblock %} -------------------------------------------------------------------------------- /sketchgraphs_models/nn/distributed.py: -------------------------------------------------------------------------------- 1 | """Utility modules for distributed and parallel training. """ 2 | 3 | import torch 4 | 5 | class SingleDeviceDistributedParallel(torch.nn.parallel.distributed.DistributedDataParallel): 6 | """This module implements a module similar to `DistributedDataParallel`, but it accepts 7 | inputs of any shape, and only supports a single device per instance. 8 | """ 9 | def __init__(self, module, device_id, find_unused_parameters=False): 10 | super(SingleDeviceDistributedParallel, self).__init__( 11 | module, [device_id], find_unused_parameters=find_unused_parameters) 12 | 13 | def forward(self, *inputs, **kwargs): 14 | if self.require_forward_param_sync: 15 | self._sync_params() 16 | 17 | output = self.module(*inputs, **kwargs) 18 | 19 | if torch.is_grad_enabled() and self.require_backward_grad_sync: 20 | self.require_forward_param_sync = True 21 | 22 | if self.find_unused_parameters: 23 | self.reducer.prepare_for_backward(list(torch.nn.parallel.distributed._find_tensors(output))) 24 | else: 25 | self.reducer.prepare_for_backward([]) 26 | 27 | return output 28 | 29 | def state_dict(self, destination=None, prefix='', keep_vars=False): 30 | return self.module.state_dict(destination, prefix, keep_vars) 31 | 32 | def load_state_dict(self, state_dict, strict=True): 33 | return self.module.load_state_dict(state_dict, strict) 34 | -------------------------------------------------------------------------------- /sketchgraphs_models/nn/data_util.py: -------------------------------------------------------------------------------- 1 | """Utilities and extensions to work with the torch.utils.data package. 2 | """ 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | class MultiEpochSampler(torch.utils.data.Sampler): 8 | """A sampler wrapper which creates multiple epochs of indices 9 | from one sampler. 10 | 11 | Due to the way torch dataloaders function, they reset all workers 12 | when starting a new epoch. This is undesirable due to the high setup 13 | cost of starting workers. 14 | 15 | This sampler thus wraps an existing batch sampler, and simply repeatedly 16 | samples several epochs from the wrapped sampler. 17 | 18 | """ 19 | def __init__(self, batch_sampler, num_epochs): 20 | """Initializes a new instance of `MultiEpochSampler`. 
21 | 22 | Parameters 23 | ---------- 24 | batch_sampler : torch.utils.data.Sampler 25 | A batch sampler to wrap 26 | num_epochs : int 27 | The number of epochs for this sampler. 28 | """ 29 | #pylint: disable=super-init-not-called 30 | self._sampler = batch_sampler 31 | self._epochs = num_epochs 32 | 33 | def __len__(self): 34 | return len(self._sampler) * self._epochs 35 | 36 | @property 37 | def batches_per_epoch(self): 38 | """Get the number of batches per single epoch of the original sampler.""" 39 | return len(self._sampler) 40 | 41 | def __iter__(self): 42 | for _ in range(self._epochs): 43 | for sample in self._sampler: 44 | yield sample 45 | -------------------------------------------------------------------------------- /docs/onshape_setup.rst: -------------------------------------------------------------------------------- 1 | Onshape setup 2 | ============= 3 | 4 | We include a set of functions (`sketchgraphs.onshape.call`) for interacting with the CAD program `Onshape `_ in order to solve geometric constraints. 5 | The SketchGraphs `demo notebook `_ demonstrates their main usage. 6 | 7 | Before calling these for the first time, a small bit of setup is required. 8 | 9 | Account & credentials 10 | --------------------- 11 | 12 | - Visit `Onshape `_ and create an account (it's free). 13 | - Create an API key at https://dev-portal.onshape.com/keys. Be sure to enable read and write permissions so that we may send and retrieve CAD sketches from Onshape. 14 | - Save the file at `sketchgraphs/onshape/creds/creds.json` 15 | 16 | 17 | Create document 18 | --------------- 19 | We'll need a document to serve as the target for our sketches. 20 | 21 | - Go to https://cad.onshape.com/documents and click Create->Document in the upper left. 22 | - The document URL will be needed to perform API calls. For example, in the `demo notebook `_, we manually paste in the target URL. 23 | - Now that we have a new document, we must set the version identifiers of our `feature_template.json` accordingly. Run the following whenever working with a new document (set the variable `url` accordingly): 24 | 25 | >>> url=https://cad.onshape.com/documents/6f6d14f8facf0bba02184e88/w/66a5db71489c81f4893101ed/e/120c56983451157d26a7102d 26 | >>> python -m sketchgraphs.onshape.call --action update_template --url $url --enable_logging 27 | 28 | You should see a "request succeeded" included in the log if the configuration is correct. 29 | 30 | Our Onshape setup is complete! Please reach out with any questions. -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | SketchGraphs models 2 | =================== 3 | 4 | This page describes the models implemented in SketchGraphs and details their usage. 5 | The models are based on a Graph Neural Network architecture, modelling the sketch as a graph 6 | with vertices given by entities and edges given by their constraints. 7 | 8 | Quickstart 9 | ---------- 10 | 11 | For an initial quick-start, we recommend that users start with the provided sequence files and 12 | associated quantization maps, and follow the default hyper-parameter values for training. 13 | Additionally, we strongly recommend using a powerful GPU, as training is compute intensive.
14 | 15 | For example, assuming that you have downloaded the training dataset, as well as the accompanying 16 | quantization statistics (available `here `_), 17 | the generative model may be trained by running: 18 | 19 | .. code-block:: bash 20 | 21 | python -m sketchgraphs_models.graph.train --dataset_train sg_t16_train.npy 22 | 23 | You may monitor the training progress on the standard output, or through `tensorboard `_. 24 | 25 | Similarly, the autoconstrain model may be trained by running: 26 | 27 | .. code-block:: bash 28 | 29 | python -m sketchgraphs_models.autoconstraint.train --dataset_train sg_t16_train.npy 30 | 31 | We will also provide pre-trained models (coming soon!). 32 | 33 | 34 | Torch scatter 35 | ------------- 36 | 37 | In order to enjoy the best training performance, we strongly recommend you install the `torch-scatter `_ 38 | package with the correct CUDA version for training on GPU. If you do not have access, 39 | the models will automatically fall back to a plain python / pytorch implementation (however, there will be a 40 | performance penalty due to a substantial amount of looping). 41 | -------------------------------------------------------------------------------- /tests/test_graph_dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import pickle 4 | import pytest 5 | 6 | import numpy as np 7 | 8 | from sketchgraphs_models.graph import dataset as graph_dataset 9 | from sketchgraphs.data import sequence as data_sequence 10 | 11 | @pytest.fixture 12 | def edge_feature_mapping(): 13 | """Dummy quantization functions.""" 14 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f: 15 | mapping = pickle.load(f) 16 | 17 | return graph_dataset.EdgeFeatureMapping( 18 | mapping['edge']['angle'], mapping['edge']['length']) 19 | 20 | @pytest.fixture 21 | def node_feature_mapping(): 22 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f: 23 | mapping = pickle.load(f) 24 | return graph_dataset.EntityFeatureMapping(mapping['node']) 25 | 26 | 27 | def test_dataset(sketches, node_feature_mapping, edge_feature_mapping): 28 | sequences = list(map(data_sequence.sketch_to_sequence, sketches)) 29 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=12) 30 | 31 | graph, target = dataset[0] 32 | assert graph_dataset.TargetType.EdgeDistance in graph.sparse_edge_features 33 | assert graph.node_features is not None 34 | assert graph.sparse_node_features is not None 35 | 36 | 37 | def test_collate_empty(node_feature_mapping, edge_feature_mapping): 38 | result = graph_dataset.collate([], node_feature_mapping, edge_feature_mapping) 39 | assert 'edge_numerical' in result 40 | 41 | 42 | def test_collate_some(sketches, node_feature_mapping, edge_feature_mapping): 43 | sequences = list(map(data_sequence.sketch_to_sequence, sketches)) 44 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=42) 45 | 46 | batch = [dataset[i] for i in range(5)] 47 | 48 | batch_info = graph_dataset.collate(batch, node_feature_mapping, edge_feature_mapping) 49 | assert 'edge_numerical' in batch_info 50 | assert batch_info['graph'].node_features is not None 51 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | utils 3 | ===== 4 | 5 | Handy functions for API key sample app 6 | 
''' 7 | 8 | import logging 9 | from logging.config import dictConfig 10 | 11 | __all__ = [ 12 | 'log' 13 | ] 14 | 15 | 16 | def log(msg, level=0): 17 | ''' 18 | Logs a message to the console, with optional level paramater 19 | 20 | Args: 21 | - msg (str): message to send to console 22 | - level (int): log level; 0 for info, 1 for error (default = 0) 23 | ''' 24 | 25 | red = '\033[91m' 26 | endc = '\033[0m' 27 | 28 | # configure the logging module 29 | cfg = { 30 | 'version': 1, 31 | 'disable_existing_loggers': False, 32 | 'formatters': { 33 | 'stdout': { 34 | 'format': '[%(levelname)s]: %(asctime)s - %(message)s', 35 | 'datefmt': '%x %X' 36 | }, 37 | 'stderr': { 38 | 'format': red + '[%(levelname)s]: %(asctime)s - %(message)s' + endc, 39 | 'datefmt': '%x %X' 40 | } 41 | }, 42 | 'handlers': { 43 | 'stdout': { 44 | 'class': 'logging.StreamHandler', 45 | 'level': 'DEBUG', 46 | 'formatter': 'stdout' 47 | }, 48 | 'stderr': { 49 | 'class': 'logging.StreamHandler', 50 | 'level': 'ERROR', 51 | 'formatter': 'stderr' 52 | } 53 | }, 54 | 'loggers': { 55 | 'info': { 56 | 'handlers': ['stdout'], 57 | 'level': 'INFO', 58 | 'propagate': True 59 | }, 60 | 'error': { 61 | 'handlers': ['stderr'], 62 | 'level': 'ERROR', 63 | 'propagate': False 64 | } 65 | } 66 | } 67 | 68 | dictConfig(cfg) 69 | 70 | lg = 'info' if level == 0 else 'error' 71 | lvl = 20 if level == 0 else 40 72 | 73 | logger = logging.getLogger(lg) 74 | logger.log(lvl, msg) 75 | -------------------------------------------------------------------------------- /sketchgraphs_models/torch_extensions/_repeat_interleave.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _ensure_repeats(repeats_or_scopes: torch.Tensor): 4 | if repeats_or_scopes.dim() == 2: 5 | return repeats_or_scopes.select(1, 1) 6 | else: 7 | return repeats_or_scopes 8 | 9 | 10 | def repeat_interleave_out(values, repeats_or_scope, dim, out): 11 | index = torch.repeat_interleave(_ensure_repeats(repeats_or_scope)) 12 | 13 | if dim is None: 14 | values = values.flatten() 15 | dim = 0 16 | 17 | return torch.index_select(values, dim, index, out=out) 18 | 19 | 20 | def repeat_interleave(values, repeats_or_scope, dim=None, out=None, out_length=None): 21 | """ Extended `repeat_interleave` with the ability to pass in pre-computed information 22 | to reduce CPU-GPU communication when the repeats tensor resides on GPU. 23 | 24 | In particular, this function can make use of a scope parameter instead of simply repeats, 25 | which is a two-dimensional tensor with two columns, one corresponding to the offsets 26 | of each segment, and the second one corresponding to the length (i.e. the number of repeats). 27 | 28 | Parameters 29 | ---------- 30 | values : torch.Tensor 31 | An array of values to repeat. 32 | repeats_or_scope : torch.Tensor 33 | Either a one-dimensional array of repeats, or a two-dimensional array 34 | representing the offset and length (which is referred to as scope). 35 | dim : int, optional 36 | The dimension of values along which to repeat. 37 | out : torch.Tensor, optional 38 | if not None, a tensor into which to produce the output. Note that this is not differentiable. 39 | out_length : int, optional 40 | if not None, the length of the output along the repeated dimension. 41 | Returns 42 | ------- 43 | A new tensor, containing values repeated the appropriate amount of times along the 44 | given dimension. 
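Examples
--------
A small illustration of the plain one-dimensional ``repeats`` form, which simply defers to `torch.repeat_interleave` (values chosen arbitrarily for this sketch):

>>> import torch
>>> values = torch.tensor([10, 20, 30])
>>> repeat_interleave(values, torch.tensor([2, 1, 3]))
tensor([10, 10, 20, 30, 30, 30])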
45 | """ 46 | if out is not None: 47 | return repeat_interleave_out(values, repeats_or_scope, dim, out) 48 | return torch.repeat_interleave(values, repeats_or_scope, dim) 49 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/entropy_rate.py: -------------------------------------------------------------------------------- 1 | """Module to compute the entropy rate of a given representation. """ 2 | 3 | import argparse 4 | import lzma 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | 10 | from sketchgraphs_models.graph import dataset, sample 11 | from sketchgraphs.data import flat_array 12 | 13 | 14 | 15 | def sequence_to_integers(seq, node_feature_mapping, edge_feature_mapping): 16 | graph = dataset.graph_info_from_sequence(seq, node_feature_mapping, edge_feature_mapping) 17 | all_tensors = ( 18 | [torch.full([1], fill_value=graph.edge_counts[0], dtype=torch.int64)] + 19 | [graph.node_features, 20 | graph.edge_features, 21 | graph.incidence.flatten()] + 22 | [v.value.flatten() for v in graph.sparse_node_features.values()] + 23 | [v.value.flatten() for v in graph.sparse_edge_features.values()]) 24 | 25 | integers = torch.cat(all_tensors) 26 | 27 | return integers.numpy().astype(np.int32) 28 | 29 | 30 | def compress_sequences(seqs, node_feature_mapping, edge_feature_mapping): 31 | compressor = lzma.LZMACompressor(preset=9 | lzma.PRESET_EXTREME) 32 | 33 | total_bytes = 0 34 | 35 | for seq in seqs: 36 | integers = sequence_to_integers(seq, node_feature_mapping, edge_feature_mapping) 37 | total_bytes += len(compressor.compress(integers.tobytes())) 38 | 39 | total_bytes += len(compressor.flush()) 40 | 41 | return total_bytes 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--dataset', required=True) 47 | parser.add_argument('--model_state', help='Path to saved model state_dict.') 48 | parser.add_argument('--limit', type=int, default=None) 49 | 50 | args = parser.parse_args() 51 | 52 | print('Loading trained model') 53 | _, node_feature_mapping, edge_feature_mapping = sample.load_saved_model(args.model_state) 54 | 55 | print('Loading testing data') 56 | seqs = flat_array.load_dictionary_flat(np.load(args.dataset, mmap_mode='r'))['sequences'] 57 | 58 | if args.limit is not None: 59 | seqs = seqs[:args.limit] 60 | 61 | total_bytes = compress_sequences(tqdm.tqdm(seqs, total=len(seqs)), node_feature_mapping, edge_feature_mapping) 62 | print('Total bytes: {0}'.format(total_bytes)) 63 | print('Average Entropy: {0:.3f} bits / graph'.format(total_bytes * 8 / len(seqs))) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/dataset/benchmark.py: -------------------------------------------------------------------------------- 1 | """Utility to benchmark data loading speed.""" 2 | 3 | import argparse 4 | import time 5 | 6 | 7 | from sketchgraphs_models.graph.train.data_loading import initialize_datasets 8 | 9 | 10 | def time_iterator(iterator, batch_size, total_time_seconds=None): 11 | start_time = time.perf_counter() 12 | last_time = start_time 13 | num_batch_processed = 0 14 | 15 | for _ in iterator: 16 | num_batch_processed += 1 17 | current_time = time.perf_counter() 18 | elapsed = current_time - last_time 19 | if elapsed > 5: 20 | print('Processed {0} elements in {1:.2f} seconds. 
({2:.2f} / second )'.format( 21 | num_batch_processed * batch_size, elapsed, num_batch_processed * batch_size / elapsed )) 22 | last_time = time.perf_counter() 23 | num_batch_processed = 0 24 | 25 | if total_time_seconds is not None and current_time - start_time > total_time_seconds: 26 | break 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument('--seed', type=int, default=42) 33 | parser.add_argument('--dataset_train', required=True, 34 | help='Pickle dataset for train data.') 35 | parser.add_argument('--dataset_auxiliary', default=None, help='path to auxiliary dataset containing metadata') 36 | parser.add_argument('--num_quantize_length', type=int, default=383, help='number of quantization values for length') 37 | parser.add_argument('--num_quantize_angle', type=int, default=127, help='number of quantization values for angle') 38 | parser.add_argument('--batch_size', type=int, default=32, help='Training batch size.') 39 | parser.add_argument('--num_epochs', type=int, default=100, 40 | help='Number of training epochs.') 41 | parser.add_argument('--num_workers', type=int, default=0, 42 | help='Number of dataloader workers.') 43 | parser.add_argument('--disable_edge_features', action='store_true', 44 | help='Disable using and predicting edge features') 45 | 46 | parser.add_argument('--max_time', type=float, default=300) 47 | 48 | args = vars(parser.parse_args()) 49 | args['dataset_test'] = None 50 | args['disable_entity_features'] = True 51 | 52 | print('Loading dataset') 53 | dataloader, _, batches_per_epoch, _, _ = initialize_datasets(args) 54 | 55 | print('Benchmarking dataloader') 56 | time_iterator(dataloader, args['batch_size'], args['max_time']) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | 62 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'SketchGraphs' 21 | copyright = '2020, Ari Seff, Yaniv Ovadia, Wenda Zhou, Ryan P. Adams' 22 | author = 'Ari Seff, Yaniv Ovadia, Wenda Zhou, Ryan P. Adams' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | 'sphinx.ext.autosummary', 33 | 'sphinx.ext.napoleon', 34 | 'sphinx.ext.intersphinx' 35 | ] 36 | 37 | autosummary_generate = True 38 | 39 | # Add any paths that contain templates here, relative to this directory. 
40 | templates_path = ['_templates'] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 46 | 47 | default_role = 'any' 48 | 49 | # -- Options for intersphinx extension --------------------------------------- 50 | 51 | intersphinx_mapping = { 52 | 'python': ('https://docs.python.org/3', None), 53 | 'torch': ('https://pytorch.org/docs/stable/', None), 54 | 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 55 | } 56 | 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | 60 | # The theme to use for HTML and HTML Help pages. See the documentation for 61 | # a list of builtin themes. 62 | # 63 | html_theme = 'sphinx_rtd_theme' 64 | 65 | # Add any paths that contain custom static files (such as style sheets) here, 66 | # relative to this directory. They are copied after the builtin static files, 67 | # so a file named "default.css" will overwrite the builtin "default.css". 68 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | #pylint: disable=missing-module-docstring,missing-function-docstring 2 | 3 | import numpy as np 4 | 5 | from sketchgraphs.data.sequence import sketch_to_sequence, NodeOp, EdgeOp, sketch_from_sequence 6 | from sketchgraphs.data.sketch import Sketch, render_sketch, EntityType, ConstraintType, SubnodeType 7 | from sketchgraphs.data.dof import get_sequence_dof 8 | 9 | 10 | def test_sketch_from_json(sketches_json): 11 | for sketch_json in sketches_json: 12 | Sketch.from_fs_json(sketch_json) 13 | 14 | 15 | def test_sequence_from_sketch(sketches_json): 16 | 17 | for sketch_json in sketches_json: 18 | sketch = Sketch.from_fs_json(sketch_json) 19 | seq = sketch_to_sequence(sketch) 20 | 21 | def test_plot_sketch(sketches_json): 22 | sketch_json_list = sketches_json[:10] 23 | 24 | for sketch_json in sketch_json_list: 25 | fig = render_sketch(Sketch.from_fs_json(sketch_json)) 26 | assert fig is not None 27 | 28 | 29 | def test_get_sequence_dof(): 30 | seq = [ 31 | NodeOp(label=EntityType.External), 32 | NodeOp(label=EntityType.Line), 33 | NodeOp(label=SubnodeType.SN_Start), 34 | EdgeOp(label=ConstraintType.Subnode, references=(2, 1)), 35 | NodeOp(label=SubnodeType.SN_End), 36 | EdgeOp(label=ConstraintType.Subnode, references=(3, 1)), 37 | NodeOp(label=EntityType.Line), 38 | EdgeOp(label=ConstraintType.Parallel, references=(4, 1)), 39 | EdgeOp(label=ConstraintType.Horizontal, references=(4,)), 40 | EdgeOp(label=ConstraintType.Distance, references=(4, 1)), 41 | NodeOp(label=SubnodeType.SN_Start), 42 | EdgeOp(label=ConstraintType.Subnode, references=(5, 4)), 43 | NodeOp(label=SubnodeType.SN_End), 44 | EdgeOp(label=ConstraintType.Subnode, references=(6, 4)), 45 | NodeOp(label=EntityType.Stop)] 46 | 47 | dof_remaining = np.sum(get_sequence_dof(seq)) 48 | assert dof_remaining == 5 49 | 50 | 51 | _UNSUPPORTED_CONSTRAINTS = ( 52 | ConstraintType.Circular_Pattern, 53 | ConstraintType.Linear_Pattern, 54 | ConstraintType.Midpoint, 55 | ConstraintType.Mirror, 56 | ) 57 | 58 | def test_sketch_from_sequence(sketches_json): 59 | for sketch_json in sketches_json: 60 | sketch = Sketch.from_fs_json(sketch_json, include_external_constraints=False) 61 | seq = 
sketch_to_sequence(sketch) 62 | 63 | if any(s.label in _UNSUPPORTED_CONSTRAINTS for s in seq): 64 | # Skip not supported constraints for now 65 | continue 66 | 67 | sketch2 = sketch_from_sequence(seq) 68 | seq2 = sketch_to_sequence(sketch2) 69 | 70 | assert len(seq) == len(seq2) 71 | for op1, op2 in zip(seq, seq2): 72 | assert op1 == op2 73 | -------------------------------------------------------------------------------- /sketchgraphs_models/autoconstraint/scripts/eval_statistics_stratified.py: -------------------------------------------------------------------------------- 1 | """Computes edge evaluation statistics stratified by size. """ 2 | 3 | import argparse 4 | import gzip 5 | import pickle 6 | import os 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def compute_stratified_statistics(statistics, sizes): 13 | unique_sizes, unique_counts = np.unique(sizes, return_counts=True) 14 | 15 | precision = statistics['precision'] 16 | recall = statistics['recall'] 17 | 18 | average_precision = np.empty(len(sizes)) 19 | average_recall = np.empty(len(sizes)) 20 | average_f1 = np.empty(len(sizes)) 21 | 22 | sd_precision = np.empty(len(sizes)) 23 | sd_recall = np.empty(len(sizes)) 24 | sd_f1 = np.empty(len(sizes)) 25 | 26 | for i, size in enumerate(unique_sizes): 27 | mask = sizes == size 28 | pm = precision[mask] 29 | rm = recall[mask] 30 | 31 | average_precision[i] = pm.mean() 32 | sd_precision[i] = pm.std() / np.sqrt(len(pm)) 33 | 34 | average_recall[i] = rm.mean() 35 | sd_recall[i] = rm.std() / np.sqrt(len(rm)) 36 | 37 | f1 = np.divide(2 * pm * rm, pm + rm, out=np.zeros_like(rm), where=pm + rm != 0) 38 | average_f1[i] = f1.mean() 39 | sd_f1[i] = f1.std() / np.sqrt(len(f1)) 40 | 41 | 42 | return { 43 | 'precision': average_precision, 44 | 'precision_sd': sd_precision, 45 | 'recall': average_recall, 46 | 'recall_sd': sd_recall, 47 | 'f1': average_f1, 48 | 'f1_sd': sd_f1, 49 | 'sizes': unique_sizes, 50 | 'counts': unique_counts, 51 | } 52 | 53 | 54 | def boxplot_stratified(statistic, sizes): 55 | unique_sizes = np.unique(sizes) 56 | unique_sizes = unique_sizes[(unique_sizes > 3) & (unique_sizes <= 50)] 57 | 58 | return plt.boxplot( 59 | [statistic[s == sizes] for s in unique_sizes], labels=unique_sizes, 60 | showfliers=False, showcaps=False, whis=0.0, medianprops={'linewidth': 2.5}) 61 | 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser() 65 | 66 | parser.add_argument('--input', required=True) 67 | parser.add_argument('--output') 68 | 69 | args = parser.parse_args() 70 | 71 | print('Reading input files') 72 | 73 | input_result_path = args.input 74 | input_basename, input_ext = os.path.splitext(input_result_path) 75 | if input_ext == '.gz': 76 | input_basename, _ = os.path.splitext(input_basename) 77 | 78 | input_stat_path = input_basename + '_stat.npz' 79 | 80 | with gzip.open(input_result_path, 'rb') as f: 81 | ops = pickle.load(f) 82 | 83 | input_stat = np.load(input_stat_path) 84 | 85 | print('Computing statistics') 86 | stratified_statistics = compute_stratified_statistics( 87 | input_stat, [len(x['node_ops']) for x in ops]) 88 | 89 | if args.output is not None: 90 | print('Saving result') 91 | np.savez_compressed(args.output, **stratified_statistics) 92 | 93 | print(stratified_statistics) 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /sketchgraphs_models/nn/functional.py: -------------------------------------------------------------------------------- 1 | """ 
Utility functions for computing specific nn functions. """ 2 | 3 | import torch 4 | from sketchgraphs_models.torch_extensions import segment_logsumexp, segment_avg_pool1d 5 | from sketchgraphs_models.torch_extensions.segment_ops import segment_argmax 6 | 7 | 8 | def segmented_cross_entropy(logits: torch.Tensor, target: torch.Tensor, scopes: torch.Tensor) -> torch.Tensor: 9 | """Segmented cross-entropy loss. 10 | 11 | Computes the cross-entropy loss from unscaled `logits` for a segmented problem. 12 | 13 | Parameters 14 | ---------- 15 | logits : torch.Tensor 16 | unscaled logits by segment 17 | target : torch.Tensor 18 | tensor of length `n_segments`, representing the index of the true label 19 | for each segment. 20 | scopes : torch.Tensor 21 | tensor of shape `[n_segments, 2]`, representing the segments as `(start, length)`. 22 | 23 | Returns 24 | ------- 25 | torch.Tensor 26 | A tensor of length `n_segments` representing the cross-entropy loss 27 | at each segment. 28 | """ 29 | input_logsumexp = segment_logsumexp(logits, scopes) 30 | return logits.index_select(0, target) - input_logsumexp 31 | 32 | 33 | def segmented_multinomial(logits, scopes, generator=None): 34 | """Segmented multinomial sample. 35 | 36 | Parameters 37 | ---------- 38 | logits : torch.Tensor 39 | unscaled logits by segment 40 | scopes : torch.Tensor 41 | tensor of shape `[n_segments, 2]` representing the segments as `(start, length)`. 42 | generator : torch.Generator, optional 43 | PRNG for sampling 44 | 45 | Returns 46 | ------- 47 | torch.Tensor 48 | A tensor of length `n_segments` representing the sampled values. 49 | """ 50 | output = logits.new_empty([scopes.shape[0]], dtype=torch.int64) 51 | 52 | logits = logits.detach() 53 | 54 | for i in range(scopes.shape[0]): 55 | segment_logits = logits.narrow(0, scopes[i, 0], scopes[i, 1]) 56 | torch.multinomial( 57 | torch.nn.functional.softmax(segment_logits, dim=0), num_samples=1, 58 | generator=generator, 59 | out=output[i:i]) 60 | 61 | return output 62 | 63 | 64 | def segmented_multinomial_extended(logits, scopes, generator=None, also_return_probs=False): 65 | """Segmented multinomial sample with implicit element. 66 | 67 | Parameters 68 | ---------- 69 | logits : torch.Tensor 70 | logits for explicit outcomes by segment 71 | scopes : torch.Tensor 72 | tensor of shape `[n_segments, 2]` representing the segments as `(start, length)`. 73 | generator : torch.Generator, optional 74 | PRNG for sampling 75 | also_return_probs: bool, optional 76 | If true, returns a tuple including the log-likelihood of the sample. 77 | 78 | Returns 79 | ------- 80 | torch.Tensor 81 | A tensor of length `n_segments` representing the sampled values.
82 | """ 83 | 84 | output = logits.new_empty([scopes.shape[0]], dtype=torch.int64) 85 | 86 | logits = logits.detach() 87 | 88 | for i in range(scopes.shape[0]): 89 | segment_logits = logits.new_zeros([scopes[i, 1] + 1]) 90 | segment_logits[:-1] = logits.narrow(0, scopes[i, 0], scopes[i, 1]) 91 | 92 | dist = torch.nn.functional.softmax(segment_logits, dim=0) 93 | torch.multinomial(dist, num_samples=1, generator=generator, 94 | out=output[i:i]) 95 | if also_return_probs: 96 | return output, dist[output] 97 | return output 98 | -------------------------------------------------------------------------------- /sketchgraphs/pipeline/graph_model/target.py: -------------------------------------------------------------------------------- 1 | """This module implements the basic definitions for targets.""" 2 | 3 | import enum 4 | 5 | from sketchgraphs.data import sketch as datalib, sequence as data_sequence 6 | 7 | NODE_TYPES_PREDICTED = list(datalib.EntityType) 8 | EDGE_TYPES_PREDICTED = list(x for x in datalib.ConstraintType if x != datalib.ConstraintType.Subnode) 9 | NODE_TYPES = list(datalib.EntityType) + list(datalib.SubnodeType) 10 | EDGE_TYPES = list(datalib.ConstraintType) 11 | 12 | 13 | EDGE_IDX_MAP = {t: i for i, t in enumerate(datalib.ConstraintType)} 14 | NODE_IDX_MAP = {t: i for i, t in enumerate(list(datalib.EntityType) + list(datalib.SubnodeType))} 15 | NODE_IDX_MAP_REVERSE = {i: t for i, t in enumerate(list(datalib.EntityType) + list(datalib.SubnodeType))} 16 | EDGE_IDX_MAP_REVERSE = {i: t for i, t in enumerate(datalib.ConstraintType)} 17 | 18 | 19 | class TargetType(enum.IntEnum): 20 | EdgeCategorical = 0 21 | EdgeAngle = 1 22 | EdgeLength = 2 23 | EdgeDistance = 3 24 | EdgeDiameter = 4 25 | EdgeRadius = 5 26 | NodeGeneric = 6 27 | NodeArc = 7 28 | NodeCircle = 8 29 | NodeLine = 9 30 | NodePoint = 10 31 | Subnode = 11 32 | 33 | @staticmethod 34 | def from_op(op): 35 | return TargetType.from_label(op.label) 36 | 37 | @staticmethod 38 | def from_label(label): 39 | if isinstance(label, datalib.SubnodeType): 40 | return TargetType.Subnode 41 | elif isinstance(label, datalib.EntityType): 42 | return _entity_to_target_type.get(label, TargetType.NodeGeneric) 43 | elif isinstance(label, datalib.ConstraintType): 44 | return _edge_to_target_type.get(label, TargetType.EdgeCategorical) 45 | else: 46 | raise ValueError('label must be one of SubnodeType, EntityType or ConstraintType.') 47 | 48 | @staticmethod 49 | def from_edge_label_int(x): 50 | fake_edge_op = datalib.EdgeOp(x, 0, 0) 51 | return TargetType.from_op(fake_edge_op) 52 | 53 | @staticmethod 54 | def from_node_label_int(x): 55 | fake_node_op = datalib.NodeOp(x) 56 | return TargetType.from_op(fake_node_op) 57 | 58 | @staticmethod 59 | def edge_types(): 60 | return (TargetType.EdgeCategorical,) + TargetType.numerical_edge_types() 61 | 62 | @staticmethod 63 | def numerical_edge_types(): 64 | return ( 65 | TargetType.EdgeAngle, 66 | TargetType.EdgeLength, 67 | TargetType.EdgeDistance, 68 | TargetType.EdgeDiameter, 69 | TargetType.EdgeRadius 70 | ) 71 | 72 | @staticmethod 73 | def numerical_node_types(): 74 | return ( 75 | TargetType.NodeArc, 76 | TargetType.NodeCircle, 77 | TargetType.NodeLine, 78 | TargetType.NodePoint 79 | ) 80 | 81 | @staticmethod 82 | def node_types(): 83 | return (TargetType.NodeGeneric,) + TargetType.numerical_node_types() 84 | 85 | _str_to_numerical_edge = { 86 | t.name.upper(): t 87 | for t in TargetType.numerical_edge_types() 88 | } 89 | 90 | _str_to_numerical_entity = { 91 | t.name.title(): t 92 | for t in 
TargetType.numerical_node_types() 93 | } 94 | 95 | _edge_to_target_type = { 96 | s: TargetType['Edge' + s.name] for s in 97 | (datalib.ConstraintType.Angle, datalib.ConstraintType.Length, datalib.ConstraintType.Distance, 98 | datalib.ConstraintType.Diameter, datalib.ConstraintType.Radius) 99 | } 100 | 101 | _entity_to_target_type = { 102 | s: TargetType['Node' + s.name] for s in 103 | (datalib.EntityType.Arc, datalib.EntityType.Circle, datalib.EntityType.Line, datalib.EntityType.Point) 104 | } 105 | -------------------------------------------------------------------------------- /sketchgraphs/data/sketch.py: -------------------------------------------------------------------------------- 1 | """This module implements parsing and representation for sketches. 2 | """ 3 | 4 | from collections import OrderedDict 5 | from typing import Dict 6 | 7 | # pylint: disable=invalid-name, too-many-arguments, too-many-return-statements, too-many-instance-attributes, wildcard-import, unused-wildcard-import 8 | 9 | 10 | from . import _entity 11 | from . import _constraint 12 | from . import _plotting 13 | 14 | from ._entity import EntityType, SubnodeType, Entity, GenericEntity, Point, Line, Circle, Arc, Spline, Ellipse, ENTITY_TYPE_TO_CLASS 15 | 16 | from ._constraint import * 17 | from ._plotting import render_sketch, render_graph 18 | 19 | 20 | class Sketch: 21 | """This class encapsulates a sketch instance. 22 | 23 | A sketch is defined by a list of entities, and a list of constraints between these entities. 24 | The Sketch class is designed to represent the sketches as obtained from Onshape in a structured 25 | and faithful manner. In particular, it can round-trip the relevant parts of the JSON representation 26 | of Onshape's FeatureScript. 27 | """ 28 | entities: Dict[str, Entity] 29 | constraints: Dict[str, Constraint] 30 | 31 | def __init__(self, entities=None, constraints=None): 32 | if entities is None: 33 | entities = OrderedDict() 34 | if constraints is None: 35 | constraints = OrderedDict() 36 | 37 | self.entities = entities 38 | self.constraints = constraints 39 | 40 | def to_dict(self) -> dict: 41 | """Create a dictionary representing this sketch. 42 | 43 | The created dictionary should be compatible with the JSON representation of the sketch. 44 | """ 45 | return { 46 | 'entities': [e.to_dict() for e in self.entities.values()], 47 | 'constraints': [c.to_dict() for c in self.constraints.values()] 48 | } 49 | 50 | @staticmethod 51 | def from_fs_json(sketch_dict, include_external_constraints=True): 52 | """Parse primitives and constraints. 53 | 54 | Parameters 55 | ---------- 56 | include_external_constraints : bool, optional 57 | If True, indicates that constraints referencing the first external node (representing 58 | the origin) should be included; otherwise, they are excluded from the graph.
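Examples
--------
A minimal usage sketch, mirroring tests/test_data.py (``sketch_json`` is assumed to be one parsed JSON document from the raw dataset):

>>> sketch = Sketch.from_fs_json(sketch_json)
>>> sketch_no_external = Sketch.from_fs_json(sketch_json, include_external_constraints=False)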
59 | """ 60 | entities = (Entity.from_dict(ed) for ed in sketch_dict['entities']) 61 | entities_dict = OrderedDict((e.entityId, e) for e in entities) 62 | 63 | constraints = (Constraint.from_dict(cd) for cd in sketch_dict['constraints']) 64 | 65 | if not include_external_constraints: 66 | constraints = ( 67 | c for c in constraints 68 | if not any(isinstance(p, ExternalReferenceParameter) for p in c.parameters)) 69 | 70 | constraints_dict = OrderedDict((c.identifier, c) for c in constraints) 71 | return Sketch(entities_dict, constraints_dict) 72 | 73 | @staticmethod 74 | def from_info(sketch_info): 75 | """Parse entities given result of `sketch information` call.""" 76 | subnode_suffixes = ('.start', '.end', '.center') # TODO: dry-ify this 77 | entities = [Entity.from_info(ed) for ed in sketch_info if not ed['id'].endswith(subnode_suffixes)] 78 | entities_dict = OrderedDict((e.entityId, e) for e in entities) 79 | return Sketch(entities=entities_dict) 80 | 81 | def __repr__(self): 82 | return 'Sketch(n_entities={0}, n_constraints={1})'.format(len(self.entities), len(self.constraints)) 83 | 84 | 85 | __all__ = [ 86 | 'Sketch', 'EntityType', 'SubnodeType', 'Entity', 'GenericEntity', 'Point', 'Line', 87 | 'Circle', 'Arc', 'Spline', 'Ellipse', 'ENTITY_TYPE_TO_CLASS'] + _constraint.__all__ + _plotting.__all__ 88 | -------------------------------------------------------------------------------- /sketchgraphs_models/autoconstraint/scripts/eval_statistics_mask.py: -------------------------------------------------------------------------------- 1 | """Evaluate statistics from the masks used applied to samples in the data.""" 2 | 3 | import argparse 4 | import functools 5 | 6 | import numpy as np 7 | import torch 8 | from torch import multiprocessing 9 | import tqdm 10 | 11 | from sketchgraphs.data import flat_array 12 | from sketchgraphs.data.sequence import NodeOp 13 | from sketchgraphs_models.autoconstraint.eval import MASK_FUNCTIONS 14 | 15 | 16 | def count_valid_choices_by_step(node_ops, mask_function): 17 | if node_ops[-1].label == 'Stop': 18 | node_ops = node_ops[:-1] 19 | 20 | masks = mask_function(node_ops) 21 | 22 | valid_choices = np.ones(len(node_ops) - 1, dtype=np.int) 23 | 24 | for i in range(1, len(node_ops)): 25 | mask_offset = i * (i + 1) // 2 26 | valid_choices[i - 1] = (torch.narrow(masks, 0, mask_offset, i + 1) == 0).int().sum().item() + 1 27 | 28 | return valid_choices 29 | 30 | 31 | def total_valid_choices(seq, mask_function): 32 | node_ops = [op for op in seq if isinstance(op, NodeOp)] 33 | 34 | valid_choices = count_valid_choices_by_step(node_ops, mask_function) 35 | 36 | if len(valid_choices) > 0: 37 | average_choices = valid_choices.mean() 38 | average_entropy = np.log2(valid_choices).mean() 39 | else: 40 | # These values will be ignored 41 | average_choices = 0 42 | average_entropy = 0 43 | 44 | return average_choices, average_entropy, len(valid_choices) 45 | 46 | 47 | def uniform_valid_perplexity(seqs, mask_function, num_workers=0): 48 | if num_workers is None or num_workers > 0: 49 | pool = multiprocessing.Pool(num_workers) 50 | map_fn = functools.partial(pool.imap, chunksize=8) 51 | else: 52 | map_fn = map 53 | 54 | valid_choices_fn = functools.partial(total_valid_choices, mask_function=mask_function) 55 | 56 | average_choices = np.empty(len(seqs)) 57 | average_entropy = np.empty(len(seqs)) 58 | seq_length = np.empty(len(seqs), dtype=np.int) 59 | 60 | for i, (choices, entropy, length) in enumerate(tqdm.tqdm(map_fn(valid_choices_fn, seqs), total=len(seqs))): 61 | 
average_choices[i] = choices 62 | average_entropy[i] = entropy 63 | seq_length[i] = length 64 | 65 | return { 66 | 'choices': average_choices, 67 | 'entropy': average_entropy, 68 | 'sequence_length': seq_length 69 | } 70 | 71 | 72 | def main(): 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--input', required=True) 75 | parser.add_argument('--mask', choices=list(MASK_FUNCTIONS.keys()), default='node_type') 76 | parser.add_argument('--output') 77 | parser.add_argument('--limit', type=int, default=None) 78 | parser.add_argument('--num_workers', type=int, default=16) 79 | 80 | args = parser.parse_args() 81 | 82 | print('Reading data') 83 | seqs = flat_array.load_dictionary_flat(np.load(args.input, mmap_mode='r'))['sequences'] 84 | seqs.share_memory_() 85 | 86 | if args.limit is not None: 87 | seqs = seqs[:args.limit] 88 | 89 | print('Computing statistics') 90 | result = uniform_valid_perplexity(seqs, MASK_FUNCTIONS[args.mask], args.num_workers) 91 | 92 | if args.output is not None: 93 | print('Saving results') 94 | np.savez_compressed(args.output, **result) 95 | 96 | choices = np.average(result['choices'], weights=result['sequence_length']) 97 | entropy = np.average(result['entropy'], weights=result['sequence_length']) 98 | print('Average choices: {:.3f}'.format(choices)) 99 | print('Average entropy: {:.3f}'.format(entropy)) 100 | 101 | if __name__ == '__main__': 102 | multiprocessing.set_start_method('spawn') 103 | main() 104 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. SketchGraphs documentation master file, created by 2 | sphinx-quickstart on Sun Jul 12 05:33:22 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to SketchGraphs' documentation! 7 | ======================================== 8 | 9 | `SketchGraphs `_ is a dataset of 15 million sketches extracted from real world CAD models intended to facilitate 10 | research in ML-aided design and geometric program induction. 11 | In addition to the raw data, we provide several processed datasets, an accompanying Python package to work with 12 | the data, as well as a couple of starter models which implement GNN-type strategies on a couple of example problems. 13 | 14 | 15 | Data 16 | ---- 17 | 18 | We provide our dataset in a number of forms, some of which may be more appropriate for your desired usage. 19 | The following data files are provided: 20 | 21 | - The raw json data as obtained from Onshape. This is provided as a set of 128 tar archives which are compressed 22 | using zstandard_. They total about 43GB of data. In general, we only recommend these for advanced usage, as 23 | they are heavy and require extensive processing to manipulate. Users interested in working with the raw data 24 | may wish to inspect `sketchgraphs.pipeline.make_sketch_dataset` and `sketchgraphs.pipeline.make_sequence_dataset` 25 | to view our data processing scripts. The data is available for download `here `_. 26 | 27 | - A dataset of construction sequences. This is provided as a single file, stored in a custom binary format. 28 | This format is much more concise (as it eliminates many of the unique identifiers used in the raw JSON format), 29 | and is better suited for ML applications. It is supported by our python libraries, and forms the baseline 30 | on which our models are trained. 
The data is available for download `here <https://sketchgraphs.cs.princeton.edu/sequence/sg_all.npy>`_ 31 | (warning: 15GB file!). 32 | 33 | - A filtered dataset of construction sequences. This is provided as a single file, similarly stored in a custom 34 | binary format. This dataset is similar to the sequence dataset, but simplified by filtering out sketches 35 | that are too large or too small, and only includes a simplified set of entities and constraints (while still 36 | capturing a large portion of the data). Additionally, this dataset has been split into training, testing and 37 | validation splits for convenience. We train our models on this subset of the data. You can download the splits 38 | here: `train <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_train.npy>`_, 39 | `test <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_test.npy>`_, 40 | `validation <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_validation.npy>`_. 41 | 42 | 43 | More details concerning the data can be found in the :doc:`data` page. 44 | 45 | 46 | .. _zstandard: https://facebook.github.io/zstd/ 47 | 48 | 49 | Models 50 | ------ 51 | 52 | In addition to the dataset, we also provide some baseline model implementations to tackle the tasks of generative 53 | modelling and autoconstrain. These models are based on Graph Neural Network approaches, and model the sketch as 54 | a graph where vertices are given by the entities in the sketch, and edges by the constraints between those entities. 55 | For more details, please refer to the dedicated :doc:`models` page. 56 | 57 | 58 | .. toctree:: 59 | models 60 | data 61 | onshape_setup 62 | .. autosummary:: 63 | :toctree: _autosummary 64 | :recursive: 65 | 66 | sketchgraphs 67 | sketchgraphs_models 68 | 69 | 70 | 71 | Indices and tables 72 | ================== 73 | 74 | * :ref:`genindex` 75 | * :ref:`modindex` 76 | * :ref:`search` 77 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/eval_likelihood.py: -------------------------------------------------------------------------------- 1 | """Module to evaluate the graph model according to data likelihood.""" 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | 9 | from sketchgraphs_models import training 10 | 11 | from sketchgraphs_models.graph import dataset, sample 12 | from sketchgraphs_models.graph import model as graph_model 13 | 14 | from sketchgraphs.data import flat_array 15 | 16 | 17 | def _total_loss(losses): 18 | result = 0 19 | 20 | for v in losses.values(): 21 | if v is None: 22 | continue 23 | 24 | if isinstance(v, dict): 25 | result += _total_loss(v) 26 | else: 27 | result += v.sum() 28 | 29 | return result 30 | 31 | 32 | def batch_from_example(seq, node_feature_mapping, edge_feature_mapping): 33 | step_indices = [i for i, op in enumerate(seq) if i > 0 and not dataset._is_subnode_edge(op)] 34 | 35 | batch_list = [] 36 | 37 | for step_idx in step_indices: 38 | graph = dataset.graph_info_from_sequence(seq[:step_idx], node_feature_mapping, edge_feature_mapping) 39 | target = seq[step_idx] 40 | batch_list.append((graph, target)) 41 | 42 | return dataset.collate(batch_list, node_feature_mapping, edge_feature_mapping) 43 | 44 | 45 | class GraphLikelihoodEvaluator: 46 | def __init__(self, model, node_feature_mapping, edge_feature_mapping, device=None): 47 | self.model = model 48 | self.node_feature_mapping = node_feature_mapping 49 | self.edge_feature_mapping = edge_feature_mapping 50 | self.device = device 51 | self._feature_dimensions = { 52 | **node_feature_mapping.feature_dimensions, 53 | **edge_feature_mapping.feature_dimensions 54 | } 55 | 56 | def compute_likelihood(self, seqs): 57 | for i, seq in enumerate(seqs): 58 | batch =
batch_from_example(seq, self.node_feature_mapping, self.edge_feature_mapping) 59 | batch_device = training.load_cuda_async(batch, device=self.device) 60 | 61 | with torch.no_grad(): 62 | readout = self.model(batch_device) 63 | losses, *_ = graph_model.compute_losses(readout, batch_device, self._feature_dimensions) 64 | loss = _total_loss(losses).cpu().item() 65 | 66 | yield loss, len(seq) 67 | 68 | 69 | def main(): 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--dataset', required=True) 72 | parser.add_argument('--model_state', help='Path to saved model state_dict.') 73 | parser.add_argument('--device', type=str, default='cuda') 74 | parser.add_argument('--limit', type=int, default=None) 75 | 76 | args = parser.parse_args() 77 | 78 | device = torch.device(args.device) 79 | 80 | print('Loading trained model') 81 | model, node_feature_mapping, edge_feature_mapping = sample.load_saved_model(args.model_state) 82 | model = model.eval().to(device) 83 | 84 | print('Loading testing data') 85 | seqs = flat_array.load_dictionary_flat(np.load(args.dataset, mmap_mode='r'))['sequences'] 86 | 87 | if args.limit is not None: 88 | seqs = seqs[:args.limit] 89 | 90 | evaluator = GraphLikelihoodEvaluator(model, node_feature_mapping, edge_feature_mapping, device) 91 | 92 | losses = np.empty(len(seqs)) 93 | length = np.empty(len(seqs), dtype=np.int64) 94 | 95 | for i, result in enumerate(tqdm.tqdm(evaluator.compute_likelihood(seqs), total=len(seqs))): 96 | losses[i], length[i] = result 97 | 98 | print('Average bits per sketch: {:.2f}'.format(losses.mean() / np.log(2))) 99 | print('Average bits per step: {:.2f}'.format(losses.sum() / np.log(2) / length.sum())) 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /sketchgraphs_models/distributed_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for distributed (multi-gpu) training. 2 | """ 3 | 4 | import math 5 | import socket 6 | import typing 7 | 8 | import torch 9 | import torch.distributed 10 | 11 | from functools import partial 12 | 13 | 14 | class DistributedTrainingInfo(typing.NamedTuple): 15 | sync_url: str 16 | world_size: int 17 | rank: int 18 | local_rank: int 19 | 20 | def __bool__(self): 21 | return self.world_size > 1 22 | 23 | 24 | 25 | def is_leader(config: DistributedTrainingInfo): 26 | """Tests whether the current process is the leader for distributed training.""" 27 | return config is None or config.rank == 0 28 | 29 | 30 | def train_boostrap_distributed(parameters, train): 31 | world_size = parameters.get('world_size', 1) 32 | 33 | if world_size == 1: 34 | # Single-node training, nothing to do. 
35 | parameters['rank'] = 0 36 | return train(parameters) 37 | 38 | parameters['rank'] = -1 39 | from torch.multiprocessing import spawn 40 | 41 | sync_url = parameters.get('sync_url') 42 | if sync_url is None: 43 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 44 | s.bind(('', 0)) 45 | port = s.getsockname()[1] 46 | s.close() 47 | sync_url = f'tcp://127.0.0.1:{port}' 48 | parameters['sync_url'] = sync_url 49 | print(f'Using URL bootstrap at {sync_url}') 50 | 51 | spawn(partial(train_distributed, train=train), nprocs=parameters['world_size'], args=(parameters,)) 52 | 53 | 54 | def initialize_distributed(config: DistributedTrainingInfo): 55 | from torch import distributed as dist 56 | 57 | dist.init_process_group( 58 | 'nccl', init_method=config.sync_url, 59 | world_size=config.world_size, rank=config.rank) 60 | 61 | torch.cuda.set_device(config.local_rank) 62 | 63 | 64 | def get_distributed_config(parameters, local_rank=None): 65 | world_size = parameters.get('world_size', 1) 66 | 67 | if world_size == 1: 68 | return DistributedTrainingInfo('', 1, 0, 0) 69 | 70 | sync_url = parameters['sync_url'] 71 | 72 | rank = local_rank 73 | 74 | return DistributedTrainingInfo(sync_url, world_size, rank, local_rank) 75 | 76 | 77 | def train_distributed(local_rank, parameters, train): 78 | config = get_distributed_config(parameters, local_rank) 79 | initialize_distributed(config) 80 | train(parameters, config) 81 | 82 | 83 | class DistributedSampler(torch.utils.data.Sampler): 84 | """Utility class which adapts a sampler into a distributed sampler which only 85 | samples a subset of the underlying sampler, according to a division by rank. 86 | """ 87 | def __init__(self, sampler, num_replicas=None, rank=None): 88 | if num_replicas is None: 89 | if not torch.distributed.is_available(): 90 | raise RuntimeError("DistributedSampler requires torch distributed to be available") 91 | num_replicas = torch.distributed.get_world_size() 92 | 93 | if rank is None: 94 | if not torch.distributed.is_available(): 95 | raise RuntimeError("DistributedSampler requires torch distributed to be available") 96 | rank = torch.distributed.get_rank() 97 | self.sampler = sampler 98 | self.num_replicas = num_replicas 99 | self.rank = rank 100 | self.num_samples = int(math.ceil(len(sampler) / self.num_replicas)) 101 | self.total_size = self.num_replicas * self.num_samples 102 | 103 | def __iter__(self): 104 | indices = list(self.sampler) 105 | indices += indices[:(self.total_size - len(indices))] 106 | assert len(indices) == self.total_size 107 | 108 | indices = indices[self.rank:self.total_size:self.num_replicas] 109 | assert len(indices) == self.num_samples 110 | 111 | return iter(indices) 112 | 113 | def __len__(self): 114 | return self.num_samples 115 | -------------------------------------------------------------------------------- /docs/data.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | 4 | We provide three datasets at various levels of pre-processing in order to enable a wide variety of uses 5 | while ensuring that it is possible to quickly get started on this data. We describe each dataset below. 6 | 7 | Tarballs 8 | -------- 9 | 10 | The raw json data as obtained from Onshape. This is provided as a set of 128 tar archives which are compressed 11 | using `zstandard <https://facebook.github.io/zstd/>`_. 12 | They total about 43GB of data. In general, we only recommend these for advanced usage, as 13 | they are heavy and require extensive processing to manipulate.
14 | Users who wish to directly access the json from python may be interested in the utility function 15 | `sketchgraphs.pipeline.make_sketch_dataset.load_json_tarball`, which iterates through a single compressed tarball 16 | and enumerates the sketches in the tarball. 17 | The data is available for download `here <https://sketchgraphs.cs.princeton.edu/shards>`_. 18 | 19 | 20 | Filtered sequence 21 | ----------------- 22 | 23 | The filtered sequences are the most convenient dataset to use, and are provided ready to use with ML models. 24 | They are stored in a custom binary format, which can be read using `sketchgraphs.data.flat_array.load_dictionary_flat`, 25 | for example by executing: 26 | 27 | >>> from sketchgraphs.data import flat_array 28 | >>> data = flat_array.load_dictionary_flat('sg_t16_train.npy') 29 | >>> data['sequences'] 30 | FlatArray(len=9179789, mem=5.3 GB) 31 | 32 | In addition to the sequences, which may be accessed through the ``data['sequences']`` key, the sequence files 33 | also contain an integer array of the same length which records the length of each sequence, accessed through 34 | the ``data['sequence_lengths']`` key, and a structured array which contains an identifier uniquely identifying 35 | the sketch, accessed through the ``data['sketch_ids']`` key. The latter is an array of tuples, the first element 36 | being the document id of the sketch, the second being the part index in that document, and the last being the 37 | sketch index in that part. 38 | 39 | This data is pre-split into a `train <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_train.npy>`_, 40 | `test <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_test.npy>`_ and 41 | `validation <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_validation.npy>`_ set (the training set 42 | is formed from the shards 1-120, the validation set from the shards 121-124, and the testing set from the 43 | shards 125-128). Additionally, we also provide pre-computed `quantization statistics `_, 44 | computed on the training set using the `sketchgraphs.pipeline.make_quantization_statistics` script. 45 | These quantization statistics are used by the models in order to handle continuous parameters. 46 | 47 | 48 | Full sequence 49 | ------------- 50 | 51 | As a middle-ground between the filtered sequences and the json tarballs, we also provide the full sequences, 52 | which are directly converted from the json tarballs with no filtering (except the exclusion of empty sketches). 53 | This full sequence file is available `here <https://sketchgraphs.cs.princeton.edu/sequence/sg_all.npy>`_ 54 | (warning: 15GB download), and contains the sequence representation of all the sketches in the dataset. 55 | We note that although it is mostly equivalent to the data contained in the tarballs for ML purposes, it is much smaller 56 | as many of the original identifiers are discarded and renamed to sequential indices. 57 | 58 | The full sequence may be accessed in the same fashion as the filtered sequences (using the `sketchgraphs.data.flat_array` 59 | module). Note that the full sequence file contains all sketches, and does not provide any train / test split. 60 | Additionally, users should be warned that the dataset contains substantial outliers (for example, the largest sketch in the 61 | dataset contains more than 350 thousand operations, whereas the 99th percentile is "only" 640 operations).
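For example, assuming the full sequence file has been downloaded to the working directory under its original name ``sg_all.npy``, a minimal loading sketch looks as follows (we assume here that the ``'sequences'`` key is present, as in the filtered splits):

>>> from sketchgraphs.data import flat_array
>>> data = flat_array.load_dictionary_flat('sg_all.npy')
>>> seqs = data['sequences']
>>> len(seqs[0])  # number of construction operations in the first sketch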
62 | 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SketchGraphs: A Large-Scale Dataset for Modeling Relational Geometry in Computer-Aided Design 2 | 3 | SketchGraphs is a dataset of 15 million sketches extracted from real-world CAD models intended to facilitate research in both ML-aided design and geometric program induction. 4 | 5 | ![SketchGraphs sketches](/assets/sketchgraphs.gif) 6 | 7 | 8 | Each sketch is represented as a geometric constraint graph where edges denote designer-imposed geometric relationships between primitives, the nodes of the graph. 9 | 10 | ![Sketch and graph](/assets/sketch_w_graph.png) 11 | 12 | Video: https://youtu.be/ki784S3wjqw 13 | Paper: https://arxiv.org/abs/2007.08506 14 | 15 | See the [demo notebook](demos/sketchgraphs_demo.ipynb) for a quick overview of the data representations in SketchGraphs as well as an example of solving constraints via Onshape's API. 16 | 17 | ## Installation 18 | 19 | SketchGraphs can be installed using pip: 20 | 21 | ```bash 22 | >> pip install -e SketchGraphs 23 | ``` 24 | 25 | This will provide you with the necessary dependencies to load and explore the data. 26 | However, to train the models, you will need to additionally install [pytorch](https://pytorch.org/) 27 | and [torch-scatter](https://github.com/rusty1s/pytorch_scatter). 28 | 29 | ## Data 30 | 31 | We provide our dataset in a number of forms, some of which may be more appropriate for your desired usage. 32 | The following data files are provided: 33 | 34 | - The raw json data as obtained from Onshape. This is provided as a set of 128 tar archives which are compressed 35 | using [zstandard](https://facebook.github.io/zstd). They total about 43GB of data. In general, we only recommend these for advanced usage, as 36 | they are heavy and require extensive processing to manipulate. Users interested in working with the raw data 37 | may wish to inspect `sketchgraphs.pipeline.make_sketch_dataset` and `sketchgraphs.pipeline.make_sequence_dataset` 38 | to view our data processing scripts. The data is available for download [here](https://sketchgraphs.cs.princeton.edu/shards). 39 | 40 | - A dataset of construction sequences. This is provided as a single file, stored in a custom binary format. 41 | This format is much more concise (as it eliminates many of the unique identifiers used in the raw JSON format), 42 | and is better suited for ML applications. It is supported by our Python libraries and forms the baseline 43 | on which our models are trained. The data is available for download [here](https://sketchgraphs.cs.princeton.edu/sequence/sg_all.npy) (warning: 15GB file!). 44 | 45 | - A filtered dataset of construction sequences. This is provided as a single file, similarly stored in a custom 46 | binary format. This dataset is similar to the sequence dataset, but simplified by filtering out sketches 47 | that are too large or too small, and only includes a simplified set of entities and constraints (while still 48 | capturing a large portion of the data). Additionally, this dataset has been split into training, testing and 49 | validation splits for convenience. We train our models on this subset of the data.
You can download the splits 50 | here: [train](https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_train.npy) 51 | [validation](https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_validation.npy) 52 | [test](https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_test.npy) 53 | 54 | For full documentation of the processing pipeline, see https://princetonlips.github.io/SketchGraphs. 55 | 56 | The original creators of the CAD sketches hold the copyright. See [Onshape Terms of Use 1.g.ii](https://www.onshape.com/legal/terms-of-use#your_content) for additional licensing details. 57 | 58 | 59 | ## Models 60 | In addition to the dataset, we also provide some baseline model implementations to tackle the tasks of generative 61 | modeling and autoconstrain. These models are based on Graph Neural Network approaches and model the sketch as 62 | a graph where vertices are given by the entities in the sketch, and edges by the constraints between those entities. 63 | For more details, please refer to https://princetonlips.github.io/SketchGraphs/models. 64 | 65 | 66 | ## Citation 67 | If you use this dataset in your research, please cite: 68 | ``` 69 | @inproceedings{SketchGraphs, 70 | title={Sketch{G}raphs: A Large-Scale Dataset for Modeling Relational Geometry in Computer-Aided Design}, 71 | author={Seff, Ari and Ovadia, Yaniv and Zhou, Wenda and Adams, Ryan P.}, 72 | booktitle={ICML 2020 Workshop on Object-Oriented Learning}, 73 | year={2020} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /sketchgraphs/pipeline/numerical_parameters.py: -------------------------------------------------------------------------------- 1 | """This module implements functionality to handle numerical parameters in the sketch. 2 | 3 | """ 4 | 5 | import math 6 | import re 7 | 8 | import numpy as np 9 | from sklearn.cluster import KMeans 10 | 11 | _IMPLICIT_MUL_PATTERN = re.compile(r"(?<=\))\s*(?=\w+) | (?<=\d)\s*(?=[A-Z]+)", re.X) 12 | 13 | 14 | ##### Constraint handling 15 | 16 | METER_CONVERSIONS = { 17 | 'METER': 1, 18 | 'METERS': 1, 19 | 'M': 1, 20 | 'MILLIMETER': 1e-3, 21 | 'MILLIMETERS': 1e-3, 22 | 'MM': 1e-3, 23 | 'CM': 1e-2, 24 | 'CENTIMETER': 1e-2, 25 | 'CENTIMETERS': 1e-2, 26 | 'INCHES': 0.0254, 27 | 'INCH': 0.0254, 28 | 'IN': 0.0254, 29 | 'FOOT': 0.3048, 30 | 'FEET': 0.3048, 31 | 'FT': 0.3048, 32 | 'YD': 0.9144 33 | } 34 | 35 | DEGREE_CONVERSIONS = { 36 | 'DEG': 1, 37 | 'DEGREE': 1, 38 | 'RADIAN': 180/np.pi, 39 | 'RAD': 180/np.pi 40 | } 41 | 42 | 43 | def normalize_expression(expression, parameter_id): 44 | """Converts a numerical expression into a normalized form. 45 | 46 | Parameters 47 | ---------- 48 | expression : str 49 | A string representing a quantity parameter value 50 | parameter_id : str 51 | A parameterId string; must be one of 'angle' or 'length' 52 | 53 | Returns 54 | ------- 55 | norm_expression : str or None 56 | Normalized expression string if successful, or None otherwise. 57 | """ 58 | expression = expression.upper().strip() 59 | 60 | if parameter_id == 'angle': 61 | conversions = DEGREE_CONVERSIONS 62 | default_unit = 'DEGREE' 63 | elif parameter_id == 'length': 64 | conversions = METER_CONVERSIONS 65 | default_unit = 'METER' 66 | else: 67 | raise ValueError('parameter_id must be one of angle or length') 68 | 69 | if '#' in expression or 'lookup' in expression: 70 | # variables in feature script are not supported 71 | return None 72 | 73 | # Adds implicit multiplication signs in order to be more easily parsed.
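# Illustrative examples (not from the source): after the upper-casing above, "3.5 IN" becomes "3.5*IN"
# and "(1/2) FT" becomes "(1/2)*FT", so the unit token resolves against the conversion tables passed to eval below.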
74 | expression = _IMPLICIT_MUL_PATTERN.sub('*', expression) 75 | 76 | try: 77 | value = eval(expression, {'PI': np.pi, 'SQRT': math.sqrt, 'TAN': math.tan}, conversions) 78 | except: 79 | return None 80 | 81 | value_str = np.format_float_positional(value, precision=4, trim='0', fractional=False) 82 | return '{} {}'.format(value_str, default_unit) 83 | 84 | 85 | def make_quantization(values, num_points, scheme): 86 | """Find optimal centers for parameter via either uniform, K-means, or CDF-based K-means. 87 | 88 | Obtains a quantization scheme for the given values, according to a given strategy. 89 | Several schemes are supported, although we prefer the 'cdf' scheme, a hybrid which 90 | avoids some issues with large outliers in the datasets faced by other schemes. 91 | 92 | Parameters 93 | ---------- 94 | values : np.array 95 | An array of values representing a sample of the values to quantize 96 | num_points : int 97 | Number of points to obtain in the dictionary 98 | scheme : str 99 | Indicates the quantization scheme to use, must be 'uniform', 'kmeans' or 'cdf' 100 | 101 | Returns 102 | ------- 103 | np.array 104 | Array of quantization codes 105 | """ 106 | 107 | if scheme == 'uniform': 108 | edges = np.linspace(np.min(values), np.max(values), num_points+1) 109 | return np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)]) 110 | 111 | if scheme == 'kmeans': 112 | km = KMeans(n_clusters=num_points) 113 | km.fit(values.reshape(-1, 1)) 114 | return np.sort(np.squeeze(km.cluster_centers_)) 115 | 116 | if scheme == 'cdf': 117 | values, cdf = make_unique_cdf(values) 118 | cdf_centers = make_quantization(cdf, num_points, 'kmeans') 119 | return np.interp(cdf_centers, cdf, values) 120 | 121 | schemes = ['uniform', 'kmeans', 'cdf'] 122 | raise ValueError("scheme must be one of " + str(schemes)) 123 | 124 | 125 | def make_unique_cdf(arr): 126 | """Return 'collapsed' cdf of arr (identical/close arr vals all have same cdf point). 127 | 128 | Parameters 129 | ---------- 130 | arr: array of parameter values 131 | 132 | Returns 133 | ------- 134 | sorted_arr: sorted copy of arr. 135 | cdf: collapsed cdf of arr. 136 | """ 137 | cdf = np.linspace(0, 1, len(arr)) 138 | sorted_arr = np.sort(arr) 139 | last_unique_idx = 0 140 | for idx, arr_val in enumerate(np.append(sorted_arr, [np.inf])): 141 | if not np.isclose(arr_val, sorted_arr[last_unique_idx]): 142 | cdf[last_unique_idx:idx] = np.mean(cdf[last_unique_idx:idx]) 143 | last_unique_idx = idx 144 | return sorted_arr, cdf 145 | -------------------------------------------------------------------------------- /sketchgraphs_models/torch_extensions/segment_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ._repeat_interleave import repeat_interleave 4 | 5 | 6 | class SegmentAvgPool1DLoop(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, values, scopes): 9 | ctx.save_for_backward(scopes) 10 | ctx.input_length = values.shape[0] 11 | result = values.new_empty((scopes.shape[0],) + values.shape[1:]) 12 | 13 | for i in range(len(scopes)): 14 | x = values.narrow(0, scopes[i, 0], scopes[i, 1]) 15 | result[i] = x.mean(dim=0) 16 | 17 | return result 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | scopes, = ctx.saved_tensors 22 | return segment_avg_pool1d_backward(grad_output, scopes, ctx.input_length), None 23 | 24 | 25 | _avg_pool_docstring = """ 26 | Segmented average pooling. 
27 | 28 | This function computes an average pool over a set of values, where the 29 | pool is taken over segments defined by scope. 30 | 31 | Parameters 32 | ---------- 33 | values : torch.Tensor 34 | A 1-dimensional tensor. 35 | scopes : torch.Tensor 36 | a 2-dimensional integer tensor representing segments. 37 | Each row of scopes represents a segment, which starts at ``scopes[i, 0]``, 38 | and has length ``scopes[i, 1]``. 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | A tensor representing the average value for each segment. 44 | """ 45 | 46 | 47 | def segment_avg_pool1d_loop(values, scopes): 48 | return SegmentAvgPool1DLoop.apply(values, scopes) 49 | 50 | 51 | def segment_avg_pool1d_scatter(values, scopes): 52 | import torch_scatter 53 | 54 | lengths = scopes.select(1, 1) 55 | offsets = lengths.new_zeros(len(lengths) + 1) 56 | torch.cumsum(lengths, 0, out=offsets[1:]) 57 | 58 | return torch_scatter.segment_mean_csr(values, offsets) 59 | 60 | 61 | segment_avg_pool1d_loop.__doc__ = _avg_pool_docstring 62 | segment_avg_pool1d_scatter.__doc__ = _avg_pool_docstring 63 | 64 | 65 | def segment_avg_pool1d_backward(grad_output, scopes, input_length): 66 | """ Backward pass for segmented average pooling. """ 67 | scopes_length = scopes.select(1, 1) 68 | norm_grad = torch.true_divide(grad_output, scopes_length.type_as(grad_output).unsqueeze(1)) 69 | 70 | result = grad_output.new_empty([input_length] + list(norm_grad.shape[1:])) 71 | return repeat_interleave(norm_grad, scopes, dim=0, out=result) 72 | 73 | 74 | def segment_max_pool1d_backward(grad_output, scopes, max_indices, input_length): 75 | result = grad_output.new_zeros((input_length, max_indices.shape[1])) 76 | 77 | scopes_offsets = scopes.select(1, 0).unsqueeze(1) 78 | linear_max_indices = scopes_offsets + max_indices 79 | 80 | result.scatter_(0, linear_max_indices, grad_output) 81 | return result 82 | 83 | 84 | class SegmentMaxPool1DLoop(torch.autograd.Function): 85 | @staticmethod 86 | def forward(ctx, values, scopes, return_indices=False): 87 | result = values.new_empty([scopes.shape[0], values.shape[1]]) 88 | result_idx = values.new_empty([scopes.shape[0], values.shape[1]], dtype=torch.int64) 89 | 90 | for i in range(len(scopes)): 91 | x = values.narrow(0, scopes[i, 0], scopes[i, 1]).t().unsqueeze(0) 92 | r, ri = torch.nn.functional.adaptive_max_pool1d(x, 1, return_indices=True) 93 | r = r.squeeze(0).squeeze(1) 94 | ri = ri.squeeze(0).squeeze(1) 95 | 96 | result[i] = r 97 | result_idx[i] = ri 98 | 99 | ctx.save_for_backward(scopes, result_idx) 100 | ctx.input_length = values.shape[0] 101 | ctx.mark_non_differentiable(result_idx) 102 | 103 | if return_indices: 104 | return result, result_idx 105 | else: 106 | return result 107 | 108 | @staticmethod 109 | def backward(ctx, grad_output, *args): 110 | scopes, result_idx = ctx.saved_tensors 111 | return segment_max_pool1d_backward(grad_output, scopes, result_idx, ctx.input_length), None, None 112 | 113 | 114 | def segment_max_pool1d(values: torch.Tensor, scopes: torch.Tensor, return_indices=False) -> torch.Tensor: 115 | """ 116 | Computes the maximum value in each segment. 117 | 118 | Parameters 119 | ---------- 120 | values : torch.Tensor 121 | A 1-dimensional tensor. 122 | scopes : torch.Tensor 123 | a 2-dimensional integer tensor representing segments. 124 | Each row of scopes represents a segment, which starts at ``scopes[i, 0]``, 125 | and has length ``scopes[i, 1]``.
126 | 127 | Returns 128 | ------- 129 | torch.Tensor 130 | A tensor representing the maximum value for each segment. 131 | """ 132 | return SegmentMaxPool1DLoop.apply(values, scopes, return_indices) 133 | 134 | 135 | try: 136 | import torch_scatter 137 | 138 | segment_avg_pool1d = segment_avg_pool1d_scatter 139 | except ImportError: 140 | segment_avg_pool1d = segment_avg_pool1d_loop 141 | 142 | 143 | __all__ = ['segment_avg_pool1d'] 144 | -------------------------------------------------------------------------------- /tests/test_graph_model.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import pickle 4 | 5 | import numpy as np 6 | 7 | import pytest 8 | 9 | from sketchgraphs_models.graph import dataset as graph_dataset 10 | from sketchgraphs_models.graph import model as graph_model 11 | 12 | from sketchgraphs.data.sketch import Sketch 13 | 14 | @pytest.fixture 15 | def edge_feature_mapping(): 16 | """Dummy quantization functions.""" 17 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f: 18 | mapping = pickle.load(f) 19 | 20 | return graph_dataset.EdgeFeatureMapping( 21 | mapping['edge']['angle'], mapping['edge']['length']) 22 | 23 | @pytest.fixture 24 | def node_feature_mapping(): 25 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f: 26 | mapping = pickle.load(f) 27 | return graph_dataset.EntityFeatureMapping(mapping['node']) 28 | 29 | 30 | def test_compute_model(sketches, node_feature_mapping, edge_feature_mapping): 31 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches)) 32 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=36) 33 | 34 | batch = [dataset[i] for i in range(10)] 35 | batch_input = graph_dataset.collate(batch, node_feature_mapping, edge_feature_mapping) 36 | 37 | model = graph_model.make_graph_model(32, {**node_feature_mapping.feature_dimensions, **edge_feature_mapping.feature_dimensions}) 38 | result = model(batch_input) 39 | assert 'node_embedding' in result 40 | 41 | 42 | def test_compute_losses(sketches, node_feature_mapping, edge_feature_mapping): 43 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches)) 44 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=36) 45 | 46 | batch = [dataset[i] for i in range(10)] 47 | batch_input = graph_dataset.collate(batch, node_feature_mapping, edge_feature_mapping) 48 | 49 | feature_dimensions = {**node_feature_mapping.feature_dimensions, **edge_feature_mapping.feature_dimensions} 50 | 51 | model = graph_model.make_graph_model(32, feature_dimensions) 52 | model_output = model(batch_input) 53 | 54 | losses, _, edge_metrics, node_metrics = graph_model.compute_losses( 55 | model_output, batch_input, feature_dimensions) 56 | assert isinstance(losses, dict) 57 | 58 | avg_losses = graph_model.compute_average_losses(losses, batch_input['graph_counts']) 59 | assert isinstance(avg_losses, dict) 60 | 61 | for t in edge_metrics: 62 | assert edge_metrics[t][0].shape == edge_metrics[t][1].shape 63 | 64 | for t in edge_metrics: 65 | assert node_metrics[t][0].shape == node_metrics[t][1].shape 66 | 67 | 68 | def test_compute_model_no_entity_features(sketches, edge_feature_mapping): 69 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches)) 70 | dataset = graph_dataset.GraphDataset(sequences, None, edge_feature_mapping, seed=36) 71 | 72 | batch = [dataset[i] for i in range(10)] 73 | batch_input = 
graph_dataset.collate(batch, None, edge_feature_mapping) 74 | 75 | model = graph_model.make_graph_model( 76 | 32, {**edge_feature_mapping.feature_dimensions}, readout_entity_features=False) 77 | result = model(batch_input) 78 | assert 'node_embedding' in result 79 | 80 | 81 | def test_compute_losses_no_entity_features(sketches, edge_feature_mapping): 82 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches)) 83 | dataset = graph_dataset.GraphDataset(sequences, None, edge_feature_mapping, seed=36) 84 | 85 | batch = [dataset[i] for i in range(10)] 86 | batch_input = graph_dataset.collate(batch, None, edge_feature_mapping) 87 | 88 | feature_dimensions = {**edge_feature_mapping.feature_dimensions} 89 | 90 | model = graph_model.make_graph_model(32, feature_dimensions, readout_entity_features=False) 91 | model_output = model(batch_input) 92 | 93 | losses, _, edge_metrics, node_metrics = graph_model.compute_losses( 94 | model_output, batch_input, feature_dimensions) 95 | assert isinstance(losses, dict) 96 | 97 | avg_losses = graph_model.compute_average_losses(losses, batch_input['graph_counts']) 98 | assert isinstance(avg_losses, dict) 99 | 100 | for t in edge_metrics: 101 | assert edge_metrics[t][0].shape == edge_metrics[t][1].shape 102 | 103 | for t in edge_metrics: 104 | assert node_metrics[t][0].shape == node_metrics[t][1].shape 105 | 106 | 107 | def test_compute_model_subnode(sketches): 108 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches)) 109 | dataset = graph_dataset.GraphDataset(sequences, seed=36) 110 | 111 | batch = [dataset[i] for i in range(10)] 112 | batch_input = graph_dataset.collate(batch) 113 | 114 | model = graph_model.make_graph_model( 115 | 32, feature_dimensions={}, readout_entity_features=False, readout_edge_features=False) 116 | result = model(batch_input) 117 | assert 'node_embedding' in result 118 | 119 | losses, _, edge_metrics, node_metrics = graph_model.compute_losses( 120 | result, batch_input, {}) 121 | assert isinstance(losses, dict) 122 | 123 | avg_losses = graph_model.compute_average_losses(losses, batch_input['graph_counts']) 124 | assert isinstance(avg_losses, dict) 125 | -------------------------------------------------------------------------------- /sketchgraphs_models/autoconstraint/eval_likelihood.py: -------------------------------------------------------------------------------- 1 | """Evaluates an auto-constraint model in terms of the sequential assigned likelihood. 
""" 2 | 3 | import argparse 4 | import itertools 5 | import gzip 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | import tqdm 11 | 12 | from sketchgraphs import data as datalib 13 | from sketchgraphs.data import flat_array 14 | 15 | from sketchgraphs_models import training 16 | from sketchgraphs_models.autoconstraint import dataset, model as auto_model, eval 17 | 18 | 19 | def _edge_ops(ops): 20 | return [op for op in ops if isinstance(op, datalib.EdgeOp)] 21 | 22 | def _node_ops(ops): 23 | return [op for op in ops if isinstance(op, datalib.NodeOp)] 24 | 25 | 26 | def ops_to_batch(ops, node_feature_mappings): 27 | """Given a sequence of operations representing a sketch, constructs a batch 28 | """ 29 | node_ops = [op for op in ops if isinstance(op, datalib.NodeOp)] 30 | 31 | batch = [] 32 | 33 | for i, op in enumerate(ops): 34 | if i == 0: 35 | # First external node does not induce edge stop token 36 | continue 37 | 38 | if op.label == datalib.ConstraintType.Subnode: 39 | continue 40 | 41 | ops_in_graph = ops[:i] 42 | 43 | features = dataset.process_node_and_edge_ops( 44 | node_ops, _edge_ops(ops_in_graph), 45 | len(_node_ops(ops_in_graph)), node_feature_mappings) 46 | 47 | if isinstance(op, datalib.NodeOp): 48 | # Stop problem 49 | partner_index = -1 50 | target_edge_label = -1 51 | else: 52 | partner_index = op.references[-1] 53 | target_edge_label = dataset.EDGE_IDX_MAP[op.label] 54 | 55 | features['partner_index'] = partner_index 56 | features['target_edge_label'] = target_edge_label 57 | 58 | batch.append(features) 59 | 60 | # Append final edge stop target 61 | batch.append({ 62 | **dataset.process_node_and_edge_ops(node_ops, _edge_ops(ops), len(node_ops), node_feature_mappings), 63 | 'partner_index': -1, 64 | 'target_edge_label': -1 65 | }) 66 | 67 | return batch 68 | 69 | 70 | class EdgeLikelihoodEvaluator: 71 | def __init__(self, model, node_feature_mappings, device=None): 72 | self.model = model 73 | self.node_feature_mappings = node_feature_mappings 74 | self.device = device 75 | 76 | def edge_likelihood(self, ops): 77 | if ops[-1].label == 'Stop': 78 | ops = ops[:-1] 79 | 80 | if len(ops) == 1: 81 | return np.empty([0, 2]) 82 | 83 | batch_list = ops_to_batch(ops, self.node_feature_mappings) 84 | 85 | batch = dataset.collate(batch_list) 86 | batch_device = training.load_cuda_async(batch, self.device) 87 | 88 | with torch.no_grad(): 89 | readout = self.model(batch_device) 90 | losses, _ = auto_model.compute_losses(batch_device, readout, reduction='none') 91 | 92 | losses_flat = torch.zeros((len(batch_list), 2)) 93 | losses_flat[batch['partner_index'].index, 0] = losses['edge_partner'].cpu() 94 | losses_flat[batch['stop_partner_index_index'], 0] = losses['edge_stop'].cpu() 95 | 96 | losses_flat[batch['partner_index'].index, 1] = losses['edge_label'].cpu() 97 | 98 | losses_flat = losses_flat[torch.argsort(batch['sorted_indices'])] 99 | 100 | return losses_flat.numpy() 101 | 102 | 103 | 104 | def main(): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--model', type=str, required=True) 107 | parser.add_argument('--dataset', type=str, required=True) 108 | parser.add_argument('--device', type=str, default='cpu') 109 | parser.add_argument('--output', type=str) 110 | parser.add_argument('--max_predictions', type=int, default=None) 111 | 112 | 113 | args = parser.parse_args() 114 | 115 | device = torch.device(args.device) 116 | 117 | print('Loading trained model') 118 | model, node_feature_mapping = eval.load_sampling_model(args.model) 119 | model = 
model.eval().to(device) 120 | 121 | print('Loading testing data') 122 | seqs = flat_array.load_dictionary_flat(np.load(args.dataset, mmap_mode='r'))['sequences'] 123 | 124 | likelihood_evaluation = EdgeLikelihoodEvaluator(model, node_feature_mapping, device) 125 | 126 | length = len(seqs) 127 | if args.max_predictions is not None: 128 | length = min(length, args.max_predictions) 129 | 130 | results = [] 131 | average_likelihoods = np.empty(length) 132 | sequence_length = np.empty(length) 133 | 134 | for i in tqdm.trange(length): 135 | seq = seqs[i] 136 | result = likelihood_evaluation.edge_likelihood(seq) 137 | results.append({ 138 | 'seq': seq, 139 | 'likelihood': result 140 | }) 141 | 142 | sequence_length[i] = result.shape[0] 143 | 144 | if result.shape[0] == 0: 145 | average_likelihoods[i] = 0.0 146 | else: 147 | average_likelihoods[i] = np.mean(np.sum(result, axis=-1), axis=0) 148 | 149 | print('Average bit per edge {0:.3f}'.format(np.average(average_likelihoods, weights=sequence_length) / np.log(2))) 150 | 151 | if args.output is None: 152 | return 153 | 154 | print('Saving to output {0}'.format(args.output)) 155 | with gzip.open(args.output, 'wb') as f: 156 | pickle.dump(results, f, protocol=4) 157 | 158 | 159 | if __name__ == '__main__': 160 | main() 161 | -------------------------------------------------------------------------------- /sketchgraphs_models/nn/summary.py: -------------------------------------------------------------------------------- 1 | """This module implements utilities to compute summary statistics. 2 | """ 3 | 4 | import posixpath 5 | import torch 6 | import torchmetrics 7 | 8 | 9 | class ClassificationSummary: 10 | """ Simple class to keep track of summaries of a classification problem. """ 11 | def __init__(self, num_outcomes=2, device=None): 12 | """ Initializes a new summary class with the given number of outcomes. 13 | 14 | Parameters 15 | ---------- 16 | num_outcomes : int 17 | the number of possible outcomes of the classification problem. 18 | device : torch.device, optional 19 | device on which to place the recorded statistics. 20 | """ 21 | self.recorded = torch.zeros(num_outcomes * num_outcomes, dtype=torch.int32, device=device) 22 | self.num_outcomes = num_outcomes 23 | 24 | @property 25 | def prediction_matrix(self): 26 | """Returns a `torch.Tensor` representing the prediction matrix.""" 27 | return self.recorded.view((self.num_outcomes, self.num_outcomes)) 28 | 29 | def record_statistics(self, labels, predictions): 30 | """ Records statistics for a batch of predictions. 31 | 32 | Parameters 33 | ---------- 34 | labels : torch.Tensor 35 | an array of true labels in integer format. Each label must correspond to an 36 | integer in 0 to num_outcomes - 1 inclusive. 37 | predictions : torch.Tensor 38 | an array of predicted labels. Must follow the same format as `labels`. 39 | """ 40 | indices = torch.add(labels.int(), predictions.int(), alpha=self.num_outcomes).long().to(device=self.recorded.device) 41 | self.recorded = self.recorded.scatter_add_( 42 | 0, indices, torch.ones_like(indices, dtype=torch.int32)) 43 | 44 | def reset_statistics(self): 45 | """ Resets statistics recorded in this accumulator. """ 46 | self.recorded = torch.zeros_like(self.recorded) 47 | 48 | def accuracy(self): 49 | """ Compute the accuracy of the recorded problem. 
""" 50 | num_correct = self.prediction_matrix.diag().sum() 51 | num_total = self.recorded.sum() 52 | 53 | return num_correct.float() / num_total.float() 54 | 55 | def confusion_matrix(self): 56 | """Returns a `torch.Tensor` representing the confusion matrix.""" 57 | return self.prediction_matrix.float() / self.prediction_matrix.sum().float() 58 | 59 | def cohen_kappa(self): 60 | """Computes the Cohen kappa measure of agreement. 61 | """ 62 | pm = self.prediction_matrix.float() 63 | N = self.recorded.sum().float() 64 | 65 | p_observed = pm.diag().sum() / N 66 | p_expected = torch.dot(pm.sum(dim=0), pm.sum(dim=1)) / (N * N) 67 | 68 | if p_expected == 1: 69 | return 1 70 | else: 71 | return 1 - (1 - p_observed) / (1 - p_expected) 72 | 73 | def marginal_labels(self): 74 | """Computes the empirical marginal distribution of the true labels.""" 75 | return self.prediction_matrix.sum(dim=0).float() / self.recorded.sum().float() 76 | 77 | def marginal_predicted(self): 78 | """Computes the empirical marginal distribution of the predicted labels.""" 79 | return self.prediction_matrix.sum(dim=1).float() / self.recorded.sum().float() 80 | 81 | def write_tensorboard(self, writer, prefix="", global_step=None, **kwargs): 82 | """Write the accuracy and kappa metrics to a tensorboard writer. 83 | 84 | Parameters 85 | ---------- 86 | writer: torch.utils.tensorboard.SummaryWriter 87 | The writer to which the metrics will be written 88 | prefix: str, optional 89 | Optional prefix for the name under which the metrics will be written 90 | global_step: int, optional 91 | Global step at which the metric is recorded 92 | **kwargs 93 | Further arguments to `torch.utils.tensorboard.SummaryWriter.add_scalar`. 94 | """ 95 | writer.add_scalar(posixpath.join(prefix, "kappa"), self.cohen_kappa(), global_step, **kwargs) 96 | writer.add_scalar(posixpath.join(prefix, "accuracy"), self.accuracy(), global_step, **kwargs) 97 | 98 | 99 | class CohenKappa(torchmetrics.Metric): 100 | """A pytorch-lightning compatible metric which computes Cohen's kappa score of agreement. 101 | """ 102 | recorded: torch.Tensor 103 | 104 | def __init__(self, num_outcomes=2, compute_on_step=True, dist_sync_on_step=False): 105 | super().__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step) 106 | 107 | self.num_outcomes = num_outcomes 108 | self.add_state("recorded", torch.zeros(num_outcomes * num_outcomes, dtype=torch.int32), dist_reduce_fx='sum') 109 | 110 | def update(self, preds: torch.Tensor, targets: torch.Tensor): 111 | indices = torch.add(targets.int(), preds.int(), alpha=self.num_outcomes).long().to(device=self.recorded.device) 112 | self.recorded.scatter_add_(0, indices, torch.ones_like(indices, dtype=torch.int32)) 113 | 114 | def compute(self): 115 | pm = self.recorded.view(self.num_outcomes, self.num_outcomes).float() 116 | N = self.recorded.sum().float() 117 | 118 | if N == 0: 119 | return pm.new_tensor(0.0) 120 | 121 | p_observed = pm.diag().sum() / N 122 | p_expected = torch.dot(pm.sum(dim=0), pm.sum(dim=1)) / (N * N) 123 | 124 | if p_expected == 1: 125 | return pm.new_tensor(1.0) 126 | else: 127 | return 1 - (1 - p_observed) / (1 - p_expected) 128 | -------------------------------------------------------------------------------- /sketchgraphs/data/_plotting.py: -------------------------------------------------------------------------------- 1 | """Functions for drawing sketches using matplotlib 2 | 3 | This module implements local plotting functionality in order to render Onshape sketches using matplotlib. 
4 | 5 | """ 6 | 7 | import math 8 | from contextlib import nullcontext 9 | 10 | import matplotlib as mpl 11 | import matplotlib.patches 12 | import matplotlib.pyplot as plt 13 | 14 | from ._entity import Arc, Circle, Line, Point 15 | 16 | 17 | def _get_linestyle(entity): 18 | return '--' if entity.isConstruction else '-' 19 | 20 | def sketch_point(ax, point: Point, color='black', show_subnodes=False): 21 | ax.scatter(point.x, point.y, c=color, marker='.') 22 | 23 | def sketch_line(ax, line: Line, color='black', show_subnodes=False): 24 | start_x, start_y = line.start_point 25 | end_x, end_y = line.end_point 26 | if show_subnodes: 27 | marker = '.' 28 | else: 29 | marker = None 30 | ax.plot((start_x, end_x), (start_y, end_y), color, linestyle=_get_linestyle(line), linewidth=1, marker=marker) 31 | 32 | def sketch_circle(ax, circle: Circle, color='black', show_subnodes=False): 33 | patch = matplotlib.patches.Circle( 34 | (circle.xCenter, circle.yCenter), circle.radius, 35 | fill=False, linestyle=_get_linestyle(circle), color=color) 36 | if show_subnodes: 37 | ax.scatter(circle.xCenter, circle.yCenter, c=color, marker='.', zorder=20) 38 | ax.add_patch(patch) 39 | 40 | def sketch_arc(ax, arc: Arc, color='black', show_subnodes=False): 41 | angle = math.atan2(arc.yDir, arc.xDir) * 180 / math.pi 42 | startParam = arc.startParam * 180 / math.pi 43 | endParam = arc.endParam * 180 / math.pi 44 | 45 | if arc.clockwise: 46 | startParam, endParam = -endParam, -startParam 47 | 48 | ax.add_patch( 49 | matplotlib.patches.Arc( 50 | (arc.xCenter, arc.yCenter), 2*arc.radius, 2*arc.radius, 51 | angle=angle, theta1=startParam, theta2=endParam, 52 | linestyle=_get_linestyle(arc), color=color)) 53 | 54 | if show_subnodes: 55 | ax.scatter(arc.xCenter, arc.yCenter, c=color, marker='.') 56 | ax.scatter(*arc.start_point, c=color, marker='.', zorder=40) 57 | ax.scatter(*arc.end_point, c=color, marker='.', zorder=40) 58 | 59 | 60 | _PLOT_BY_TYPE = { 61 | Arc: sketch_arc, 62 | Circle: sketch_circle, 63 | Line: sketch_line, 64 | Point: sketch_point 65 | } 66 | 67 | 68 | def render_sketch(sketch, ax=None, show_axes=False, show_origin=False, 69 | hand_drawn=False, show_subnodes=False, show_points=True): 70 | """Renders the given sketch using matplotlib. 71 | 72 | Parameters 73 | ---------- 74 | sketch : Sketch 75 | The sketch instance to render 76 | ax : matplotlib.Axis, optional 77 | Axis object on which to render the sketch. If None, a new figure is created. 78 | show_axes : bool 79 | Indicates whether axis lines should be drawn 80 | show_origin : bool 81 | Indicates whether origin point should be drawn 82 | hand_drawn : bool 83 | Indicates whether to emulate a hand-drawn appearance 84 | show_subnodes : bool 85 | Indicates whether endpoints/centerpoints should be drawn 86 | show_points : bool 87 | Indicates whether Point entities should be drawn 88 | 89 | Returns 90 | ------- 91 | matplotlib.Figure 92 | If `ax` is not provided, the newly created figure. Otherwise, `None`. 
93 | """ 94 | with plt.xkcd( 95 | scale=1, length=100, randomness=3) if hand_drawn else nullcontext(): 96 | 97 | if ax is None: 98 | fig = plt.figure() 99 | ax = fig.add_subplot(111, aspect='equal') 100 | else: 101 | fig = None 102 | 103 | # Eliminate upper and right axes 104 | ax.spines['right'].set_color('none') 105 | ax.spines['top'].set_color('none') 106 | 107 | if not show_axes: 108 | ax.set_yticklabels([]) 109 | ax.set_xticklabels([]) 110 | _ = [line.set_marker('None') for line in ax.get_xticklines()] 111 | _ = [line.set_marker('None') for line in ax.get_yticklines()] 112 | 113 | # Eliminate lower and left axes 114 | ax.spines['left'].set_color('none') 115 | ax.spines['bottom'].set_color('none') 116 | 117 | if show_origin: 118 | point_size = mpl.rcParams['lines.markersize'] * 1 119 | ax.scatter(0, 0, s=point_size, c='black') 120 | 121 | for ent in sketch.entities.values(): 122 | sketch_fn = _PLOT_BY_TYPE.get(type(ent)) 123 | if isinstance(ent, Point): 124 | if not show_points: 125 | continue 126 | if sketch_fn is None: 127 | continue 128 | sketch_fn(ax, ent, show_subnodes=show_subnodes) 129 | 130 | # Rescale axis limits 131 | ax.relim() 132 | ax.autoscale_view() 133 | 134 | return fig 135 | 136 | 137 | def render_graph(graph, filename, show_node_idxs=False): 138 | """Renders the given pgv.AGraph to an image file. 139 | 140 | Parameters 141 | ---------- 142 | graph : pgv.AGraph 143 | The graph to render 144 | filename : string 145 | Where to save the image file 146 | show_node_idxs: bool 147 | If true, append node indexes to their labels 148 | 149 | 150 | Returns 151 | ------- 152 | None 153 | """ 154 | if show_node_idxs: 155 | for idx, node in enumerate(graph.nodes()): 156 | node.attr['label'] += ' (' + str(idx) + ')' 157 | graph.layout('dot') 158 | graph.draw(filename) 159 | 160 | 161 | __all__ = ['render_sketch', 'render_graph'] -------------------------------------------------------------------------------- /tests/test_torch_extensions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | 5 | from scipy.special import logsumexp 6 | 7 | from sketchgraphs_models.torch_extensions import _repeat_interleave, segment_ops, segment_pool 8 | 9 | 10 | def test_repeat_python(): 11 | x = np.random.randn(40).reshape(4, 10) 12 | times = [2, 5, 0, 1] 13 | 14 | expected = np.repeat(x, times, axis=0) 15 | result = _repeat_interleave.repeat_interleave(torch.tensor(x), torch.tensor(times), 0) 16 | 17 | assert np.allclose(result.numpy(), expected) 18 | 19 | 20 | def test_segment_logsumexp_python(): 21 | x = np.random.randn(40) 22 | lengths = [5, 10, 6, 4, 15] 23 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 24 | 25 | scopes = np.stack((offsets, lengths), axis=1) 26 | 27 | expected = np.array([logsumexp(x[s[0]:s[0] + s[1]]) for s in scopes]) 28 | result = segment_ops.segment_logsumexp_python(torch.tensor(x), torch.tensor(scopes)) 29 | 30 | assert np.allclose(result, expected) 31 | 32 | 33 | def test_segment_logsumexp_python_grad(): 34 | x = np.random.randn(40) 35 | 36 | lengths = [5, 10, 6, 4, 15] 37 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 38 | 39 | scopes = np.stack((offsets, lengths), axis=1) 40 | 41 | torch.autograd.gradcheck( 42 | segment_ops.segment_logsumexp_python, 43 | (torch.tensor(x, requires_grad=True), torch.tensor(scopes))) 44 | 45 | 46 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 47 | def test_segment_logsumexp_scatter(device): 48 | x = np.random.randn(40) 
49 | lengths = [0, 5, 10, 6, 4, 15, 0] 50 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 51 | 52 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64) 53 | 54 | expected = np.array([logsumexp(x[s[0]:s[0] + s[1]]) if s[1] != 0 else -np.inf for s in scopes]) 55 | 56 | result = segment_ops.segment_logsumexp_scatter(torch.tensor(x, device=device), torch.tensor(scopes, device=device)) 57 | 58 | assert np.allclose(result.cpu().numpy(), expected) 59 | 60 | 61 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 62 | def test_segment_logsumexp_scatter_grad(device): 63 | x = np.random.randn(40) 64 | 65 | lengths = [5, 10, 6, 4, 15] 66 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 67 | 68 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64) 69 | 70 | torch.autograd.gradcheck( 71 | segment_ops.segment_logsumexp_scatter, 72 | (torch.tensor(x, requires_grad=True, device=device), torch.tensor(scopes, device=device))) 73 | 74 | 75 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 76 | def test_segment_logsumexp_scatter_grad_full(device): 77 | x = np.random.randn(20) 78 | 79 | scopes = torch.tensor([[0, 20]], dtype=torch.int64, device=device) 80 | 81 | torch.autograd.gradcheck( 82 | segment_ops.segment_logsumexp_scatter, 83 | (torch.tensor(x, requires_grad=True, device=device), scopes)) 84 | 85 | 86 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 87 | def test_segment_argmax(device): 88 | x = np.random.randn(40) 89 | 90 | lengths = np.array([0, 5, 10, 6, 4, 15, 0]) 91 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 92 | 93 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64) 94 | 95 | x = torch.tensor(x, device=device) 96 | scopes = torch.tensor(scopes, device=device) 97 | 98 | expected_values, expected_index = segment_ops.segment_argmax_python(x, scopes) 99 | result_values, result_index = segment_ops.segment_argmax_scatter(x, scopes) 100 | 101 | result_values = result_values.cpu().numpy() 102 | expected_values = expected_values.cpu().numpy() 103 | result_index = result_index.cpu().numpy() 104 | expected_index = expected_index.cpu().numpy() 105 | 106 | assert np.allclose(result_values, expected_values) 107 | assert np.allclose(result_index[lengths > 0], expected_index[lengths > 0]) 108 | 109 | 110 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 111 | def test_segment_argmax_backward(device): 112 | x = np.random.randn(40) 113 | 114 | lengths = [5, 10, 6, 4, 15] 115 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 116 | 117 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64) 118 | 119 | torch.autograd.gradcheck( 120 | segment_ops.segment_argmax_scatter, 121 | (torch.tensor(x, requires_grad=True, device=device), 122 | torch.tensor(scopes, device=device), 123 | False)) 124 | 125 | 126 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 127 | def test_segment_pool(device): 128 | x = np.random.randn(40) 129 | 130 | lengths = [5, 10, 6, 4, 15] 131 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 132 | 133 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64) 134 | 135 | x = torch.tensor(x, device=device) 136 | scopes = torch.tensor(scopes, device=device) 137 | 138 | expected_values = segment_pool.segment_avg_pool1d_loop(x, scopes) 139 | result_values = segment_pool.segment_avg_pool1d_scatter(x, scopes) 140 | 141 | assert torch.allclose(expected_values, result_values) 142 | 143 | 144 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 145 | def test_segment_pool_2d(device): 146 | 
x = np.random.randn(40, 5) 147 | 148 | lengths = [5, 10, 6, 4, 15] 149 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1]))) 150 | 151 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64) 152 | 153 | x = torch.tensor(x, device=device) 154 | scopes = torch.tensor(scopes, device=device) 155 | 156 | expected_values = segment_pool.segment_avg_pool1d_loop(x, scopes) 157 | result_values = segment_pool.segment_avg_pool1d_scatter(x, scopes) 158 | 159 | assert torch.allclose(expected_values, result_values) 160 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/model/message_passing.py: -------------------------------------------------------------------------------- 1 | """This module contains the components of the graph model associated with handling 2 | the message passing. 3 | """ 4 | 5 | import torch 6 | 7 | from sketchgraphs_models import nn as sg_nn 8 | import sketchgraphs_models.nn.functional 9 | from sketchgraphs.pipeline import graph_model 10 | 11 | 12 | class DenseSparsePreEmbedding(torch.nn.Module): 13 | """This is a generic pre-embedding module which combines dense and sparse pre-embeddings.""" 14 | def __init__(self, target_type, feature_embeddings, fixed_embedding_cardinality, fixed_embedding_dim, 15 | sparse_embedding_dim=None, embedding_dim=None): 16 | """Initializes a new DenseSparsePreEmbedding module. 17 | 18 | Parameters 19 | ---------- 20 | target_type : enum 21 | The underlying enumeration indicating target types 22 | feature_embeddings : dict of modules 23 | A dictionary of embeddings for each of the sparse feature types. 24 | fixed_embedding_cardinality : int 25 | The number of classes in the fixed (dense) embedding layer. 26 | fixed_embedding_dim : int 27 | The dimension of the embedding for the fixed layer. 28 | sparse_embedding_dim : int, optional 29 | The dimension of the sparse embeddings. If None, assumed to be the same as the dense embedding. 30 | embedding_dim : int, optional 31 | The outpu dimension of the embedding. If None, assumed to be the same as the dense embedding. 32 | """ 33 | super(DenseSparsePreEmbedding, self).__init__() 34 | sparse_embedding_dim = sparse_embedding_dim or fixed_embedding_dim 35 | embedding_dim = embedding_dim or fixed_embedding_dim 36 | 37 | self.target_type = target_type 38 | self.feature_embeddings = torch.nn.ModuleDict(feature_embeddings) 39 | self.sparse_embedding_dim = sparse_embedding_dim 40 | self.fixed_embedding_dim = fixed_embedding_dim 41 | self.fixed_embedding = torch.nn.Embedding(fixed_embedding_cardinality, fixed_embedding_dim) 42 | self.dense_merge = sg_nn.ConcatenateLinear(fixed_embedding_dim, sparse_embedding_dim, embedding_dim) 43 | 44 | def forward(self, fixed_features, sparse_features): 45 | fixed_embeddings = self.fixed_embedding(fixed_features) 46 | sparse_embeddings = fixed_embeddings.new_zeros((fixed_embeddings.shape[0], self.sparse_embedding_dim)) 47 | 48 | for k, embedding_network in self.feature_embeddings.items(): 49 | sf = sparse_features[self.target_type[k]] 50 | if sf is None or len(sf.index) == 0: 51 | continue 52 | 53 | assert (sf.index < fixed_embeddings.shape[0]).all() 54 | sparse_embeddings[sf.index] = embedding_network(sf.value) 55 | 56 | return self.dense_merge(fixed_embeddings, sparse_embeddings) 57 | 58 | 59 | class DenseOnlyEmbedding(torch.nn.Module): 60 | """Generic pre-embedding module which encapsulates a pytorch embedding layer. 
61 | 62 | This class is simply provided for compatibility with `DenseSparsePreEmbedding`, to construct 63 | models where sparse embeddings are not present. 64 | """ 65 | def __init__(self, cardinality, dimension): 66 | super(DenseOnlyEmbedding, self).__init__() 67 | self.fixed_embedding = torch.nn.Embedding(cardinality, dimension) 68 | 69 | def forward(self, features, *_): 70 | return self.fixed_embedding(features) 71 | 72 | 73 | class GraphModelCore(torch.nn.Module): 74 | """Component of the entity model used to compute global features, i.e. graph and node embeddings. 75 | 76 | This component is responsible for the computation that is independent of any 77 | specific target (edge / node). It is split off to ease sharing with sampling models. 78 | """ 79 | def __init__(self, message_passing, node_embedding, edge_embedding, graph_embedding): 80 | super(GraphModelCore, self).__init__() 81 | self.message_passing = message_passing 82 | self.node_embedding = node_embedding 83 | self.edge_embedding = edge_embedding 84 | self.graph_embedding = graph_embedding 85 | 86 | def forward(self, data): 87 | graph = data['graph'] 88 | 89 | with torch.autograd.profiler.record_function('feature_embedding'): 90 | node_pre_embedding = self.node_embedding(graph.node_features, graph.sparse_node_features) 91 | edge_pre_embedding = self.edge_embedding(graph.edge_features, graph.sparse_edge_features) 92 | 93 | node_post_embedding = self.message_passing(node_pre_embedding, graph.incidence, (edge_pre_embedding,)) 94 | graph_embedding = self.graph_embedding(node_post_embedding, graph) 95 | return node_post_embedding, graph_embedding 96 | 97 | 98 | class GraphPostEmbedding(torch.nn.Module): 99 | """Component of the graph model which computes graph-wide representation by aggregating node representations. 100 | """ 101 | def __init__(self, hidden_size, graph_embedding_size=None): 102 | super(GraphPostEmbedding, self).__init__() 103 | 104 | if graph_embedding_size is None: 105 | graph_embedding_size = 2 * hidden_size 106 | 107 | self.node_gating_net = torch.nn.Sequential( 108 | torch.nn.Linear(hidden_size, 1), 109 | torch.nn.Sigmoid() 110 | ) 111 | self.node_to_graph_net = torch.nn.Linear(hidden_size, graph_embedding_size) 112 | 113 | def forward(self, node_embedding, graph): 114 | scopes = graph_model.scopes_from_offsets(graph.node_offsets) 115 | 116 | transformed_embedding = self.node_gating_net(node_embedding) * self.node_to_graph_net(node_embedding) 117 | 118 | graph_embedding = sg_nn.functional.segment_avg_pool1d( 119 | transformed_embedding, scopes) * graph.node_counts.unsqueeze(-1) 120 | 121 | return graph_embedding 122 | -------------------------------------------------------------------------------- /sketchgraphs/pipeline/make_quantization_statistics.py: -------------------------------------------------------------------------------- 1 | """This script computes statistics required for quantization of continuous parameters. 2 | 3 | Many models we use require quantization to process the continuous parameters in the sketch. 4 | This scripts computes the required statistics for quantization of the dataset. 5 | 6 | """ 7 | 8 | import argparse 9 | import collections 10 | import functools 11 | import gzip 12 | import itertools 13 | import multiprocessing 14 | import pickle 15 | import os 16 | 17 | import numpy as np 18 | import tqdm 19 | 20 | from sketchgraphs.data import flat_array 21 | from sketchgraphs.data.sequence import EdgeOp 22 | from sketchgraphs.data.sketch import EntityType, ENTITY_TYPE_TO_CLASS 23 | from . 
import numerical_parameters 24 | 25 | 26 | _EDGE_PARAMETER_IDS = ('angle', 'length') 27 | 28 | 29 | def _worker_edges(dataset_path, worker_idx, num_workers, result_queue): 30 | # Load data 31 | data = flat_array.load_dictionary_flat(np.load(dataset_path, mmap_mode='r')) 32 | sequences = data['sequences'] 33 | 34 | # Extract sub-sequence for worker 35 | length_for_worker, num_additional = divmod(len(sequences), num_workers) 36 | offset = worker_idx * length_for_worker + max(worker_idx, num_additional) 37 | if worker_idx < num_additional: 38 | length_for_worker += 1 39 | 40 | seq_indices = range(offset, min((offset+length_for_worker, len(sequences)))) 41 | 42 | # Process data 43 | expression_counters = { 44 | k: collections.Counter() for k in _EDGE_PARAMETER_IDS 45 | } 46 | 47 | num_processed = 0 48 | 49 | for seq_idx in seq_indices: 50 | seq = sequences[seq_idx] 51 | 52 | try: 53 | for op in seq: 54 | if not isinstance(op, EdgeOp): 55 | continue 56 | 57 | for k in _EDGE_PARAMETER_IDS: 58 | if k in op.parameters: 59 | value = op.parameters[k] 60 | value = numerical_parameters.normalize_expression(value, k) 61 | expression_counters[k][value] += 1 62 | except Exception: 63 | print('Error processing sequence at index {0}'.format(seq_idx)) 64 | 65 | num_processed += 1 66 | if num_processed > 1000: 67 | result_queue.put(num_processed) 68 | num_processed = 0 69 | result_queue.put(num_processed) 70 | 71 | result_queue.put(expression_counters) 72 | 73 | 74 | def _worker_node(param_combination, filepath, num_centers, max_values=None): 75 | label, param_name = param_combination 76 | sequences = flat_array.load_dictionary_flat(np.load(filepath, mmap_mode='r'))['sequences'] 77 | 78 | values = (op.parameters[param_name] for op in itertools.chain.from_iterable(sequences) 79 | if op.label == label and param_name in op.parameters) 80 | 81 | if max_values is not None: 82 | values = itertools.islice(values, max_values) 83 | 84 | values = np.array(list(values)) 85 | centers = numerical_parameters.make_quantization(values, num_centers, 'cdf') 86 | return centers 87 | 88 | 89 | def process_edges(dataset_path, num_threads): 90 | print('Checking total sketch dataset size.') 91 | total_sequences = len(flat_array.load_dictionary_flat(np.load(dataset_path, mmap_mode='r'))['sequences']) 92 | 93 | result_queue = multiprocessing.Queue() 94 | 95 | workers = [] 96 | 97 | for worker_idx in range(num_threads): 98 | workers.append( 99 | multiprocessing.Process( 100 | target=_worker_edges, 101 | args=(dataset_path, worker_idx, num_threads, result_queue))) 102 | 103 | for worker in workers: 104 | worker.start() 105 | 106 | active_workers = len(workers) 107 | 108 | total_result = {} 109 | 110 | print('Processing sequences for edge statistics') 111 | with tqdm.tqdm(total=total_sequences) as pbar: 112 | while active_workers > 0: 113 | result = result_queue.get() 114 | 115 | if isinstance(result, int): 116 | pbar.update(result) 117 | continue 118 | 119 | for k, v in result.items(): 120 | total_result.setdefault(k, collections.Counter()).update(v) 121 | active_workers -= 1 122 | 123 | for worker in workers: 124 | worker.join() 125 | 126 | return total_result 127 | 128 | 129 | def process_nodes(dataset_path, num_centers, num_threads): 130 | print('Processing sequences for node statistics') 131 | label_parameter_combinations = [ 132 | (t, parameter_name) 133 | for t in (EntityType.Arc, EntityType.Circle, EntityType.Line, EntityType.Point) 134 | for parameter_name in ENTITY_TYPE_TO_CLASS[t].float_ids 135 | ] 136 | 137 | pool = 
multiprocessing.Pool(num_threads) 138 | 139 | all_centers = pool.map( 140 | functools.partial( 141 | _worker_node, filepath=dataset_path, num_centers=num_centers, max_values=50000), 142 | label_parameter_combinations) 143 | 144 | result = {} 145 | for (t, parameter_name), centers in zip(label_parameter_combinations, all_centers): 146 | result.setdefault(t, {})[parameter_name] = centers 147 | 148 | return result 149 | 150 | 151 | def main(): 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--input', type=str, help='Input sequence dataset', required=True) 154 | parser.add_argument('--output', type=str, help='Output dataset path', default='meta.pkl.gz') 155 | parser.add_argument('--num_threads', type=int, default=0) 156 | parser.add_argument('--node_num_centers', type=int, default=256) 157 | 158 | args = parser.parse_args() 159 | 160 | num_threads = args.num_threads 161 | if num_threads is None: 162 | num_threads = len(os.sched_getaffinity(0)) 163 | 164 | edge_results = process_edges(args.input, num_threads) 165 | node_results = process_nodes(args.input, args.node_num_centers, num_threads) 166 | 167 | print('Saving results in {0}'.format(args.output)) 168 | with gzip.open(args.output, 'wb', compresslevel=9) as f: 169 | pickle.dump({ 170 | 'edge': edge_results, 171 | 'node': node_results 172 | }, f, protocol=pickle.HIGHEST_PROTOCOL) 173 | 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /sketchgraphs/data/dof.py: -------------------------------------------------------------------------------- 1 | """This module contains the implementation and data for heuristic degrees of freedom computation.""" 2 | 3 | import numpy as np 4 | 5 | from ._entity import EntityType, SubnodeType 6 | from ._constraint import ConstraintType 7 | from .sequence import EdgeOp, NodeOp 8 | 9 | 10 | NODE_DOF = { 11 | EntityType.Point: 2, 12 | EntityType.Line: 4, 13 | EntityType.Circle: 3, 14 | EntityType.Arc: 5, 15 | } 16 | 17 | EDGE_DOF_REMOVED = { 18 | ConstraintType.Angle: { 19 | (EntityType.Line, EntityType.Line): 1}, 20 | ConstraintType.Centerline_Dimension: { 21 | (EntityType.Arc, EntityType.Line): 0, 22 | (EntityType.Circle, EntityType.Line): 0, 23 | (EntityType.Line, EntityType.Line): 0, 24 | (EntityType.Line, EntityType.Point): 0}, 25 | ConstraintType.Coincident: { 26 | (EntityType.Arc, EntityType.Arc): 3, 27 | (EntityType.Arc, EntityType.Circle): 3, 28 | (EntityType.Arc, EntityType.Point): 1, 29 | (EntityType.Circle, EntityType.Circle): 3, 30 | (EntityType.Circle, EntityType.Point): 1, 31 | (EntityType.Line, EntityType.Line): 2, 32 | (EntityType.Line, EntityType.Point): 1, 33 | (EntityType.Point, EntityType.Point): 2}, 34 | ConstraintType.Concentric: { 35 | (EntityType.Arc, EntityType.Arc): 2, 36 | (EntityType.Arc, EntityType.Circle): 2, 37 | (EntityType.Arc, EntityType.Point): 2, 38 | (EntityType.Circle, EntityType.Circle): 2, 39 | (EntityType.Circle, EntityType.Point): 2, 40 | (EntityType.Point, EntityType.Point): 2}, 41 | ConstraintType.Diameter: { 42 | (EntityType.Arc, EntityType.Arc): 1, 43 | (EntityType.Circle, EntityType.Circle):1}, 44 | ConstraintType.Distance: { 45 | (EntityType.Arc, EntityType.Arc): 1, 46 | (EntityType.Arc, EntityType.Circle): 1, 47 | (EntityType.Arc, EntityType.Line): 1, 48 | (EntityType.Arc, EntityType.Point): 1, 49 | (EntityType.Circle, EntityType.Circle): 1, 50 | (EntityType.Circle, EntityType.Line): 1, 51 | (EntityType.Circle, EntityType.Point): 1, 52 | (EntityType.Line, 
EntityType.Line): 1, 53 | (EntityType.Line, EntityType.Point): 1, 54 | (EntityType.Point, EntityType.Point): 1}, 55 | ConstraintType.Equal: { 56 | (EntityType.Arc, EntityType.Arc): 1, 57 | (EntityType.Arc, EntityType.Circle): 1, 58 | (EntityType.Circle, EntityType.Circle): 1, 59 | (EntityType.Line, EntityType.Line): 1}, 60 | ConstraintType.Fix: { 61 | (EntityType.Arc, EntityType.Arc): 3, 62 | (EntityType.Circle, EntityType.Circle): 3, 63 | (EntityType.Line, EntityType.Line): 2, 64 | (EntityType.Point, EntityType.Point): 2}, 65 | ConstraintType.Horizontal: { 66 | (EntityType.Line, EntityType.Line): 1, 67 | (EntityType.Point, EntityType.Point): 1}, 68 | ConstraintType.Intersected: {}, 69 | ConstraintType.Length: { 70 | (EntityType.Arc, EntityType.Arc): 1, 71 | (EntityType.Line, EntityType.Line): 1}, 72 | ConstraintType.Midpoint: { 73 | (EntityType.Arc, EntityType.Point): 2, 74 | (EntityType.Line, EntityType.Point): 2}, 75 | ConstraintType.Normal: { 76 | (EntityType.Arc, EntityType.Line): 1, 77 | (EntityType.Circle, EntityType.Line): 1}, 78 | ConstraintType.Offset: { 79 | (EntityType.Arc, EntityType.Arc): 2, 80 | (EntityType.Arc, EntityType.Circle): 2, 81 | (EntityType.Circle, EntityType.Circle): 2, 82 | (EntityType.Line, EntityType.Line): 1}, 83 | ConstraintType.Parallel: { 84 | (EntityType.Line, EntityType.Line): 1}, 85 | ConstraintType.Perpendicular: { 86 | (EntityType.Line, EntityType.Line): 1}, 87 | ConstraintType.Radius: { 88 | (EntityType.Arc, EntityType.Arc): 1, 89 | (EntityType.Circle, EntityType.Circle): 1}, 90 | ConstraintType.Subnode: { 91 | (EntityType.Arc, EntityType.Point): 0, 92 | (EntityType.Circle, EntityType.Point): 0, 93 | (EntityType.Line, EntityType.Point): 0}, 94 | ConstraintType.Tangent: { 95 | (EntityType.Arc, EntityType.Arc): 1, 96 | (EntityType.Arc, EntityType.Circle): 1, 97 | (EntityType.Arc, EntityType.Line): 1, 98 | (EntityType.Circle, EntityType.Circle): 1, 99 | (EntityType.Circle, EntityType.Line): 1}, 100 | ConstraintType.Vertical: { 101 | (EntityType.Line, EntityType.Line): 1, 102 | (EntityType.Point, EntityType.Point): 1} 103 | } 104 | 105 | 106 | def get_node_label_for_dof(label) -> EntityType: 107 | """Get the node label to be used for degrees of freedom computation. 108 | 109 | When computing DOF, subnode labels are merged to point as they represent points. 110 | 111 | Parameters 112 | ---------- 113 | label : Union[EntityType, SubnodeType] 114 | The original node label 115 | 116 | Returns 117 | ------- 118 | EntityType 119 | The corresponding `EntityType` to be used for DOF computation. 120 | """ 121 | return EntityType.Point if isinstance(label, SubnodeType) else label 122 | 123 | 124 | def _get_dof_removed_for_edge(edge_op: EdgeOp, nodes): 125 | ref_types = [get_node_label_for_dof(nodes[r].label) for r in edge_op.references] 126 | 127 | if len(ref_types) == 1: 128 | ref_types = ref_types + ref_types 129 | elif len(ref_types) > 2: 130 | return 0 131 | t1, t2 = ref_types 132 | 133 | if EntityType.External in ref_types: 134 | return 0 135 | 136 | 137 | dof_dict = EDGE_DOF_REMOVED.get(edge_op.label) 138 | if dof_dict is None: 139 | return 0 140 | 141 | if (t1, t2) in dof_dict: 142 | return dof_dict[(t1, t2)] 143 | if (t2, t1) in dof_dict: 144 | return dof_dict[(t2, t1)] 145 | 146 | return 0 147 | 148 | 149 | def get_sequence_dof(seq): 150 | """Returns array of total DoF contribution from each op. 
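    As a worked illustration (derived from the NODE_DOF and EDGE_DOF_REMOVED tables above, not from running the library): a sequence consisting of a Line node (4 DoF), a Point node (2 DoF) and a Coincident constraint referencing both (which removes 1 DoF for the (Line, Point) pair) would yield the array [4, 2, -1].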
151 | 152 | Parameters 153 | ---------- 154 | seq : Iterable of NodeOp or EdgeOp 155 | The construction sequence of the sketch to analyze 156 | 157 | Returns 158 | ------- 159 | np.ndarray 160 | An integer array representing the number of degrees of freedom lost or gained 161 | at each construction step. 162 | """ 163 | out = np.zeros(len(seq), dtype=np.int32) 164 | nodes = [] 165 | for i, op in enumerate(seq): 166 | if isinstance(op, NodeOp): 167 | nodes.append(op) 168 | out[i] = NODE_DOF.get(op.label, 0) 169 | elif isinstance(op, EdgeOp): 170 | out[i] = -1*_get_dof_removed_for_edge(op, nodes) 171 | return out 172 | 173 | 174 | __all__ = ['get_sequence_dof', 'get_node_label_for_dof'] 175 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/model/numerical_features.py: -------------------------------------------------------------------------------- 1 | """ This module contains the components of the graph model that are concerned with handling 2 | numerical features associated with edges and nodes. 3 | 4 | """ 5 | 6 | import torch 7 | 8 | from sketchgraphs_models import nn as sg_nn 9 | 10 | 11 | class NumericalFeatureEncoding(torch.nn.Module): 12 | """Encode an array of numerical features. 13 | 14 | This encodes a sequence of features (presented as a sequence of integers) 15 | into a sequence of vector through an embedding. 16 | """ 17 | def __init__(self, feature_dims, embedding_dim): 18 | super(NumericalFeatureEncoding, self).__init__() 19 | self.feature_dims = list(feature_dims) 20 | self.register_buffer( 21 | 'feature_offsets', 22 | torch.cumsum(torch.tensor([0] + self.feature_dims[:-1], dtype=torch.int64), dim=0)) 23 | self.embeddings = torch.nn.Embedding( 24 | sum(feature_dims), embedding_dim, sparse=False) 25 | 26 | def forward(self, features): 27 | return self.embeddings(features + self.feature_offsets) 28 | 29 | 30 | class NumericalFeaturesEmbedding(torch.nn.Module): 31 | """Transform a sequence of numerical feature vectors into a single vector. 32 | 33 | Currently, this module simply aggregates the features by averaging, although more 34 | elaborate aggregation schemes (e.g. RNN) could be chosen. 35 | """ 36 | def __init__(self, embedding_dim): 37 | super(NumericalFeaturesEmbedding, self).__init__() 38 | self.embedding_dim = embedding_dim 39 | 40 | def forward(self, embeddings): 41 | return embeddings.mean(axis=-2) 42 | 43 | 44 | class NumericalFeatureDecoder(torch.nn.Module): 45 | """Module for decoding numerical feature embeddings to logits. 46 | 47 | Takes an array of features, and decodes them into a a sequence of varying length logits. 48 | """ 49 | def __init__(self, feature_dims, embedding_dim): 50 | super(NumericalFeatureDecoder, self).__init__() 51 | self.feature_dims = feature_dims 52 | self.linear_logistic = torch.nn.ModuleList([ 53 | torch.nn.Linear(embedding_dim, fd) for fd in feature_dims 54 | ]) 55 | 56 | def forward(self, embeddings): 57 | logits = [] 58 | 59 | for i, linear in enumerate(self.linear_logistic): 60 | logits.append(linear(embeddings[i])) 61 | 62 | return torch.cat(logits, dim=-1) 63 | 64 | 65 | class NumericalFeatureReadout(torch.nn.Module): 66 | """Module for numerical feature readout. 67 | 68 | This module is responsible for producing numerical edge features from the edge label 69 | and computed embeddings. 70 | """ 71 | def __init__(self, initial_input, feature_encoders, feature_decoders, sequence_model): 72 | """Creates a new edge feature readout model. 
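    Read together with the forward method further down, this is, on a hedged reading, an autoregressive readout: the sequence model consumes the initial input followed by the encoded target features, and every output position except the last is decoded into logits for the corresponding feature, so feature i is predicted from the initial input together with features 0..i-1.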
73 | 74 | Parameters 75 | ---------- 76 | initial_input : torch.nn.Module 77 | A module which creates the embedding for the first input slot based on the passed in data. 78 | It is called with all remaining arguments to the forward method. 79 | 80 | feature_encoders : torch.nn.Module 81 | A module which encodes the features from the edge for feeding in to network. 82 | It is called with an integer tensor of size `[batch, num_features]`. 83 | 84 | feature_decoders : torch.nn.Module 85 | A module which decodes sequence embeddings into an array of logits. 86 | 87 | sequence_model : torch.nn.Module 88 | The main computational module for this instance, transforms a sequence 89 | of embedding into another sequence of embedding. 90 | """ 91 | super(NumericalFeatureReadout, self).__init__() 92 | self.encoders = feature_encoders 93 | self.decoders = feature_decoders 94 | self.sequence_model = sequence_model 95 | self.initial_input = initial_input 96 | 97 | 98 | def forward(self, input_features, *args): 99 | initial_input = self.initial_input(*args) 100 | input_embeddings = self.encoders(input_features).permute(1, 0, 2) 101 | 102 | input_sequence = torch.cat((initial_input.expand(1, -1, -1), input_embeddings), dim=0) 103 | 104 | output_sequence, _ = self.sequence_model(input_sequence) 105 | 106 | return self.decoders(output_sequence[:-1]) 107 | 108 | 109 | def edge_decoder_initial_input(embedding_size): 110 | """Initial input function for edge readouts.""" 111 | return sg_nn.Sequential( 112 | sg_nn.ConcatenateLinear(embedding_size, embedding_size, embedding_size), 113 | torch.nn.ReLU(), 114 | torch.nn.Linear(embedding_size, embedding_size)) 115 | 116 | 117 | def entity_decoder_initial_input(embedding_size): 118 | """Initial input function for entity readouts.""" 119 | return torch.nn.Identity(embedding_size=embedding_size) 120 | 121 | 122 | def make_embedding_and_readout(embedding_size: int, feature_dimensions, initial_input_factory): 123 | """Creates feature embedding and readout networks for the given features. 124 | 125 | Parameters 126 | ---------- 127 | embedding_size : int 128 | Dimension of the embeddings to use. 129 | feature_dimensions : dict 130 | Dictionary whose values are lists of integers corresponding to the number of outcomes 131 | for each feature. 132 | initial_input_factory : int -> torch.nn.Module 133 | A function which returns a module responsible for transforming the inputs of the readout 134 | into initial embeddings for the internal sequence model. 135 | 136 | Returns 137 | ------- 138 | feature_embeddings : dict 139 | A dictionary containing the feature embedding modules. 140 | feature_readouts : dict 141 | A dictionary containing the feature readout modules. 
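    A hedged usage sketch (the names and the embedding size here are illustrative, not taken from the training code): feature_embeddings, feature_readouts = make_embedding_and_readout(embedding_size=384, feature_dimensions=my_feature_dims, initial_input_factory=edge_decoder_initial_input), where my_feature_dims is a hypothetical dictionary mapping label enum members (only k.name is used below) to ordered dicts of per-feature cardinalities (only .values() is consumed).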
142 | """ 143 | 144 | feature_encodings = { 145 | k: NumericalFeatureEncoding(dimensions.values(), embedding_size) 146 | for k, dimensions in feature_dimensions.items() 147 | } 148 | 149 | feature_embeddings = { 150 | k.name: torch.nn.Sequential( 151 | encoding, NumericalFeaturesEmbedding(embedding_size) 152 | ) 153 | for k, encoding in feature_encodings.items() 154 | } 155 | 156 | feature_readouts = { 157 | k.name: NumericalFeatureReadout( 158 | initial_input=initial_input_factory(embedding_size), 159 | feature_encoders=feature_encodings[k], 160 | feature_decoders=NumericalFeatureDecoder(dimensions.values(), embedding_size), 161 | sequence_model=torch.nn.GRU(embedding_size, embedding_size)) 162 | for k, dimensions in feature_dimensions.items() 163 | } 164 | 165 | return feature_embeddings, feature_readouts 166 | -------------------------------------------------------------------------------- /sketchgraphs_models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """This module provides utilities and generic build blocks for graph neural networks.""" 2 | 3 | import contextlib 4 | import torch 5 | 6 | def autograd_range(name): 7 | """ Creates an autograd range for pytorch autograd profiling 8 | """ 9 | return torch.autograd.profiler.record_function(name) 10 | 11 | 12 | def aggregate_by_incidence(values: torch.Tensor, incidence: torch.Tensor, 13 | transform_edge_messages=None, transform_edge_messages_args=None, 14 | output_size=None): 15 | """Aggregates values according to an incidence matrix. 16 | 17 | Effectively computes the following operation: 18 | 19 | .. code-block:: python 20 | 21 | output[i] = values[incidence[1, incidence[0] == i]].sum(axis=0) 22 | 23 | This operation essentially implements a sparse-matrix multiplication in coo format in a naive way. 24 | Optimization opportunity: write using actual cuSparse. 25 | 26 | Parameters 27 | ---------- 28 | values : torch.Tensor 29 | A tensor of rank at least 2 30 | incidence : torch.Tensor 31 | a `[2, k]` tensor 32 | transform_edge_messages : function, optional 33 | an arbitrary function which transforms edge messages. 34 | transform_edge_messages_args : any 35 | Arbitrary set of arguments that are passed to the `transform_edge_messages` function. 36 | output_size : List[int], optional 37 | if not `None`, the size of the output tensor. Otherwise, we assume the output tensor 38 | is the same size as `values`. 39 | 40 | Returns 41 | ------- 42 | torch.Tensor 43 | The output tensor, of the same rank as values. 44 | """ 45 | if output_size is None: 46 | output_size = values.shape[0] 47 | 48 | with autograd_range('broadcast_messages'): 49 | # broadcast node values to edge messages 50 | edge_messages = values.index_select(0, incidence[1]) 51 | 52 | if transform_edge_messages is not None: 53 | # apply transformation to edge messages if necessary. 54 | if transform_edge_messages_args is None: 55 | transform_edge_messages_args = tuple() 56 | 57 | with autograd_range('transform_messages'): 58 | edge_messages = transform_edge_messages(edge_messages, *transform_edge_messages_args) 59 | 60 | with autograd_range('aggregate_messages'): 61 | # collect edge messages into node values 62 | output = values.new_zeros([output_size] + list(edge_messages.shape[1:])) 63 | output.index_add_(0, incidence[0], edge_messages) 64 | 65 | return output 66 | 67 | 68 | class MessagePassingNetwork(torch.nn.Module): 69 | """ Custom configurable message-passing network. 
70 | 71 | This class implements the main plumbing for a message passing network. 72 | but exposes points that can be configured to easily create different variants of the networks. 73 | """ 74 | def __init__(self, depth, message_aggregation_network, transform_edge_messages=None): 75 | """ Creates a new module representing the message passing network. 76 | Parameters 77 | ---------- 78 | depth : int 79 | number of message passing iterations to execute. 80 | message_aggregation_network : torch.nn.Module 81 | A module representing the model used to compute the embeddings 82 | to be used at the next step. This model receives the array 83 | of messages corresponding to the sum of the propagated messages, and the array 84 | of previous node embeddings. 85 | transform_edge_messages : torch.nn.Module 86 | A module representing the model used to transform 87 | edge messages at each step. See `aggregate_by_incidence`. 88 | """ 89 | super(MessagePassingNetwork, self).__init__() 90 | 91 | self.depth = depth 92 | self.message_aggregation_network = message_aggregation_network 93 | self.transform_edge_messages = transform_edge_messages 94 | 95 | 96 | __constants__ = ["depth"] 97 | 98 | def forward(self, node_embedding, incidence, edge_transform_args=None): 99 | """Forward function for the message passing network. 100 | 101 | Parameters 102 | ---------- 103 | node_embedding : torch.Tensor 104 | Tensor of shape `[num_nodes, ...]` representing the data at each node in the graph. 105 | incidence : torch.Tensor 106 | tensor of shape `[2, num_edges]` representing edge incidence in the graph 107 | edge_transform_args : any 108 | A tuple of further arguments to be passed to the edge transformation network. 109 | 110 | Returns 111 | ------- 112 | torch.Tensor 113 | The final node embedding values after the message passing has been carried out. 114 | """ 115 | # Compute message passing along edges 116 | with autograd_range("propagate_messages"): 117 | for _ in range(self.depth): 118 | activation = aggregate_by_incidence( 119 | node_embedding, incidence, self.transform_edge_messages, edge_transform_args) 120 | node_embedding = self.message_aggregation_network(activation, node_embedding) 121 | 122 | return node_embedding 123 | 124 | 125 | class ConcatenateLinear(torch.nn.Module): 126 | """A torch module which concatenates several inputs and mixes them using a linear layer. """ 127 | def __init__(self, left_size, right_size, output_size): 128 | """Creates a new concatenating linear layer. 129 | 130 | Parameters 131 | ---------- 132 | left_size : int 133 | Size of the left input 134 | right_size : int 135 | Size of the right input 136 | output_size : int 137 | Size of the output. 138 | """ 139 | super(ConcatenateLinear, self).__init__() 140 | 141 | self.left_size = left_size 142 | self.right_size = right_size 143 | self.output_size = output_size 144 | 145 | self._linear = torch.nn.Linear(left_size + right_size, output_size) 146 | 147 | def forward(self, left, right): 148 | return self._linear(torch.cat((left, right), dim=-1)) 149 | 150 | 151 | class Sequential(torch.nn.Module): 152 | """ Similar to `torch.nn.Sequential`, except can pass through modules which 153 | take multiple input arguments, and return tuples. 
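    For instance (an illustrative pairing with made-up sizes), Sequential(ConcatenateLinear(64, 64, 64), torch.nn.ReLU()) can be called as module(left, right): the first submodule consumes both tensors and its single output is then passed on to ReLU, mirroring the pattern used by edge_decoder_initial_input in the graph model.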
154 | """ 155 | def __init__(self, *args): 156 | super(Sequential, self).__init__() 157 | self._sequence_modules = torch.nn.ModuleList(args) 158 | 159 | def forward(self, *args): 160 | for module in self._sequence_modules: 161 | if not isinstance(args, (list, tuple)): 162 | args = [args] 163 | args = module(*args) 164 | 165 | return args 166 | -------------------------------------------------------------------------------- /sketchgraphs_models/graph/train/data_loading.py: -------------------------------------------------------------------------------- 1 | """This module contains the main functions used to load the required data from disk for training.""" 2 | 3 | import functools 4 | import gzip 5 | import pickle 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from sketchgraphs_models import distributed_utils 12 | from sketchgraphs_models.nn import data_util 13 | from sketchgraphs_models.graph import dataset 14 | 15 | from sketchgraphs.data import flat_array 16 | 17 | 18 | def load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features=True, edge_features=True): 19 | data = flat_array.load_dictionary_flat(np.load(dataset_file, mmap_mode='r')) 20 | 21 | if auxiliary_file is None: 22 | root, _ = os.path.splitext(dataset_file) 23 | auxiliary_file = root + '.stats.pkl.gz' 24 | 25 | if entity_features or edge_features: 26 | with gzip.open(auxiliary_file, 'rb') as f: 27 | auxiliary_dict = pickle.load(f) 28 | 29 | if entity_features: 30 | entity_feature_mapping = dataset.EntityFeatureMapping(auxiliary_dict['node']) 31 | else: 32 | entity_feature_mapping = None 33 | 34 | seqs = data['sequences'] 35 | weights = data['sequence_lengths'] 36 | 37 | if edge_features: 38 | if isinstance(quantization['angle'], dataset.QuantizationMap): 39 | angle_map = quantization['angle'] 40 | else: 41 | angle_map = dataset.QuantizationMap.from_counter(auxiliary_dict['edge']['angle'], quantization['angle']) 42 | 43 | if isinstance(quantization['length'], dataset.QuantizationMap): 44 | length_map = quantization['length'] 45 | else: 46 | length_map = dataset.QuantizationMap.from_counter(auxiliary_dict['edge']['length'], quantization['length']) 47 | edge_feature_mapping = dataset.EdgeFeatureMapping(angle_map, length_map) 48 | else: 49 | edge_feature_mapping = None 50 | 51 | return { 52 | 'sequences': seqs.share_memory_(), 53 | 'entity_feature_mapping': entity_feature_mapping, 54 | 'edge_feature_mapping': edge_feature_mapping, 55 | 'weights': weights 56 | } 57 | 58 | 59 | def load_dataset_and_weights_with_mapping(dataset_file, node_feature_mapping, edge_feature_mapping, seed=None): 60 | data = flat_array.load_dictionary_flat(np.load(dataset_file, mmap_mode='r')) 61 | seqs = data['sequences'] 62 | seqs.share_memory_() 63 | 64 | ds = dataset.GraphDataset(seqs, node_feature_mapping, edge_feature_mapping, seed) 65 | 66 | return ds, data['sequence_lengths'] 67 | 68 | 69 | def load_dataset_and_weights(dataset_file, auxiliary_file, quantization, seed=None, 70 | entity_features=True, edge_features=True, force_entity_categorical_features=False): 71 | data = load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features, edge_features) 72 | 73 | if data['entity_feature_mapping'] is None and force_entity_categorical_features: 74 | # Create an entity mapping which only computes the categorical features (i.e. 
isConstruction and clockwise) 75 | data['entity_feature_mapping'] = dataset.EntityFeatureMapping() 76 | 77 | return dataset.GraphDataset( 78 | data['sequences'], data['entity_feature_mapping'], data['edge_feature_mapping'], seed=seed), data['weights'] 79 | 80 | 81 | def make_dataloader_train(collate_fn, ds_train, weights, batch_size, num_epochs, num_workers, distributed_config=None): 82 | sampler = torch.utils.data.WeightedRandomSampler( 83 | weights, len(weights), replacement=True) 84 | 85 | if distributed_config is not None: 86 | sampler = distributed_utils.DistributedSampler( 87 | sampler, distributed_config.world_size, distributed_config.rank) 88 | 89 | batch_sampler = torch.utils.data.BatchSampler( 90 | sampler, batch_size, drop_last=False) 91 | 92 | dataloader_train = torch.utils.data.DataLoader( 93 | ds_train, 94 | collate_fn=collate_fn, 95 | batch_sampler=data_util.MultiEpochSampler(batch_sampler, num_epochs), 96 | num_workers=num_workers, 97 | pin_memory=True) 98 | 99 | batches_per_epoch = len(batch_sampler) 100 | 101 | return dataloader_train, batches_per_epoch 102 | 103 | 104 | def _make_dataloader_eval(ds_eval, weights, batch_size, num_workers, distributed_config=None): 105 | sampler = torch.utils.data.WeightedRandomSampler( 106 | weights, len(weights), replacement=True) 107 | 108 | if distributed_config is not None: 109 | sampler = distributed_utils.DistributedSampler( 110 | sampler, distributed_config.world_size, distributed_config.rank) 111 | 112 | dataloader_eval = torch.utils.data.DataLoader( 113 | ds_eval, 114 | collate_fn=functools.partial( 115 | dataset.collate, 116 | entity_feature_mapping=ds_eval.node_feature_mapping, 117 | edge_feature_mapping=ds_eval.edge_feature_mapping), 118 | sampler=sampler, 119 | batch_size=batch_size, 120 | num_workers=num_workers, 121 | pin_memory=True) 122 | 123 | return dataloader_eval 124 | 125 | 126 | def initialize_datasets(args, distributed_config: distributed_utils.DistributedTrainingInfo = None): 127 | """Initialize datasets and dataloaders. 128 | 129 | Parameters 130 | ---------- 131 | args : dict 132 | Dictionary containing all the dataset configurations. 133 | 134 | distributed_config : distributed_utils.DistributedTrainingInfo, optional 135 | If not None, configuration options for distributed training. 
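    Note: the implementation below reads the following keys from ``args``: 'num_quantize_angle', 'num_quantize_length', 'dataset_train', 'dataset_auxiliary', 'dataset_test', 'seed', 'batch_size', 'num_workers' and 'num_epochs', together with the optional flags 'disable_entity_features', 'disable_edge_features' and 'force_entity_categorical_features'.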
136 | 137 | Returns 138 | ------- 139 | torch.data.utils.Dataloader 140 | Training dataloader 141 | torch.data.utils.Dataloader 142 | If not None, testing dataloader 143 | int 144 | Number of batches per training epoch 145 | dataset.EntityFeatureMapping 146 | Feature mapping in use for entities 147 | dataset.EdgeFeatureMapping 148 | Feature mapping in use for constraints 149 | """ 150 | quantization = {'angle': args['num_quantize_angle'], 'length': args['num_quantize_length']} 151 | 152 | dataset_train_path = args['dataset_train'] 153 | auxiliary_path = args['dataset_auxiliary'] 154 | 155 | ds_train, weights_train = load_dataset_and_weights( 156 | dataset_train_path, auxiliary_path, quantization, args['seed'], 157 | not args.get('disable_entity_features', False), not args.get('disable_edge_features', False), 158 | args.get('force_entity_categorical_features', False)) 159 | 160 | batch_size = args['batch_size'] 161 | num_workers = args['num_workers'] 162 | 163 | if distributed_config: 164 | batch_size = batch_size // distributed_config.world_size 165 | num_workers = num_workers // distributed_config.world_size 166 | 167 | collate_fn = functools.partial( 168 | dataset.collate, 169 | entity_feature_mapping=ds_train.node_feature_mapping, 170 | edge_feature_mapping=ds_train.edge_feature_mapping) 171 | 172 | dl_train, batches_per_epoch = make_dataloader_train( 173 | collate_fn, ds_train, weights_train, batch_size, args['num_epochs'], num_workers, distributed_config) 174 | 175 | if args['dataset_test'] is not None: 176 | ds_test, weights_test = load_dataset_and_weights_with_mapping( 177 | args['dataset_test'], ds_train.node_feature_mapping, ds_train.edge_feature_mapping, args['seed']) 178 | dl_test = _make_dataloader_eval( 179 | ds_test, weights_test, batch_size, num_workers, distributed_config) 180 | else: 181 | dl_test = None 182 | 183 | return dl_train, dl_test, batches_per_epoch, ds_train.node_feature_mapping, ds_train.edge_feature_mapping 184 | -------------------------------------------------------------------------------- /sketchgraphs_models/torch_extensions/segment_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from ._repeat_interleave import repeat_interleave 4 | 5 | 6 | def segment_op_python(values, scopes, op): 7 | if scopes.dim() != 2: 8 | raise ValueError("Scopes must be two-dimensional.") 9 | 10 | if scopes.shape[1] != 2: 11 | raise ValueError("Scopes must be of length two in second dimension.") 12 | 13 | output = torch.empty(scopes.shape[0], dtype=values.dtype, device=values.device) 14 | 15 | for i in range(scopes.shape[0]): 16 | output[i] = op(values.narrow(0, scopes[i, 0], scopes[i, 1])) 17 | 18 | return output 19 | 20 | 21 | def segment_logsumexp_python(values: torch.Tensor, scopes: torch.Tensor): 22 | return SegmentLogsumexpPython.apply(values, scopes) 23 | 24 | 25 | class SegmentLogsumexpPython(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, values, scopes): 28 | result = segment_op_python(values, scopes, partial(torch.logsumexp, dim=0)) 29 | ctx.save_for_backward(values, result, scopes) 30 | return result 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | values, logsumexp, scopes = ctx.saved_tensors 35 | lengths = scopes.select(1, 1) 36 | 37 | return segment_logsumexp_backward_python(grad_output, values, logsumexp, lengths), None 38 | 39 | 40 | def _min_value(dtype): 41 | if dtype.is_floating_point: 42 | return torch.finfo(dtype).min 43 | else: 
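        # Integer dtypes have no torch.finfo; fall back to torch.iinfo for the minimum representable value.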
44 | return torch.iinfo(dtype).min 45 | 46 | 47 | def segment_argmax_loop(values, scopes): 48 | output_values = values.new_empty(scopes.shape[0]) 49 | output_index = scopes.new_empty(scopes.shape[0]) 50 | 51 | for i in range(scopes.shape[0]): 52 | if scopes[i, 1] != 0: 53 | output_values[i], output_index[i] = torch.max(values.narrow(0, scopes[i, 0], scopes[i, 1]), dim=0) 54 | else: 55 | output_values[i] = _min_value(output_values.dtype) 56 | output_index[i] = -1 57 | 58 | return output_values, output_index 59 | 60 | 61 | def segment_logsumexp_backward_python(grad_output, values, logsumexp, lengths): 62 | lengths = lengths.long() 63 | grad_output_repeat = repeat_interleave(grad_output, lengths, dim=0) 64 | derivative_repeat = (values - repeat_interleave(logsumexp, lengths, dim=0)).exp_() 65 | return derivative_repeat.mul_(grad_output_repeat) 66 | 67 | 68 | class SegmentLogsumexpScatter(torch.autograd.Function): 69 | @staticmethod 70 | def forward(ctx, values, scopes): 71 | import torch_scatter 72 | 73 | lengths = scopes.select(1, 1) 74 | offsets = lengths.new_zeros(len(lengths) + 1) 75 | torch.cumsum(lengths, 0, out=offsets[1:]) 76 | 77 | value_segment_max, _ = torch_scatter.segment_max_csr(values, offsets) 78 | value_segment_max_expanded = torch.repeat_interleave(value_segment_max, lengths) 79 | 80 | values_exp = values.sub(value_segment_max_expanded).exp_() 81 | values_sumexp = torch_scatter.segment_sum_csr(values_exp, offsets) 82 | 83 | logsumexp = values_sumexp.log_().add_(value_segment_max) 84 | 85 | ctx.save_for_backward(values, logsumexp, lengths) 86 | 87 | return logsumexp 88 | 89 | @staticmethod 90 | def backward(ctx, grad_output): 91 | values, logsumexp, lengths = ctx.saved_tensors 92 | 93 | lse_expand = torch.repeat_interleave(logsumexp, lengths) 94 | grad_out_expand = torch.repeat_interleave(grad_output, lengths) 95 | 96 | return values.sub(lse_expand).exp_().mul_(grad_out_expand), None 97 | 98 | 99 | def segment_logsumexp_scatter(values, scopes): 100 | return SegmentLogsumexpScatter.apply(values, scopes) 101 | 102 | 103 | def segment_argmax_backward(grad_output, argmax, scopes, input_shape, sparse_grad=True): 104 | grad_idx = torch.add(argmax, scopes.select(1, 0)).unsqueeze(0) 105 | grad_input = torch.sparse_coo_tensor(grad_idx, grad_output, size=input_shape) 106 | 107 | if not sparse_grad: 108 | grad_input = grad_input.to_dense() 109 | 110 | return grad_input 111 | 112 | 113 | class SegmentArgmaxPython(torch.autograd.Function): 114 | @staticmethod 115 | def forward(ctx, values, scopes, sparse_grad=True): 116 | ctx.input_shape = values.shape 117 | ctx.sparse_grad = sparse_grad 118 | 119 | max_values, argmax = segment_argmax_loop(values, scopes) 120 | ctx.save_for_backward(argmax, scopes) 121 | ctx.mark_non_differentiable(argmax) 122 | 123 | return max_values, argmax 124 | 125 | @staticmethod 126 | def backward(ctx, grad_output, grad_output_index): 127 | argmax, scopes = ctx.saved_tensors 128 | grad_values = segment_argmax_backward(grad_output, argmax, scopes, ctx.input_shape, ctx.sparse_grad) 129 | return grad_values, None, None 130 | 131 | 132 | def segment_argmax_python(values, scopes, sparse_grad=True): 133 | return SegmentArgmaxPython.apply(values, scopes, sparse_grad) 134 | 135 | 136 | def segment_argmax_scatter(values, scopes, sparse_grad=True): 137 | import torch_scatter 138 | 139 | lengths = scopes.select(1, 1) 140 | offsets = lengths.new_zeros(len(lengths) + 1) 141 | torch.cumsum(lengths, 0, out=offsets[1:]) 142 | 143 | max_val, argmax = 
torch_scatter.segment_max_csr(values, offsets) 144 | max_val = torch.where(lengths != 0, max_val, max_val.new_full((1,), fill_value=torch.finfo(max_val.dtype).min)) 145 | argmax = argmax.sub(offsets[:-1]) 146 | 147 | return max_val, argmax 148 | 149 | 150 | _segment_argmax_docstring = \ 151 | """ 152 | Compute the maximum value and location in each segment. 153 | 154 | This function computes, for each segment, the maximum value in the segment, 155 | and the offset of the location of the maximum value from the start of the segment. 156 | 157 | This function can handle the case where the segment is zero length, in which case 158 | the maximum is given the lowest finite value representable by the type, and the 159 | index is not defined. 160 | 161 | Parameters 162 | ---------- 163 | values : torch.Tensor 164 | a 1-dimensional `torch.Tensor` representing the values. 165 | scopes : torch.Tensor 166 | a 2-dimensional integer `torch.Tensor` representing the segments. The ith segment 167 | has offset ``scopes[i, 0]`` and length ``scopes[i, 1]``. 168 | 169 | Returns 170 | ------- 171 | torch.Tensor 172 | A tensor of the same type as `values` representing the maximum in each segment. 173 | torch.Tensor 174 | An integer tensor representing the location of the maximum in each segment. 175 | """ 176 | 177 | segment_argmax_scatter.__doc__ = _segment_argmax_docstring 178 | segment_argmax_python.__doc__ = _segment_argmax_docstring 179 | 180 | _segment_logsumexp_docstring = \ 181 | """ 182 | Compute the log-sum-exp in each segment. 183 | 184 | This function computes, for each segment, the log-sum-exp value in 185 | a numerically stable fashion. 186 | 187 | Parameters 188 | ---------- 189 | values : torch.Tensor 190 | A 1-dimensional tensor representing the values. 191 | scopes : torch.Tensor 192 | A 2-dimensional integer tensor representing the segments. The ith segment 193 | has offset ``scopes[i, 0]`` and length ``scopes[i, 1]``. 194 | 195 | Returns 196 | ------- 197 | torch.Tensor 198 | A tensor representing the log-sum-exp value for each segment. 199 | """ 200 | 201 | segment_logsumexp_scatter.__doc__ = _segment_logsumexp_docstring 202 | segment_logsumexp_python.__doc__ = _segment_logsumexp_docstring 203 | 204 | try: 205 | import torch_scatter 206 | segment_logsumexp = segment_logsumexp_scatter 207 | segment_argmax = segment_argmax_scatter 208 | 209 | except ImportError: 210 | segment_logsumexp = segment_logsumexp_python 211 | segment_argmax = segment_argmax_python 212 | 213 | __all__ = ['segment_logsumexp', 'segment_argmax'] 214 | -------------------------------------------------------------------------------- /sketchgraphs/pipeline/make_sketch_dataset.py: -------------------------------------------------------------------------------- 1 | """This script is responsible for creating the main sketch dataset from the JSON files. 2 | 3 | Some basic filtering is applied in order to exclude empty sketches. However, as this dataset 4 | is intended to capture the original data, further filtering is left to scripts such as 5 | `sketchgraphs.pipeline.make_sequence_dataset`, which process the dataset for learning.
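A typical invocation might look like the following (the paths and thread count are illustrative; the flags are those defined by the argument parser below): python -m sketchgraphs.pipeline.make_sketch_dataset --glob_pattern 'data/json/*.json.gz' --output_path sketches.npy --num_threads 8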
6 | 7 | """ 8 | 9 | import argparse 10 | import collections 11 | import glob 12 | import gzip 13 | import itertools 14 | import json 15 | import multiprocessing as mp 16 | import tarfile 17 | import traceback 18 | import os 19 | 20 | import numpy as np 21 | import tqdm 22 | import zstandard as zstd 23 | 24 | from sketchgraphs.data.sketch import Sketch 25 | from sketchgraphs.data import flat_array 26 | 27 | 28 | def _load_json(path): 29 | open_ = gzip.open if path.endswith('gz') else open 30 | with open_(path) as fh: 31 | return json.load(fh) 32 | 33 | 34 | def filter_sketch(sketch: Sketch): 35 | """Basic filtering which excludes empty sketches, or sketches with no constraints.""" 36 | return len(sketch.constraints) == 0 or len(sketch.entities) == 0 37 | 38 | 39 | def parse_sketch_id(filename): 40 | basename = os.path.basename(filename) 41 | while '.' in basename: 42 | basename, _ = os.path.splitext(basename) 43 | 44 | document_id, part_id = basename.split('_') 45 | return document_id, int(part_id) 46 | 47 | 48 | def load_json_tarball(path): 49 | """Loads a json tarball as an iterable of sketches. 50 | 51 | Parameters 52 | ---------- 53 | path : str 54 | A path to the location of a single shard 55 | 56 | Returns 57 | ------- 58 | iterable of `Sketch` 59 | An iterable of `Sketch` representing all the sketches present in the tarball. 60 | """ 61 | with open(path, 'rb') as base_file: 62 | dctx = zstd.ZstdDecompressor() 63 | with dctx.stream_reader(base_file) as tarball: 64 | with tarfile.open(fileobj=tarball, mode='r|') as directory: 65 | while True: 66 | json_file = directory.next() 67 | if json_file is None: 68 | break 69 | 70 | if not json_file.isfile(): 71 | continue 72 | 73 | document_id, part_id = parse_sketch_id(json_file.name) 74 | data = directory.extractfile(json_file).read() 75 | if len(data) == 0: 76 | # skip empty files 77 | continue 78 | 79 | try: 80 | sketches_json = json.loads(data) 81 | except json.JSONDecodeError as exc: 82 | raise ValueError('Error decoding JSON for document {0} part {1}.'.format(document_id, part_id)) 83 | for i, sketch_json in enumerate(sketches_json): 84 | yield (document_id, part_id, i), Sketch.from_fs_json(sketch_json) 85 | 86 | 87 | 88 | def _worker(paths_queue, processed_sketches, max_sketches, sketch_counter): 89 | num_filtered = 0 90 | num_invalid = 0 91 | 92 | while max_sketches is None or sketch_counter.value < max_sketches: 93 | paths = paths_queue.get() 94 | 95 | if paths is None: 96 | break 97 | 98 | sketches = [] 99 | 100 | for path in paths: 101 | sketch_list = _load_json(path) 102 | 103 | for sketch_json in sketch_list: 104 | try: 105 | sketch = Sketch.from_fs_json(sketch_json) 106 | except Exception as err: 107 | num_invalid += 1 108 | print('Error processing sketch in file {0}'.format(path)) 109 | traceback.print_exception(type(err), err, err.__traceback__) 110 | 111 | if filter_sketch(sketch): 112 | num_filtered += 1 113 | continue 114 | 115 | sketches.append(sketch) 116 | 117 | offsets, data = flat_array.raw_list_flat(sketches) 118 | 119 | processed_sketches.put((offsets, data)) 120 | 121 | with sketch_counter.get_lock(): 122 | sketch_counter.value += len(sketches) 123 | 124 | processed_sketches.put({ 125 | 'num_filtered': num_filtered, 126 | 'num_invalid': num_invalid 127 | }) 128 | 129 | 130 | def process(paths, threads, max_sketches=None): 131 | path_queue = mp.Queue() 132 | sketch_queue = mp.Queue() 133 | sketch_counter = mp.Value('q', 0) 134 | 135 | # Enqueue all the objects 136 | print('Enqueueing files to process.') 137 | paths_it 
= iter(paths) 138 | while True: 139 | path_chunk = list(itertools.islice(paths_it, 128)) 140 | if len(path_chunk) == 0: 141 | break 142 | 143 | path_queue.put_nowait(path_chunk) 144 | 145 | workers = [] 146 | 147 | for _ in range(threads or mp.cpu_count()): 148 | workers.append( 149 | mp.Process( 150 | target=_worker, 151 | args=(path_queue, sketch_queue, max_sketches, sketch_counter))) 152 | 153 | for worker in workers: 154 | path_queue.put_nowait(None) 155 | worker.start() 156 | 157 | active_workers = len(workers) 158 | 159 | offsets_arrays = [] 160 | data_arrays = [] 161 | 162 | statistics = collections.Counter() 163 | 164 | # Read-in data 165 | with tqdm.tqdm(total=len(paths)) as pbar: 166 | while active_workers > 0: 167 | result = sketch_queue.get() 168 | 169 | if isinstance(result, dict): 170 | statistics += collections.Counter(result) 171 | active_workers -= 1 172 | continue 173 | 174 | offsets, data = result 175 | offsets_arrays.append(offsets) 176 | data_arrays.append(data) 177 | 178 | pbar.update(128) 179 | 180 | # Finalize workers 181 | for worker in workers: 182 | worker.join() 183 | 184 | # Merge final flat array 185 | all_offsets, all_data = flat_array.merge_raw_list(offsets_arrays, data_arrays) 186 | total_sketches = len(all_offsets) - 1 187 | del offsets_arrays 188 | del data_arrays 189 | 190 | # Pack as required 191 | flat_data = flat_array.pack_list_flat(all_offsets, all_data) 192 | del all_offsets 193 | del all_data 194 | 195 | print('Done processing data.\nProcessed sketches: {0}'.format(total_sketches)) 196 | print('Filtered sketches: {0}'.format(statistics['num_filtered'])) 197 | print('Invalid sketches: {0}'.format(statistics['num_invalid'])) 198 | return flat_data 199 | 200 | 201 | def gather_sorted_paths(patterns): 202 | if isinstance(patterns, str): 203 | patterns = [patterns] 204 | out = [] 205 | for pattern in patterns: 206 | out.extend(glob.glob(pattern)) 207 | out.sort() 208 | return out 209 | 210 | 211 | def main(): 212 | parser = argparse.ArgumentParser(description='Process json files to create sketch dataset') 213 | parser.add_argument('--glob_pattern', required=True, action='append', 214 | help='Glob pattern(s) for json / json.gz files.') 215 | parser.add_argument('--output_path', required=True, help='Path for output file.') 216 | parser.add_argument('--max_files', type=int, help='Max number of json files to consider.') 217 | parser.add_argument('--max_sketches', type=int, help='Maximum number of sketches to consider.') 218 | parser.add_argument('--num_threads', type=int, default=0, help='Number of multiprocessing workers.') 219 | 220 | args = parser.parse_args() 221 | 222 | print('Globbing for sketch files to include.') 223 | paths = gather_sorted_paths(args.glob_pattern) 224 | print('Found %i files.' 
% len(paths)) 225 | if args.max_files is not None: 226 | paths = paths[:args.max_files] 227 | 228 | result = process(paths, args.num_threads, args.max_sketches) 229 | 230 | print('Saving data to {0}'.format(args.output_path)) 231 | np.save(args.output_path, result) 232 | 233 | 234 | if __name__ == '__main__': 235 | main() 236 | -------------------------------------------------------------------------------- /sketchgraphs_models/autoconstraint/dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset for auto-constraint model.""" 2 | 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from sketchgraphs.data import sequence as datalib 9 | from sketchgraphs.pipeline.graph_model.target import NODE_TYPES, EDGE_TYPES, EDGE_TYPES_PREDICTED, NODE_IDX_MAP, EDGE_IDX_MAP 10 | from sketchgraphs.pipeline import graph_model as graph_utils 11 | 12 | from sketchgraphs_models.graph.dataset import EntityFeatureMapping, EdgeFeatureMapping, _sparse_feature_to_torch 13 | 14 | 15 | def _reindex_sparse_batch(sparse_batch, pack_batch_offsets): 16 | return graph_utils.SparseFeatureBatch( 17 | pack_batch_offsets[sparse_batch.index], 18 | sparse_batch.value) 19 | 20 | 21 | def collate(batch): 22 | # Sort batch for packing 23 | node_lengths = [len(x['node_features']) for x in batch] 24 | sorted_indices = np.argsort(node_lengths)[::-1].copy() 25 | 26 | batch = [batch[i] for i in sorted_indices] 27 | 28 | graph = graph_utils.GraphInfo.merge(*[x['graph'] for x in batch]) 29 | edge_label = torch.tensor( 30 | [x['target_edge_label'] for x in batch if x['target_edge_label'] != -1], dtype=torch.int64) 31 | node_features = torch.nn.utils.rnn.pack_sequence([x['node_features'] for x in batch]) 32 | batch_offsets = graph_utils.offsets_from_counts(node_features.batch_sizes) 33 | 34 | node_features_graph_index = torch.cat([ 35 | i + batch_offsets[:graph.node_counts[i]] for i in range(len(batch)) 36 | ], dim=0) 37 | 38 | sparse_node_features = {} 39 | 40 | for k in batch[0]['sparse_node_features']: 41 | sparse_node_features[k] = graph_utils.SparseFeatureBatch.merge( 42 | [_reindex_sparse_batch(x['sparse_node_features'][k], batch_offsets) for x in batch], range(len(batch))) 43 | 44 | last_graph_node_index = batch_offsets[graph.node_counts - 1] + torch.arange(len(graph.node_counts), dtype=torch.int64) 45 | 46 | partner_index_index = [] 47 | partner_index = [] 48 | 49 | stop_partner_index_index = [] 50 | 51 | for i, x in enumerate(batch): 52 | if x['partner_index'] == -1: 53 | stop_partner_index_index.append(i) 54 | continue 55 | 56 | partner_index_index.append(i) 57 | partner_index.append(x['partner_index'] + graph.node_offsets[i]) 58 | 59 | partner_index = graph_utils.SparseFeatureBatch( 60 | torch.tensor(partner_index_index, dtype=torch.int64), 61 | torch.tensor(partner_index, dtype=torch.int64) 62 | ) 63 | 64 | stop_partner_index_index = torch.tensor(stop_partner_index_index, dtype=torch.int64) 65 | 66 | return { 67 | 'graph': graph, 68 | 'edge_label': edge_label, 69 | 'partner_index': partner_index, 70 | 'stop_partner_index_index': stop_partner_index_index, 71 | 'node_features': node_features, 72 | 'node_features_graph_index': node_features_graph_index, 73 | 'sparse_node_features': sparse_node_features, 74 | 'last_graph_node_index': last_graph_node_index, 75 | 'sorted_indices': torch.as_tensor(sorted_indices) 76 | } 77 | 78 | 79 | 80 | def process_node_and_edge_ops(node_ops, edge_ops_in_graph, num_nodes_in_graph, node_feature_mappings: 
Optional[EntityFeatureMapping]): 81 | all_node_labels = torch.tensor([NODE_IDX_MAP[op.label] for op in node_ops], dtype=torch.int64) 82 | edge_labels = torch.tensor([EDGE_IDX_MAP[op.label] for op in edge_ops_in_graph], dtype=torch.int64) 83 | 84 | if len(edge_ops_in_graph) > 0: 85 | incidence = torch.tensor([(op.references[0], op.references[-1]) for op in edge_ops_in_graph], 86 | dtype=torch.int64).T.contiguous() 87 | incidence = torch.cat((incidence, torch.flip(incidence, [0])), dim=1) 88 | else: 89 | incidence = torch.empty([2, 0], dtype=torch.int64) 90 | 91 | edge_features = edge_labels.repeat(2) 92 | 93 | if node_feature_mappings is not None: 94 | sparse_node_features = _sparse_feature_to_torch(node_feature_mappings.all_sparse_features(node_ops)) 95 | else: 96 | sparse_node_features = None 97 | 98 | graph = graph_utils.GraphInfo.from_single_graph(incidence, None, edge_features, num_nodes_in_graph) 99 | 100 | return { 101 | 'graph': graph, 102 | 'node_features': all_node_labels, 103 | 'sparse_node_features': sparse_node_features 104 | } 105 | 106 | 107 | 108 | class AutoconstraintDataset(torch.utils.data.Dataset): 109 | def __init__(self, sequences, node_feature_mappings, seed=10): 110 | self.sequences = sequences 111 | self.node_feature_mappings = node_feature_mappings 112 | self._rng = np.random.Generator(np.random.Philox(seed)) 113 | 114 | def __getitem__(self, idx): 115 | idx = idx % len(self.sequences) 116 | seq = self.sequences[idx] 117 | 118 | if not isinstance(seq[0], datalib.NodeOp): 119 | raise ValueError('First operation in sequence is not a NodeOp') 120 | 121 | if seq[-1].label != datalib.EntityType.Stop: 122 | seq.append(datalib.NodeOp(datalib.EntityType.Stop, {})) 123 | 124 | node_ops = [seq[0]] 125 | edge_ops = [] 126 | 127 | num_predicted_edge_ops_per_node = [] 128 | num_non_predicted_edge_ops_per_node = [] 129 | 130 | predicted_edge_ops_for_current_node = 0 131 | non_predicted_edge_ops_for_current_node = 0 132 | 133 | for op in seq[1:]: 134 | if isinstance(op, datalib.NodeOp): 135 | num_predicted_edge_ops_per_node.append(predicted_edge_ops_for_current_node) 136 | num_non_predicted_edge_ops_per_node.append(non_predicted_edge_ops_for_current_node) 137 | 138 | predicted_edge_ops_for_current_node = 0 139 | non_predicted_edge_ops_for_current_node = 0 140 | 141 | node_ops.append(op) 142 | else: 143 | if op.label in EDGE_TYPES_PREDICTED: 144 | predicted_edge_ops_for_current_node += 1 145 | else: 146 | non_predicted_edge_ops_for_current_node += 1 147 | 148 | edge_ops.append(op) 149 | 150 | node_ops = node_ops[:-1] 151 | 152 | num_predicted_edge_ops_per_node = np.array(num_predicted_edge_ops_per_node, dtype=np.int64) 153 | num_non_predicted_edge_ops_per_node = np.array(num_non_predicted_edge_ops_per_node, dtype=np.int64) 154 | 155 | predicted_edge_ops_offsets = num_predicted_edge_ops_per_node.cumsum() 156 | non_predicted_edge_ops_offsets = num_non_predicted_edge_ops_per_node.cumsum() 157 | 158 | num_predicted_edge_ops = predicted_edge_ops_offsets[-1] 159 | 160 | stop_target = self._rng.uniform() < len(node_ops) / (len(node_ops) + num_predicted_edge_ops) 161 | 162 | if stop_target: 163 | target_node_idx = self._rng.integers(len(node_ops)) 164 | num_nodes_in_graph = target_node_idx + 1 165 | edge_ops_in_graph = edge_ops[:predicted_edge_ops_offsets[target_node_idx] + non_predicted_edge_ops_offsets[target_node_idx]] 166 | target_edge_label = -1 167 | partner_index = -1 168 | else: 169 | target_predicted_edge_idx = self._rng.integers(num_predicted_edge_ops) 170 | target_node_idx = 
np.searchsorted(predicted_edge_ops_offsets, target_predicted_edge_idx, side='right') 171 | num_nodes_in_graph = target_node_idx + 1 172 | 173 | target_edge_idx = target_predicted_edge_idx + non_predicted_edge_ops_offsets[target_node_idx] 174 | target_edge = edge_ops[target_edge_idx] 175 | edge_ops_in_graph = edge_ops[:target_edge_idx] 176 | target_edge_label = EDGE_IDX_MAP[target_edge.label] 177 | partner_index = target_edge.references[-1] 178 | assert target_edge_label < len(EDGE_TYPES_PREDICTED) 179 | 180 | input_features = process_node_and_edge_ops( 181 | node_ops, edge_ops_in_graph, num_nodes_in_graph, self.node_feature_mappings) 182 | 183 | return { 184 | **input_features, 185 | 'target_edge_label': target_edge_label, 186 | 'partner_index': partner_index, 187 | } 188 | 189 | __all__ = [ 190 | 'NODE_TYPES', 'EDGE_TYPES', 'EDGE_TYPES_PREDICTED', 'NODE_IDX_MAP', 'EDGE_IDX_MAP', 191 | 'EntityFeatureMapping', 'EdgeFeatureMapping', 'collate', 'AutoconstraintDataset' 192 | ] 193 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/onshape.py: -------------------------------------------------------------------------------- 1 | ''' 2 | onshape 3 | ====== 4 | 5 | Provides access to the Onshape REST API 6 | ''' 7 | 8 | from . import utils 9 | 10 | import os 11 | import random 12 | import string 13 | import json 14 | import hmac 15 | import hashlib 16 | import base64 17 | import urllib.request 18 | import urllib.parse 19 | import urllib.error 20 | import datetime 21 | import requests 22 | from urllib.parse import urlparse 23 | from urllib.parse import parse_qs 24 | 25 | __all__ = [ 26 | 'Onshape' 27 | ] 28 | 29 | 30 | class Onshape(): 31 | ''' 32 | Provides access to the Onshape REST API. 33 | 34 | Attributes: 35 | - stack (str): Base URL 36 | - creds (str, default='./sketchgraphs/onshape/creds/creds.json'): Credentials location 37 | - logging (bool, default=True): Turn logging on or off 38 | ''' 39 | 40 | def __init__(self, stack, creds='./sketchgraphs/onshape/creds/creds.json', logging=True): 41 | ''' 42 | Instantiates an instance of the Onshape class. Reads credentials from a JSON file 43 | of this format: 44 | 45 | { 46 | "http://cad.onshape.com": { 47 | "access_key": "YOUR KEY HERE", 48 | "secret_key": "YOUR KEY HERE" 49 | }, 50 | etc... add new object for each stack to test on 51 | } 52 | 53 | The creds.json file should be stored in the root project folder; optionally, 54 | you can specify the location of a different file. 
55 | 56 | Args: 57 | - stack (str): Base URL 58 | - creds (str, default='./sketchgraphs/onshape/creds/creds.json'): Credentials location 59 | ''' 60 | 61 | if not os.path.isfile(creds): 62 | raise IOError('%s is not a file' % creds) 63 | 64 | with open(creds) as f: 65 | try: 66 | stacks = json.load(f) 67 | if stack in stacks: 68 | self._url = stack 69 | self._access_key = stacks[stack]['access_key'].encode( 70 | 'utf-8') 71 | self._secret_key = stacks[stack]['secret_key'].encode( 72 | 'utf-8') 73 | self._logging = logging 74 | else: 75 | raise ValueError('specified stack not in file') 76 | except TypeError: 77 | raise ValueError('%s is not valid json' % creds) 78 | 79 | if self._logging: 80 | utils.log('onshape instance created: url = %s, access key = %s' % ( 81 | self._url, self._access_key)) 82 | 83 | def _make_nonce(self): 84 | ''' 85 | Generate a unique ID for the request, 25 chars in length 86 | 87 | Returns: 88 | - str: Cryptographic nonce 89 | ''' 90 | 91 | chars = string.digits + string.ascii_letters 92 | nonce = ''.join(random.choice(chars) for i in range(25)) 93 | 94 | if self._logging: 95 | utils.log('nonce created: %s' % nonce) 96 | 97 | return nonce 98 | 99 | def _make_auth(self, method, date, nonce, path, query={}, ctype='application/json'): 100 | ''' 101 | Create the request signature to authenticate 102 | 103 | Args: 104 | - method (str): HTTP method 105 | - date (str): HTTP date header string 106 | - nonce (str): Cryptographic nonce 107 | - path (str): URL pathname 108 | - query (dict, default={}): URL query string in key-value pairs 109 | - ctype (str, default='application/json'): HTTP Content-Type 110 | ''' 111 | 112 | query = urllib.parse.urlencode(query) 113 | 114 | hmac_str = (method + '\n' + nonce + '\n' + date + '\n' + ctype + '\n' + path + 115 | '\n' + query + '\n').lower().encode('utf-8') 116 | 117 | signature = base64.b64encode( 118 | hmac.new(self._secret_key, hmac_str, digestmod=hashlib.sha256).digest()) 119 | auth = 'On ' + self._access_key.decode('utf-8') + \ 120 | ':HmacSHA256:' + signature.decode('utf-8') 121 | 122 | if self._logging: 123 | utils.log({ 124 | 'query': query, 125 | 'hmac_str': hmac_str, 126 | 'signature': signature, 127 | 'auth': auth 128 | }) 129 | 130 | return auth 131 | 132 | def _make_headers(self, method, path, query={}, headers={}): 133 | ''' 134 | Creates a headers object to sign the request 135 | 136 | Args: 137 | - method (str): HTTP method 138 | - path (str): Request path, e.g. /api/documents. 
No query string 139 | - query (dict, default={}): Query string in key-value format 140 | - headers (dict, default={}): Other headers to pass in 141 | 142 | Returns: 143 | - dict: Dictionary containing all headers 144 | ''' 145 | 146 | date = datetime.datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') 147 | nonce = self._make_nonce() 148 | ctype = headers.get( 149 | 'Content-Type') if headers.get('Content-Type') else 'application/json' 150 | 151 | auth = self._make_auth(method, date, nonce, path, 152 | query=query, ctype=ctype) 153 | 154 | req_headers = { 155 | 'Content-Type': 'application/json', 156 | 'Date': date, 157 | 'On-Nonce': nonce, 158 | 'Authorization': auth, 159 | 'User-Agent': 'Onshape Python Sample App', 160 | 'Accept': 'application/json' 161 | } 162 | 163 | # add in user-defined headers 164 | for h in headers: 165 | req_headers[h] = headers[h] 166 | 167 | return req_headers 168 | 169 | def request(self, method, path, query={}, headers={}, body={}, base_url=None, timeout=None, check_status=True): 170 | ''' 171 | Issues a request to Onshape 172 | 173 | Args: 174 | - method (str): HTTP method 175 | - path (str): Path e.g. /api/documents/:id 176 | - query (dict, default={}): Query params in key-value pairs 177 | - headers (dict, default={}): Key-value pairs of headers 178 | - body (dict, default={}): Body for POST request 179 | - base_url (str, default=None): Host, including scheme and port (if different from creds file) 180 | - timeout (float, default=None): Timeout to use with requests.request(). 181 | - check_status (bool, default=True): Raise exception if response status code is unsuccessful. 182 | 183 | Returns: 184 | - requests.Response: Object containing the response from Onshape 185 | ''' 186 | 187 | req_headers = self._make_headers(method, path, query, headers) 188 | if base_url is None: 189 | base_url = self._url 190 | url = base_url + path + '?' + urllib.parse.urlencode(query) 191 | 192 | if self._logging: 193 | utils.log(body) 194 | utils.log(req_headers) 195 | utils.log('request url: ' + url) 196 | 197 | # only parse as json string if we have to 198 | body = json.dumps(body) if type(body) == dict else body 199 | 200 | res = requests.request( 201 | method, url, headers=req_headers, data=body, allow_redirects=False, stream=True, 202 | timeout=timeout) 203 | 204 | if res.status_code == 307: 205 | location = urlparse(res.headers["Location"]) 206 | querystring = parse_qs(location.query) 207 | 208 | if self._logging: 209 | utils.log('request redirected to: ' + location.geturl()) 210 | 211 | new_query = {} 212 | new_base_url = location.scheme + '://' + location.netloc 213 | 214 | for key in querystring: 215 | # won't work for repeated query params 216 | new_query[key] = querystring[key][0] 217 | 218 | return self.request(method, location.path, query=new_query, headers=headers, base_url=new_base_url) 219 | elif not 200 <= res.status_code <= 206: 220 | if self._logging: 221 | utils.log('request failed, details: ' + res.text, level=1) 222 | else: 223 | if self._logging: 224 | utils.log('request succeeded, details: ' + res.text) 225 | if check_status: 226 | res.raise_for_status() 227 | return res -------------------------------------------------------------------------------- /sketchgraphs_models/graph/train/harness.py: -------------------------------------------------------------------------------- 1 | """This module contains the main training harness for training graph models. 
""" 2 | 3 | import collections 4 | import itertools 5 | import os 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from sketchgraphs_models import training 12 | from sketchgraphs_models.graph import model as graph_model 13 | from sketchgraphs_models.nn import summary 14 | 15 | 16 | def _detach(x): 17 | if isinstance(x, torch.Tensor): 18 | return x.detach() 19 | else: 20 | return x 21 | 22 | 23 | def _mean_value(v): 24 | values = [a.mean().cpu().numpy() for a in v] 25 | return np.mean(values) if values else np.nan 26 | 27 | 28 | def _total_loss(losses): 29 | result = 0 30 | 31 | for v in losses.values(): 32 | if v is None: 33 | continue 34 | 35 | if isinstance(v, dict): 36 | result += _total_loss(v) 37 | else: 38 | result += v.sum() 39 | 40 | return result 41 | 42 | 43 | class GraphModelHarness(training.TrainingHarness): 44 | """This class is the main harness for training graph models. 45 | 46 | The harness is responsible for coordinating all the procedures that surround training, 47 | such as learning rate scheduling, data loading, and logging. 48 | """ 49 | def __init__(self, model, opt, node_feature_dimension, edge_feature_dimension, 50 | config_train, config_eval=None, scheduler=None, output_dir=None, dist_config=None, 51 | profile_enabled=False, additional_model_information=None): 52 | super(GraphModelHarness, self).__init__(model, opt, config_train, config_eval, dist_config) 53 | self.scheduler = scheduler 54 | self.output_dir = output_dir 55 | 56 | self.node_feature_dimension = node_feature_dimension 57 | self.edge_feature_dimension = edge_feature_dimension 58 | self.feature_dimension = {**node_feature_dimension, **edge_feature_dimension} 59 | self.profile_enabled = profile_enabled 60 | self._last_profile_step = 0 61 | self.additional_model_information = additional_model_information or {} 62 | 63 | def _make_feature_summary(fd): 64 | return collections.OrderedDict( 65 | (t, collections.OrderedDict( 66 | (feature_name, summary.ClassificationSummary(dim)) 67 | for feature_name, dim in feature_description.items() 68 | )) for t, feature_description in fd.items()) 69 | 70 | self.edge_feature_summaries = _make_feature_summary(edge_feature_dimension) 71 | self.node_feature_summaries = _make_feature_summary(node_feature_dimension) 72 | 73 | 74 | def _get_profile_path(self, global_step): 75 | if not self.profile_enabled: 76 | return None 77 | 78 | if self._last_profile_step is None or global_step - self._last_profile_step > 100000: 79 | self._last_profile_step = global_step 80 | return 'profile_step_{0}.pkl'.format(global_step) 81 | 82 | return None 83 | 84 | def single_step(self, batch, global_step): 85 | self.opt.zero_grad() 86 | 87 | profile_path = self._get_profile_path(global_step) 88 | with torch.autograd.profiler.profile(enabled=profile_path is not None, use_cuda=True) as trace: 89 | with torch.autograd.profiler.record_function("forward"): 90 | readout = self.model(batch) 91 | losses, accuracy, edge_metrics, node_metrics = graph_model.compute_losses( 92 | readout, batch, self.feature_dimension) 93 | total_loss = _total_loss(losses) 94 | if self.model.training: 95 | with torch.autograd.profiler.record_function("backward"): 96 | total_loss.backward() 97 | 98 | with torch.autograd.profiler.record_function("opt_update"): 99 | self.opt.step() 100 | 101 | if profile_path is not None: 102 | with open(profile_path, 'wb') as f: 103 | pickle.dump(trace, f, pickle.HIGHEST_PROTOCOL) 104 | 105 | losses = training.map_structure_flat(losses, _detach) 106 | losses = 
graph_model.compute_average_losses(losses, batch['graph_counts']) 107 | avg_loss = total_loss.detach() / float(sum(batch['graph_counts'])) 108 | losses['average'] = avg_loss 109 | 110 | if self.is_leader(): 111 | def _record_classification_summaries(metrics, summaries): 112 | for t, (labels, preds) in metrics.items(): 113 | for i, cs in enumerate(summaries[t].values()): 114 | cs.record_statistics(labels[:, i], preds[:, i]) 115 | 116 | _record_classification_summaries(edge_metrics, self.edge_feature_summaries) 117 | _record_classification_summaries(node_metrics, self.node_feature_summaries) 118 | 119 | return losses, accuracy 120 | 121 | def on_epoch_end(self, epoch, global_step): 122 | if self.scheduler is not None: 123 | self.scheduler.step() 124 | 125 | if self.config_train.tb_writer is not None and self.is_leader(): 126 | lr = self.scheduler.get_last_lr()[0] 127 | self.config_train.tb_writer.add_scalar('learning_rate', lr, global_step) 128 | 129 | if self.is_leader() and self.output_dir is not None and (epoch + 1) % 10 == 0: 130 | self.log('Saving checkpoint for epoch {}'.format(epoch + 1)) 131 | torch.save( 132 | { 133 | 'opt': self.opt.state_dict(), 134 | 'model': self.model.state_dict(), 135 | 'epoch': epoch, 136 | 'global_step': global_step, 137 | **self.additional_model_information, 138 | }, 139 | os.path.join(self.output_dir, 'model_state_{0}.pt'.format(epoch + 1))) 140 | 141 | def write_summaries(self, global_step, losses, accuracies, tb_writer): 142 | if tb_writer is None: 143 | return 144 | 145 | for k, v in losses.items(): 146 | if k == 'edge_features' or k == 'node_features': 147 | v = _mean_value(v.values()) 148 | tb_writer.add_scalar('loss/' + k, v, global_step) 149 | 150 | for k, v in accuracies.items(): 151 | if k == 'edge_features' or k == 'node_features': 152 | v = _mean_value(v.values()) 153 | tb_writer.add_scalar('accuracy/' + k, v, global_step) 154 | 155 | for t, feature_schema in self.edge_feature_summaries.items(): 156 | for n, cs in feature_schema.items(): 157 | cs.write_tensorboard(tb_writer, 'kappa' + '/' + t.name + '/' + n, global_step) 158 | 159 | for t, feature_schema in self.node_feature_summaries.items(): 160 | for n, cs in feature_schema.items(): 161 | cs.write_tensorboard(tb_writer, 'kappa' + '/' + t.name + '/' + n, global_step) 162 | 163 | def print_statistics(self, loss_acc, accuracy_acc): 164 | self.log(f'\tLoss ({loss_acc["average"]:.3f}): Node({loss_acc["node_label"]:.3f}, {loss_acc["node_stop"]:.3f}) ' 165 | f'Edge({loss_acc["edge_label"]:.3f}, {loss_acc["edge_partner"]:.3f}, ' 166 | f'{_mean_value(loss_acc["edge_features"].values()):.3f}) ' 167 | f'Subnode({loss_acc["subnode_stop"]:.3f}).') 168 | self.log(f'\tAccuracy: Node({accuracy_acc["node_label"]:4.1%}, {accuracy_acc["node_stop"]:4.1%}) ' 169 | f'Edge({accuracy_acc["edge_label"]:4.1%}, {accuracy_acc["edge_partner"]:4.1%}, ' 170 | f'{_mean_value(accuracy_acc["edge_features"].values()):4.1%}) ' 171 | f'Subnode({accuracy_acc["subnode_stop"]:4.1%})') 172 | self.log() 173 | 174 | def _summary_text(target, features): 175 | return f'{target.name}: ' + '; '.join(f'{n} ({cs.cohen_kappa():.3f})' for n, cs in features.items()) 176 | 177 | if self.node_feature_summaries: 178 | self.log('Kappa Entity') 179 | for t, features in self.node_feature_summaries.items(): 180 | self.log(_summary_text(t, features)) 181 | 182 | self.log() 183 | 184 | if self.edge_feature_summaries: 185 | self.log('Kappa Edges') 186 | for t, features in self.edge_feature_summaries.items(): 187 | self.log(_summary_text(t, features)) 
188 | self.log() 189 | 190 | def reset_statistics(self): 191 | for features in itertools.chain(self.node_feature_summaries.values(), self.edge_feature_summaries.values()): 192 | for classification_summary in features.values(): 193 | classification_summary.reset_statistics() 194 | 195 | super(GraphModelHarness, self).reset_statistics() 196 | -------------------------------------------------------------------------------- /sketchgraphs/data/constraint_checks.py: -------------------------------------------------------------------------------- 1 | """This module implements a basic constraint checker for a solved graph. 2 | 3 | This module implements a number of functions to help check basic relational constraints 4 | between entities. Note that only checking is implemented: this is not an implementation 5 | of a solver, and cannot solve for the desired constraints. 6 | 7 | """ 8 | 9 | import numpy as np 10 | from numpy.linalg import norm 11 | 12 | from ._entity import Point, Line, Circle, Arc, SubnodeType, ENTITY_TYPE_TO_CLASS, Entity 13 | from ._constraint import ConstraintType 14 | from .sequence import NodeOp, EdgeOp 15 | 16 | 17 | def get_entity_by_idx(seq, idx: int) -> Entity: 18 | """Returns the entity or sub-entity corresponding to idx. 19 | 20 | Parameters 21 | ---------- 22 | seq : List[Union[NodeOp, EdgeOp]] 23 | A list of node and edge operations representing the construction sequence. 24 | idx : int 25 | An integer representing the index of the desired entity. 26 | 27 | Returns 28 | ------- 29 | Entity 30 | An entity object representing the entity corresponding to the `NodeOp` at the given index. 31 | """ 32 | node_ops = [op for op in seq if isinstance(op, NodeOp)] 33 | label = node_ops[idx].label 34 | 35 | def _entity_from_op(op): 36 | entity = ENTITY_TYPE_TO_CLASS[op.label]('dummy_id') 37 | for param_id, val in op.parameters.items(): 38 | setattr(entity, param_id, val) 39 | return entity 40 | 41 | if not isinstance(label, SubnodeType): 42 | return _entity_from_op(node_ops[idx]) 43 | 44 | for i in reversed(range(idx)): 45 | this_label = node_ops[i].label 46 | if not isinstance(this_label, SubnodeType): 47 | parent = _entity_from_op(node_ops[i]) 48 | break 49 | 50 | if label == SubnodeType.SN_Start: 51 | return parent.start_point 52 | elif label == SubnodeType.SN_End: 53 | return parent.end_point 54 | elif label == SubnodeType.SN_Center: 55 | return parent.center_point 56 | 57 | raise ValueError('Could not find entity corresponding to idx.') 58 | 59 | 60 | def check_edge_satisfied(seq, op: EdgeOp): 61 | """Determines whether the given EdgeOp instance is geometrically satisfied in the graph represented by `seq`. 62 | 63 | Parameters 64 | ---------- 65 | seq : List[Union[EdgeOp, NodeOp]] 66 | A construction sequence representing the underlying graph. 67 | op : EdgeOp 68 | An edge op to check for validity in the current graph. 69 | 70 | Returns 71 | ------- 72 | bool 73 | `True` if the current edge constraint is satisified, otherwise `False`. 74 | 75 | Raises 76 | ------ 77 | ValueError 78 | If the current constraint is not supported (e.g. its type is not supported), 79 | or it refers to an external entity, which is not supported. 80 | """ 81 | if 0 in op.references: 82 | raise ValueError('External constraints not supported.') 83 | 84 | entities = [get_entity_by_idx(seq, ref) for ref in op.references] 85 | 86 | try: 87 | constraint_f = CONSTRAINT_BY_LABEL[op.label] 88 | except KeyError: 89 | raise ValueError('%s not currently supported.' 
% op.label) 90 | 91 | return constraint_f(*entities) 92 | 93 | 94 | def get_sorted_types(entities): 95 | """Obtains the types and sorts the entities based on their type order. 96 | 97 | Parameters 98 | ---------- 99 | entities : iterable of `Entity` 100 | An list of entities to be sorted. 101 | 102 | Returns 103 | ------- 104 | types : List 105 | A list of types representing the type of each entity 106 | entities : List 107 | A list of entities, containing the same elements as the input iterable, 108 | but sorted in the order given by types. 109 | """ 110 | types = [Point if isinstance(ent, np.ndarray) else type(ent) for ent in entities] 111 | type_names = [t.__name__ for t in types] 112 | idxs = np.argsort(type_names) 113 | return [types[idx] for idx in idxs], [entities[idx] for idx in idxs] 114 | 115 | 116 | def _ensure_array(point): 117 | if isinstance(point, np.ndarray): 118 | return point 119 | else: 120 | return np.array([point.x, point.y]) 121 | 122 | 123 | def coincident(*entities): 124 | types, entities = get_sorted_types(entities) 125 | 126 | if types == [Point, Point]: 127 | return np.allclose(_ensure_array(entities[0]), _ensure_array(entities[1])) 128 | 129 | elif types == [Line, Point]: 130 | vec1 = entities[0].end_point - entities[0].start_point 131 | vec2 = _ensure_array(entities[1]) - entities[0].start_point 132 | return np.isclose(np.cross(vec1, vec2), 0) 133 | 134 | elif types == [Line, Line]: 135 | return coincident(entities[0], entities[1].start_point) and coincident(entities[0], entities[1].end_point) 136 | 137 | elif types in [[Arc, Point], [Circle, Point]]: 138 | circle_or_arc, point = entities 139 | dist = norm(_ensure_array(point) - circle_or_arc.center_point) 140 | return np.isclose(circle_or_arc.radius, dist) 141 | 142 | elif types in [[Circle, Circle], [Arc, Arc], [Arc, Circle]]: 143 | return np.allclose([entities[0].xCenter, entities[0].yCenter, entities[0].radius], 144 | [entities[1].xCenter, entities[1].yCenter, entities[1].radius]) 145 | 146 | else: 147 | return None 148 | 149 | 150 | def parallel(*ents): 151 | types, ents = get_sorted_types(ents) 152 | 153 | if types == [Line, Line]: 154 | vec1 = ents[0].end_point - ents[0].start_point 155 | vec2 = ents[1].end_point - ents[1].start_point 156 | return np.isclose(np.cross(vec1, vec2), 0) 157 | 158 | else: 159 | return None 160 | 161 | 162 | def horizontal(*ents): 163 | types, ents = get_sorted_types(ents) 164 | 165 | if types == [Line]: 166 | return horizontal(ents[0].start_point, ents[0].end_point) 167 | elif types == [Point, Point]: 168 | return np.isclose(ents[0][1], ents[1][1], atol=1e-6) 169 | else: 170 | return None 171 | 172 | 173 | def vertical(*ents): 174 | types, ents = get_sorted_types(ents) 175 | 176 | if types == [Line]: 177 | return vertical(ents[0].start_point, ents[0].end_point) 178 | elif types == [Point, Point]: 179 | return np.isclose(ents[0][0], ents[1][0], atol=1e-6) 180 | else: 181 | return None 182 | 183 | 184 | def perpendicular(*ents): 185 | types, ents = get_sorted_types(ents) 186 | 187 | if types == [Line, Line]: 188 | vec1 = ents[0].start_point - ents[0].end_point 189 | vec2 = ents[1].start_point - ents[1].end_point 190 | return np.isclose(np.dot(vec1, vec2), 0) 191 | else: 192 | return None 193 | 194 | 195 | def tangent(*ents): 196 | types, ents = get_sorted_types(ents) 197 | 198 | if types == [Circle, Line]: 199 | circle, line = ents 200 | p1, p2 = (line.start_point, line.end_point) 201 | p3 = circle.center_point 202 | 203 | line_dir = p2 - p1 204 | line_dir_norm = norm(line_dir) 
205 | 206 | if np.abs(line_dir_norm) < 1e-6: 207 | dist = norm(p1 - p3) 208 | else: 209 | dist = norm(np.cross(line_dir, p1-p3)) / line_dir_norm 210 | 211 | return np.isclose(circle.radius, dist) 212 | 213 | elif types == [Arc, Line]: 214 | arc, line = ents 215 | circle = Circle('', xCenter=arc.xCenter, yCenter=arc.yCenter, radius=arc.radius) 216 | return tangent(circle, line) 217 | elif types in [[Arc, Arc], [Arc, Circle], [Circle, Circle]]: 218 | dist = norm(ents[1].center_point - ents[0].center_point) 219 | return (np.isclose(dist, ents[0].radius + ents[1].radius) 220 | or np.isclose(dist, np.abs(ents[0].radius - ents[1].radius))) 221 | else: 222 | return None 223 | 224 | 225 | def equal(*ents): 226 | types, ents = get_sorted_types(ents) 227 | 228 | if types == [Line, Line]: 229 | line0, line1 = ents 230 | vec0 = line0.end_point - line0.start_point 231 | vec1 = line1.end_point - line1.start_point 232 | return np.isclose(norm(vec0), norm(vec1)) 233 | 234 | elif types in [[Circle, Circle], [Arc, Arc], [Arc, Circle]]: 235 | return np.isclose(ents[0].radius, ents[1].radius) 236 | 237 | else: 238 | return None 239 | 240 | 241 | def midpoint(*ents): 242 | types, ents = get_sorted_types(ents) 243 | 244 | if types == [Line, Point]: 245 | line, point = ents 246 | mid_coords = (line.start_point + line.end_point) / 2 247 | return np.allclose(mid_coords, _ensure_array(point)) 248 | 249 | else: 250 | return None 251 | 252 | 253 | def concentric(*ents): 254 | types, ents = get_sorted_types(ents) 255 | 256 | if types in [[Circle, Circle], [Arc, Arc], [Arc, Circle]]: 257 | return coincident(ents[0].center_point, ents[1].center_point) 258 | 259 | else: 260 | return None 261 | 262 | 263 | CONSTRAINT_BY_LABEL = { 264 | ConstraintType.Coincident: coincident, 265 | ConstraintType.Parallel: parallel, 266 | ConstraintType.Horizontal: horizontal, 267 | ConstraintType.Vertical: vertical, 268 | ConstraintType.Perpendicular: perpendicular, 269 | ConstraintType.Tangent: tangent, 270 | ConstraintType.Equal: equal, 271 | ConstraintType.Midpoint: midpoint, 272 | ConstraintType.Concentric: concentric 273 | } 274 | -------------------------------------------------------------------------------- /sketchgraphs/onshape/call.py: -------------------------------------------------------------------------------- 1 | """Simple command line utilties for interacting with Onshape API. 2 | 3 | A sketch is considered a feature of an Onshape PartStudio. This script enables adding a sketch to a part (add_feature), retrieving all features from a part including sketches (get_features), and retrieving the possibly updated state of each sketch's entities/primitives post constraint solving (get_info). 4 | """ 5 | import argparse 6 | import json 7 | import urllib.parse 8 | 9 | from . import Client 10 | 11 | 12 | TEMPLATE_PATH = 'sketchgraphs/onshape/feature_template.json' 13 | 14 | 15 | def _parse_resp(resp): 16 | """Parse the response of a retrieval call. 17 | """ 18 | parsed_resp = json.loads(resp.content.decode('utf8').replace("'", '"')) 19 | return parsed_resp 20 | 21 | 22 | def _save_or_print_resp(resp_dict, output_path=None, indent=4): 23 | """Saves or prints the given response dict. 24 | """ 25 | if output_path: 26 | with open(output_path, 'w') as fh: 27 | json.dump(resp_dict, fh, indent=indent) 28 | else: 29 | print(json.dumps(resp_dict, indent=indent)) 30 | 31 | 32 | def _create_client(logging): 33 | """Creates a `Client` with the given bool value for `logging`. 
34 | """ 35 | client = Client(stack='https://cad.onshape.com', 36 | logging=logging) 37 | return client 38 | 39 | 40 | def _parse_url(url): 41 | """Extracts doc, workspace, element ids from url. 42 | """ 43 | _, _, docid, _, wid, _, eid = urllib.parse.urlparse(url).path.split('/') 44 | return docid, wid, eid 45 | 46 | 47 | def update_template(url, logging=False): 48 | """Updates version identifiers in feature_template.json. 49 | 50 | Parameters 51 | ---------- 52 | url : str 53 | URL of Onshape PartStudio 54 | logging: bool 55 | Whether to log API messages (default False) 56 | 57 | Returns 58 | ------- 59 | None 60 | """ 61 | # Get PartStudio features (including version IDs) 62 | features = get_features(url, logging) 63 | # Get current feature template 64 | with open(TEMPLATE_PATH, 'r') as fh: 65 | template = json.load(fh) 66 | for version_key in ['serializationVersion', 'sourceMicroversion', 'libraryVersion']: 67 | template[version_key] = features[version_key] 68 | # Save updated feature template 69 | with open(TEMPLATE_PATH, 'w') as fh: 70 | json.dump(template, fh, indent=4) 71 | 72 | 73 | def add_feature(url, sketch_dict, sketch_name=None, logging=False): 74 | """Adds a sketch to a part. 75 | 76 | Parameters 77 | ---------- 78 | url : str 79 | URL of Onshape PartStudio 80 | sketch_dict: dict 81 | A dictionary representing a `Sketch` instance with keys `entities` and `constraints` 82 | sketch_name: str 83 | Optional name for the sketch. If none provided, defaults to 'My Sketch'. 84 | logging: bool 85 | Whether to log API messages (default False) 86 | 87 | Returns 88 | ------- 89 | None 90 | """ 91 | # Get doc ids and create Client 92 | docid, wid, eid = _parse_url(url) 93 | client = _create_client(logging) 94 | # Get feature template 95 | with open(TEMPLATE_PATH, 'r') as fh: 96 | template = json.load(fh) 97 | # Add sketch's entities and constraints to the template 98 | template['feature']['message']['entities'] = sketch_dict['entities'] 99 | template['feature']['message']['constraints'] = sketch_dict['constraints'] 100 | if not sketch_name: 101 | sketch_name = 'My Sketch' 102 | template['feature']['message']['name'] = sketch_name 103 | # Send to Onshape 104 | client.add_feature(docid, wid, eid, payload=template) 105 | 106 | 107 | def get_features(url, logging=False): 108 | """Retrieves features from a part. 109 | 110 | Parameters 111 | ---------- 112 | url : str 113 | URL of Onshape PartStudio 114 | logging : bool 115 | Whether to log API messages (default False) 116 | 117 | Returns 118 | ------- 119 | features : dict 120 | A dictionary containing the part's features 121 | 122 | """ 123 | # Get doc ids and create Client 124 | docid, wid, eid = _parse_url(url) 125 | client = _create_client(logging) 126 | # Get features 127 | resp = client.get_features(docid, wid, eid) 128 | features = _parse_resp(resp) 129 | return features 130 | 131 | 132 | def get_info(url, sketch_name=None, logging=False): 133 | """Retrieves possibly updated states of entities in a part's sketches. 134 | 135 | Parameters 136 | ---------- 137 | url : str 138 | URL of Onshape PartStudio 139 | sketch_name : str 140 | If provided, only the entity info for the specified sketch will be returned. Otherwise, the full response is returned. 
141 | logging : bool 142 | Whether to log API messages (default False) 143 | 144 | Returns 145 | ------- 146 | sketch_info : dict 147 | A dictionary containing entity info for sketches 148 | 149 | """ 150 | # Get doc ids and create Client 151 | docid, wid, eid = _parse_url(url) 152 | client = _create_client(logging) 153 | # Get features 154 | resp = client.sketch_information(docid, wid, eid) 155 | sketch_info = _parse_resp(resp) 156 | if sketch_name: 157 | sketch_found = False 158 | for sk in sketch_info['sketches']: 159 | if sk['sketch'] == sketch_name: 160 | sketch_info = sk 161 | sketch_found = True 162 | break 163 | if not sketch_found: 164 | raise ValueError("No sketch found with given name.") 165 | return sketch_info 166 | 167 | 168 | def get_states(url, logging=False): 169 | """Retrieves states of sketches in a part. 170 | 171 | If there are no issues with a sketch, the feature state is `OK`. If there 172 | are issues, e.g., unsolved constraints, the state is `WARNING`. All sketches 173 | in the queried PartStudio must have unique names. 174 | 175 | Parameters 176 | ---------- 177 | url : str 178 | URL of Onshape PartStudio 179 | logging : bool 180 | Whether to log API messages (default False) 181 | 182 | Returns 183 | ------- 184 | sketch_states : dict 185 | A dictionary containing sketch names as keys and associated states as 186 | values 187 | """ 188 | # Get features for the given part url 189 | features = get_features(url, logging=logging) 190 | # Gather all feature states 191 | feat_states = {f['key']:f['value']['message']['featureStatus'] 192 | for f in features['featureStates']} 193 | # Gather sketch states 194 | sketch_states = {} 195 | for feat in features['features']: 196 | if feat['typeName'] != 'BTMSketch': 197 | continue 198 | sk_name = feat['message']['name'] 199 | sk_id = feat['message']['featureId'] 200 | # Check if name already encountered 201 | if sk_name in sketch_states: 202 | raise ValueError("Each sketch must have a unique name.") 203 | sketch_states[sk_name] = feat_states[sk_id] 204 | return sketch_states 205 | 206 | 207 | def main(): 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--url', 210 | help='URL of Onshape PartStudio',required=True) 211 | parser.add_argument('--action', 212 | help='The API call to perform', required=True, 213 | choices=['add_feature', 'get_features', 'get_info', 'get_states', 214 | 'update_template']) 215 | parser.add_argument('--payload_path', 216 | help='Path to payload being sent to Onshape', default=None) 217 | parser.add_argument('--output_path', 218 | help='Path to save result of API call', default=None) 219 | parser.add_argument('--enable_logging', 220 | help='Whether to log API messages', action='store_true') 221 | parser.add_argument('--sketch_name', 222 | help='Optional name for sketch', default=None) 223 | 224 | args = parser.parse_args() 225 | 226 | # Parse the URL 227 | _, _, docid, _, wid, _, eid = urllib.parse.urlparse(args.url).path.split('/') 228 | 229 | # Create client 230 | client = Client(stack='https://cad.onshape.com', 231 | logging=args.enable_logging) 232 | 233 | # Perform the specified action 234 | if args.action =='add_feature': 235 | # Add a sketch to a part 236 | if not args.payload_path: 237 | raise ValueError("payload_path required when adding a feature") 238 | with open(args.payload_path, 'r') as fh: 239 | sketch_dict = json.load(fh) 240 | add_feature(args.url, sketch_dict, args.sketch_name, 241 | args.enable_logging) 242 | 243 | elif args.action == 'get_features': 244 | # Retrieve 
features from a part 245 | features = get_features(args.url, args.enable_logging) 246 | _save_or_print_resp(features, output_path=args.output_path) 247 | 248 | elif args.action == 'get_info': 249 | # Retrieve possibly updated states of entities in a part's sketches 250 | sketch_info = get_info(args.url, args.sketch_name, args.enable_logging) 251 | _save_or_print_resp(sketch_info, output_path=args.output_path) 252 | 253 | elif args.action == 'get_states': 254 | # Retrieve states of sketches in a part 255 | sketch_states = get_states(args.url, args.enable_logging) 256 | _save_or_print_resp(sketch_states, output_path=args.output_path) 257 | 258 | elif args.action == 'update_template': 259 | # Updates version identifiers in template 260 | update_template(args.url, args.enable_logging) 261 | 262 | 263 | if __name__ == '__main__': 264 | main() --------------------------------------------------------------------------------
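
Usage sketch for the autoconstraint dataset listed above (an illustrative example under stated assumptions, not part of the repository): AutoconstraintDataset.__getitem__ draws a single prediction target per construction sequence, either the next constraint edge or a stop decision, and collate merges a list of such items into one batch. The sketch assumes that sequences is a list of NodeOp/EdgeOp construction sequences and that node_feature_mappings is an EntityFeatureMapping instance, both produced elsewhere by the SketchGraphs preprocessing pipeline.

from sketchgraphs_models.autoconstraint.dataset import AutoconstraintDataset, collate

def make_batch(sequences, node_feature_mappings, batch_size=8):
    # One randomly sampled prediction target (edge or stop) per sequence.
    dataset = AutoconstraintDataset(sequences, node_feature_mappings, seed=10)
    items = [dataset[i] for i in range(min(batch_size, len(sequences)))]
    # collate sorts the items by node count, merges the per-example graphs
    # and packs the node label sequences.
    return collate(items)

# The returned dict contains 'graph', 'edge_label', 'partner_index',
# 'stop_partner_index_index', 'node_features' (a PackedSequence) and
# 'sparse_node_features', matching the keys built in collate above.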
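A second usage sketch, for the geometric predicates in sketchgraphs/data/constraint_checks.py: get_sorted_types treats a plain numpy array as a Point, so the predicates can be exercised directly on coordinate pairs without constructing Entity objects. The coordinates below are made up for illustration.

import numpy as np

from sketchgraphs.data.constraint_checks import coincident, horizontal, vertical

p0 = np.array([0.0, 1.0])
p1 = np.array([2.0, 1.0])

print(coincident(p0, p0))  # True: identical coordinates
print(horizontal(p0, p1))  # True: equal y coordinates (within atol=1e-6)
print(vertical(p0, p1))    # False: x coordinates differ
# Type combinations not handled by a predicate return None rather than
# raising, as in the else branches above.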
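Finally, a sketch of driving the Onshape helpers in sketchgraphs/onshape/call.py from Python rather than through the command-line interface. It assumes a valid creds.json at the default location sketchgraphs/onshape/creds/creds.json and API access to the referenced PartStudio; the document, workspace and element ids in the URL are placeholders.

from sketchgraphs.onshape.call import get_features, get_states

# Placeholder PartStudio URL of the form expected by _parse_url.
url = 'https://cad.onshape.com/documents/<docid>/w/<wid>/e/<eid>'

# All features of the PartStudio, including sketches and version identifiers.
features = get_features(url, logging=False)
print(len(features['features']))

# Per-sketch solve status: 'OK' when all constraints are satisfied,
# 'WARNING' otherwise.
for name, state in get_states(url, logging=False).items():
    print(name, state)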