├── tests
│   ├── __init__.py
│   ├── testdata
│   │   ├── sample_json.tar.xz
│   │   └── sg_t16.stats.pkl.gz
│   ├── conftest.py
│   ├── test_constraint_checks.py
│   ├── test_flat_array.py
│   ├── test_graph_dataset.py
│   ├── test_data.py
│   ├── test_graph_model.py
│   └── test_torch_extensions.py
├── sketchgraphs_models
│   ├── graph
│   │   ├── train
│   │   │   ├── __main__.py
│   │   │   ├── data_loading.py
│   │   │   └── harness.py
│   │   ├── __init__.py
│   │   ├── entropy_rate.py
│   │   ├── dataset
│   │   │   └── benchmark.py
│   │   ├── eval_likelihood.py
│   │   └── model
│   │       ├── message_passing.py
│   │       └── numerical_features.py
│   ├── autoconstraint
│   │   ├── scripts
│   │   │   ├── __init__.py
│   │   │   ├── eval_statistics_stratified.py
│   │   │   └── eval_statistics_mask.py
│   │   ├── __init__.py
│   │   ├── eval_likelihood.py
│   │   └── dataset.py
│   ├── torch_extensions
│   │   ├── __init__.py
│   │   ├── index.py
│   │   ├── _repeat_interleave.py
│   │   ├── segment_pool.py
│   │   └── segment_ops.py
│   ├── __init__.py
│   ├── nn
│   │   ├── distributed.py
│   │   ├── data_util.py
│   │   ├── functional.py
│   │   ├── summary.py
│   │   └── __init__.py
│   └── distributed_utils.py
├── assets
│   ├── sketchgraphs.gif
│   └── sketch_w_graph.png
├── requirements.txt
├── sketchgraphs
│   ├── onshape
│   │   ├── creds
│   │   │   └── creds.json.example
│   │   ├── __init__.py
│   │   ├── LICENSE
│   │   ├── feature_template.json
│   │   ├── utils.py
│   │   ├── onshape.py
│   │   └── call.py
│   ├── pipeline
│   │   ├── __init__.py
│   │   ├── graph_model
│   │   │   ├── __init__.py
│   │   │   └── target.py
│   │   ├── numerical_parameters.py
│   │   ├── make_quantization_statistics.py
│   │   └── make_sketch_dataset.py
│   ├── __init__.py
│   └── data
│       ├── __init__.py
│       ├── sketch.py
│       ├── _plotting.py
│       ├── dof.py
│       └── constraint_checks.py
├── setup.py
├── .gitignore
├── docs
│   ├── _templates
│   │   └── autosummary
│   │       ├── class.rst
│   │       └── module.rst
│   ├── Makefile
│   ├── make.bat
│   ├── onshape_setup.rst
│   ├── models.rst
│   ├── conf.py
│   ├── index.rst
│   └── data.rst
├── LICENSE
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sketchgraphs_models/graph/train/__main__.py:
--------------------------------------------------------------------------------
1 | from . import main
2 | main()
3 |
--------------------------------------------------------------------------------
/assets/sketchgraphs.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/assets/sketchgraphs.gif
--------------------------------------------------------------------------------
/assets/sketch_w_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/assets/sketch_w_graph.png
--------------------------------------------------------------------------------
/sketchgraphs_models/autoconstraint/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | """Miscellaneous scripts used to process the result."""
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.17
2 | matplotlib>=3.1
3 | zstandard>=0.14
4 | lz4>=3.1
5 | tqdm>=4.48
6 | requests>=2.25.1
7 |
--------------------------------------------------------------------------------
/tests/testdata/sample_json.tar.xz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/tests/testdata/sample_json.tar.xz
--------------------------------------------------------------------------------
/tests/testdata/sg_t16.stats.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PrincetonLIPS/SketchGraphs/HEAD/tests/testdata/sg_t16.stats.pkl.gz
--------------------------------------------------------------------------------
/sketchgraphs_models/autoconstraint/__init__.py:
--------------------------------------------------------------------------------
1 | """This module implements the model, dataset and training procedure for the auto-constraint model."""
2 |
--------------------------------------------------------------------------------
/sketchgraphs/onshape/creds/creds.json.example:
--------------------------------------------------------------------------------
1 | {
2 | "https://cad.onshape.com": {
3 | "access_key": "cA...vT",
4 | "secret_key": "aU...Hw"
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='sketchgraphs',
5 | version='0.1',
6 | packages=find_packages(),
7 | license='MIT'
8 | )
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .pytest_cache/
2 | __pycache__/
3 |
4 | docs/_autosummary
5 | docs/_build
6 |
7 | creds.json
8 | build/
9 | *.cpython-37m-x86_64-linux-gnu.so
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/sketchgraphs/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | """This module implements the transformations in our data pipeline, taking the raw JSON data obtained
2 | from Onshape into data that is suitable for model training.
3 | """
4 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline }}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :members:
7 | :special-members:
8 | :exclude-members: __dict__,__weakref__,__repr__,__str__
9 |
--------------------------------------------------------------------------------
/sketchgraphs/__init__.py:
--------------------------------------------------------------------------------
1 | """Tools for manipulating and learning with the SketchGraphs dataset.
2 |
3 | The sub-modules of this module implement the required operations to create datasets
4 | at various levels of abstraction from the source data. In particular, there are three
5 | main scripts used to produce training data for the models.
6 |
7 | """
8 |
--------------------------------------------------------------------------------
/sketchgraphs/onshape/__init__.py:
--------------------------------------------------------------------------------
1 | """This module contains basic code to interact with the Onshape API.
2 |
3 | The Onshape API contains some important functionality to perform more advanced manipulation of sketches.
4 | In particular, the API provides functions for solving constraints in sketches.
5 |
6 | """
7 |
8 | from .onshape import Onshape
9 | from .client import Client
10 | from .utils import log
11 |
--------------------------------------------------------------------------------
/sketchgraphs/pipeline/graph_model/__init__.py:
--------------------------------------------------------------------------------
1 | """This module provides the data pipeline to convert from the stored sequence format to
2 | a format suitable for training our graph-based deep learning model. In particular, it features
3 | a number of important helpers for quantizing numerical features, and representing the graph in
4 | terms of tensors.
5 |
6 | """
7 |
8 | from . import _graph_info
9 | from ._graph_info import *
10 |
11 | __all__ = _graph_info.__all__
12 |
--------------------------------------------------------------------------------
/sketchgraphs_models/graph/__init__.py:
--------------------------------------------------------------------------------
1 | """ This module contains the main implementation for the graph models.
2 |
3 | The graph models represent a family of models which recursively build
4 | the sketch by viewing the sketch as a graph of entities and constraints.
5 | As described in the paper, the baseline generative model does not consider
6 | entity (primitive) coordinates and relies on constraints to determine
7 | the final configuration of the solved sketch.
8 |
9 | """
10 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Shared test fixtures in order to easily provide testing data to all tests."""
2 |
3 | import json
4 | import tarfile
5 |
6 | import pytest
7 |
8 | from sketchgraphs.data.sketch import Sketch
9 |
10 |
11 | @pytest.fixture
12 | def sketches_json():
13 | sample_file = 'tests/testdata/sample_json.tar.xz'
14 |
15 | result = []
16 |
17 | with tarfile.open(sample_file, 'r:xz') as tar_archive:
18 | for f in tar_archive:
19 | if not f.isfile():
20 | continue
21 | result.extend(json.load(tar_archive.extractfile(f)))
22 |
23 | return result
24 |
25 | @pytest.fixture
26 | def sketches(sketches_json):
27 | """Return a list of sample sketches."""
28 | return [Sketch.from_fs_json(j) for j in sketches_json]
29 |
--------------------------------------------------------------------------------
/tests/test_constraint_checks.py:
--------------------------------------------------------------------------------
1 | """Tests for constraint checking"""
2 |
3 | import itertools
4 |
5 | from sketchgraphs.data import constraint_checks, sketch_to_sequence, EdgeOp
6 |
7 | def test_constraint_check(sketches):
8 | for sketch_idx, sketch in enumerate(itertools.islice(sketches, 1000)):
9 | sequence = sketch_to_sequence(sketch)
10 | for op_idx, op in enumerate(sequence):
11 | if not isinstance(op, EdgeOp):
12 | continue
13 | try:
14 | check_value = constraint_checks.check_edge_satisfied(sequence, op)
15 | except ValueError:
16 | check_value = True
17 |
18 | if check_value is None:
19 | check_value = True
20 |
21 | assert check_value, "constraint check failed at sketch {0}, op {1}".format(sketch_idx, op_idx)
22 |
--------------------------------------------------------------------------------
/sketchgraphs_models/torch_extensions/__init__.py:
--------------------------------------------------------------------------------
1 | """ Pytorch extensions for our models.
2 |
3 | This module provides some important extensions that work with segmented data, which is
4 | used extensively in our models due to the nature of graph batching. This module provides
5 | a pure Python implementation, although it can leverage a C++ extension if present. The
6 | extension enables a speed-up of over 2x (on GPU) when present, so it is recommended to
7 | compile the C++ extension if possible.
8 |
9 | """
10 |
11 | from ._repeat_interleave import repeat_interleave
12 | from .segment_ops import segment_logsumexp
13 | from .segment_pool import segment_avg_pool1d, segment_max_pool1d
14 |
15 | from .index import segment_triu_indices, segment_cartesian_product
16 |
17 | __all__ = ['repeat_interleave', 'segment_logsumexp', 'segment_avg_pool1d', 'segment_max_pool1d',
18 | 'segment_triu_indices', 'segment_cartesian_product']
19 |
--------------------------------------------------------------------------------
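A minimal usage sketch of the segmented ("scope") layout, to make the description above concrete. This is illustrative only and not part of the repository sources; it assumes the `(offset, length)` scope convention documented in `_repeat_interleave.py` and `nn/functional.py`, and calls `segment_logsumexp(values, scopes)` the same way `nn/functional.py` does. The tensor values are made up.

    import torch
    from sketchgraphs_models.torch_extensions import segment_logsumexp

    # Three segments with 2, 3 and 1 elements, stored contiguously in one tensor.
    lengths = torch.tensor([2, 3, 1])
    offsets = torch.cumsum(lengths, 0) - lengths        # tensor([0, 2, 5])
    scopes = torch.stack([offsets, lengths], dim=1)     # shape [3, 2]: one (offset, length) row per segment

    logits = torch.randn(6)                             # concatenated per-element logits
    per_segment = segment_logsumexp(logits, scopes)     # one value per segment (expected shape [3])
    print(per_segment)
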
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/sketchgraphs_models/torch_extensions/index.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def segment_triu_indices_loop(scopes, offset=0, dtype=torch.long, device='cpu'):
5 | indices = torch.cat([
6 | torch.triu_indices(
7 | n, n, offset, dtype=dtype, device=device).t() + o for (o, n) in scopes
8 | ], dim=0)
9 |
10 | return indices
11 |
12 |
13 | def segment_cartesian_product_loop(values_a, values_b, scopes_a, scopes_b):
14 | result = []
15 |
16 | for sa, sb in zip(scopes_a, scopes_b):
17 | va = values_a.narrow(0, sa[0], sa[1]).unsqueeze(1).expand((-1, sb[1], -1))
18 | vb = values_b.narrow(0, sb[0], sb[1]).unsqueeze(0).expand((sa[1], -1, -1))
19 |
20 | values_cat = torch.cat((va, vb), dim=-1)
21 | result.append(values_cat.flatten(end_dim=1))
22 |
23 | return torch.cat(result, dim=0)
24 |
25 |
26 | segment_triu_indices = segment_triu_indices_loop
27 | segment_cartesian_product = segment_cartesian_product_loop
28 |
--------------------------------------------------------------------------------
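As a worked illustration (not part of the repository sources), applying `segment_triu_indices` to two segments of sizes 3 and 2 yields the per-segment upper-triangular index pairs, each shifted by its segment's offset:

    from sketchgraphs_models.torch_extensions.index import segment_triu_indices

    # Two segments: size 3 starting at offset 0, size 2 starting at offset 3.
    pairs = segment_triu_indices([(0, 3), (3, 2)], offset=0)
    print(pairs.tolist())
    # [[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2],   <- first segment
    #  [3, 3], [3, 4], [4, 4]]                           <- second segment, shifted by its offset 3
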
/sketchgraphs/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This module contains the main data representations for the SketchGraphs dataset.
2 |
3 | There are two main projections for the data.
4 |
5 | The first is the `sketch` projection, which is close to the underlying FeatureScript JSON format from Onshape.
6 | This projection is favourable for interactions with Onshape's API, and serves as a first step for parsing
7 | the raw data from Onshape.
8 |
9 | The second projection is the `sequence` projection, which is more
10 | specialized towards applications in machine learning. It is more streamlined, but requires conversion
11 | to interact with Onshape's API. However, it interacts more naturally with machine learning applications.
12 |
13 | """
14 |
15 | from .sketch import *
16 | from .sequence import *
17 |
18 | from .sketch import Sketch, EntityType, SubnodeType, Entity, GenericEntity, Point, Line, Circle, Arc, Spline, Ellipse, ENTITY_TYPE_TO_CLASS
19 | from .sketch import render_sketch, render_graph
20 |
--------------------------------------------------------------------------------
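A minimal sketch of moving between the two projections. The file path is hypothetical; the calls mirror those used in `tests/conftest.py` and `tests/test_data.py`.

    import json
    from sketchgraphs.data import Sketch, sketch_to_sequence, render_sketch

    # Hypothetical path to one raw FeatureScript sketch dictionary, e.g. an element
    # of the archives under tests/testdata/sample_json.tar.xz.
    with open('sketch.json') as f:
        sketch_json = json.load(f)

    sketch = Sketch.from_fs_json(sketch_json)   # `sketch` projection (close to the Onshape JSON)
    sequence = sketch_to_sequence(sketch)       # `sequence` projection (list of NodeOp / EdgeOp)
    fig = render_sketch(sketch)                 # matplotlib figure of the rendered sketch
    print(sketch, len(sequence))
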
/sketchgraphs_models/__init__.py:
--------------------------------------------------------------------------------
1 | """This module contains supporting code to the main sketchgraphs dataset. In particular, it contains two models based
2 | on graph representations of the data, which are geared towards the task of generation (see `sketchgraphs_models.graph`)
3 | and autoconstrain (see `sketchgraphs_models.autoconstraint`). In addition, this module contains a number of submodules
4 | to help the implementation of these models.
5 |
6 | These models are tuned and operate on a subset of the full sketchgraphs dataset. In particular, they only handle
7 | constraints which relate at most two entities (e.g. the mirror constraint is not handled), and are trained on a subset
8 | of sketches which excludes sketches that are too large or too small. In addition, these models use a quantization
9 | strategy to model continuous parameters in the dataset, which must be pre-computed before training. To create
10 | quantization maps, see `sketchgraphs.pipeline.make_quantization_statistics`.
11 |
12 | """
13 |
--------------------------------------------------------------------------------
/tests/test_flat_array.py:
--------------------------------------------------------------------------------
1 | #pylint: disable=missing-module-docstring,missing-function-docstring
2 |
3 | import itertools
4 |
5 | import numpy as np
6 |
7 | from sketchgraphs.data import flat_array
8 |
9 |
10 | def test_flat_array_of_int():
11 | x = [2, 3, 4, 5]
12 |
13 | x_flat = flat_array.save_list_flat(x)
14 | x_array = flat_array.FlatSerializedArray.from_flat_array(x_flat)
15 |
16 | assert len(x_array) == len(x)
17 |
18 | for i, j in itertools.zip_longest(x, x_array):
19 | assert i == j
20 |
21 |
22 | def test_flat_dictionary():
23 | x = [2, 3, 4, 5]
24 | y = np.array([3, 5])
25 | z = ["A", "python", "list"]
26 |
27 | x_flat = flat_array.save_list_flat(x)
28 |
29 | dict_flat = flat_array.pack_dictionary_flat({
30 | 'x': x_flat,
31 | 'y': y,
32 | 'z': z
33 | })
34 |
35 | result = flat_array.load_dictionary_flat(dict_flat)
36 |
37 | assert isinstance(result['x'], flat_array.FlatSerializedArray)
38 | assert len(result['x']) == len(x)
39 |
40 | assert result['z'] == z
41 | assert all(result['y'] == y)
42 |
--------------------------------------------------------------------------------
/sketchgraphs/onshape/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 PTC Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Ari Seff, Yaniv Ovadia, Wenda Zhou, Ryan P. Adams
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/sketchgraphs/onshape/feature_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "feature": {
3 | "type": 151,
4 | "typeName": "BTMSketch",
5 | "message": {
6 | "entities": [],
7 | "constraints": [],
8 | "featureType": "newSketch",
9 | "name": "My Sketch",
10 | "parameters": [{
11 | "type": 148,
12 | "typeName": "BTMParameterQueryList",
13 | "message": {
14 | "queries": [{
15 | "type": 138,
16 | "typeName": "BTMIndividualQuery",
17 | "message": {
18 | "geometryIds": ["JEC"],
19 | "hasUserCode": false
20 | }
21 | }],
22 | "parameterId": "sketchPlane",
23 | "hasUserCode": false
24 | }
25 | }],
26 | "suppressed": false,
27 | "namespace": "",
28 | "subFeatures": [],
29 | "returnAfterSubfeatures": false,
30 | "suppressionState": {
31 | "type": 0
32 | },
33 | "hasUserCode": false
34 | }
35 | },
36 | "serializationVersion": "1.1.18",
37 | "sourceMicroversion": "57c52f87718e8008cce43b47",
38 | "rejectMicroversionSkew": false,
39 | "microversionSkew": false,
40 | "libraryVersion": 1324
41 | }
42 |
--------------------------------------------------------------------------------
/docs/_templates/autosummary/module.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |
5 | {% block classes %}
6 | {% if classes %}
7 | .. rubric:: {{ _('Classes') }}
8 |
9 | .. autosummary::
10 | :toctree:
11 | {% for item in classes %}
12 | {{ item }}
13 | {%- endfor %}
14 | {% endif %}
15 | {% endblock %}
16 |
17 | {% block attributes %}
18 | {% if attributes %}
19 | .. rubric:: Module Attributes
20 |
21 | {% for item in attributes %}
22 | .. autoattribute:: {{ item }}
23 | {%- endfor %}
24 | {% endif %}
25 | {% endblock %}
26 |
27 | {% block functions %}
28 | {% if functions %}
29 | .. rubric:: {{ _('Functions') }}
30 |
31 | {% for item in functions %}
32 | .. autofunction:: {{ item }}
33 | {%- endfor %}
34 | {% endif %}
35 | {% endblock %}
36 |
37 |
38 | {% block exceptions %}
39 | {% if exceptions %}
40 | .. rubric:: {{ _('Exceptions') }}
41 |
42 | .. autosummary::
43 | {% for item in exceptions %}
44 | {{ item }}
45 | {%- endfor %}
46 | {% endif %}
47 | {% endblock %}
48 |
49 | {% block modules %}
50 | {% if modules %}
51 | .. rubric:: Modules
52 |
53 | .. autosummary::
54 | :toctree:
55 | :recursive:
56 | {% for item in modules %}
57 | {{ item }}
58 | {%- endfor %}
59 | {% endif %}
60 | {% endblock %}
--------------------------------------------------------------------------------
/sketchgraphs_models/nn/distributed.py:
--------------------------------------------------------------------------------
1 | """Utility modules for distributed and parallel training. """
2 |
3 | import torch
4 |
5 | class SingleDeviceDistributedParallel(torch.nn.parallel.distributed.DistributedDataParallel):
6 |     """This class implements a module similar to `DistributedDataParallel`, but it accepts
7 |     inputs of any shape, and only supports a single device per instance.
8 | """
9 | def __init__(self, module, device_id, find_unused_parameters=False):
10 | super(SingleDeviceDistributedParallel, self).__init__(
11 | module, [device_id], find_unused_parameters=find_unused_parameters)
12 |
13 | def forward(self, *inputs, **kwargs):
14 | if self.require_forward_param_sync:
15 | self._sync_params()
16 |
17 | output = self.module(*inputs, **kwargs)
18 |
19 | if torch.is_grad_enabled() and self.require_backward_grad_sync:
20 | self.require_forward_param_sync = True
21 |
22 | if self.find_unused_parameters:
23 | self.reducer.prepare_for_backward(list(torch.nn.parallel.distributed._find_tensors(output)))
24 | else:
25 | self.reducer.prepare_for_backward([])
26 |
27 | return output
28 |
29 | def state_dict(self, destination=None, prefix='', keep_vars=False):
30 | return self.module.state_dict(destination, prefix, keep_vars)
31 |
32 | def load_state_dict(self, state_dict, strict=True):
33 | return self.module.load_state_dict(state_dict, strict)
34 |
--------------------------------------------------------------------------------
/sketchgraphs_models/nn/data_util.py:
--------------------------------------------------------------------------------
1 | """Utilities and extensions to work with the torch.utils.data package.
2 | """
3 |
4 | import torch
5 | import torch.utils.data
6 |
7 | class MultiEpochSampler(torch.utils.data.Sampler):
8 | """A sampler wrapper which creates multiple epochs of indices
9 | from one sampler.
10 |
11 | Due to the way torch dataloaders function, they reset all workers
12 | when starting a new epoch. This is undesirable due to the high setup
13 | cost of starting workers.
14 |
15 | This sampler thus wraps an existing batch sampler, and simply repeatedly
16 | samples several epochs from the wrapped sampler.
17 |
18 | """
19 | def __init__(self, batch_sampler, num_epochs):
20 | """Initializes a new instance of `MultiEpochSampler`.
21 |
22 | Parameters
23 | ----------
24 | batch_sampler : torch.utils.data.Sampler
25 | A batch sampler to wrap
26 | num_epochs : int
27 | The number of epochs for this sampler.
28 | """
29 | #pylint: disable=super-init-not-called
30 | self._sampler = batch_sampler
31 | self._epochs = num_epochs
32 |
33 | def __len__(self):
34 | return len(self._sampler) * self._epochs
35 |
36 | @property
37 | def batches_per_epoch(self):
38 | """Get the number of batches per one single epoch of the original sampler."""
39 | return len(self._sampler)
40 |
41 | def __iter__(self):
42 | for _ in range(self._epochs):
43 | for sample in self._sampler:
44 | yield sample
45 |
--------------------------------------------------------------------------------
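A brief usage sketch for `MultiEpochSampler` (illustrative only; the data source, batch size and epoch count are placeholders):

    import torch.utils.data as td
    from sketchgraphs_models.nn.data_util import MultiEpochSampler

    # Wrap a standard batch sampler so that a single DataLoader pass covers several
    # epochs, avoiding the cost of restarting workers at every epoch boundary.
    base = td.BatchSampler(td.RandomSampler(range(1000)), batch_size=32, drop_last=False)
    sampler = MultiEpochSampler(base, num_epochs=10)

    assert len(sampler) == 10 * sampler.batches_per_epoch
    # loader = td.DataLoader(dataset, batch_sampler=sampler, num_workers=4)
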
/docs/onshape_setup.rst:
--------------------------------------------------------------------------------
1 | Onshape setup
2 | =============
3 |
4 | We include a set of functions (`sketchgraphs.onshape.call`) for interacting with the CAD program `Onshape `_ in order to solve geometric constraints.
5 | The SketchGraphs `demo notebook `_ demonstrates their main usage.
6 |
7 | Before calling these for the first time, a small bit of setup is required.
8 |
9 | Account & credentials
10 | ---------------------
11 |
12 | - Visit `Onshape `_ and create an account (it's free).
13 | - Create an API key at https://dev-portal.onshape.com/keys. Be sure to enable read and write permissions so that we may send and retrieve CAD sketches from Onshape.
14 | - Save the resulting API credentials file at `sketchgraphs/onshape/creds/creds.json` (see `creds.json.example` in the same directory for the expected format).
15 |
16 |
17 | Create document
18 | ---------------
19 | We'll need a document to serve as the target for our sketches.
20 |
21 | - Go to https://cad.onshape.com/documents and click Create->Document in the upper left.
22 | - The document URL will be needed to perform API calls. For example, in the `demo notebook `_, we manually paste in the target URL.
23 | - Now that we have a new document, we must set the version identifiers of our `feature_template.json` accordingly. Run the following in a shell whenever working with a new document (set the variable `url` accordingly)::
24 | 
25 |     url=https://cad.onshape.com/documents/6f6d14f8facf0bba02184e88/w/66a5db71489c81f4893101ed/e/120c56983451157d26a7102d
26 |     python -m sketchgraphs.onshape.call --action update_template --url $url --enable_logging
27 |
28 | You should see a "request succeeded" included in the log if the configuration is correct.
29 |
30 | Our Onshape setup is complete! Please reach out with any questions.
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | SketchGraphs models
2 | ===================
3 |
4 | This page describes the models implemented in SketchGraphs and details their usage.
5 | The models are based on a Graph Neural Network architecture, modelling the sketch as a graph
6 | with vertices given by entities and edges given by their constraints.
7 |
8 | Quickstart
9 | ----------
10 |
11 | For an initial quick-start, we recommend that users start with the provided sequence files and
12 | associated quantization maps, and follow the default hyper-parameter values for training.
13 | Additionally, we strongly recommend using a powerful GPU, as training is compute intensive.
14 |
15 | For example, assuming that you have downloaded the training dataset, as well as the accompanying
16 | quantization statistics (available `here `_),
17 | the generative model may be trained by running:
18 |
19 | .. code-block:: bash
20 |
21 | python -m sketchgraphs_models.graph.train --dataset_train sg_t16_train.npy
22 |
23 | You may monitor the training progress on the standard output, or through `tensorboard `_.
24 |
25 | Similarly, the autoconstrain model may be trained by running:
26 |
27 | .. code-block:: bash
28 |
29 | python -m sketchgraphs_models.autoconstraint.train --dataset_train sg_t16_train.npy
30 |
31 | We will also provide pre-trained models (coming soon!).
32 |
33 |
34 | Torch scatter
35 | -------------
36 |
37 | In order to enjoy the best training performance, we strongly recommend you install the `torch-scatter `_
38 | package with the correct CUDA version for training on GPU. If it is not installed,
39 | the models will automatically fall back to a plain python / pytorch implementation (however, there will be a
40 | performance penalty due to a substantial amount of looping).
41 |
--------------------------------------------------------------------------------
/tests/test_graph_dataset.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | import pickle
4 | import pytest
5 |
6 | import numpy as np
7 |
8 | from sketchgraphs_models.graph import dataset as graph_dataset
9 | from sketchgraphs.data import sequence as data_sequence
10 |
11 | @pytest.fixture
12 | def edge_feature_mapping():
13 | """Dummy quantization functions."""
14 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f:
15 | mapping = pickle.load(f)
16 |
17 | return graph_dataset.EdgeFeatureMapping(
18 | mapping['edge']['angle'], mapping['edge']['length'])
19 |
20 | @pytest.fixture
21 | def node_feature_mapping():
22 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f:
23 | mapping = pickle.load(f)
24 | return graph_dataset.EntityFeatureMapping(mapping['node'])
25 |
26 |
27 | def test_dataset(sketches, node_feature_mapping, edge_feature_mapping):
28 | sequences = list(map(data_sequence.sketch_to_sequence, sketches))
29 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=12)
30 |
31 | graph, target = dataset[0]
32 | assert graph_dataset.TargetType.EdgeDistance in graph.sparse_edge_features
33 | assert graph.node_features is not None
34 | assert graph.sparse_node_features is not None
35 |
36 |
37 | def test_collate_empty(node_feature_mapping, edge_feature_mapping):
38 | result = graph_dataset.collate([], node_feature_mapping, edge_feature_mapping)
39 | assert 'edge_numerical' in result
40 |
41 |
42 | def test_collate_some(sketches, node_feature_mapping, edge_feature_mapping):
43 | sequences = list(map(data_sequence.sketch_to_sequence, sketches))
44 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=42)
45 |
46 | batch = [dataset[i] for i in range(5)]
47 |
48 | batch_info = graph_dataset.collate(batch, node_feature_mapping, edge_feature_mapping)
49 | assert 'edge_numerical' in batch_info
50 | assert batch_info['graph'].node_features is not None
51 |
--------------------------------------------------------------------------------
/sketchgraphs/onshape/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | utils
3 | =====
4 |
5 | Handy functions for API key sample app
6 | '''
7 |
8 | import logging
9 | from logging.config import dictConfig
10 |
11 | __all__ = [
12 | 'log'
13 | ]
14 |
15 |
16 | def log(msg, level=0):
17 | '''
18 |     Logs a message to the console, with an optional level parameter
19 |
20 | Args:
21 | - msg (str): message to send to console
22 | - level (int): log level; 0 for info, 1 for error (default = 0)
23 | '''
24 |
25 | red = '\033[91m'
26 | endc = '\033[0m'
27 |
28 | # configure the logging module
29 | cfg = {
30 | 'version': 1,
31 | 'disable_existing_loggers': False,
32 | 'formatters': {
33 | 'stdout': {
34 | 'format': '[%(levelname)s]: %(asctime)s - %(message)s',
35 | 'datefmt': '%x %X'
36 | },
37 | 'stderr': {
38 | 'format': red + '[%(levelname)s]: %(asctime)s - %(message)s' + endc,
39 | 'datefmt': '%x %X'
40 | }
41 | },
42 | 'handlers': {
43 | 'stdout': {
44 | 'class': 'logging.StreamHandler',
45 | 'level': 'DEBUG',
46 | 'formatter': 'stdout'
47 | },
48 | 'stderr': {
49 | 'class': 'logging.StreamHandler',
50 | 'level': 'ERROR',
51 | 'formatter': 'stderr'
52 | }
53 | },
54 | 'loggers': {
55 | 'info': {
56 | 'handlers': ['stdout'],
57 | 'level': 'INFO',
58 | 'propagate': True
59 | },
60 | 'error': {
61 | 'handlers': ['stderr'],
62 | 'level': 'ERROR',
63 | 'propagate': False
64 | }
65 | }
66 | }
67 |
68 | dictConfig(cfg)
69 |
70 | lg = 'info' if level == 0 else 'error'
71 | lvl = 20 if level == 0 else 40
72 |
73 | logger = logging.getLogger(lg)
74 | logger.log(lvl, msg)
75 |
--------------------------------------------------------------------------------
/sketchgraphs_models/torch_extensions/_repeat_interleave.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _ensure_repeats(repeats_or_scopes: torch.Tensor):
4 | if repeats_or_scopes.dim() == 2:
5 | return repeats_or_scopes.select(1, 1)
6 | else:
7 | return repeats_or_scopes
8 |
9 |
10 | def repeat_interleave_out(values, repeats_or_scope, dim, out):
11 | index = torch.repeat_interleave(_ensure_repeats(repeats_or_scope))
12 |
13 | if dim is None:
14 | values = values.flatten()
15 | dim = 0
16 |
17 | return torch.index_select(values, dim, index, out=out)
18 |
19 |
20 | def repeat_interleave(values, repeats_or_scope, dim=None, out=None, out_length=None):
21 | """ Extended `repeat_interleave` with the ability to pass in pre-computed information
22 | to reduce CPU-GPU communication when the repeats tensor resides on GPU.
23 |
24 | In particular, this function can make use of a scope parameter instead of simply repeats,
25 | which is a two-dimensional tensor with two columns, one corresponding to the offsets
26 | of each segment, and the second one corresponding to the length (i.e. the number of repeats).
27 |
28 | Parameters
29 | ----------
30 | values : torch.Tensor
31 | An array of values to repeat.
32 | repeats_or_scope : torch.Tensor
33 | Either a one-dimensional array of repeats, or a two-dimensional array
34 | representing the offset and length (which is referred to as scope).
35 | dim : int, optional
36 | The dimension of values along which to repeat.
37 | out : torch.Tensor, optional
38 | if not None, a tensor into which to produce the output. Note that this is not differentiable.
39 | out_length : int, optional
40 | if not None, the length of the output along the repeated dimension.
41 | Returns
42 | -------
43 | A new tensor, containing values repeated the appropriate amount of times along the
44 | given dimension.
45 | """
46 | if out is not None:
47 | return repeat_interleave_out(values, repeats_or_scope, dim, out)
48 | return torch.repeat_interleave(values, repeats_or_scope, dim)
49 |
--------------------------------------------------------------------------------
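An illustrative sketch (not part of the repository sources) contrasting the two accepted forms of the second argument; note that in this pure-python fallback the scope form is only handled on the `out=` path.

    import torch
    from sketchgraphs_models.torch_extensions import repeat_interleave

    values = torch.tensor([[1.0], [2.0], [3.0]])
    repeats = torch.tensor([2, 1, 3])

    # Plain 1-d repeats, equivalent to torch.repeat_interleave.
    out_a = repeat_interleave(values, repeats, dim=0)             # rows 1, 1, 2, 3, 3, 3

    # Scope form: one (offset, length) row per segment, with a pre-allocated output buffer.
    scopes = torch.tensor([[0, 2], [2, 1], [3, 3]])
    buffer = values.new_empty((6, 1))
    out_b = repeat_interleave(values, scopes, dim=0, out=buffer)

    assert torch.equal(out_a, out_b)
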
/sketchgraphs_models/graph/entropy_rate.py:
--------------------------------------------------------------------------------
1 | """Module to compute the entropy rate of a given representation. """
2 |
3 | import argparse
4 | import lzma
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 |
10 | from sketchgraphs_models.graph import dataset, sample
11 | from sketchgraphs.data import flat_array
12 |
13 |
14 |
15 | def sequence_to_integers(seq, node_feature_mapping, edge_feature_mapping):
16 | graph = dataset.graph_info_from_sequence(seq, node_feature_mapping, edge_feature_mapping)
17 | all_tensors = (
18 | [torch.full([1], fill_value=graph.edge_counts[0], dtype=torch.int64)] +
19 | [graph.node_features,
20 | graph.edge_features,
21 | graph.incidence.flatten()] +
22 | [v.value.flatten() for v in graph.sparse_node_features.values()] +
23 | [v.value.flatten() for v in graph.sparse_edge_features.values()])
24 |
25 | integers = torch.cat(all_tensors)
26 |
27 | return integers.numpy().astype(np.int32)
28 |
29 |
30 | def compress_sequences(seqs, node_feature_mapping, edge_feature_mapping):
31 | compressor = lzma.LZMACompressor(preset=9 | lzma.PRESET_EXTREME)
32 |
33 | total_bytes = 0
34 |
35 | for seq in seqs:
36 | integers = sequence_to_integers(seq, node_feature_mapping, edge_feature_mapping)
37 | total_bytes += len(compressor.compress(integers.tobytes()))
38 |
39 | total_bytes += len(compressor.flush())
40 |
41 | return total_bytes
42 |
43 |
44 | def main():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--dataset', required=True)
47 | parser.add_argument('--model_state', help='Path to saved model state_dict.')
48 | parser.add_argument('--limit', type=int, default=None)
49 |
50 | args = parser.parse_args()
51 |
52 | print('Loading trained model')
53 | _, node_feature_mapping, edge_feature_mapping = sample.load_saved_model(args.model_state)
54 |
55 | print('Loading testing data')
56 | seqs = flat_array.load_dictionary_flat(np.load(args.dataset, mmap_mode='r'))['sequences']
57 |
58 | if args.limit is not None:
59 | seqs = seqs[:args.limit]
60 |
61 | total_bytes = compress_sequences(tqdm.tqdm(seqs, total=len(seqs)), node_feature_mapping, edge_feature_mapping)
62 | print('Total bytes: {0}'.format(total_bytes))
63 | print('Average Entropy: {0:.3f} bits / graph'.format(total_bytes * 8 / len(seqs)))
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
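As a worked example of the final report (illustrative numbers only): if `compress_sequences` returns 1,250,000 compressed bytes for 10,000 sequences, the script prints a total of 1,250,000 bytes and an average entropy of 1,250,000 * 8 / 10,000 = 1000 bits / graph.
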
/sketchgraphs_models/graph/dataset/benchmark.py:
--------------------------------------------------------------------------------
1 | """Utility to benchmark data loading speed."""
2 |
3 | import argparse
4 | import time
5 |
6 |
7 | from sketchgraphs_models.graph.train.data_loading import initialize_datasets
8 |
9 |
10 | def time_iterator(iterator, batch_size, total_time_seconds=None):
11 | start_time = time.perf_counter()
12 | last_time = start_time
13 | num_batch_processed = 0
14 |
15 | for _ in iterator:
16 | num_batch_processed += 1
17 | current_time = time.perf_counter()
18 | elapsed = current_time - last_time
19 | if elapsed > 5:
20 | print('Processed {0} elements in {1:.2f} seconds. ({2:.2f} / second )'.format(
21 | num_batch_processed * batch_size, elapsed, num_batch_processed * batch_size / elapsed ))
22 | last_time = time.perf_counter()
23 | num_batch_processed = 0
24 |
25 | if total_time_seconds is not None and current_time - start_time > total_time_seconds:
26 | break
27 |
28 |
29 | def main():
30 | parser = argparse.ArgumentParser()
31 |
32 | parser.add_argument('--seed', type=int, default=42)
33 | parser.add_argument('--dataset_train', required=True,
34 | help='Pickle dataset for train data.')
35 | parser.add_argument('--dataset_auxiliary', default=None, help='path to auxiliary dataset containing metadata')
36 | parser.add_argument('--num_quantize_length', type=int, default=383, help='number of quantization values for length')
37 | parser.add_argument('--num_quantize_angle', type=int, default=127, help='number of quantization values for angle')
38 | parser.add_argument('--batch_size', type=int, default=32, help='Training batch size.')
39 | parser.add_argument('--num_epochs', type=int, default=100,
40 | help='Number of training epochs.')
41 | parser.add_argument('--num_workers', type=int, default=0,
42 | help='Number of dataloader workers.')
43 | parser.add_argument('--disable_edge_features', action='store_true',
44 | help='Disable using and predicting edge features')
45 |
46 | parser.add_argument('--max_time', type=float, default=300)
47 |
48 | args = vars(parser.parse_args())
49 | args['dataset_test'] = None
50 | args['disable_entity_features'] = True
51 |
52 | print('Loading dataset')
53 | dataloader, _, batches_per_epoch, _, _ = initialize_datasets(args)
54 |
55 | print('Benchmarking dataloader')
56 | time_iterator(dataloader, args['batch_size'], args['max_time'])
57 |
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
62 |
--------------------------------------------------------------------------------
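A typical invocation might look as follows; the module path assumes the package layout shown in the directory tree, and the sequence file name follows the example used in `docs/models.rst` (both are illustrative):

    python -m sketchgraphs_models.graph.dataset.benchmark \
        --dataset_train sg_t16_train.npy --batch_size 32 --num_workers 4 --max_time 60
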
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'SketchGraphs'
21 | copyright = '2020, Ari Seff, Yaniv Ovadia, Wenda Zhou, Ryan P. Adams'
22 | author = 'Ari Seff, Yaniv Ovadia, Wenda Zhou, Ryan P. Adams'
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = [
31 | 'sphinx.ext.autodoc',
32 | 'sphinx.ext.autosummary',
33 | 'sphinx.ext.napoleon',
34 | 'sphinx.ext.intersphinx'
35 | ]
36 |
37 | autosummary_generate = True
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ['_templates']
41 |
42 | # List of patterns, relative to source directory, that match files and
43 | # directories to ignore when looking for source files.
44 | # This pattern also affects html_static_path and html_extra_path.
45 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
46 |
47 | default_role = 'any'
48 |
49 | # -- Options for intersphinx extension ---------------------------------------
50 |
51 | intersphinx_mapping = {
52 | 'python': ('https://docs.python.org/3', None),
53 | 'torch': ('https://pytorch.org/docs/stable/', None),
54 | 'numpy': ('https://docs.scipy.org/doc/numpy/', None),
55 | }
56 |
57 |
58 | # -- Options for HTML output -------------------------------------------------
59 |
60 | # The theme to use for HTML and HTML Help pages. See the documentation for
61 | # a list of builtin themes.
62 | #
63 | html_theme = 'sphinx_rtd_theme'
64 |
65 | # Add any paths that contain custom static files (such as style sheets) here,
66 | # relative to this directory. They are copied after the builtin static files,
67 | # so a file named "default.css" will overwrite the builtin "default.css".
68 | html_static_path = ['_static']
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | #pylint: disable=missing-module-docstring,missing-function-docstring
2 |
3 | import numpy as np
4 |
5 | from sketchgraphs.data.sequence import sketch_to_sequence, NodeOp, EdgeOp, sketch_from_sequence
6 | from sketchgraphs.data.sketch import Sketch, render_sketch, EntityType, ConstraintType, SubnodeType
7 | from sketchgraphs.data.dof import get_sequence_dof
8 |
9 |
10 | def test_sketch_from_json(sketches_json):
11 | for sketch_json in sketches_json:
12 | Sketch.from_fs_json(sketch_json)
13 |
14 |
15 | def test_sequence_from_sketch(sketches_json):
16 |
17 | for sketch_json in sketches_json:
18 | sketch = Sketch.from_fs_json(sketch_json)
19 | seq = sketch_to_sequence(sketch)
20 |
21 | def test_plot_sketch(sketches_json):
22 | sketch_json_list = sketches_json[:10]
23 |
24 | for sketch_json in sketch_json_list:
25 | fig = render_sketch(Sketch.from_fs_json(sketch_json))
26 | assert fig is not None
27 |
28 |
29 | def test_get_sequence_dof():
30 | seq = [
31 | NodeOp(label=EntityType.External),
32 | NodeOp(label=EntityType.Line),
33 | NodeOp(label=SubnodeType.SN_Start),
34 | EdgeOp(label=ConstraintType.Subnode, references=(2, 1)),
35 | NodeOp(label=SubnodeType.SN_End),
36 | EdgeOp(label=ConstraintType.Subnode, references=(3, 1)),
37 | NodeOp(label=EntityType.Line),
38 | EdgeOp(label=ConstraintType.Parallel, references=(4, 1)),
39 | EdgeOp(label=ConstraintType.Horizontal, references=(4,)),
40 | EdgeOp(label=ConstraintType.Distance, references=(4, 1)),
41 | NodeOp(label=SubnodeType.SN_Start),
42 | EdgeOp(label=ConstraintType.Subnode, references=(5, 4)),
43 | NodeOp(label=SubnodeType.SN_End),
44 | EdgeOp(label=ConstraintType.Subnode, references=(6, 4)),
45 | NodeOp(label=EntityType.Stop)]
46 |
47 | dof_remaining = np.sum(get_sequence_dof(seq))
48 | assert dof_remaining == 5
49 |
50 |
51 | _UNSUPPORTED_CONSTRAINTS = (
52 | ConstraintType.Circular_Pattern,
53 | ConstraintType.Linear_Pattern,
54 | ConstraintType.Midpoint,
55 | ConstraintType.Mirror,
56 | )
57 |
58 | def test_sketch_from_sequence(sketches_json):
59 | for sketch_json in sketches_json:
60 | sketch = Sketch.from_fs_json(sketch_json, include_external_constraints=False)
61 | seq = sketch_to_sequence(sketch)
62 |
63 | if any(s.label in _UNSUPPORTED_CONSTRAINTS for s in seq):
64 | # Skip not supported constraints for now
65 | continue
66 |
67 | sketch2 = sketch_from_sequence(seq)
68 | seq2 = sketch_to_sequence(sketch2)
69 |
70 | assert len(seq) == len(seq2)
71 | for op1, op2 in zip(seq, seq2):
72 | assert op1 == op2
73 |
--------------------------------------------------------------------------------
/sketchgraphs_models/autoconstraint/scripts/eval_statistics_stratified.py:
--------------------------------------------------------------------------------
1 | """Computes edge evaluation statistics stratified by size. """
2 |
3 | import argparse
4 | import gzip
5 | import pickle
6 | import os
7 |
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 |
12 | def compute_stratified_statistics(statistics, sizes):
13 | unique_sizes, unique_counts = np.unique(sizes, return_counts=True)
14 |
15 | precision = statistics['precision']
16 | recall = statistics['recall']
17 |
18 |     average_precision = np.empty(len(unique_sizes))
19 |     average_recall = np.empty(len(unique_sizes))
20 |     average_f1 = np.empty(len(unique_sizes))
21 | 
22 |     sd_precision = np.empty(len(unique_sizes))
23 |     sd_recall = np.empty(len(unique_sizes))
24 |     sd_f1 = np.empty(len(unique_sizes))
25 |
26 | for i, size in enumerate(unique_sizes):
27 | mask = sizes == size
28 | pm = precision[mask]
29 | rm = recall[mask]
30 |
31 | average_precision[i] = pm.mean()
32 | sd_precision[i] = pm.std() / np.sqrt(len(pm))
33 |
34 | average_recall[i] = rm.mean()
35 | sd_recall[i] = rm.std() / np.sqrt(len(rm))
36 |
37 | f1 = np.divide(2 * pm * rm, pm + rm, out=np.zeros_like(rm), where=pm + rm != 0)
38 | average_f1[i] = f1.mean()
39 | sd_f1[i] = f1.std() / np.sqrt(len(f1))
40 |
41 |
42 | return {
43 | 'precision': average_precision,
44 | 'precision_sd': sd_precision,
45 | 'recall': average_recall,
46 | 'recall_sd': sd_recall,
47 | 'f1': average_f1,
48 | 'f1_sd': sd_f1,
49 | 'sizes': unique_sizes,
50 | 'counts': unique_counts,
51 | }
52 |
53 |
54 | def boxplot_stratified(statistic, sizes):
55 | unique_sizes = np.unique(sizes)
56 | unique_sizes = unique_sizes[(unique_sizes > 3) & (unique_sizes <= 50)]
57 |
58 | return plt.boxplot(
59 | [statistic[s == sizes] for s in unique_sizes], labels=unique_sizes,
60 | showfliers=False, showcaps=False, whis=0.0, medianprops={'linewidth': 2.5})
61 |
62 |
63 | def main():
64 | parser = argparse.ArgumentParser()
65 |
66 | parser.add_argument('--input', required=True)
67 | parser.add_argument('--output')
68 |
69 | args = parser.parse_args()
70 |
71 | print('Reading input files')
72 |
73 | input_result_path = args.input
74 | input_basename, input_ext = os.path.splitext(input_result_path)
75 | if input_ext == '.gz':
76 | input_basename, _ = os.path.splitext(input_basename)
77 |
78 | input_stat_path = input_basename + '_stat.npz'
79 |
80 | with gzip.open(input_result_path, 'rb') as f:
81 | ops = pickle.load(f)
82 |
83 | input_stat = np.load(input_stat_path)
84 |
85 | print('Computing statistics')
86 | stratified_statistics = compute_stratified_statistics(
87 | input_stat, [len(x['node_ops']) for x in ops])
88 |
89 | if args.output is not None:
90 | print('Saving result')
91 | np.savez_compressed(args.output, **stratified_statistics)
92 |
93 | print(stratified_statistics)
94 |
95 |
96 | if __name__ == '__main__':
97 | main()
98 |
--------------------------------------------------------------------------------
/sketchgraphs_models/nn/functional.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for computing specific nn functions. """
2 |
3 | import torch
4 | from sketchgraphs_models.torch_extensions import segment_logsumexp, segment_avg_pool1d
5 | from sketchgraphs_models.torch_extensions.segment_ops import segment_argmax
6 |
7 |
8 | def segmented_cross_entropy(logits: torch.Tensor, target: torch.Tensor, scopes: torch.Tensor) -> torch.Tensor:
9 | """Segmented cross-entropy loss.
10 |
11 | Computes the cross-entropy loss from unscaled `logits` for a segmented problem.
12 |
13 | Parameters
14 | ----------
15 | logits : torch.Tensor
16 | unscaled logits by segment
17 | target : torch.Tensor
18 | tensor of length `n_segments`, representing the index of the true label
19 | for each segment.
20 |     scopes : torch.Tensor
21 | tensor of shape `[n_segments, 2]`, representing the segments as `(start, length)`.
22 |
23 | Returns
24 | -------
25 | torch.Tensor
26 | A tensor of length `n_segments` representing the cross-entropy loss
27 | at each segment.
28 | """
29 | input_logsumexp = segment_logsumexp(logits, scopes)
30 | return logits.index_select(0, target) - input_logsumexp
31 |
32 |
33 | def segmented_multinomial(logits, scopes, generator=None):
34 | """Segmented multinomial sample.
35 |
36 | Parameters
37 | ----------
38 | logits : torch.Tensor
39 | unscaled logits by segment
40 | scopes : torch.Tensor
41 | tensor of shape `[n_segments, 2]` representing the segments as `(start, length)`.
42 | generator : torch.Generator, optional
43 | PRNG for sampling
44 |
45 | Returns
46 | -------
47 | torch.Tensor
48 | A tensor of length `n_segments` representing the sampled values.
49 | """
50 | output = logits.new_empty([scopes.shape[0]], dtype=torch.int64)
51 |
52 | logits = logits.detach()
53 |
54 | for i in range(scopes.shape[0]):
55 | segment_logits = logits.narrow(0, scopes[i, 0], scopes[i, 1])
56 | torch.multinomial(
57 | torch.nn.functional.softmax(segment_logits, dim=0), num_samples=1,
58 | generator=generator,
59 |             out=output[i:i + 1])
60 |
61 | return output
62 |
63 |
64 | def segmented_multinomial_extended(logits, scopes, generator=None, also_return_probs=False):
65 | """Segmented multinomial sample with implicit element.
66 |
67 | Parameters
68 | ----------
69 | logits : torch.Tensor
70 | logits for explicit outcomes by segment
71 | scopes : torch.Tensor
72 | tensor of shape `[n_segments, 2]` representing the segments as `(start, length)`.
73 | generator : torch.Generator, optional
74 | PRNG for sampling
75 | also_return_probs: bool, optional
76 | If true, returns tuple including log-likelihood of the sample.
77 |
78 | Returns
79 | -------
80 | torch.Tensor
81 | A tensor of length `n_segments` representing the sampled values.
82 | """
83 |
84 | output = logits.new_empty([scopes.shape[0]], dtype=torch.int64)
85 |
86 | logits = logits.detach()
87 |
88 | for i in range(scopes.shape[0]):
89 | segment_logits = logits.new_zeros([scopes[i, 1] + 1])
90 | segment_logits[:-1] = logits.narrow(0, scopes[i, 0], scopes[i, 1])
91 |
92 | dist = torch.nn.functional.softmax(segment_logits, dim=0)
93 | torch.multinomial(dist, num_samples=1, generator=generator,
94 |                           out=output[i:i + 1])
95 | if also_return_probs:
96 | return output, dist[output]
97 | return output
98 |
--------------------------------------------------------------------------------
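A small illustration (not part of the repository sources) of the segmented layout these helpers expect; the logits are arbitrary.

    import torch
    from sketchgraphs_models.nn.functional import segmented_multinomial

    # Two segments of unscaled logits, stored contiguously: lengths 3 and 2.
    logits = torch.tensor([0.5, 1.0, -0.2, 2.0, 0.1])
    scopes = torch.tensor([[0, 3], [3, 2]])

    samples = segmented_multinomial(logits, scopes)
    # One sampled index per segment, local to that segment:
    # samples[0] lies in {0, 1, 2}, samples[1] lies in {0, 1}.
    print(samples)
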
/sketchgraphs/pipeline/graph_model/target.py:
--------------------------------------------------------------------------------
1 | """This module implements the basic definitions for targets."""
2 |
3 | import enum
4 |
5 | from sketchgraphs.data import sketch as datalib, sequence as data_sequence
6 |
7 | NODE_TYPES_PREDICTED = list(datalib.EntityType)
8 | EDGE_TYPES_PREDICTED = list(x for x in datalib.ConstraintType if x != datalib.ConstraintType.Subnode)
9 | NODE_TYPES = list(datalib.EntityType) + list(datalib.SubnodeType)
10 | EDGE_TYPES = list(datalib.ConstraintType)
11 |
12 |
13 | EDGE_IDX_MAP = {t: i for i, t in enumerate(datalib.ConstraintType)}
14 | NODE_IDX_MAP = {t: i for i, t in enumerate(list(datalib.EntityType) + list(datalib.SubnodeType))}
15 | NODE_IDX_MAP_REVERSE = {i: t for i, t in enumerate(list(datalib.EntityType) + list(datalib.SubnodeType))}
16 | EDGE_IDX_MAP_REVERSE = {i: t for i, t in enumerate(datalib.ConstraintType)}
17 |
18 |
19 | class TargetType(enum.IntEnum):
20 | EdgeCategorical = 0
21 | EdgeAngle = 1
22 | EdgeLength = 2
23 | EdgeDistance = 3
24 | EdgeDiameter = 4
25 | EdgeRadius = 5
26 | NodeGeneric = 6
27 | NodeArc = 7
28 | NodeCircle = 8
29 | NodeLine = 9
30 | NodePoint = 10
31 | Subnode = 11
32 |
33 | @staticmethod
34 | def from_op(op):
35 | return TargetType.from_label(op.label)
36 |
37 | @staticmethod
38 | def from_label(label):
39 | if isinstance(label, datalib.SubnodeType):
40 | return TargetType.Subnode
41 | elif isinstance(label, datalib.EntityType):
42 | return _entity_to_target_type.get(label, TargetType.NodeGeneric)
43 | elif isinstance(label, datalib.ConstraintType):
44 | return _edge_to_target_type.get(label, TargetType.EdgeCategorical)
45 | else:
46 | raise ValueError('label must be one of SubnodeType, EntityType or ConstraintType.')
47 |
48 | @staticmethod
49 | def from_edge_label_int(x):
50 | fake_edge_op = datalib.EdgeOp(x, 0, 0)
51 | return TargetType.from_op(fake_edge_op)
52 |
53 | @staticmethod
54 | def from_node_label_int(x):
55 | fake_node_op = datalib.NodeOp(x)
56 | return TargetType.from_op(fake_node_op)
57 |
58 | @staticmethod
59 | def edge_types():
60 | return (TargetType.EdgeCategorical,) + TargetType.numerical_edge_types()
61 |
62 | @staticmethod
63 | def numerical_edge_types():
64 | return (
65 | TargetType.EdgeAngle,
66 | TargetType.EdgeLength,
67 | TargetType.EdgeDistance,
68 | TargetType.EdgeDiameter,
69 | TargetType.EdgeRadius
70 | )
71 |
72 | @staticmethod
73 | def numerical_node_types():
74 | return (
75 | TargetType.NodeArc,
76 | TargetType.NodeCircle,
77 | TargetType.NodeLine,
78 | TargetType.NodePoint
79 | )
80 |
81 | @staticmethod
82 | def node_types():
83 | return (TargetType.NodeGeneric,) + TargetType.numerical_node_types()
84 |
85 | _str_to_numerical_edge = {
86 | t.name.upper(): t
87 | for t in TargetType.numerical_edge_types()
88 | }
89 |
90 | _str_to_numerical_entity = {
91 | t.name.title(): t
92 | for t in TargetType.numerical_node_types()
93 | }
94 |
95 | _edge_to_target_type = {
96 | s: TargetType['Edge' + s.name] for s in
97 | (datalib.ConstraintType.Angle, datalib.ConstraintType.Length, datalib.ConstraintType.Distance,
98 | datalib.ConstraintType.Diameter, datalib.ConstraintType.Radius)
99 | }
100 |
101 | _entity_to_target_type = {
102 | s: TargetType['Node' + s.name] for s in
103 | (datalib.EntityType.Arc, datalib.EntityType.Circle, datalib.EntityType.Line, datalib.EntityType.Point)
104 | }
105 |
--------------------------------------------------------------------------------
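A few concrete mappings, following directly from the tables defined at the bottom of the file (illustrative, not part of the repository sources):

    from sketchgraphs.data import sketch as datalib
    from sketchgraphs.pipeline.graph_model.target import TargetType

    # Labels carrying numerical parameters map to dedicated target types;
    # everything else falls back to the categorical / generic buckets.
    assert TargetType.from_label(datalib.ConstraintType.Angle) == TargetType.EdgeAngle
    assert TargetType.from_label(datalib.ConstraintType.Parallel) == TargetType.EdgeCategorical
    assert TargetType.from_label(datalib.EntityType.Circle) == TargetType.NodeCircle
    assert TargetType.from_label(datalib.SubnodeType.SN_Start) == TargetType.Subnode
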
/sketchgraphs/data/sketch.py:
--------------------------------------------------------------------------------
1 | """This module implements parsing and representation for sketches.
2 | """
3 |
4 | from collections import OrderedDict
5 | from typing import Dict
6 |
7 | # pylint: disable=invalid-name, too-many-arguments, too-many-return-statements, too-many-instance-attributes, wildcard-import, unused-wildcard-import
8 |
9 |
10 | from . import _entity
11 | from . import _constraint
12 | from . import _plotting
13 |
14 | from ._entity import EntityType, SubnodeType, Entity, GenericEntity, Point, Line, Circle, Arc, Spline, Ellipse, ENTITY_TYPE_TO_CLASS
15 |
16 | from ._constraint import *
17 | from ._plotting import render_sketch, render_graph
18 |
19 |
20 | class Sketch:
21 | """This class encapsulates a sketch instance.
22 |
23 | A sketch is defined by a list of entities, and a list of constraints between these entities.
24 | The Sketch class is designed to represent the sketches as obtained from Onshape in a structured
25 | and faithful manner. In particular, it can round-trip the relevant parts of the JSON representation
26 | of Onshape's feature-script.
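
    For example (illustrative; assuming ``fs_json`` holds the ``'entities'`` and
    ``'constraints'`` lists from the feature-script JSON of a sketch):

    >>> sketch = Sketch.from_fs_json(fs_json)
    >>> round_tripped = sketch.to_dict()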
27 | """
28 | entities: Dict[str, Entity]
29 | constraints: Dict[str, Constraint]
30 |
31 | def __init__(self, entities=None, constraints=None):
32 | if entities is None:
33 | entities = OrderedDict()
34 | if constraints is None:
35 | constraints = OrderedDict()
36 |
37 | self.entities = entities
38 | self.constraints = constraints
39 |
40 | def to_dict(self) -> dict:
41 | """Create a dictionary representing this sketch.
42 |
43 |         The created dictionary should be compatible with the JSON representation of the sketch.
44 | """
45 | return {
46 | 'entities': [e.to_dict() for e in self.entities.values()],
47 | 'constraints': [c.to_dict() for c in self.constraints.values()]
48 | }
49 |
50 | @staticmethod
51 | def from_fs_json(sketch_dict, include_external_constraints=True):
52 | """Parse primitives and constraints.
53 |
54 | Parameters
55 | ----------
56 |         include_external_constraints : bool, optional
57 |             If True, constraints referencing the first external node (representing the origin)
58 |             are included; otherwise, those constraints are excluded from the sketch.
59 | """
60 | entities = (Entity.from_dict(ed) for ed in sketch_dict['entities'])
61 | entities_dict = OrderedDict((e.entityId, e) for e in entities)
62 |
63 | constraints = (Constraint.from_dict(cd) for cd in sketch_dict['constraints'])
64 |
65 | if not include_external_constraints:
66 | constraints = (
67 | c for c in constraints
68 | if not any(isinstance(p, ExternalReferenceParameter) for p in c.parameters))
69 |
70 | constraints_dict = OrderedDict((c.identifier, c) for c in constraints)
71 | return Sketch(entities_dict, constraints_dict)
72 |
73 | @staticmethod
74 | def from_info(sketch_info):
75 | """Parse entities given result of `sketch information` call."""
76 | subnode_suffixes = ('.start', '.end', '.center') # TODO: dry-ify this
77 | entities = [Entity.from_info(ed) for ed in sketch_info if not ed['id'].endswith(subnode_suffixes)]
78 | entities_dict = OrderedDict((e.entityId, e) for e in entities)
79 | return Sketch(entities=entities_dict)
80 |
81 | def __repr__(self):
82 | return 'Sketch(n_entities={0}, n_constraints={1})'.format(len(self.entities), len(self.constraints))
83 |
84 |
85 | __all__ = [
86 | 'Sketch', 'EntityType', 'SubnodeType', 'Entity', 'GenericEntity', 'Point', 'Line',
87 | 'Circle', 'Arc', 'Spline', 'Ellipse', 'ENTITY_TYPE_TO_CLASS'] + _constraint.__all__ + _plotting.__all__
88 |
--------------------------------------------------------------------------------
/sketchgraphs_models/autoconstraint/scripts/eval_statistics_mask.py:
--------------------------------------------------------------------------------
1 | """Evaluate statistics for the masks applied to samples in the data."""
2 |
3 | import argparse
4 | import functools
5 |
6 | import numpy as np
7 | import torch
8 | from torch import multiprocessing
9 | import tqdm
10 |
11 | from sketchgraphs.data import flat_array
12 | from sketchgraphs.data.sequence import NodeOp
13 | from sketchgraphs_models.autoconstraint.eval import MASK_FUNCTIONS
14 |
15 |
16 | def count_valid_choices_by_step(node_ops, mask_function):
17 | if node_ops[-1].label == 'Stop':
18 | node_ops = node_ops[:-1]
19 |
20 | masks = mask_function(node_ops)
21 |
22 |     valid_choices = np.ones(len(node_ops) - 1, dtype=np.int64)
23 |
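    # The mask tensor is assumed to be laid out as a flattened lower triangle: step i
    # contributes i + 1 entries starting at offset i * (i + 1) // 2, and entries equal
    # to 0 are counted as valid choices.  The trailing "+ 1" below presumably accounts
    # for the always-available stop choice.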
24 | for i in range(1, len(node_ops)):
25 | mask_offset = i * (i + 1) // 2
26 | valid_choices[i - 1] = (torch.narrow(masks, 0, mask_offset, i + 1) == 0).int().sum().item() + 1
27 |
28 | return valid_choices
29 |
30 |
31 | def total_valid_choices(seq, mask_function):
32 | node_ops = [op for op in seq if isinstance(op, NodeOp)]
33 |
34 | valid_choices = count_valid_choices_by_step(node_ops, mask_function)
35 |
36 | if len(valid_choices) > 0:
37 | average_choices = valid_choices.mean()
38 | average_entropy = np.log2(valid_choices).mean()
39 | else:
40 | # These values will be ignored
41 | average_choices = 0
42 | average_entropy = 0
43 |
44 | return average_choices, average_entropy, len(valid_choices)
45 |
46 |
47 | def uniform_valid_perplexity(seqs, mask_function, num_workers=0):
48 | if num_workers is None or num_workers > 0:
49 | pool = multiprocessing.Pool(num_workers)
50 | map_fn = functools.partial(pool.imap, chunksize=8)
51 | else:
52 | map_fn = map
53 |
54 | valid_choices_fn = functools.partial(total_valid_choices, mask_function=mask_function)
55 |
56 | average_choices = np.empty(len(seqs))
57 | average_entropy = np.empty(len(seqs))
58 |     seq_length = np.empty(len(seqs), dtype=np.int64)
59 |
60 | for i, (choices, entropy, length) in enumerate(tqdm.tqdm(map_fn(valid_choices_fn, seqs), total=len(seqs))):
61 | average_choices[i] = choices
62 | average_entropy[i] = entropy
63 | seq_length[i] = length
64 |
65 | return {
66 | 'choices': average_choices,
67 | 'entropy': average_entropy,
68 | 'sequence_length': seq_length
69 | }
70 |
71 |
72 | def main():
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument('--input', required=True)
75 | parser.add_argument('--mask', choices=list(MASK_FUNCTIONS.keys()), default='node_type')
76 | parser.add_argument('--output')
77 | parser.add_argument('--limit', type=int, default=None)
78 | parser.add_argument('--num_workers', type=int, default=16)
79 |
80 | args = parser.parse_args()
81 |
82 | print('Reading data')
83 | seqs = flat_array.load_dictionary_flat(np.load(args.input, mmap_mode='r'))['sequences']
84 | seqs.share_memory_()
85 |
86 | if args.limit is not None:
87 | seqs = seqs[:args.limit]
88 |
89 | print('Computing statistics')
90 | result = uniform_valid_perplexity(seqs, MASK_FUNCTIONS[args.mask], args.num_workers)
91 |
92 | if args.output is not None:
93 | print('Saving results')
94 | np.savez_compressed(args.output, **result)
95 |
96 | choices = np.average(result['choices'], weights=result['sequence_length'])
97 | entropy = np.average(result['entropy'], weights=result['sequence_length'])
98 | print('Average choices: {:.3f}'.format(choices))
99 | print('Average entropy: {:.3f}'.format(entropy))
100 |
101 | if __name__ == '__main__':
102 | multiprocessing.set_start_method('spawn')
103 | main()
104 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. SketchGraphs documentation master file, created by
2 | sphinx-quickstart on Sun Jul 12 05:33:22 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to SketchGraphs' documentation!
7 | ========================================
8 |
9 | `SketchGraphs `_ is a dataset of 15 million sketches extracted from real world CAD models intended to facilitate
10 | research in ML-aided design and geometric program induction.
11 | In addition to the raw data, we provide several processed datasets, an accompanying Python package to work with
12 | the data, as well as a couple of starter models which implement GNN-type strategies on two example problems.
13 |
14 |
15 | Data
16 | ----
17 |
18 | We provide our dataset in a number of forms, some of which may be more appropriate for your desired usage.
19 | The following data files are provided:
20 |
21 | - The raw json data as obtained from Onshape. This is provided as a set of 128 tar archives which are compressed
22 | using zstandard_. They total about 43GB of data. In general, we only recommend these for advanced usage, as
23 | they are heavy and require extensive processing to manipulate. Users interested in working with the raw data
24 | may wish to inspect `sketchgraphs.pipeline.make_sketch_dataset` and `sketchgraphs.pipeline.make_sequence_dataset`
25 |   to view our data processing scripts. The data is available for download `here <https://sketchgraphs.cs.princeton.edu/shards>`_.
26 |
27 | - A dataset of construction sequences. This is provided as a single file, stored in a custom binary format.
28 | This format is much more concise (as it eliminates many of the unique identifiers used in the raw JSON format),
29 |   and is better suited for ML applications. It is supported by our Python libraries, and forms the baseline
30 |   on which our models are trained. The data is available for download `here <https://sketchgraphs.cs.princeton.edu/sequence/sg_all.npy>`_
31 |   (warning: 15GB file!).
32 |
33 | - A filtered dataset of construction sequences. This is provided as a single file, similarly stored in a custom
34 | binary format. This dataset is similar to the sequence dataset, but simplified by filtering out sketches
35 | that are too large or too small, and only includes a simplified set of entities and constraints (while still
36 | capturing a large portion of the data). Additionally, this dataset has been split into training, testing and
37 |   validation splits for convenience. We train our models on this subset of the data. You can download the splits
38 |   here: `train <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_train.npy>`_,
39 |   `test <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_test.npy>`_,
40 |   `validation <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_validation.npy>`_.
41 |
42 |
43 | More details concerning the data can be found in the :doc:`data` page.
44 |
45 |
46 | .. _zstandard: https://facebook.github.io/zstd/
47 |
48 |
49 | Models
50 | ------
51 |
52 | In addition to the dataset, we also provide some baseline model implementations to tackle the tasks of generative
53 | modelling and autoconstrain. These models are based on Graph Neural Network approaches, and model the sketch as
54 | a graph where vertices are given by the entities in the sketch, and edges by the constraints between those entities.
55 | For more details, please refer to the dedicated :doc:`models` page.
56 |
57 |
58 | .. toctree::
59 | models
60 | data
61 | onshape_setup
62 | .. autosummary::
63 | :toctree: _autosummary
64 | :recursive:
65 |
66 | sketchgraphs
67 | sketchgraphs_models
68 |
69 |
70 |
71 | Indices and tables
72 | ==================
73 |
74 | * :ref:`genindex`
75 | * :ref:`modindex`
76 | * :ref:`search`
77 |
--------------------------------------------------------------------------------
/sketchgraphs_models/graph/eval_likelihood.py:
--------------------------------------------------------------------------------
1 | """Module to evaluate the graph model according to data likelihood."""
2 |
3 | import argparse
4 |
5 | import numpy as np
6 | import torch
7 | import tqdm
8 |
9 | from sketchgraphs_models import training
10 |
11 | from sketchgraphs_models.graph import dataset, sample
12 | from sketchgraphs_models.graph import model as graph_model
13 |
14 | from sketchgraphs.data import flat_array
15 |
16 |
17 | def _total_loss(losses):
18 | result = 0
19 |
20 | for v in losses.values():
21 | if v is None:
22 | continue
23 |
24 | if isinstance(v, dict):
25 | result += _total_loss(v)
26 | else:
27 | result += v.sum()
28 |
29 | return result
30 |
31 |
32 | def batch_from_example(seq, node_feature_mapping, edge_feature_mapping):
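    """Builds a batch containing one prediction problem per step of ``seq``
    (the initial node and subnode edges are excluded).

    Summing the per-step losses over the resulting batch therefore gives the
    negative log-likelihood (in nats) of the whole construction sequence.
    """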
33 | step_indices = [i for i, op in enumerate(seq) if i > 0 and not dataset._is_subnode_edge(op)]
34 |
35 | batch_list = []
36 |
37 | for step_idx in step_indices:
38 | graph = dataset.graph_info_from_sequence(seq[:step_idx], node_feature_mapping, edge_feature_mapping)
39 | target = seq[step_idx]
40 | batch_list.append((graph, target))
41 |
42 | return dataset.collate(batch_list, node_feature_mapping, edge_feature_mapping)
43 |
44 |
45 | class GraphLikelihoodEvaluator:
46 | def __init__(self, model, node_feature_mapping, edge_feature_mapping, device=None):
47 | self.model = model
48 | self.node_feature_mapping = node_feature_mapping
49 | self.edge_feature_mapping = edge_feature_mapping
50 | self.device = device
51 | self._feature_dimensions = {
52 | **node_feature_mapping.feature_dimensions,
53 | **edge_feature_mapping.feature_dimensions
54 | }
55 |
56 | def compute_likelihood(self, seqs):
57 | for i, seq in enumerate(seqs):
58 | batch = batch_from_example(seq, self.node_feature_mapping, self.edge_feature_mapping)
59 | batch_device = training.load_cuda_async(batch, device=self.device)
60 |
61 | with torch.no_grad():
62 | readout = self.model(batch_device)
63 | losses, *_ = graph_model.compute_losses(readout, batch_device, self._feature_dimensions)
64 | loss = _total_loss(losses).cpu().item()
65 |
66 | yield loss, len(seq)
67 |
68 |
69 | def main():
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--dataset', required=True)
72 | parser.add_argument('--model_state', help='Path to saved model state_dict.')
73 | parser.add_argument('--device', type=str, default='cuda')
74 | parser.add_argument('--limit', type=int, default=None)
75 |
76 | args = parser.parse_args()
77 |
78 | device = torch.device(args.device)
79 |
80 | print('Loading trained model')
81 | model, node_feature_mapping, edge_feature_mapping = sample.load_saved_model(args.model_state)
82 | model = model.eval().to(device)
83 |
84 | print('Loading testing data')
85 | seqs = flat_array.load_dictionary_flat(np.load(args.dataset, mmap_mode='r'))['sequences']
86 |
87 | if args.limit is not None:
88 | seqs = seqs[:args.limit]
89 |
90 | evaluator = GraphLikelihoodEvaluator(model, node_feature_mapping, edge_feature_mapping, device)
91 |
92 | losses = np.empty(len(seqs))
93 | length = np.empty(len(seqs), dtype=np.int64)
94 |
95 | for i, result in enumerate(tqdm.tqdm(evaluator.compute_likelihood(seqs), total=len(seqs))):
96 | losses[i], length[i] = result
97 |
98 | print('Average bits per sketch: {:.2f}'.format(losses.mean() / np.log(2)))
99 | print('Average bits per step: {:.2f}'.format(losses.sum() / np.log(2) / length.sum()))
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
104 |
--------------------------------------------------------------------------------
/sketchgraphs_models/distributed_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for distributed (multi-gpu) training.
2 | """
3 |
4 | import math
5 | import socket
6 | import typing
7 |
8 | import torch
9 | import torch.distributed
10 |
11 | from functools import partial
12 |
13 |
14 | class DistributedTrainingInfo(typing.NamedTuple):
15 | sync_url: str
16 | world_size: int
17 | rank: int
18 | local_rank: int
19 |
20 | def __bool__(self):
21 | return self.world_size > 1
22 |
23 |
24 |
25 | def is_leader(config: DistributedTrainingInfo):
26 | """Tests whether the current process is the leader for distributed training."""
27 | return config is None or config.rank == 0
28 |
29 |
30 | def train_boostrap_distributed(parameters, train):
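    """Bootstraps (possibly distributed) training.

    When ``parameters['world_size']`` is 1, ``train(parameters)`` is called directly.
    Otherwise one worker process per replica is spawned, synchronised through a TCP
    init URL on a free local port unless ``parameters['sync_url']`` is already set.
    """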
31 | world_size = parameters.get('world_size', 1)
32 |
33 | if world_size == 1:
34 | # Single-node training, nothing to do.
35 | parameters['rank'] = 0
36 | return train(parameters)
37 |
38 | parameters['rank'] = -1
39 | from torch.multiprocessing import spawn
40 |
41 | sync_url = parameters.get('sync_url')
42 | if sync_url is None:
43 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
44 | s.bind(('', 0))
45 | port = s.getsockname()[1]
46 | s.close()
47 | sync_url = f'tcp://127.0.0.1:{port}'
48 | parameters['sync_url'] = sync_url
49 | print(f'Using URL bootstrap at {sync_url}')
50 |
51 | spawn(partial(train_distributed, train=train), nprocs=parameters['world_size'], args=(parameters,))
52 |
53 |
54 | def initialize_distributed(config: DistributedTrainingInfo):
55 | from torch import distributed as dist
56 |
57 | dist.init_process_group(
58 | 'nccl', init_method=config.sync_url,
59 | world_size=config.world_size, rank=config.rank)
60 |
61 | torch.cuda.set_device(config.local_rank)
62 |
63 |
64 | def get_distributed_config(parameters, local_rank=None):
65 | world_size = parameters.get('world_size', 1)
66 |
67 | if world_size == 1:
68 | return DistributedTrainingInfo('', 1, 0, 0)
69 |
70 | sync_url = parameters['sync_url']
71 |
72 | rank = local_rank
73 |
74 | return DistributedTrainingInfo(sync_url, world_size, rank, local_rank)
75 |
76 |
77 | def train_distributed(local_rank, parameters, train):
78 | config = get_distributed_config(parameters, local_rank)
79 | initialize_distributed(config)
80 | train(parameters, config)
81 |
82 |
83 | class DistributedSampler(torch.utils.data.Sampler):
84 | """Utility class which adapts a sampler into a distributed sampler which only
85 | samples a subset of the underlying sampler, according to a division by rank.
86 | """
87 | def __init__(self, sampler, num_replicas=None, rank=None):
88 |         if num_replicas is None:
89 |             if not torch.distributed.is_available():
90 |                 raise RuntimeError("DistributedSampler requires torch distributed to be available")
91 |             num_replicas = torch.distributed.get_world_size()
92 | 
93 |         if rank is None:
94 |             if not torch.distributed.is_available():
95 |                 raise RuntimeError("DistributedSampler requires torch distributed to be available")
96 |             rank = torch.distributed.get_rank()
97 | self.sampler = sampler
98 | self.num_replicas = num_replicas
99 | self.rank = rank
100 | self.num_samples = int(math.ceil(len(sampler) / self.num_replicas))
101 | self.total_size = self.num_replicas * self.num_samples
102 |
103 | def __iter__(self):
104 | indices = list(self.sampler)
105 | indices += indices[:(self.total_size - len(indices))]
106 | assert len(indices) == self.total_size
107 |
108 | indices = indices[self.rank:self.total_size:self.num_replicas]
109 | assert len(indices) == self.num_samples
110 |
111 | return iter(indices)
112 |
113 | def __len__(self):
114 | return self.num_samples
115 |
--------------------------------------------------------------------------------
/docs/data.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ========
3 |
4 | We provide three datasets at various levels of pre-processing in order to enable a wide variety of uses
5 | while making it easy to get started quickly with the data. We describe each dataset below.
6 |
7 | Tarballs
8 | --------
9 |
10 | The raw json data as obtained from Onshape. This is provided as a set of 128 tar archives which are compressed
11 | using `zstandard `_.
12 | They total about 43GB of data. In general, we only recommend these for advanced usage, as
13 | they are heavy and require extensive processing to manipulate.
14 | Users who wish to directly access the json from python may be interested in the utility function
15 | `sketchgraphs.pipeline.make_sketch_dataset.load_json_tarball`, which iterates through a single compressed tarball
16 | and enumerates the sketches in the tarball.
17 | The data is available for download `here <https://sketchgraphs.cs.princeton.edu/shards>`_.
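
For example (a minimal sketch; the exact signature and yielded objects of
``load_json_tarball`` should be checked against its documentation, and the shard file name
below is only a placeholder):

.. code-block:: python

   from sketchgraphs.pipeline.make_sketch_dataset import load_json_tarball

   for sketch_json in load_json_tarball('shard_0.tar.zst'):
       ...  # one entry per sketch contained in the shard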
18 |
19 |
20 | Filtered sequence
21 | -----------------
22 |
23 | The filtered sequences are the most convenient dataset to use, and are provided ready to use with ML models.
24 | They are stored in a custom binary format, which can be read using `sketchgraphs.data.flat_array.load_dictionary_flat`,
25 | for example by executing:
26 |
27 | >>> from sketchgraphs.data import flat_array
28 | >>> data = flat_array.load_dictionary_flat('sg_t16_train.npy')
29 | >>> data['sequences']
30 | FlatArray(len=9179789, mem=5.3 GB)
31 |
32 | In addition to the sequences, which may be accessed through the ``data['sequences']`` key, the sequence files
33 | also contain an integer array of the same length which record the length of each sequence, accessed through
34 | the ``data['sequence_lengths']`` key, and a structured array which contains an identifier uniquely identifying
35 | the sketch, accessed through the ``data['sketch_ids']`` key. The latter is an array of tuples, the first element
36 | being the document id of the sketch, the second being the part index in that document, and the last being the
37 | sketch index in that part.
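
For instance, continuing the example above (a minimal illustration):

>>> first_length = data['sequence_lengths'][0]
>>> first_id = data['sketch_ids'][0]   # (document id, part index, sketch index)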
38 |
39 | This data is pre-split into a `train <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_train.npy>`_,
40 | `test <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_test.npy>`_ and
41 | `validation <https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_validation.npy>`_ set (the training set
42 | is formed from the shards 1-120, the validation set from the shards 121-124, and the testing set from the
43 | shards 125-128). We also provide pre-computed `quantization statistics `_,
44 | computed on the training set using the `sketchgraphs.pipeline.make_quantization_statistics` script.
45 | These quantization statistics are used by the models in order to handle continuous parameters.
46 |
47 |
48 | Full sequence
49 | -------------
50 |
51 | As a middle-ground between the filtered sequences and the json tarballs, we also provide the full sequences,
52 | which are directly converted from the json tarballs with no filtering (except the exclusion of empty sketches).
53 | This full sequence file is available `here <https://sketchgraphs.cs.princeton.edu/sequence/sg_all.npy>`_
54 | (warning: 15GB download), and contains the sequence representation of all the sketches in the dataset.
55 | We note that although it is mostly equivalent to the data contained in the tarballs for ML purposes, it is much smaller
56 | as many of the original identifiers are discarded or replaced by sequential indices.
57 |
58 | The full sequence may be accessed in the same fashion as the filtered sequences (using the `sketchgraphs.data.flat_array`
59 | module). Note that the full sequence file contains all sketches, and does not provide any train / test split.
60 | Additionally, users should be warned that the dataset contains substantial outliers (for example, the largest sketch in the
61 | dataset contains more than 350 thousand operations, whereas the 99th percentile is "only" 640 operations).
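
For example, to load the full sequence array and index into it (a minimal sketch; the file
name assumes the download above, and the length cut-off of 640 is arbitrary):

.. code-block:: python

   import numpy as np
   from sketchgraphs.data import flat_array

   seqs = flat_array.load_dictionary_flat(np.load('sg_all.npy', mmap_mode='r'))['sequences']
   # Indices of sketches at or below the quoted 99th-percentile length; this decodes every sequence.
   moderate = [i for i in range(len(seqs)) if len(seqs[i]) <= 640]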
62 |
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SketchGraphs: A Large-Scale Dataset for Modeling Relational Geometry in Computer-Aided Design
2 |
3 | SketchGraphs is a dataset of 15 million sketches extracted from real-world CAD models intended to facilitate research in both ML-aided design and geometric program induction.
4 |
5 | 
6 |
7 |
8 | Each sketch is represented as a geometric constraint graph where edges denote designer-imposed geometric relationships between primitives, the nodes of the graph.
9 |
10 | 
11 |
12 | Video: https://youtu.be/ki784S3wjqw
13 | Paper: https://arxiv.org/abs/2007.08506
14 |
15 | See [demo notebook](demos/sketchgraphs_demo.ipynb) for a quick overview of the data representations in SketchGraphs as well as an example of solving constraints via Onshape's API.
16 |
17 | ## Installation
18 |
19 | SketchGraphs can be installed using pip:
20 |
21 | ```bash
22 | >> pip install -e SketchGraphs
23 | ```
24 |
25 | This will provide you with the necessary dependencies to load and explore the data.
26 | However, to train the models, you will need to additionally install [pytorch](https://pytorch.org/)
27 | and [torch-scatter](https://github.com/rusty1s/pytorch_scatter).
28 |
29 | ## Data
30 |
31 | We provide our dataset in a number of forms, some of which may be more appropriate for your desired usage.
32 | The following data files are provided:
33 |
34 | - The raw json data as obtained from Onshape. This is provided as a set of 128 tar archives which are compressed
35 | using [zstandard](https://facebook.github.io/zstd). They total about 43GB of data. In general, we only recommend these for advanced usage, as
36 | they are heavy and require extensive processing to manipulate. Users interested in working with the raw data
37 | may wish to inspect `sketchgraphs.pipeline.make_sketch_dataset` and `sketchgraphs.pipeline.make_sequence_dataset`
38 | to view our data processing scripts. The data is available for download [here](https://sketchgraphs.cs.princeton.edu/shards).
39 |
40 | - A dataset of construction sequences. This is provided as a single file, stored in a custom binary format.
41 | This format is much more concise (as it eliminates many of the unique identifiers used in the raw JSON format),
42 | and is better suited for ML applications. It is supported by our Python libraries and forms the baseline
43 | on which our models are trained. The data is available for download [here](https://sketchgraphs.cs.princeton.edu/sequence/sg_all.npy) (warning: 15GB file!).
44 |
45 | - A filtered dataset of construction sequences. This is provided as a single file, similarly stored in a custom
46 | binary format. This dataset is similar to the sequence dataset, but simplified by filtering out sketches
47 | that are too large or too small, and only includes a simplified set of entities and constraints (while still
48 | capturing a large portion of the data). Additionally, this dataset has been split into training, testing and
49 | validation splits for convenience. We train our models on this subset of the data. You can download the splits
50 | here: [train](https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_train.npy)
51 | [validation](https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_validation.npy)
52 | [test](https://sketchgraphs.cs.princeton.edu/sequence/sg_t16_test.npy)
53 |
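Once downloaded, the sequence files above can be loaded with the `sketchgraphs.data.flat_array`
module. A minimal sketch (the file name assumes the filtered training split):

```python
import numpy as np
from sketchgraphs.data import flat_array

data = flat_array.load_dictionary_flat(np.load('sg_t16_train.npy', mmap_mode='r'))
sequences = data['sequences']
print(len(sequences), 'construction sequences')
```
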
54 | For full documentation of the processing pipeline, see https://princetonlips.github.io/SketchGraphs.
55 |
56 | The original creators of the CAD sketches hold the copyright. See [Onshape Terms of Use 1.g.ii](https://www.onshape.com/legal/terms-of-use#your_content) for additional licensing details.
57 |
58 |
59 | ## Models
60 | In addition to the dataset, we also provide some baseline model implementations to tackle the tasks of generative
61 | modeling and autoconstrain. These models are based on Graph Neural Network approaches and model the sketch as
62 | a graph where vertices are given by the entities in the sketch, and edges by the constraints between those entities.
63 | For more details, please refer to https://princetonlips.github.io/SketchGraphs/models.
64 |
65 |
66 | ## Citation
67 | If you use this dataset in your research, please cite:
68 | ```
69 | @inproceedings{SketchGraphs,
70 | title={Sketch{G}raphs: A Large-Scale Dataset for Modeling Relational Geometry in Computer-Aided Design},
71 | author={Seff, Ari and Ovadia, Yaniv and Zhou, Wenda and Adams, Ryan P.},
72 | booktitle={ICML 2020 Workshop on Object-Oriented Learning},
73 | year={2020}
74 | }
75 | ```
76 |
--------------------------------------------------------------------------------
/sketchgraphs/pipeline/numerical_parameters.py:
--------------------------------------------------------------------------------
1 | """This module implements functionality to handle numerical parameters in the sketch.
2 |
3 | """
4 |
5 | import math
6 | import re
7 |
8 | import numpy as np
9 | from sklearn.cluster import KMeans
10 |
11 | _IMPLICIT_MUL_PATTERN = re.compile(r"(?<=\))\s*(?=\w+) | (?<=\d)\s*(?=[A-Z]+)", re.X)
12 |
13 |
14 | ##### Constraint handling
15 |
16 | METER_CONVERSIONS = {
17 | 'METER': 1,
18 | 'METERS': 1,
19 | 'M': 1,
20 | 'MILLIMETER': 1e-3,
21 | 'MILLIMETERS': 1e-3,
22 | 'MM': 1e-3,
23 | 'CM': 1e-2,
24 | 'CENTIMETER': 1e-2,
25 | 'CENTIMETERS': 1e-2,
26 | 'INCHES': 0.0254,
27 | 'INCH': 0.0254,
28 | 'IN': 0.0254,
29 | 'FOOT': 0.3048,
30 |     'FEET': 0.3048,
31 | 'FT': 0.3048,
32 | 'YD': 0.9144
33 | }
34 |
35 | DEGREE_CONVERSIONS = {
36 | 'DEG': 1,
37 | 'DEGREE': 1,
38 | 'RADIAN': 180/np.pi,
39 | 'RAD': 180/np.pi
40 | }
41 |
42 |
43 | def normalize_expression(expression, parameter_id):
44 | """Converts a numerical expression into a normalized form.
45 |
46 | Parameters
47 | ----------
48 | expression : str
49 | A string representing a quantity parameter value
50 | parameter_id : str
51 | A parameterId string; must be one of 'angle' or 'length'
52 |
53 | Returns
54 | -------
55 | norm_expression : str or None
56 | Normalized expression string if successful, or None otherwise.
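
    Examples
    --------
    Illustrative only (the exact digits depend on ``np.format_float_positional``):

    >>> normalize_expression('3 mm', 'length')
    '0.003 METER'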
57 | """
58 | expression = expression.upper().strip()
59 |
60 | if parameter_id == 'angle':
61 | conversions = DEGREE_CONVERSIONS
62 | default_unit = 'DEGREE'
63 | elif parameter_id == 'length':
64 | conversions = METER_CONVERSIONS
65 | default_unit = 'METER'
66 | else:
67 | raise ValueError('parameter_id must be one of angle or length')
68 |
69 |     if '#' in expression or 'LOOKUP' in expression:
70 | # variables in feature script are not supported
71 | return None
72 |
73 | # Adds implicit multiplication signs in order to be more easily parsed.
74 | expression = _IMPLICIT_MUL_PATTERN.sub('*', expression)
75 |
76 | try:
77 | value = eval(expression, {'PI': np.pi, 'SQRT': math.sqrt, 'TAN': math.tan}, conversions)
78 |     except Exception:
79 | return None
80 |
81 | value_str = np.format_float_positional(value, precision=4, trim='0', fractional=False)
82 | return '{} {}'.format(value_str, default_unit)
83 |
84 |
85 | def make_quantization(values, num_points, scheme):
86 | """Find optimal centers for parameter via either uniform, K-means, or CDF-based K-means.
87 |
88 | Obtains a quantization scheme for the given values, according to a given strategy.
89 |     Several schemes are supported, although we prefer the 'cdf' scheme, a hybrid which
90 |     avoids some of the issues that large outliers in the dataset cause for the other schemes.
91 |
92 | Parameters
93 | ----------
94 | values : np.array
95 | An array of values representing a sample of the values to quantize
96 | num_points : int
97 | Number of points to obtain in the dictionary
98 | scheme : str
99 | Indicates the quantization scheme to use, must be 'uniform', 'kmeans' or 'cdf'
100 |
101 | Returns
102 | -------
103 | np.array
104 |         Array of ``num_points`` quantization centers
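
    Examples
    --------
    A minimal sketch (center locations depend on the K-means initialisation for the
    'kmeans' and 'cdf' schemes):

    >>> values = np.random.rand(10000)
    >>> centers = make_quantization(values, num_points=16, scheme='cdf')
    >>> centers.shape
    (16,)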
105 | """
106 |
107 | if scheme == 'uniform':
108 | edges = np.linspace(np.min(values), np.max(values), num_points+1)
109 | return np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
110 |
111 | if scheme == 'kmeans':
112 | km = KMeans(n_clusters=num_points)
113 | km.fit(values.reshape(-1, 1))
114 | return np.sort(np.squeeze(km.cluster_centers_))
115 |
116 | if scheme == 'cdf':
117 | values, cdf = make_unique_cdf(values)
118 | cdf_centers = make_quantization(cdf, num_points, 'kmeans')
119 | return np.interp(cdf_centers, cdf, values)
120 |
121 | schemes = ['uniform', 'kmeans', 'cdf']
122 | raise ValueError("scheme must be one of " + str(schemes))
123 |
124 |
125 | def make_unique_cdf(arr):
126 | """Return 'collapsed' cdf of arr (identical/close arr vals all have same cdf point).
127 |
128 | Parameters
129 | ----------
130 | arr: array of parameter values
131 |
132 | Returns
133 | -------
134 | sorted_arr: sorted copy of arr.
135 | cdf: collapsed cdf of arr.
136 | """
137 | cdf = np.linspace(0, 1, len(arr))
138 | sorted_arr = np.sort(arr)
139 | last_unique_idx = 0
140 | for idx, arr_val in enumerate(np.append(sorted_arr, [np.inf])):
141 | if not np.isclose(arr_val, sorted_arr[last_unique_idx]):
142 | cdf[last_unique_idx:idx] = np.mean(cdf[last_unique_idx:idx])
143 | last_unique_idx = idx
144 | return sorted_arr, cdf
145 |
--------------------------------------------------------------------------------
/sketchgraphs_models/torch_extensions/segment_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ._repeat_interleave import repeat_interleave
4 |
5 |
6 | class SegmentAvgPool1DLoop(torch.autograd.Function):
7 | @staticmethod
8 | def forward(ctx, values, scopes):
9 | ctx.save_for_backward(scopes)
10 | ctx.input_length = values.shape[0]
11 | result = values.new_empty((scopes.shape[0],) + values.shape[1:])
12 |
13 | for i in range(len(scopes)):
14 | x = values.narrow(0, scopes[i, 0], scopes[i, 1])
15 | result[i] = x.mean(dim=0)
16 |
17 | return result
18 |
19 | @staticmethod
20 | def backward(ctx, grad_output):
21 | scopes, = ctx.saved_tensors
22 | return segment_avg_pool1d_backward(grad_output, scopes, ctx.input_length), None
23 |
24 |
25 | _avg_pool_docstring = """
26 | Segmented average pooling.
27 |
28 | This function computes an average pool over a set of values, where the
29 | pool is taken over segments defined by scope.
30 |
31 | Parameters
32 | ----------
33 | values : torch.Tensor
34 |     A tensor of shape ``(N, ...)``; pooling is performed along the first dimension.
35 | scopes : torch.Tensor
36 | a 2-dimensional integer tensor representing segments.
37 | Each row of scopes represents a segment, which starts at ``scopes[i, 0]``,
38 | and has length ``scopes[i, 1]``.
39 |
40 | Returns
41 | -------
42 | torch.Tensor
43 | A tensor representing the average value for each segment.
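
Examples
--------
Illustrative only: two segments covering rows ``0:2`` and ``2:5`` of ``values``.

>>> values = torch.arange(10.0).reshape(5, 2)
>>> scopes = torch.tensor([[0, 2], [2, 3]])
>>> segment_avg_pool1d(values, scopes)
tensor([[1., 2.],
        [6., 7.]])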
44 | """
45 |
46 |
47 | def segment_avg_pool1d_loop(values, scopes):
48 | return SegmentAvgPool1DLoop.apply(values, scopes)
49 |
50 |
51 | def segment_avg_pool1d_scatter(values, scopes):
52 | import torch_scatter
53 |
54 | lengths = scopes.select(1, 1)
55 | offsets = lengths.new_zeros(len(lengths) + 1)
56 | torch.cumsum(lengths, 0, out=offsets[1:])
57 |
58 | return torch_scatter.segment_mean_csr(values, offsets)
59 |
60 |
61 | segment_avg_pool1d_loop.__doc__ = _avg_pool_docstring
62 | segment_avg_pool1d_scatter.__doc__ = _avg_pool_docstring
63 |
64 |
65 | def segment_avg_pool1d_backward(grad_output, scopes, input_length):
66 | """ Backward pass for segmented average pooling. """
67 | scopes_length = scopes.select(1, 1)
68 | norm_grad = torch.true_divide(grad_output, scopes_length.type_as(grad_output).unsqueeze(1))
69 |
70 | result = grad_output.new_empty([input_length] + list(norm_grad.shape[1:]))
71 | return repeat_interleave(norm_grad, scopes, dim=0, out=result)
72 |
73 |
74 | def segment_max_pool1d_backward(grad_output, scopes, max_indices, input_length):
75 | result = grad_output.new_zeros((input_length, max_indices.shape[1]))
76 |
77 | scopes_offsets = scopes.select(1, 0).unsqueeze(1)
78 | linear_max_indices = scopes_offsets + max_indices
79 |
80 | result.scatter_(0, linear_max_indices, grad_output)
81 | return result
82 |
83 |
84 | class SegmentMaxPool1DLoop(torch.autograd.Function):
85 | @staticmethod
86 | def forward(ctx, values, scopes, return_indices=False):
87 | result = values.new_empty([scopes.shape[0], values.shape[1]])
88 | result_idx = values.new_empty([scopes.shape[0], values.shape[1]], dtype=torch.int64)
89 |
90 | for i in range(len(scopes)):
91 | x = values.narrow(0, scopes[i, 0], scopes[i, 1]).t().unsqueeze(0)
92 | r, ri = torch.nn.functional.adaptive_max_pool1d(x, 1, return_indices=True)
93 | r = r.squeeze(0).squeeze(1)
94 | ri = ri.squeeze(0).squeeze(1)
95 |
96 | result[i] = r
97 | result_idx[i] = ri
98 |
99 | ctx.save_for_backward(scopes, result_idx)
100 | ctx.input_length = values.shape[0]
101 | ctx.mark_non_differentiable(result_idx)
102 |
103 | if return_indices:
104 | return result, result_idx
105 | else:
106 | return result
107 |
108 | @staticmethod
109 | def backward(ctx, grad_output, *args):
110 | scopes, result_idx = ctx.saved_tensors
111 | return segment_max_pool1d_backward(grad_output, scopes, result_idx, ctx.input_length), None, None
112 |
113 |
114 | def segment_max_pool1d(values: torch.Tensor, scopes: torch.Tensor, return_indices=False) -> torch.Tensor:
115 | """
116 | Computes the maximum value in each segment.
117 |
118 | Parameters
119 | ----------
120 | values : torch.Tensor
121 |         A tensor of shape ``(N, C)``; the maximum is taken along the first dimension within each segment.
122 | scopes : torch.Tensor
123 | a 2-dimensional integer tensor representing segments.
124 | Each row of scopes represents a segment, which starts at ``scopes[i, 0]``,
125 | and has length ``scopes[i, 1]``.
126 |
127 | Returns
128 | -------
129 | torch.Tensor
130 | A tensor representing the maximum value for each segment.
131 | """
132 | return SegmentMaxPool1DLoop.apply(values, scopes, return_indices)
133 |
134 |
135 | try:
136 | import torch_scatter
137 |
138 | segment_avg_pool1d = segment_avg_pool1d_scatter
139 | except ImportError:
140 | segment_avg_pool1d = segment_avg_pool1d_loop
141 |
142 |
143 | __all__ = ['segment_avg_pool1d']
144 |
--------------------------------------------------------------------------------
/tests/test_graph_model.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | import pickle
4 |
5 | import numpy as np
6 |
7 | import pytest
8 |
9 | from sketchgraphs_models.graph import dataset as graph_dataset
10 | from sketchgraphs_models.graph import model as graph_model
11 |
12 | from sketchgraphs.data.sketch import Sketch
13 |
14 | @pytest.fixture
15 | def edge_feature_mapping():
16 | """Dummy quantization functions."""
17 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f:
18 | mapping = pickle.load(f)
19 |
20 | return graph_dataset.EdgeFeatureMapping(
21 | mapping['edge']['angle'], mapping['edge']['length'])
22 |
23 | @pytest.fixture
24 | def node_feature_mapping():
25 | with gzip.open('tests/testdata/sg_t16.stats.pkl.gz', 'rb') as f:
26 | mapping = pickle.load(f)
27 | return graph_dataset.EntityFeatureMapping(mapping['node'])
28 |
29 |
30 | def test_compute_model(sketches, node_feature_mapping, edge_feature_mapping):
31 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
32 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=36)
33 |
34 | batch = [dataset[i] for i in range(10)]
35 | batch_input = graph_dataset.collate(batch, node_feature_mapping, edge_feature_mapping)
36 |
37 | model = graph_model.make_graph_model(32, {**node_feature_mapping.feature_dimensions, **edge_feature_mapping.feature_dimensions})
38 | result = model(batch_input)
39 | assert 'node_embedding' in result
40 |
41 |
42 | def test_compute_losses(sketches, node_feature_mapping, edge_feature_mapping):
43 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
44 | dataset = graph_dataset.GraphDataset(sequences, node_feature_mapping, edge_feature_mapping, seed=36)
45 |
46 | batch = [dataset[i] for i in range(10)]
47 | batch_input = graph_dataset.collate(batch, node_feature_mapping, edge_feature_mapping)
48 |
49 | feature_dimensions = {**node_feature_mapping.feature_dimensions, **edge_feature_mapping.feature_dimensions}
50 |
51 | model = graph_model.make_graph_model(32, feature_dimensions)
52 | model_output = model(batch_input)
53 |
54 | losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
55 | model_output, batch_input, feature_dimensions)
56 | assert isinstance(losses, dict)
57 |
58 | avg_losses = graph_model.compute_average_losses(losses, batch_input['graph_counts'])
59 | assert isinstance(avg_losses, dict)
60 |
61 | for t in edge_metrics:
62 | assert edge_metrics[t][0].shape == edge_metrics[t][1].shape
63 |
64 |     for t in node_metrics:
65 |         assert node_metrics[t][0].shape == node_metrics[t][1].shape
66 |
67 |
68 | def test_compute_model_no_entity_features(sketches, edge_feature_mapping):
69 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
70 | dataset = graph_dataset.GraphDataset(sequences, None, edge_feature_mapping, seed=36)
71 |
72 | batch = [dataset[i] for i in range(10)]
73 | batch_input = graph_dataset.collate(batch, None, edge_feature_mapping)
74 |
75 | model = graph_model.make_graph_model(
76 | 32, {**edge_feature_mapping.feature_dimensions}, readout_entity_features=False)
77 | result = model(batch_input)
78 | assert 'node_embedding' in result
79 |
80 |
81 | def test_compute_losses_no_entity_features(sketches, edge_feature_mapping):
82 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
83 | dataset = graph_dataset.GraphDataset(sequences, None, edge_feature_mapping, seed=36)
84 |
85 | batch = [dataset[i] for i in range(10)]
86 | batch_input = graph_dataset.collate(batch, None, edge_feature_mapping)
87 |
88 | feature_dimensions = {**edge_feature_mapping.feature_dimensions}
89 |
90 | model = graph_model.make_graph_model(32, feature_dimensions, readout_entity_features=False)
91 | model_output = model(batch_input)
92 |
93 | losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
94 | model_output, batch_input, feature_dimensions)
95 | assert isinstance(losses, dict)
96 |
97 | avg_losses = graph_model.compute_average_losses(losses, batch_input['graph_counts'])
98 | assert isinstance(avg_losses, dict)
99 |
100 | for t in edge_metrics:
101 | assert edge_metrics[t][0].shape == edge_metrics[t][1].shape
102 |
103 |     for t in node_metrics:
104 |         assert node_metrics[t][0].shape == node_metrics[t][1].shape
105 |
106 |
107 | def test_compute_model_subnode(sketches):
108 | sequences = list(map(graph_dataset.sketch_to_sequence, sketches))
109 | dataset = graph_dataset.GraphDataset(sequences, seed=36)
110 |
111 | batch = [dataset[i] for i in range(10)]
112 | batch_input = graph_dataset.collate(batch)
113 |
114 | model = graph_model.make_graph_model(
115 | 32, feature_dimensions={}, readout_entity_features=False, readout_edge_features=False)
116 | result = model(batch_input)
117 | assert 'node_embedding' in result
118 |
119 | losses, _, edge_metrics, node_metrics = graph_model.compute_losses(
120 | result, batch_input, {})
121 | assert isinstance(losses, dict)
122 |
123 | avg_losses = graph_model.compute_average_losses(losses, batch_input['graph_counts'])
124 | assert isinstance(avg_losses, dict)
125 |
--------------------------------------------------------------------------------
/sketchgraphs_models/autoconstraint/eval_likelihood.py:
--------------------------------------------------------------------------------
1 | """Evaluates an auto-constraint model in terms of the sequential assigned likelihood. """
2 |
3 | import argparse
4 | import itertools
5 | import gzip
6 | import pickle
7 |
8 | import numpy as np
9 | import torch
10 | import tqdm
11 |
12 | from sketchgraphs import data as datalib
13 | from sketchgraphs.data import flat_array
14 |
15 | from sketchgraphs_models import training
16 | from sketchgraphs_models.autoconstraint import dataset, model as auto_model, eval
17 |
18 |
19 | def _edge_ops(ops):
20 | return [op for op in ops if isinstance(op, datalib.EdgeOp)]
21 |
22 | def _node_ops(ops):
23 | return [op for op in ops if isinstance(op, datalib.NodeOp)]
24 |
25 |
26 | def ops_to_batch(ops, node_feature_mappings):
27 |     """Given a sequence of operations representing a sketch, constructs a batch with one
28 |     prediction problem per operation after the first (subnode edges excluded), plus a final edge-stop problem."""
29 | node_ops = [op for op in ops if isinstance(op, datalib.NodeOp)]
30 |
31 | batch = []
32 |
33 | for i, op in enumerate(ops):
34 | if i == 0:
35 | # First external node does not induce edge stop token
36 | continue
37 |
38 | if op.label == datalib.ConstraintType.Subnode:
39 | continue
40 |
41 | ops_in_graph = ops[:i]
42 |
43 | features = dataset.process_node_and_edge_ops(
44 | node_ops, _edge_ops(ops_in_graph),
45 | len(_node_ops(ops_in_graph)), node_feature_mappings)
46 |
47 | if isinstance(op, datalib.NodeOp):
48 | # Stop problem
49 | partner_index = -1
50 | target_edge_label = -1
51 | else:
52 | partner_index = op.references[-1]
53 | target_edge_label = dataset.EDGE_IDX_MAP[op.label]
54 |
55 | features['partner_index'] = partner_index
56 | features['target_edge_label'] = target_edge_label
57 |
58 | batch.append(features)
59 |
60 | # Append final edge stop target
61 | batch.append({
62 | **dataset.process_node_and_edge_ops(node_ops, _edge_ops(ops), len(node_ops), node_feature_mappings),
63 | 'partner_index': -1,
64 | 'target_edge_label': -1
65 | })
66 |
67 | return batch
68 |
69 |
70 | class EdgeLikelihoodEvaluator:
71 | def __init__(self, model, node_feature_mappings, device=None):
72 | self.model = model
73 | self.node_feature_mappings = node_feature_mappings
74 | self.device = device
75 |
76 | def edge_likelihood(self, ops):
77 | if ops[-1].label == 'Stop':
78 | ops = ops[:-1]
79 |
80 | if len(ops) == 1:
81 | return np.empty([0, 2])
82 |
83 | batch_list = ops_to_batch(ops, self.node_feature_mappings)
84 |
85 | batch = dataset.collate(batch_list)
86 | batch_device = training.load_cuda_async(batch, self.device)
87 |
88 | with torch.no_grad():
89 | readout = self.model(batch_device)
90 | losses, _ = auto_model.compute_losses(batch_device, readout, reduction='none')
91 |
92 | losses_flat = torch.zeros((len(batch_list), 2))
93 | losses_flat[batch['partner_index'].index, 0] = losses['edge_partner'].cpu()
94 | losses_flat[batch['stop_partner_index_index'], 0] = losses['edge_stop'].cpu()
95 |
96 | losses_flat[batch['partner_index'].index, 1] = losses['edge_label'].cpu()
97 |
98 | losses_flat = losses_flat[torch.argsort(batch['sorted_indices'])]
99 |
100 | return losses_flat.numpy()
101 |
102 |
103 |
104 | def main():
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument('--model', type=str, required=True)
107 | parser.add_argument('--dataset', type=str, required=True)
108 | parser.add_argument('--device', type=str, default='cpu')
109 | parser.add_argument('--output', type=str)
110 | parser.add_argument('--max_predictions', type=int, default=None)
111 |
112 |
113 | args = parser.parse_args()
114 |
115 | device = torch.device(args.device)
116 |
117 | print('Loading trained model')
118 | model, node_feature_mapping = eval.load_sampling_model(args.model)
119 | model = model.eval().to(device)
120 |
121 | print('Loading testing data')
122 | seqs = flat_array.load_dictionary_flat(np.load(args.dataset, mmap_mode='r'))['sequences']
123 |
124 | likelihood_evaluation = EdgeLikelihoodEvaluator(model, node_feature_mapping, device)
125 |
126 | length = len(seqs)
127 | if args.max_predictions is not None:
128 | length = min(length, args.max_predictions)
129 |
130 | results = []
131 | average_likelihoods = np.empty(length)
132 | sequence_length = np.empty(length)
133 |
134 | for i in tqdm.trange(length):
135 | seq = seqs[i]
136 | result = likelihood_evaluation.edge_likelihood(seq)
137 | results.append({
138 | 'seq': seq,
139 | 'likelihood': result
140 | })
141 |
142 | sequence_length[i] = result.shape[0]
143 |
144 | if result.shape[0] == 0:
145 | average_likelihoods[i] = 0.0
146 | else:
147 | average_likelihoods[i] = np.mean(np.sum(result, axis=-1), axis=0)
148 |
149 | print('Average bit per edge {0:.3f}'.format(np.average(average_likelihoods, weights=sequence_length) / np.log(2)))
150 |
151 | if args.output is None:
152 | return
153 |
154 | print('Saving to output {0}'.format(args.output))
155 | with gzip.open(args.output, 'wb') as f:
156 | pickle.dump(results, f, protocol=4)
157 |
158 |
159 | if __name__ == '__main__':
160 | main()
161 |
--------------------------------------------------------------------------------
/sketchgraphs_models/nn/summary.py:
--------------------------------------------------------------------------------
1 | """This module implements utilities to compute summary statistics.
2 | """
3 |
4 | import posixpath
5 | import torch
6 | import torchmetrics
7 |
8 |
9 | class ClassificationSummary:
10 | """ Simple class to keep track of summaries of a classification problem. """
11 | def __init__(self, num_outcomes=2, device=None):
12 | """ Initializes a new summary class with the given number of outcomes.
13 |
14 | Parameters
15 | ----------
16 | num_outcomes : int
17 | the number of possible outcomes of the classification problem.
18 | device : torch.device, optional
19 | device on which to place the recorded statistics.
20 | """
21 | self.recorded = torch.zeros(num_outcomes * num_outcomes, dtype=torch.int32, device=device)
22 | self.num_outcomes = num_outcomes
23 |
24 | @property
25 | def prediction_matrix(self):
26 | """Returns a `torch.Tensor` representing the prediction matrix."""
27 | return self.recorded.view((self.num_outcomes, self.num_outcomes))
28 |
29 | def record_statistics(self, labels, predictions):
30 | """ Records statistics for a batch of predictions.
31 |
32 | Parameters
33 | ----------
34 | labels : torch.Tensor
35 | an array of true labels in integer format. Each label must correspond to an
36 | integer in 0 to num_outcomes - 1 inclusive.
37 | predictions : torch.Tensor
38 | an array of predicted labels. Must follow the same format as `labels`.
39 | """
40 | indices = torch.add(labels.int(), predictions.int(), alpha=self.num_outcomes).long().to(device=self.recorded.device)
41 | self.recorded = self.recorded.scatter_add_(
42 | 0, indices, torch.ones_like(indices, dtype=torch.int32))
43 |
44 | def reset_statistics(self):
45 | """ Resets statistics recorded in this accumulator. """
46 | self.recorded = torch.zeros_like(self.recorded)
47 |
48 | def accuracy(self):
49 | """ Compute the accuracy of the recorded problem. """
50 | num_correct = self.prediction_matrix.diag().sum()
51 | num_total = self.recorded.sum()
52 |
53 | return num_correct.float() / num_total.float()
54 |
55 | def confusion_matrix(self):
56 | """Returns a `torch.Tensor` representing the confusion matrix."""
57 | return self.prediction_matrix.float() / self.prediction_matrix.sum().float()
58 |
59 | def cohen_kappa(self):
60 | """Computes the Cohen kappa measure of agreement.
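
        The returned value is ``1 - (1 - p_o) / (1 - p_e)``, where ``p_o`` is the observed
        agreement (fraction of matching label/prediction pairs) and ``p_e`` is the agreement
        expected by chance from the marginals; this is algebraically identical to the usual
        ``(p_o - p_e) / (1 - p_e)`` definition.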
61 | """
62 | pm = self.prediction_matrix.float()
63 | N = self.recorded.sum().float()
64 |
65 | p_observed = pm.diag().sum() / N
66 | p_expected = torch.dot(pm.sum(dim=0), pm.sum(dim=1)) / (N * N)
67 |
68 | if p_expected == 1:
69 | return 1
70 | else:
71 | return 1 - (1 - p_observed) / (1 - p_expected)
72 |
73 | def marginal_labels(self):
74 | """Computes the empirical marginal distribution of the true labels."""
75 | return self.prediction_matrix.sum(dim=0).float() / self.recorded.sum().float()
76 |
77 | def marginal_predicted(self):
78 | """Computes the empirical marginal distribution of the predicted labels."""
79 | return self.prediction_matrix.sum(dim=1).float() / self.recorded.sum().float()
80 |
81 | def write_tensorboard(self, writer, prefix="", global_step=None, **kwargs):
82 | """Write the accuracy and kappa metrics to a tensorboard writer.
83 |
84 | Parameters
85 | ----------
86 | writer: torch.utils.tensorboard.SummaryWriter
87 | The writer to which the metrics will be written
88 | prefix: str, optional
89 | Optional prefix for the name under which the metrics will be written
90 | global_step: int, optional
91 | Global step at which the metric is recorded
92 | **kwargs
93 | Further arguments to `torch.utils.tensorboard.SummaryWriter.add_scalar`.
94 | """
95 | writer.add_scalar(posixpath.join(prefix, "kappa"), self.cohen_kappa(), global_step, **kwargs)
96 | writer.add_scalar(posixpath.join(prefix, "accuracy"), self.accuracy(), global_step, **kwargs)
97 |
98 |
99 | class CohenKappa(torchmetrics.Metric):
100 | """A pytorch-lightning compatible metric which computes Cohen's kappa score of agreement.
101 | """
102 | recorded: torch.Tensor
103 |
104 | def __init__(self, num_outcomes=2, compute_on_step=True, dist_sync_on_step=False):
105 | super().__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step)
106 |
107 | self.num_outcomes = num_outcomes
108 | self.add_state("recorded", torch.zeros(num_outcomes * num_outcomes, dtype=torch.int32), dist_reduce_fx='sum')
109 |
110 | def update(self, preds: torch.Tensor, targets: torch.Tensor):
111 | indices = torch.add(targets.int(), preds.int(), alpha=self.num_outcomes).long().to(device=self.recorded.device)
112 | self.recorded.scatter_add_(0, indices, torch.ones_like(indices, dtype=torch.int32))
113 |
114 | def compute(self):
115 | pm = self.recorded.view(self.num_outcomes, self.num_outcomes).float()
116 | N = self.recorded.sum().float()
117 |
118 | if N == 0:
119 | return pm.new_tensor(0.0)
120 |
121 | p_observed = pm.diag().sum() / N
122 | p_expected = torch.dot(pm.sum(dim=0), pm.sum(dim=1)) / (N * N)
123 |
124 | if p_expected == 1:
125 | return pm.new_tensor(1.0)
126 | else:
127 | return 1 - (1 - p_observed) / (1 - p_expected)
128 |
--------------------------------------------------------------------------------
/sketchgraphs/data/_plotting.py:
--------------------------------------------------------------------------------
1 | """Functions for drawing sketches using matplotlib
2 |
3 | This module implements local plotting functionality in order to render Onshape sketches using matplotlib.
4 |
5 | """
6 |
7 | import math
8 | from contextlib import nullcontext
9 |
10 | import matplotlib as mpl
11 | import matplotlib.patches
12 | import matplotlib.pyplot as plt
13 |
14 | from ._entity import Arc, Circle, Line, Point
15 |
16 |
17 | def _get_linestyle(entity):
18 | return '--' if entity.isConstruction else '-'
19 |
20 | def sketch_point(ax, point: Point, color='black', show_subnodes=False):
21 | ax.scatter(point.x, point.y, c=color, marker='.')
22 |
23 | def sketch_line(ax, line: Line, color='black', show_subnodes=False):
24 | start_x, start_y = line.start_point
25 | end_x, end_y = line.end_point
26 | if show_subnodes:
27 | marker = '.'
28 | else:
29 | marker = None
30 | ax.plot((start_x, end_x), (start_y, end_y), color, linestyle=_get_linestyle(line), linewidth=1, marker=marker)
31 |
32 | def sketch_circle(ax, circle: Circle, color='black', show_subnodes=False):
33 | patch = matplotlib.patches.Circle(
34 | (circle.xCenter, circle.yCenter), circle.radius,
35 | fill=False, linestyle=_get_linestyle(circle), color=color)
36 | if show_subnodes:
37 | ax.scatter(circle.xCenter, circle.yCenter, c=color, marker='.', zorder=20)
38 | ax.add_patch(patch)
39 |
40 | def sketch_arc(ax, arc: Arc, color='black', show_subnodes=False):
41 | angle = math.atan2(arc.yDir, arc.xDir) * 180 / math.pi
42 | startParam = arc.startParam * 180 / math.pi
43 | endParam = arc.endParam * 180 / math.pi
44 |
45 | if arc.clockwise:
46 | startParam, endParam = -endParam, -startParam
47 |
48 | ax.add_patch(
49 | matplotlib.patches.Arc(
50 | (arc.xCenter, arc.yCenter), 2*arc.radius, 2*arc.radius,
51 | angle=angle, theta1=startParam, theta2=endParam,
52 | linestyle=_get_linestyle(arc), color=color))
53 |
54 | if show_subnodes:
55 | ax.scatter(arc.xCenter, arc.yCenter, c=color, marker='.')
56 | ax.scatter(*arc.start_point, c=color, marker='.', zorder=40)
57 | ax.scatter(*arc.end_point, c=color, marker='.', zorder=40)
58 |
59 |
60 | _PLOT_BY_TYPE = {
61 | Arc: sketch_arc,
62 | Circle: sketch_circle,
63 | Line: sketch_line,
64 | Point: sketch_point
65 | }
66 |
67 |
68 | def render_sketch(sketch, ax=None, show_axes=False, show_origin=False,
69 | hand_drawn=False, show_subnodes=False, show_points=True):
70 | """Renders the given sketch using matplotlib.
71 |
72 | Parameters
73 | ----------
74 | sketch : Sketch
75 | The sketch instance to render
76 | ax : matplotlib.Axis, optional
77 | Axis object on which to render the sketch. If None, a new figure is created.
78 | show_axes : bool
79 | Indicates whether axis lines should be drawn
80 | show_origin : bool
81 | Indicates whether origin point should be drawn
82 | hand_drawn : bool
83 | Indicates whether to emulate a hand-drawn appearance
84 | show_subnodes : bool
85 | Indicates whether endpoints/centerpoints should be drawn
86 | show_points : bool
87 | Indicates whether Point entities should be drawn
88 |
89 | Returns
90 | -------
91 | matplotlib.Figure
92 | If `ax` is not provided, the newly created figure. Otherwise, `None`.
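
    Examples
    --------
    >>> fig = render_sketch(sketch, show_subnodes=True)  # assumes ``sketch`` is a parsed ``Sketch`` instance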
93 | """
94 | with plt.xkcd(
95 | scale=1, length=100, randomness=3) if hand_drawn else nullcontext():
96 |
97 | if ax is None:
98 | fig = plt.figure()
99 | ax = fig.add_subplot(111, aspect='equal')
100 | else:
101 | fig = None
102 |
103 | # Eliminate upper and right axes
104 | ax.spines['right'].set_color('none')
105 | ax.spines['top'].set_color('none')
106 |
107 | if not show_axes:
108 | ax.set_yticklabels([])
109 | ax.set_xticklabels([])
110 | _ = [line.set_marker('None') for line in ax.get_xticklines()]
111 | _ = [line.set_marker('None') for line in ax.get_yticklines()]
112 |
113 | # Eliminate lower and left axes
114 | ax.spines['left'].set_color('none')
115 | ax.spines['bottom'].set_color('none')
116 |
117 | if show_origin:
118 | point_size = mpl.rcParams['lines.markersize'] * 1
119 | ax.scatter(0, 0, s=point_size, c='black')
120 |
121 | for ent in sketch.entities.values():
122 | sketch_fn = _PLOT_BY_TYPE.get(type(ent))
123 | if isinstance(ent, Point):
124 | if not show_points:
125 | continue
126 | if sketch_fn is None:
127 | continue
128 | sketch_fn(ax, ent, show_subnodes=show_subnodes)
129 |
130 | # Rescale axis limits
131 | ax.relim()
132 | ax.autoscale_view()
133 |
134 | return fig
135 |
136 |
137 | def render_graph(graph, filename, show_node_idxs=False):
138 | """Renders the given pgv.AGraph to an image file.
139 |
140 | Parameters
141 | ----------
142 | graph : pgv.AGraph
143 | The graph to render
144 | filename : string
145 | Where to save the image file
146 | show_node_idxs: bool
147 | If true, append node indexes to their labels
148 |
149 |
150 | Returns
151 | -------
152 | None
153 | """
154 | if show_node_idxs:
155 | for idx, node in enumerate(graph.nodes()):
156 | node.attr['label'] += ' (' + str(idx) + ')'
157 | graph.layout('dot')
158 | graph.draw(filename)
159 |
160 |
161 | __all__ = ['render_sketch', 'render_graph']
--------------------------------------------------------------------------------
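A minimal usage sketch for `render_sketch` above. It assumes `sketch` is an existing `sketchgraphs.data.sketch.Sketch` instance (for example parsed from one of the dataset JSON records with `Sketch.from_fs_json`); the option names mirror the docstring.

    from sketchgraphs.data._plotting import render_sketch

    # `sketch` is assumed to be a pre-loaded Sketch instance.
    fig = render_sketch(sketch, show_subnodes=True, show_points=True)
    if fig is not None:   # a new figure is only created when no `ax` is passed
        fig.savefig('sketch.png', dpi=150, bbox_inches='tight')
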
/tests/test_torch_extensions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 |
5 | from scipy.special import logsumexp
6 |
7 | from sketchgraphs_models.torch_extensions import _repeat_interleave, segment_ops, segment_pool
8 |
9 |
10 | def test_repeat_python():
11 | x = np.random.randn(40).reshape(4, 10)
12 | times = [2, 5, 0, 1]
13 |
14 | expected = np.repeat(x, times, axis=0)
15 | result = _repeat_interleave.repeat_interleave(torch.tensor(x), torch.tensor(times), 0)
16 |
17 | assert np.allclose(result.numpy(), expected)
18 |
19 |
20 | def test_segment_logsumexp_python():
21 | x = np.random.randn(40)
22 | lengths = [5, 10, 6, 4, 15]
23 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
24 |
25 | scopes = np.stack((offsets, lengths), axis=1)
26 |
27 | expected = np.array([logsumexp(x[s[0]:s[0] + s[1]]) for s in scopes])
28 | result = segment_ops.segment_logsumexp_python(torch.tensor(x), torch.tensor(scopes))
29 |
30 | assert np.allclose(result, expected)
31 |
32 |
33 | def test_segment_logsumexp_python_grad():
34 | x = np.random.randn(40)
35 |
36 | lengths = [5, 10, 6, 4, 15]
37 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
38 |
39 | scopes = np.stack((offsets, lengths), axis=1)
40 |
41 | torch.autograd.gradcheck(
42 | segment_ops.segment_logsumexp_python,
43 | (torch.tensor(x, requires_grad=True), torch.tensor(scopes)))
44 |
45 |
46 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
47 | def test_segment_logsumexp_scatter(device):
48 | x = np.random.randn(40)
49 | lengths = [0, 5, 10, 6, 4, 15, 0]
50 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
51 |
52 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64)
53 |
54 | expected = np.array([logsumexp(x[s[0]:s[0] + s[1]]) if s[1] != 0 else -np.inf for s in scopes])
55 |
56 | result = segment_ops.segment_logsumexp_scatter(torch.tensor(x, device=device), torch.tensor(scopes, device=device))
57 |
58 | assert np.allclose(result.cpu().numpy(), expected)
59 |
60 |
61 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
62 | def test_segment_logsumexp_scatter_grad(device):
63 | x = np.random.randn(40)
64 |
65 | lengths = [5, 10, 6, 4, 15]
66 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
67 |
68 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64)
69 |
70 | torch.autograd.gradcheck(
71 | segment_ops.segment_logsumexp_scatter,
72 | (torch.tensor(x, requires_grad=True, device=device), torch.tensor(scopes, device=device)))
73 |
74 |
75 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
76 | def test_segment_logsumexp_scatter_grad_full(device):
77 | x = np.random.randn(20)
78 |
79 | scopes = torch.tensor([[0, 20]], dtype=torch.int64, device=device)
80 |
81 | torch.autograd.gradcheck(
82 | segment_ops.segment_logsumexp_scatter,
83 | (torch.tensor(x, requires_grad=True, device=device), scopes))
84 |
85 |
86 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
87 | def test_segment_argmax(device):
88 | x = np.random.randn(40)
89 |
90 | lengths = np.array([0, 5, 10, 6, 4, 15, 0])
91 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
92 |
93 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64)
94 |
95 | x = torch.tensor(x, device=device)
96 | scopes = torch.tensor(scopes, device=device)
97 |
98 | expected_values, expected_index = segment_ops.segment_argmax_python(x, scopes)
99 | result_values, result_index = segment_ops.segment_argmax_scatter(x, scopes)
100 |
101 | result_values = result_values.cpu().numpy()
102 | expected_values = expected_values.cpu().numpy()
103 | result_index = result_index.cpu().numpy()
104 | expected_index = expected_index.cpu().numpy()
105 |
106 | assert np.allclose(result_values, expected_values)
107 | assert np.allclose(result_index[lengths > 0], expected_index[lengths > 0])
108 |
109 |
110 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
111 | def test_segment_argmax_backward(device):
112 | x = np.random.randn(40)
113 |
114 | lengths = [5, 10, 6, 4, 15]
115 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
116 |
117 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64)
118 |
119 | torch.autograd.gradcheck(
120 | segment_ops.segment_argmax_scatter,
121 | (torch.tensor(x, requires_grad=True, device=device),
122 | torch.tensor(scopes, device=device),
123 | False))
124 |
125 |
126 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
127 | def test_segment_pool(device):
128 | x = np.random.randn(40)
129 |
130 | lengths = [5, 10, 6, 4, 15]
131 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
132 |
133 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64)
134 |
135 | x = torch.tensor(x, device=device)
136 | scopes = torch.tensor(scopes, device=device)
137 |
138 | expected_values = segment_pool.segment_avg_pool1d_loop(x, scopes)
139 | result_values = segment_pool.segment_avg_pool1d_scatter(x, scopes)
140 |
141 | assert torch.allclose(expected_values, result_values)
142 |
143 |
144 | @pytest.mark.parametrize("device", ["cpu", "cuda"])
145 | def test_segment_pool_2d(device):
146 | x = np.random.randn(40, 5)
147 |
148 | lengths = [5, 10, 6, 4, 15]
149 | offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
150 |
151 | scopes = np.stack((offsets, lengths), axis=1).astype(np.int64)
152 |
153 | x = torch.tensor(x, device=device)
154 | scopes = torch.tensor(scopes, device=device)
155 |
156 | expected_values = segment_pool.segment_avg_pool1d_loop(x, scopes)
157 | result_values = segment_pool.segment_avg_pool1d_scatter(x, scopes)
158 |
159 | assert torch.allclose(expected_values, result_values)
160 |
--------------------------------------------------------------------------------
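The tests above repeatedly build `scopes` from a list of segment lengths. A small hypothetical helper (not part of the library) makes that layout explicit: column 0 is the segment offset and column 1 its length.

    import numpy as np
    import torch

    def make_scopes(lengths, device='cpu'):
        """Build an int64 [num_segments, 2] scopes tensor of (offset, length) rows."""
        lengths = np.asarray(lengths, dtype=np.int64)
        offsets = np.concatenate(([0], np.cumsum(lengths[:-1])))
        return torch.tensor(np.stack((offsets, lengths), axis=1), device=device)

    # e.g. make_scopes([5, 10, 6, 4, 15]) segments a flat tensor of 40 values.
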
/sketchgraphs_models/graph/model/message_passing.py:
--------------------------------------------------------------------------------
1 | """This module contains the components of the graph model associated with handling
2 | the message passing.
3 | """
4 |
5 | import torch
6 |
7 | from sketchgraphs_models import nn as sg_nn
8 | import sketchgraphs_models.nn.functional
9 | from sketchgraphs.pipeline import graph_model
10 |
11 |
12 | class DenseSparsePreEmbedding(torch.nn.Module):
13 | """This is a generic pre-embedding module which combines dense and sparse pre-embeddings."""
14 | def __init__(self, target_type, feature_embeddings, fixed_embedding_cardinality, fixed_embedding_dim,
15 | sparse_embedding_dim=None, embedding_dim=None):
16 | """Initializes a new DenseSparsePreEmbedding module.
17 |
18 | Parameters
19 | ----------
20 | target_type : enum
21 | The underlying enumeration indicating target types
22 | feature_embeddings : dict of modules
23 | A dictionary of embeddings for each of the sparse feature types.
24 | fixed_embedding_cardinality : int
25 | The number of classes in the fixed (dense) embedding layer.
26 | fixed_embedding_dim : int
27 | The dimension of the embedding for the fixed layer.
28 | sparse_embedding_dim : int, optional
29 | The dimension of the sparse embeddings. If None, assumed to be the same as the dense embedding.
30 | embedding_dim : int, optional
31 |             The output dimension of the embedding. If None, assumed to be the same as the dense embedding.
32 | """
33 | super(DenseSparsePreEmbedding, self).__init__()
34 | sparse_embedding_dim = sparse_embedding_dim or fixed_embedding_dim
35 | embedding_dim = embedding_dim or fixed_embedding_dim
36 |
37 | self.target_type = target_type
38 | self.feature_embeddings = torch.nn.ModuleDict(feature_embeddings)
39 | self.sparse_embedding_dim = sparse_embedding_dim
40 | self.fixed_embedding_dim = fixed_embedding_dim
41 | self.fixed_embedding = torch.nn.Embedding(fixed_embedding_cardinality, fixed_embedding_dim)
42 | self.dense_merge = sg_nn.ConcatenateLinear(fixed_embedding_dim, sparse_embedding_dim, embedding_dim)
43 |
44 | def forward(self, fixed_features, sparse_features):
45 | fixed_embeddings = self.fixed_embedding(fixed_features)
46 | sparse_embeddings = fixed_embeddings.new_zeros((fixed_embeddings.shape[0], self.sparse_embedding_dim))
47 |
48 | for k, embedding_network in self.feature_embeddings.items():
49 | sf = sparse_features[self.target_type[k]]
50 | if sf is None or len(sf.index) == 0:
51 | continue
52 |
53 | assert (sf.index < fixed_embeddings.shape[0]).all()
54 | sparse_embeddings[sf.index] = embedding_network(sf.value)
55 |
56 | return self.dense_merge(fixed_embeddings, sparse_embeddings)
57 |
58 |
59 | class DenseOnlyEmbedding(torch.nn.Module):
60 | """Generic pre-embedding module which encapsulates a pytorch embedding layer.
61 |
62 | This class is simply provided for compatibility with `DenseSparsePreEmbedding`, to construct
63 | models where sparse embeddings are not present.
64 | """
65 | def __init__(self, cardinality, dimension):
66 | super(DenseOnlyEmbedding, self).__init__()
67 | self.fixed_embedding = torch.nn.Embedding(cardinality, dimension)
68 |
69 | def forward(self, features, *_):
70 | return self.fixed_embedding(features)
71 |
72 |
73 | class GraphModelCore(torch.nn.Module):
74 | """Component of the entity model used to compute global features, i.e. graph and node embeddings.
75 |
76 | This component is responsible for the computation that is independent of any
77 | specific target (edge / node). It is split off to ease sharing with sampling models.
78 | """
79 | def __init__(self, message_passing, node_embedding, edge_embedding, graph_embedding):
80 | super(GraphModelCore, self).__init__()
81 | self.message_passing = message_passing
82 | self.node_embedding = node_embedding
83 | self.edge_embedding = edge_embedding
84 | self.graph_embedding = graph_embedding
85 |
86 | def forward(self, data):
87 | graph = data['graph']
88 |
89 | with torch.autograd.profiler.record_function('feature_embedding'):
90 | node_pre_embedding = self.node_embedding(graph.node_features, graph.sparse_node_features)
91 | edge_pre_embedding = self.edge_embedding(graph.edge_features, graph.sparse_edge_features)
92 |
93 | node_post_embedding = self.message_passing(node_pre_embedding, graph.incidence, (edge_pre_embedding,))
94 | graph_embedding = self.graph_embedding(node_post_embedding, graph)
95 | return node_post_embedding, graph_embedding
96 |
97 |
98 | class GraphPostEmbedding(torch.nn.Module):
99 | """Component of the graph model which computes graph-wide representation by aggregating node representations.
100 | """
101 | def __init__(self, hidden_size, graph_embedding_size=None):
102 | super(GraphPostEmbedding, self).__init__()
103 |
104 | if graph_embedding_size is None:
105 | graph_embedding_size = 2 * hidden_size
106 |
107 | self.node_gating_net = torch.nn.Sequential(
108 | torch.nn.Linear(hidden_size, 1),
109 | torch.nn.Sigmoid()
110 | )
111 | self.node_to_graph_net = torch.nn.Linear(hidden_size, graph_embedding_size)
112 |
113 | def forward(self, node_embedding, graph):
114 | scopes = graph_model.scopes_from_offsets(graph.node_offsets)
115 |
116 | transformed_embedding = self.node_gating_net(node_embedding) * self.node_to_graph_net(node_embedding)
117 |
118 | graph_embedding = sg_nn.functional.segment_avg_pool1d(
119 | transformed_embedding, scopes) * graph.node_counts.unsqueeze(-1)
120 |
121 | return graph_embedding
122 |
--------------------------------------------------------------------------------
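A construction-only sketch showing how the pieces above fit together. The hidden size, cardinalities, and the `ConcatenateLinear` aggregator are placeholder choices; the real training harness builds these from the dataset's feature mappings.

    from sketchgraphs_models import nn as sg_nn
    from sketchgraphs_models.graph.model.message_passing import (
        DenseOnlyEmbedding, GraphModelCore, GraphPostEmbedding)

    hidden = 64
    core = GraphModelCore(
        message_passing=sg_nn.MessagePassingNetwork(
            depth=3,
            # Receives (summed incoming messages, previous node embedding) at each step.
            message_aggregation_network=sg_nn.ConcatenateLinear(hidden, hidden, hidden)),
        node_embedding=DenseOnlyEmbedding(cardinality=16, dimension=hidden),   # placeholder cardinality
        edge_embedding=DenseOnlyEmbedding(cardinality=16, dimension=hidden),
        graph_embedding=GraphPostEmbedding(hidden))
    # core(data) expects data['graph'] to provide node/edge features, incidence and offsets.
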
/sketchgraphs/pipeline/make_quantization_statistics.py:
--------------------------------------------------------------------------------
1 | """This script computes statistics required for quantization of continuous parameters.
2 |
3 | Many models we use require quantization to process the continuous parameters in the sketch.
4 | This script computes the required statistics for quantizing the dataset.
5 |
6 | """
7 |
8 | import argparse
9 | import collections
10 | import functools
11 | import gzip
12 | import itertools
13 | import multiprocessing
14 | import pickle
15 | import os
16 |
17 | import numpy as np
18 | import tqdm
19 |
20 | from sketchgraphs.data import flat_array
21 | from sketchgraphs.data.sequence import EdgeOp
22 | from sketchgraphs.data.sketch import EntityType, ENTITY_TYPE_TO_CLASS
23 | from . import numerical_parameters
24 |
25 |
26 | _EDGE_PARAMETER_IDS = ('angle', 'length')
27 |
28 |
29 | def _worker_edges(dataset_path, worker_idx, num_workers, result_queue):
30 | # Load data
31 | data = flat_array.load_dictionary_flat(np.load(dataset_path, mmap_mode='r'))
32 | sequences = data['sequences']
33 |
34 | # Extract sub-sequence for worker
35 | length_for_worker, num_additional = divmod(len(sequences), num_workers)
36 |     offset = worker_idx * length_for_worker + min(worker_idx, num_additional)
37 | if worker_idx < num_additional:
38 | length_for_worker += 1
39 |
40 |     seq_indices = range(offset, min(offset + length_for_worker, len(sequences)))
41 |
42 | # Process data
43 | expression_counters = {
44 | k: collections.Counter() for k in _EDGE_PARAMETER_IDS
45 | }
46 |
47 | num_processed = 0
48 |
49 | for seq_idx in seq_indices:
50 | seq = sequences[seq_idx]
51 |
52 | try:
53 | for op in seq:
54 | if not isinstance(op, EdgeOp):
55 | continue
56 |
57 | for k in _EDGE_PARAMETER_IDS:
58 | if k in op.parameters:
59 | value = op.parameters[k]
60 | value = numerical_parameters.normalize_expression(value, k)
61 | expression_counters[k][value] += 1
62 | except Exception:
63 | print('Error processing sequence at index {0}'.format(seq_idx))
64 |
65 | num_processed += 1
66 | if num_processed > 1000:
67 | result_queue.put(num_processed)
68 | num_processed = 0
69 | result_queue.put(num_processed)
70 |
71 | result_queue.put(expression_counters)
72 |
73 |
74 | def _worker_node(param_combination, filepath, num_centers, max_values=None):
75 | label, param_name = param_combination
76 | sequences = flat_array.load_dictionary_flat(np.load(filepath, mmap_mode='r'))['sequences']
77 |
78 | values = (op.parameters[param_name] for op in itertools.chain.from_iterable(sequences)
79 | if op.label == label and param_name in op.parameters)
80 |
81 | if max_values is not None:
82 | values = itertools.islice(values, max_values)
83 |
84 | values = np.array(list(values))
85 | centers = numerical_parameters.make_quantization(values, num_centers, 'cdf')
86 | return centers
87 |
88 |
89 | def process_edges(dataset_path, num_threads):
90 | print('Checking total sketch dataset size.')
91 | total_sequences = len(flat_array.load_dictionary_flat(np.load(dataset_path, mmap_mode='r'))['sequences'])
92 |
93 | result_queue = multiprocessing.Queue()
94 |
95 | workers = []
96 |
97 | for worker_idx in range(num_threads):
98 | workers.append(
99 | multiprocessing.Process(
100 | target=_worker_edges,
101 | args=(dataset_path, worker_idx, num_threads, result_queue)))
102 |
103 | for worker in workers:
104 | worker.start()
105 |
106 | active_workers = len(workers)
107 |
108 | total_result = {}
109 |
110 | print('Processing sequences for edge statistics')
111 | with tqdm.tqdm(total=total_sequences) as pbar:
112 | while active_workers > 0:
113 | result = result_queue.get()
114 |
115 | if isinstance(result, int):
116 | pbar.update(result)
117 | continue
118 |
119 | for k, v in result.items():
120 | total_result.setdefault(k, collections.Counter()).update(v)
121 | active_workers -= 1
122 |
123 | for worker in workers:
124 | worker.join()
125 |
126 | return total_result
127 |
128 |
129 | def process_nodes(dataset_path, num_centers, num_threads):
130 | print('Processing sequences for node statistics')
131 | label_parameter_combinations = [
132 | (t, parameter_name)
133 | for t in (EntityType.Arc, EntityType.Circle, EntityType.Line, EntityType.Point)
134 | for parameter_name in ENTITY_TYPE_TO_CLASS[t].float_ids
135 | ]
136 |
137 | pool = multiprocessing.Pool(num_threads)
138 |
139 | all_centers = pool.map(
140 | functools.partial(
141 | _worker_node, filepath=dataset_path, num_centers=num_centers, max_values=50000),
142 | label_parameter_combinations)
143 |
144 | result = {}
145 | for (t, parameter_name), centers in zip(label_parameter_combinations, all_centers):
146 | result.setdefault(t, {})[parameter_name] = centers
147 |
148 | return result
149 |
150 |
151 | def main():
152 | parser = argparse.ArgumentParser()
153 | parser.add_argument('--input', type=str, help='Input sequence dataset', required=True)
154 | parser.add_argument('--output', type=str, help='Output dataset path', default='meta.pkl.gz')
155 | parser.add_argument('--num_threads', type=int, default=0)
156 | parser.add_argument('--node_num_centers', type=int, default=256)
157 |
158 | args = parser.parse_args()
159 |
160 | num_threads = args.num_threads
161 |     if not num_threads:
162 | num_threads = len(os.sched_getaffinity(0))
163 |
164 | edge_results = process_edges(args.input, num_threads)
165 | node_results = process_nodes(args.input, args.node_num_centers, num_threads)
166 |
167 | print('Saving results in {0}'.format(args.output))
168 | with gzip.open(args.output, 'wb', compresslevel=9) as f:
169 | pickle.dump({
170 | 'edge': edge_results,
171 | 'node': node_results
172 | }, f, protocol=pickle.HIGHEST_PROTOCOL)
173 |
174 |
175 | if __name__ == '__main__':
176 | main()
177 |
--------------------------------------------------------------------------------
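A sketch of inspecting this script's output (assuming it was run with the default `--output` of `meta.pkl.gz`). The layout below follows what the script pickles: per-parameter expression counters under 'edge' and per-entity quantization centers under 'node', which is also what the graph-model data loading expects.

    import gzip
    import pickle

    with gzip.open('meta.pkl.gz', 'rb') as f:
        stats = pickle.load(f)

    # 'edge' maps parameter id ('angle', 'length') to a Counter of normalized expressions.
    print(stats['edge']['angle'].most_common(5))

    # 'node' maps an EntityType to {parameter name: quantization centers}.
    for entity_type, centers_by_param in stats['node'].items():
        print(entity_type, sorted(centers_by_param))
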
/sketchgraphs/data/dof.py:
--------------------------------------------------------------------------------
1 | """This module contains the implementation and data for heuristic degrees of freedom computation."""
2 |
3 | import numpy as np
4 |
5 | from ._entity import EntityType, SubnodeType
6 | from ._constraint import ConstraintType
7 | from .sequence import EdgeOp, NodeOp
8 |
9 |
10 | NODE_DOF = {
11 | EntityType.Point: 2,
12 | EntityType.Line: 4,
13 | EntityType.Circle: 3,
14 | EntityType.Arc: 5,
15 | }
16 |
17 | EDGE_DOF_REMOVED = {
18 | ConstraintType.Angle: {
19 | (EntityType.Line, EntityType.Line): 1},
20 | ConstraintType.Centerline_Dimension: {
21 | (EntityType.Arc, EntityType.Line): 0,
22 | (EntityType.Circle, EntityType.Line): 0,
23 | (EntityType.Line, EntityType.Line): 0,
24 | (EntityType.Line, EntityType.Point): 0},
25 | ConstraintType.Coincident: {
26 | (EntityType.Arc, EntityType.Arc): 3,
27 | (EntityType.Arc, EntityType.Circle): 3,
28 | (EntityType.Arc, EntityType.Point): 1,
29 | (EntityType.Circle, EntityType.Circle): 3,
30 | (EntityType.Circle, EntityType.Point): 1,
31 | (EntityType.Line, EntityType.Line): 2,
32 | (EntityType.Line, EntityType.Point): 1,
33 | (EntityType.Point, EntityType.Point): 2},
34 | ConstraintType.Concentric: {
35 | (EntityType.Arc, EntityType.Arc): 2,
36 | (EntityType.Arc, EntityType.Circle): 2,
37 | (EntityType.Arc, EntityType.Point): 2,
38 | (EntityType.Circle, EntityType.Circle): 2,
39 | (EntityType.Circle, EntityType.Point): 2,
40 | (EntityType.Point, EntityType.Point): 2},
41 | ConstraintType.Diameter: {
42 | (EntityType.Arc, EntityType.Arc): 1,
43 |         (EntityType.Circle, EntityType.Circle): 1},
44 | ConstraintType.Distance: {
45 | (EntityType.Arc, EntityType.Arc): 1,
46 | (EntityType.Arc, EntityType.Circle): 1,
47 | (EntityType.Arc, EntityType.Line): 1,
48 | (EntityType.Arc, EntityType.Point): 1,
49 | (EntityType.Circle, EntityType.Circle): 1,
50 | (EntityType.Circle, EntityType.Line): 1,
51 | (EntityType.Circle, EntityType.Point): 1,
52 | (EntityType.Line, EntityType.Line): 1,
53 | (EntityType.Line, EntityType.Point): 1,
54 | (EntityType.Point, EntityType.Point): 1},
55 | ConstraintType.Equal: {
56 | (EntityType.Arc, EntityType.Arc): 1,
57 | (EntityType.Arc, EntityType.Circle): 1,
58 | (EntityType.Circle, EntityType.Circle): 1,
59 | (EntityType.Line, EntityType.Line): 1},
60 | ConstraintType.Fix: {
61 | (EntityType.Arc, EntityType.Arc): 3,
62 | (EntityType.Circle, EntityType.Circle): 3,
63 | (EntityType.Line, EntityType.Line): 2,
64 | (EntityType.Point, EntityType.Point): 2},
65 | ConstraintType.Horizontal: {
66 | (EntityType.Line, EntityType.Line): 1,
67 | (EntityType.Point, EntityType.Point): 1},
68 | ConstraintType.Intersected: {},
69 | ConstraintType.Length: {
70 | (EntityType.Arc, EntityType.Arc): 1,
71 | (EntityType.Line, EntityType.Line): 1},
72 | ConstraintType.Midpoint: {
73 | (EntityType.Arc, EntityType.Point): 2,
74 | (EntityType.Line, EntityType.Point): 2},
75 | ConstraintType.Normal: {
76 | (EntityType.Arc, EntityType.Line): 1,
77 | (EntityType.Circle, EntityType.Line): 1},
78 | ConstraintType.Offset: {
79 | (EntityType.Arc, EntityType.Arc): 2,
80 | (EntityType.Arc, EntityType.Circle): 2,
81 | (EntityType.Circle, EntityType.Circle): 2,
82 | (EntityType.Line, EntityType.Line): 1},
83 | ConstraintType.Parallel: {
84 | (EntityType.Line, EntityType.Line): 1},
85 | ConstraintType.Perpendicular: {
86 | (EntityType.Line, EntityType.Line): 1},
87 | ConstraintType.Radius: {
88 | (EntityType.Arc, EntityType.Arc): 1,
89 | (EntityType.Circle, EntityType.Circle): 1},
90 | ConstraintType.Subnode: {
91 | (EntityType.Arc, EntityType.Point): 0,
92 | (EntityType.Circle, EntityType.Point): 0,
93 | (EntityType.Line, EntityType.Point): 0},
94 | ConstraintType.Tangent: {
95 | (EntityType.Arc, EntityType.Arc): 1,
96 | (EntityType.Arc, EntityType.Circle): 1,
97 | (EntityType.Arc, EntityType.Line): 1,
98 | (EntityType.Circle, EntityType.Circle): 1,
99 | (EntityType.Circle, EntityType.Line): 1},
100 | ConstraintType.Vertical: {
101 | (EntityType.Line, EntityType.Line): 1,
102 | (EntityType.Point, EntityType.Point): 1}
103 | }
104 |
105 |
106 | def get_node_label_for_dof(label) -> EntityType:
107 | """Get the node label to be used for degrees of freedom computation.
108 |
109 |     When computing DOF, subnode labels are mapped to `EntityType.Point`, as they represent points.
110 |
111 | Parameters
112 | ----------
113 | label : Union[EntityType, SubnodeType]
114 | The original node label
115 |
116 | Returns
117 | -------
118 | EntityType
119 | The corresponding `EntityType` to be used for DOF computation.
120 | """
121 | return EntityType.Point if isinstance(label, SubnodeType) else label
122 |
123 |
124 | def _get_dof_removed_for_edge(edge_op: EdgeOp, nodes):
125 | ref_types = [get_node_label_for_dof(nodes[r].label) for r in edge_op.references]
126 |
127 | if len(ref_types) == 1:
128 | ref_types = ref_types + ref_types
129 | elif len(ref_types) > 2:
130 | return 0
131 | t1, t2 = ref_types
132 |
133 | if EntityType.External in ref_types:
134 | return 0
135 |
136 |
137 | dof_dict = EDGE_DOF_REMOVED.get(edge_op.label)
138 | if dof_dict is None:
139 | return 0
140 |
141 | if (t1, t2) in dof_dict:
142 | return dof_dict[(t1, t2)]
143 | if (t2, t1) in dof_dict:
144 | return dof_dict[(t2, t1)]
145 |
146 | return 0
147 |
148 |
149 | def get_sequence_dof(seq):
150 | """Returns array of total DoF contribution from each op.
151 |
152 | Parameters
153 | ----------
154 | seq : Iterable of NodeOp or EdgeOp
155 | The construction sequence of the sketch to analyze
156 |
157 | Returns
158 | -------
159 | np.ndarray
160 | An integer array representing the number of degrees of freedom lost or gained
161 | at each construction step.
162 | """
163 | out = np.zeros(len(seq), dtype=np.int32)
164 | nodes = []
165 | for i, op in enumerate(seq):
166 | if isinstance(op, NodeOp):
167 | nodes.append(op)
168 | out[i] = NODE_DOF.get(op.label, 0)
169 | elif isinstance(op, EdgeOp):
170 | out[i] = -1*_get_dof_removed_for_edge(op, nodes)
171 | return out
172 |
173 |
174 | __all__ = ['get_sequence_dof', 'get_node_label_for_dof']
175 |
--------------------------------------------------------------------------------
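A small worked example for `get_sequence_dof`. It assumes the `NodeOp(label, parameters)` and `EdgeOp(label, references, parameters)` constructors used elsewhere in the pipeline, and imports `ConstraintType` from the private `_constraint` module, mirroring dof.py's own relative import. Two lines contribute four degrees of freedom each, and a line-line coincidence removes two.

    from sketchgraphs.data.sequence import NodeOp, EdgeOp
    from sketchgraphs.data.sketch import EntityType
    from sketchgraphs.data._constraint import ConstraintType
    from sketchgraphs.data.dof import get_sequence_dof

    seq = [
        NodeOp(EntityType.Line, {}),                      # +4 DoF
        NodeOp(EntityType.Line, {}),                      # +4 DoF
        EdgeOp(ConstraintType.Coincident, (0, 1), {}),    # -2 DoF for a Line-Line coincidence
    ]
    print(get_sequence_dof(seq))   # expected: [ 4  4 -2]
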
/sketchgraphs_models/graph/model/numerical_features.py:
--------------------------------------------------------------------------------
1 | """ This module contains the components of the graph model that are concerned with handling
2 | numerical features associated with edges and nodes.
3 |
4 | """
5 |
6 | import torch
7 |
8 | from sketchgraphs_models import nn as sg_nn
9 |
10 |
11 | class NumericalFeatureEncoding(torch.nn.Module):
12 | """Encode an array of numerical features.
13 |
14 | This encodes a sequence of features (presented as a sequence of integers)
15 |     into a sequence of vectors through an embedding.
16 | """
17 | def __init__(self, feature_dims, embedding_dim):
18 | super(NumericalFeatureEncoding, self).__init__()
19 | self.feature_dims = list(feature_dims)
20 | self.register_buffer(
21 | 'feature_offsets',
22 | torch.cumsum(torch.tensor([0] + self.feature_dims[:-1], dtype=torch.int64), dim=0))
23 | self.embeddings = torch.nn.Embedding(
24 | sum(feature_dims), embedding_dim, sparse=False)
25 |
26 | def forward(self, features):
27 | return self.embeddings(features + self.feature_offsets)
28 |
29 |
30 | class NumericalFeaturesEmbedding(torch.nn.Module):
31 | """Transform a sequence of numerical feature vectors into a single vector.
32 |
33 | Currently, this module simply aggregates the features by averaging, although more
34 | elaborate aggregation schemes (e.g. RNN) could be chosen.
35 | """
36 | def __init__(self, embedding_dim):
37 | super(NumericalFeaturesEmbedding, self).__init__()
38 | self.embedding_dim = embedding_dim
39 |
40 | def forward(self, embeddings):
41 | return embeddings.mean(axis=-2)
42 |
43 |
44 | class NumericalFeatureDecoder(torch.nn.Module):
45 | """Module for decoding numerical feature embeddings to logits.
46 |
47 |     Takes an array of features and decodes them into a sequence of varying-length logits.
48 | """
49 | def __init__(self, feature_dims, embedding_dim):
50 | super(NumericalFeatureDecoder, self).__init__()
51 | self.feature_dims = feature_dims
52 | self.linear_logistic = torch.nn.ModuleList([
53 | torch.nn.Linear(embedding_dim, fd) for fd in feature_dims
54 | ])
55 |
56 | def forward(self, embeddings):
57 | logits = []
58 |
59 | for i, linear in enumerate(self.linear_logistic):
60 | logits.append(linear(embeddings[i]))
61 |
62 | return torch.cat(logits, dim=-1)
63 |
64 |
65 | class NumericalFeatureReadout(torch.nn.Module):
66 | """Module for numerical feature readout.
67 |
68 | This module is responsible for producing numerical edge features from the edge label
69 | and computed embeddings.
70 | """
71 | def __init__(self, initial_input, feature_encoders, feature_decoders, sequence_model):
72 | """Creates a new edge feature readout model.
73 |
74 | Parameters
75 | ----------
76 | initial_input : torch.nn.Module
77 | A module which creates the embedding for the first input slot based on the passed in data.
78 | It is called with all remaining arguments to the forward method.
79 |
80 | feature_encoders : torch.nn.Module
81 |             A module which encodes the features from the edge for feeding into the network.
82 | It is called with an integer tensor of size `[batch, num_features]`.
83 |
84 | feature_decoders : torch.nn.Module
85 | A module which decodes sequence embeddings into an array of logits.
86 |
87 | sequence_model : torch.nn.Module
88 |             The main computational module for this instance; transforms a sequence
89 |             of embeddings into another sequence of embeddings.
90 | """
91 | super(NumericalFeatureReadout, self).__init__()
92 | self.encoders = feature_encoders
93 | self.decoders = feature_decoders
94 | self.sequence_model = sequence_model
95 | self.initial_input = initial_input
96 |
97 |
98 | def forward(self, input_features, *args):
99 | initial_input = self.initial_input(*args)
100 | input_embeddings = self.encoders(input_features).permute(1, 0, 2)
101 |
102 | input_sequence = torch.cat((initial_input.expand(1, -1, -1), input_embeddings), dim=0)
103 |
104 | output_sequence, _ = self.sequence_model(input_sequence)
105 |
106 | return self.decoders(output_sequence[:-1])
107 |
108 |
109 | def edge_decoder_initial_input(embedding_size):
110 | """Initial input function for edge readouts."""
111 | return sg_nn.Sequential(
112 | sg_nn.ConcatenateLinear(embedding_size, embedding_size, embedding_size),
113 | torch.nn.ReLU(),
114 | torch.nn.Linear(embedding_size, embedding_size))
115 |
116 |
117 | def entity_decoder_initial_input(embedding_size):
118 | """Initial input function for entity readouts."""
119 | return torch.nn.Identity(embedding_size=embedding_size)
120 |
121 |
122 | def make_embedding_and_readout(embedding_size: int, feature_dimensions, initial_input_factory):
123 | """Creates feature embedding and readout networks for the given features.
124 |
125 | Parameters
126 | ----------
127 | embedding_size : int
128 | Dimension of the embeddings to use.
129 | feature_dimensions : dict
130 | Dictionary whose values are lists of integers corresponding to the number of outcomes
131 | for each feature.
132 | initial_input_factory : int -> torch.nn.Module
133 | A function which returns a module responsible for transforming the inputs of the readout
134 | into initial embeddings for the internal sequence model.
135 |
136 | Returns
137 | -------
138 | feature_embeddings : dict
139 | A dictionary containing the feature embedding modules.
140 | feature_readouts : dict
141 | A dictionary containing the feature readout modules.
142 | """
143 |
144 | feature_encodings = {
145 | k: NumericalFeatureEncoding(dimensions.values(), embedding_size)
146 | for k, dimensions in feature_dimensions.items()
147 | }
148 |
149 | feature_embeddings = {
150 | k.name: torch.nn.Sequential(
151 | encoding, NumericalFeaturesEmbedding(embedding_size)
152 | )
153 | for k, encoding in feature_encodings.items()
154 | }
155 |
156 | feature_readouts = {
157 | k.name: NumericalFeatureReadout(
158 | initial_input=initial_input_factory(embedding_size),
159 | feature_encoders=feature_encodings[k],
160 | feature_decoders=NumericalFeatureDecoder(dimensions.values(), embedding_size),
161 | sequence_model=torch.nn.GRU(embedding_size, embedding_size))
162 | for k, dimensions in feature_dimensions.items()
163 | }
164 |
165 | return feature_embeddings, feature_readouts
166 |
--------------------------------------------------------------------------------
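A shape-only sketch of how `make_embedding_and_readout` is driven. The enum and dimension counts are hypothetical placeholders standing in for the dataset's feature mappings; each inner dict maps a feature name to its number of discrete outcomes.

    import enum
    from sketchgraphs_models.graph.model.numerical_features import (
        make_embedding_and_readout, edge_decoder_initial_input)

    class FakeTarget(enum.Enum):   # placeholder for the real target-type enum
        Angle = 0
        Length = 1

    feature_dimensions = {
        FakeTarget.Angle: {'isConstruction': 2, 'angle': 256},
        FakeTarget.Length: {'isConstruction': 2, 'length': 256},
    }

    embeddings, readouts = make_embedding_and_readout(
        embedding_size=64,
        feature_dimensions=feature_dimensions,
        initial_input_factory=edge_decoder_initial_input)

    print(sorted(embeddings), sorted(readouts))   # both keyed by the enum member name
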
/sketchgraphs_models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """This module provides utilities and generic build blocks for graph neural networks."""
2 |
3 | import contextlib
4 | import torch
5 |
6 | def autograd_range(name):
7 | """ Creates an autograd range for pytorch autograd profiling
8 | """
9 | return torch.autograd.profiler.record_function(name)
10 |
11 |
12 | def aggregate_by_incidence(values: torch.Tensor, incidence: torch.Tensor,
13 | transform_edge_messages=None, transform_edge_messages_args=None,
14 | output_size=None):
15 | """Aggregates values according to an incidence matrix.
16 |
17 | Effectively computes the following operation:
18 |
19 | .. code-block:: python
20 |
21 | output[i] = values[incidence[1, incidence[0] == i]].sum(axis=0)
22 |
23 | This operation essentially implements a sparse-matrix multiplication in coo format in a naive way.
24 | Optimization opportunity: write using actual cuSparse.
25 |
26 | Parameters
27 | ----------
28 | values : torch.Tensor
29 | A tensor of rank at least 2
30 | incidence : torch.Tensor
31 | a `[2, k]` tensor
32 | transform_edge_messages : function, optional
33 | an arbitrary function which transforms edge messages.
34 | transform_edge_messages_args : any
35 | Arbitrary set of arguments that are passed to the `transform_edge_messages` function.
36 | output_size : List[int], optional
37 | if not `None`, the size of the output tensor. Otherwise, we assume the output tensor
38 | is the same size as `values`.
39 |
40 | Returns
41 | -------
42 | torch.Tensor
43 | The output tensor, of the same rank as values.
44 | """
45 | if output_size is None:
46 | output_size = values.shape[0]
47 |
48 | with autograd_range('broadcast_messages'):
49 | # broadcast node values to edge messages
50 | edge_messages = values.index_select(0, incidence[1])
51 |
52 | if transform_edge_messages is not None:
53 | # apply transformation to edge messages if necessary.
54 | if transform_edge_messages_args is None:
55 | transform_edge_messages_args = tuple()
56 |
57 | with autograd_range('transform_messages'):
58 | edge_messages = transform_edge_messages(edge_messages, *transform_edge_messages_args)
59 |
60 | with autograd_range('aggregate_messages'):
61 | # collect edge messages into node values
62 | output = values.new_zeros([output_size] + list(edge_messages.shape[1:]))
63 | output.index_add_(0, incidence[0], edge_messages)
64 |
65 | return output
66 |
67 |
68 | class MessagePassingNetwork(torch.nn.Module):
69 | """ Custom configurable message-passing network.
70 |
71 |     This class implements the main plumbing for a message-passing network,
72 |     but exposes configuration points that make it easy to create different variants of the network.
73 | """
74 | def __init__(self, depth, message_aggregation_network, transform_edge_messages=None):
75 | """ Creates a new module representing the message passing network.
76 | Parameters
77 | ----------
78 | depth : int
79 | number of message passing iterations to execute.
80 | message_aggregation_network : torch.nn.Module
81 | A module representing the model used to compute the embeddings
82 | to be used at the next step. This model receives the array
83 | of messages corresponding to the sum of the propagated messages, and the array
84 | of previous node embeddings.
85 | transform_edge_messages : torch.nn.Module
86 | A module representing the model used to transform
87 | edge messages at each step. See `aggregate_by_incidence`.
88 | """
89 | super(MessagePassingNetwork, self).__init__()
90 |
91 | self.depth = depth
92 | self.message_aggregation_network = message_aggregation_network
93 | self.transform_edge_messages = transform_edge_messages
94 |
95 |
96 | __constants__ = ["depth"]
97 |
98 | def forward(self, node_embedding, incidence, edge_transform_args=None):
99 | """Forward function for the message passing network.
100 |
101 | Parameters
102 | ----------
103 | node_embedding : torch.Tensor
104 | Tensor of shape `[num_nodes, ...]` representing the data at each node in the graph.
105 | incidence : torch.Tensor
106 | tensor of shape `[2, num_edges]` representing edge incidence in the graph
107 | edge_transform_args : any
108 | A tuple of further arguments to be passed to the edge transformation network.
109 |
110 | Returns
111 | -------
112 | torch.Tensor
113 | The final node embedding values after the message passing has been carried out.
114 | """
115 | # Compute message passing along edges
116 | with autograd_range("propagate_messages"):
117 | for _ in range(self.depth):
118 | activation = aggregate_by_incidence(
119 | node_embedding, incidence, self.transform_edge_messages, edge_transform_args)
120 | node_embedding = self.message_aggregation_network(activation, node_embedding)
121 |
122 | return node_embedding
123 |
124 |
125 | class ConcatenateLinear(torch.nn.Module):
126 | """A torch module which concatenates several inputs and mixes them using a linear layer. """
127 | def __init__(self, left_size, right_size, output_size):
128 | """Creates a new concatenating linear layer.
129 |
130 | Parameters
131 | ----------
132 | left_size : int
133 | Size of the left input
134 | right_size : int
135 | Size of the right input
136 | output_size : int
137 | Size of the output.
138 | """
139 | super(ConcatenateLinear, self).__init__()
140 |
141 | self.left_size = left_size
142 | self.right_size = right_size
143 | self.output_size = output_size
144 |
145 | self._linear = torch.nn.Linear(left_size + right_size, output_size)
146 |
147 | def forward(self, left, right):
148 | return self._linear(torch.cat((left, right), dim=-1))
149 |
150 |
151 | class Sequential(torch.nn.Module):
152 | """ Similar to `torch.nn.Sequential`, except can pass through modules which
153 | take multiple input arguments, and return tuples.
154 | """
155 | def __init__(self, *args):
156 | super(Sequential, self).__init__()
157 | self._sequence_modules = torch.nn.ModuleList(args)
158 |
159 | def forward(self, *args):
160 | for module in self._sequence_modules:
161 | if not isinstance(args, (list, tuple)):
162 | args = [args]
163 | args = module(*args)
164 |
165 | return args
166 |
--------------------------------------------------------------------------------
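A tiny concrete check of `aggregate_by_incidence` against the formula in its docstring: row i of the output is the sum of `values[j]` over edges whose first entry is i.

    import torch
    from sketchgraphs_models.nn import aggregate_by_incidence

    values = torch.tensor([[1.0], [10.0], [100.0]])
    incidence = torch.tensor([[0, 0, 2],     # destination nodes
                              [1, 2, 0]])    # source nodes
    print(aggregate_by_incidence(values, incidence))
    # -> [[110.], [0.], [1.]]
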
/sketchgraphs_models/graph/train/data_loading.py:
--------------------------------------------------------------------------------
1 | """This module contains the main functions used to load the required data from disk for training."""
2 |
3 | import functools
4 | import gzip
5 | import pickle
6 | import os
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from sketchgraphs_models import distributed_utils
12 | from sketchgraphs_models.nn import data_util
13 | from sketchgraphs_models.graph import dataset
14 |
15 | from sketchgraphs.data import flat_array
16 |
17 |
18 | def load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features=True, edge_features=True):
19 | data = flat_array.load_dictionary_flat(np.load(dataset_file, mmap_mode='r'))
20 |
21 | if auxiliary_file is None:
22 | root, _ = os.path.splitext(dataset_file)
23 | auxiliary_file = root + '.stats.pkl.gz'
24 |
25 | if entity_features or edge_features:
26 | with gzip.open(auxiliary_file, 'rb') as f:
27 | auxiliary_dict = pickle.load(f)
28 |
29 | if entity_features:
30 | entity_feature_mapping = dataset.EntityFeatureMapping(auxiliary_dict['node'])
31 | else:
32 | entity_feature_mapping = None
33 |
34 | seqs = data['sequences']
35 | weights = data['sequence_lengths']
36 |
37 | if edge_features:
38 | if isinstance(quantization['angle'], dataset.QuantizationMap):
39 | angle_map = quantization['angle']
40 | else:
41 | angle_map = dataset.QuantizationMap.from_counter(auxiliary_dict['edge']['angle'], quantization['angle'])
42 |
43 | if isinstance(quantization['length'], dataset.QuantizationMap):
44 | length_map = quantization['length']
45 | else:
46 | length_map = dataset.QuantizationMap.from_counter(auxiliary_dict['edge']['length'], quantization['length'])
47 | edge_feature_mapping = dataset.EdgeFeatureMapping(angle_map, length_map)
48 | else:
49 | edge_feature_mapping = None
50 |
51 | return {
52 | 'sequences': seqs.share_memory_(),
53 | 'entity_feature_mapping': entity_feature_mapping,
54 | 'edge_feature_mapping': edge_feature_mapping,
55 | 'weights': weights
56 | }
57 |
58 |
59 | def load_dataset_and_weights_with_mapping(dataset_file, node_feature_mapping, edge_feature_mapping, seed=None):
60 | data = flat_array.load_dictionary_flat(np.load(dataset_file, mmap_mode='r'))
61 | seqs = data['sequences']
62 | seqs.share_memory_()
63 |
64 | ds = dataset.GraphDataset(seqs, node_feature_mapping, edge_feature_mapping, seed)
65 |
66 | return ds, data['sequence_lengths']
67 |
68 |
69 | def load_dataset_and_weights(dataset_file, auxiliary_file, quantization, seed=None,
70 | entity_features=True, edge_features=True, force_entity_categorical_features=False):
71 | data = load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features, edge_features)
72 |
73 | if data['entity_feature_mapping'] is None and force_entity_categorical_features:
74 | # Create an entity mapping which only computes the categorical features (i.e. isConstruction and clockwise)
75 | data['entity_feature_mapping'] = dataset.EntityFeatureMapping()
76 |
77 | return dataset.GraphDataset(
78 | data['sequences'], data['entity_feature_mapping'], data['edge_feature_mapping'], seed=seed), data['weights']
79 |
80 |
81 | def make_dataloader_train(collate_fn, ds_train, weights, batch_size, num_epochs, num_workers, distributed_config=None):
82 | sampler = torch.utils.data.WeightedRandomSampler(
83 | weights, len(weights), replacement=True)
84 |
85 | if distributed_config is not None:
86 | sampler = distributed_utils.DistributedSampler(
87 | sampler, distributed_config.world_size, distributed_config.rank)
88 |
89 | batch_sampler = torch.utils.data.BatchSampler(
90 | sampler, batch_size, drop_last=False)
91 |
92 | dataloader_train = torch.utils.data.DataLoader(
93 | ds_train,
94 | collate_fn=collate_fn,
95 | batch_sampler=data_util.MultiEpochSampler(batch_sampler, num_epochs),
96 | num_workers=num_workers,
97 | pin_memory=True)
98 |
99 | batches_per_epoch = len(batch_sampler)
100 |
101 | return dataloader_train, batches_per_epoch
102 |
103 |
104 | def _make_dataloader_eval(ds_eval, weights, batch_size, num_workers, distributed_config=None):
105 | sampler = torch.utils.data.WeightedRandomSampler(
106 | weights, len(weights), replacement=True)
107 |
108 | if distributed_config is not None:
109 | sampler = distributed_utils.DistributedSampler(
110 | sampler, distributed_config.world_size, distributed_config.rank)
111 |
112 | dataloader_eval = torch.utils.data.DataLoader(
113 | ds_eval,
114 | collate_fn=functools.partial(
115 | dataset.collate,
116 | entity_feature_mapping=ds_eval.node_feature_mapping,
117 | edge_feature_mapping=ds_eval.edge_feature_mapping),
118 | sampler=sampler,
119 | batch_size=batch_size,
120 | num_workers=num_workers,
121 | pin_memory=True)
122 |
123 | return dataloader_eval
124 |
125 |
126 | def initialize_datasets(args, distributed_config: distributed_utils.DistributedTrainingInfo = None):
127 | """Initialize datasets and dataloaders.
128 |
129 | Parameters
130 | ----------
131 | args : dict
132 | Dictionary containing all the dataset configurations.
133 |
134 | distributed_config : distributed_utils.DistributedTrainingInfo, optional
135 | If not None, configuration options for distributed training.
136 |
137 | Returns
138 | -------
139 |     torch.utils.data.DataLoader
140 |         Training dataloader
141 |     torch.utils.data.DataLoader or None
142 |         Testing dataloader, or None if no test dataset was specified
143 | int
144 | Number of batches per training epoch
145 | dataset.EntityFeatureMapping
146 | Feature mapping in use for entities
147 | dataset.EdgeFeatureMapping
148 | Feature mapping in use for constraints
149 | """
150 | quantization = {'angle': args['num_quantize_angle'], 'length': args['num_quantize_length']}
151 |
152 | dataset_train_path = args['dataset_train']
153 | auxiliary_path = args['dataset_auxiliary']
154 |
155 | ds_train, weights_train = load_dataset_and_weights(
156 | dataset_train_path, auxiliary_path, quantization, args['seed'],
157 | not args.get('disable_entity_features', False), not args.get('disable_edge_features', False),
158 | args.get('force_entity_categorical_features', False))
159 |
160 | batch_size = args['batch_size']
161 | num_workers = args['num_workers']
162 |
163 | if distributed_config:
164 | batch_size = batch_size // distributed_config.world_size
165 | num_workers = num_workers // distributed_config.world_size
166 |
167 | collate_fn = functools.partial(
168 | dataset.collate,
169 | entity_feature_mapping=ds_train.node_feature_mapping,
170 | edge_feature_mapping=ds_train.edge_feature_mapping)
171 |
172 | dl_train, batches_per_epoch = make_dataloader_train(
173 | collate_fn, ds_train, weights_train, batch_size, args['num_epochs'], num_workers, distributed_config)
174 |
175 | if args['dataset_test'] is not None:
176 | ds_test, weights_test = load_dataset_and_weights_with_mapping(
177 | args['dataset_test'], ds_train.node_feature_mapping, ds_train.edge_feature_mapping, args['seed'])
178 | dl_test = _make_dataloader_eval(
179 | ds_test, weights_test, batch_size, num_workers, distributed_config)
180 | else:
181 | dl_test = None
182 |
183 | return dl_train, dl_test, batches_per_epoch, ds_train.node_feature_mapping, ds_train.edge_feature_mapping
184 |
--------------------------------------------------------------------------------
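The `args` dictionary consumed by `initialize_datasets` above, with the key names taken directly from the function body; the paths and sizes shown are placeholder values.

    args = {
        'dataset_train': 'data/sequences_train.npy',   # hypothetical path
        'dataset_test': None,                          # or a path to a held-out split
        'dataset_auxiliary': None,                     # defaults to <train root>.stats.pkl.gz
        'num_quantize_angle': 256,
        'num_quantize_length': 256,
        'seed': 42,
        'batch_size': 2048,
        'num_workers': 4,
        'num_epochs': 50,
        # Optional flags, read with dict.get and therefore safe to omit:
        # 'disable_entity_features': False,
        # 'disable_edge_features': False,
        # 'force_entity_categorical_features': False,
    }

    # dl_train, dl_test, batches_per_epoch, node_map, edge_map = initialize_datasets(args)
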
/sketchgraphs_models/torch_extensions/segment_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 | from ._repeat_interleave import repeat_interleave
4 |
5 |
6 | def segment_op_python(values, scopes, op):
7 | if scopes.dim() != 2:
8 | raise ValueError("Scopes must be two-dimensional.")
9 |
10 | if scopes.shape[1] != 2:
11 | raise ValueError("Scopes must be of length two in second dimension.")
12 |
13 | output = torch.empty(scopes.shape[0], dtype=values.dtype, device=values.device)
14 |
15 | for i in range(scopes.shape[0]):
16 | output[i] = op(values.narrow(0, scopes[i, 0], scopes[i, 1]))
17 |
18 | return output
19 |
20 |
21 | def segment_logsumexp_python(values: torch.Tensor, scopes: torch.Tensor):
22 | return SegmentLogsumexpPython.apply(values, scopes)
23 |
24 |
25 | class SegmentLogsumexpPython(torch.autograd.Function):
26 | @staticmethod
27 | def forward(ctx, values, scopes):
28 | result = segment_op_python(values, scopes, partial(torch.logsumexp, dim=0))
29 | ctx.save_for_backward(values, result, scopes)
30 | return result
31 |
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | values, logsumexp, scopes = ctx.saved_tensors
35 | lengths = scopes.select(1, 1)
36 |
37 | return segment_logsumexp_backward_python(grad_output, values, logsumexp, lengths), None
38 |
39 |
40 | def _min_value(dtype):
41 | if dtype.is_floating_point:
42 | return torch.finfo(dtype).min
43 | else:
44 | return torch.iinfo(dtype).min
45 |
46 |
47 | def segment_argmax_loop(values, scopes):
48 | output_values = values.new_empty(scopes.shape[0])
49 | output_index = scopes.new_empty(scopes.shape[0])
50 |
51 | for i in range(scopes.shape[0]):
52 | if scopes[i, 1] != 0:
53 | output_values[i], output_index[i] = torch.max(values.narrow(0, scopes[i, 0], scopes[i, 1]), dim=0)
54 | else:
55 | output_values[i] = _min_value(output_values.dtype)
56 | output_index[i] = -1
57 |
58 | return output_values, output_index
59 |
60 |
61 | def segment_logsumexp_backward_python(grad_output, values, logsumexp, lengths):
62 | lengths = lengths.long()
63 | grad_output_repeat = repeat_interleave(grad_output, lengths, dim=0)
64 | derivative_repeat = (values - repeat_interleave(logsumexp, lengths, dim=0)).exp_()
65 | return derivative_repeat.mul_(grad_output_repeat)
66 |
67 |
68 | class SegmentLogsumexpScatter(torch.autograd.Function):
69 | @staticmethod
70 | def forward(ctx, values, scopes):
71 | import torch_scatter
72 |
73 | lengths = scopes.select(1, 1)
74 | offsets = lengths.new_zeros(len(lengths) + 1)
75 | torch.cumsum(lengths, 0, out=offsets[1:])
76 |
77 | value_segment_max, _ = torch_scatter.segment_max_csr(values, offsets)
78 | value_segment_max_expanded = torch.repeat_interleave(value_segment_max, lengths)
79 |
80 | values_exp = values.sub(value_segment_max_expanded).exp_()
81 | values_sumexp = torch_scatter.segment_sum_csr(values_exp, offsets)
82 |
83 | logsumexp = values_sumexp.log_().add_(value_segment_max)
84 |
85 | ctx.save_for_backward(values, logsumexp, lengths)
86 |
87 | return logsumexp
88 |
89 | @staticmethod
90 | def backward(ctx, grad_output):
91 | values, logsumexp, lengths = ctx.saved_tensors
92 |
93 | lse_expand = torch.repeat_interleave(logsumexp, lengths)
94 | grad_out_expand = torch.repeat_interleave(grad_output, lengths)
95 |
96 | return values.sub(lse_expand).exp_().mul_(grad_out_expand), None
97 |
98 |
99 | def segment_logsumexp_scatter(values, scopes):
100 | return SegmentLogsumexpScatter.apply(values, scopes)
101 |
102 |
103 | def segment_argmax_backward(grad_output, argmax, scopes, input_shape, sparse_grad=True):
104 | grad_idx = torch.add(argmax, scopes.select(1, 0)).unsqueeze(0)
105 | grad_input = torch.sparse_coo_tensor(grad_idx, grad_output, size=input_shape)
106 |
107 | if not sparse_grad:
108 | grad_input = grad_input.to_dense()
109 |
110 | return grad_input
111 |
112 |
113 | class SegmentArgmaxPython(torch.autograd.Function):
114 | @staticmethod
115 | def forward(ctx, values, scopes, sparse_grad=True):
116 | ctx.input_shape = values.shape
117 | ctx.sparse_grad = sparse_grad
118 |
119 | max_values, argmax = segment_argmax_loop(values, scopes)
120 | ctx.save_for_backward(argmax, scopes)
121 | ctx.mark_non_differentiable(argmax)
122 |
123 | return max_values, argmax
124 |
125 | @staticmethod
126 | def backward(ctx, grad_output, grad_output_index):
127 | argmax, scopes = ctx.saved_tensors
128 | grad_values = segment_argmax_backward(grad_output, argmax, scopes, ctx.input_shape, ctx.sparse_grad)
129 | return grad_values, None, None
130 |
131 |
132 | def segment_argmax_python(values, scopes, sparse_grad=True):
133 | return SegmentArgmaxPython.apply(values, scopes, sparse_grad)
134 |
135 |
136 | def segment_argmax_scatter(values, scopes, sparse_grad=True):
137 | import torch_scatter
138 |
139 | lengths = scopes.select(1, 1)
140 | offsets = lengths.new_zeros(len(lengths) + 1)
141 | torch.cumsum(lengths, 0, out=offsets[1:])
142 |
143 | max_val, argmax = torch_scatter.segment_max_csr(values, offsets)
144 | max_val = torch.where(lengths != 0, max_val, max_val.new_full((1,), fill_value=torch.finfo(max_val.dtype).min))
145 | argmax = argmax.sub(offsets[:-1])
146 |
147 | return max_val, argmax
148 |
149 |
150 | _segment_argmax_docstring = \
151 | """
152 | Compute the maximum value and location in each segment.
153 |
154 | This function computes, for each segment, the maximum value in the segment,
155 | and the offset of the location of the maximum value from the start of the segment.
156 |
157 | This function can handle the case where the segment is zero length, in which case
158 | the maximum is given the lowest finite value representable by the type, and the
159 | index is not defined.
160 |
161 | Parameters
162 | ----------
163 | values : torch.Tensor
164 | a 1-dimensional `torch.Tensor` representing the values.
165 | scopes : torch.Tensor
166 | a 2-dimensional integer `torch.Tensor` representing the segments. The ith segment
167 | has offset ``scopes[i, 0]`` and length ``scopes[i, 1]``.
168 |
169 | Returns
170 | -------
171 | torch.Tensor
172 | A tensor of the same type as `values` representing the maximum in each segment.
173 | torch.Tensor
174 |         An integer tensor representing the location of the maximum in each segment.
175 | """
176 |
177 | segment_argmax_scatter.__doc__ = _segment_argmax_docstring
178 | segment_argmax_python.__doc__ = _segment_argmax_docstring
179 |
180 | _segment_logsumexp_docstring = \
181 | """
182 | Compute the log-sum-exp in each segment.
183 |
184 | This function computes, for each segment, the log-sum-exp value in
185 | a numerically stable fashion.
186 |
187 | Parameters
188 | ----------
189 | values : torch.Tensor
190 | A 1-dimensional tensor representing the values.
191 | scopes : torch.Tensor
192 | A 2-dimensional integer tensor representing the segments. The ith segment
193 | has offset ``scopes[i, 0]`` and length ``scopes[i, 1]``.
194 |
195 | Returns
196 | -------
197 | torch.Tensor
198 | A tensor representing the log-sum-exp value for each segment.
199 | """
200 |
201 | segment_logsumexp_scatter.__doc__ = _segment_logsumexp_docstring
202 | segment_logsumexp_python.__doc__ = _segment_logsumexp_docstring
203 |
204 | try:
205 | import torch_scatter
206 | segment_logsumexp = segment_logsumexp_scatter
207 | segment_argmax = segment_argmax_scatter
208 |
209 | except ImportError:
210 | segment_logsumexp = segment_logsumexp_python
211 | segment_argmax = segment_argmax_python
212 |
213 | __all__ = ['segment_logsumexp', 'segment_argmax']
214 |
--------------------------------------------------------------------------------
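A usage sketch for the dispatching `segment_logsumexp` / `segment_argmax` entry points defined above; each row of `scopes` is the (offset, length) of one segment, and the scatter-based kernels are used automatically when `torch_scatter` is installed.

    import torch
    from sketchgraphs_models.torch_extensions import segment_ops

    values = torch.randn(10)
    scopes = torch.tensor([[0, 4], [4, 6]])   # two segments: values[0:4] and values[4:10]

    lse = segment_ops.segment_logsumexp(values, scopes)
    max_val, argmax = segment_ops.segment_argmax(values, scopes)
    print(lse.shape, max_val.shape, argmax.shape)   # each torch.Size([2])
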
/sketchgraphs/pipeline/make_sketch_dataset.py:
--------------------------------------------------------------------------------
1 | """This script is responsible for creating the main sketch dataset from the JSON files.
2 |
3 | Some basic filtering is applied in order to exclude empty sketches. However, as this dataset
4 | is intended to capture the original data, further filtering is left to scripts such as
5 | `sketchgraphs.pipeline.make_sequence_dataset`, which processes the dataset for learning.
6 |
7 | """
8 |
9 | import argparse
10 | import collections
11 | import glob
12 | import gzip
13 | import itertools
14 | import json
15 | import multiprocessing as mp
16 | import tarfile
17 | import traceback
18 | import os
19 |
20 | import numpy as np
21 | import tqdm
22 | import zstandard as zstd
23 |
24 | from sketchgraphs.data.sketch import Sketch
25 | from sketchgraphs.data import flat_array
26 |
27 |
28 | def _load_json(path):
29 | open_ = gzip.open if path.endswith('gz') else open
30 | with open_(path) as fh:
31 | return json.load(fh)
32 |
33 |
34 | def filter_sketch(sketch: Sketch):
35 | """Basic filtering which excludes empty sketches, or sketches with no constraints."""
36 | return len(sketch.constraints) == 0 or len(sketch.entities) == 0
37 |
38 |
39 | def parse_sketch_id(filename):
40 | basename = os.path.basename(filename)
41 | while '.' in basename:
42 | basename, _ = os.path.splitext(basename)
43 |
44 | document_id, part_id = basename.split('_')
45 | return document_id, int(part_id)
46 |
47 |
48 | def load_json_tarball(path):
49 | """Loads a json tarball as an iterable of sketches.
50 |
51 | Parameters
52 | ----------
53 | path : str
54 | A path to the location of a single shard
55 |
56 | Returns
57 | -------
58 | iterable of `Sketch`
59 | An iterable of `Sketch` representing all the sketches present in the tarball.
60 | """
61 | with open(path, 'rb') as base_file:
62 | dctx = zstd.ZstdDecompressor()
63 | with dctx.stream_reader(base_file) as tarball:
64 | with tarfile.open(fileobj=tarball, mode='r|') as directory:
65 | while True:
66 | json_file = directory.next()
67 | if json_file is None:
68 | break
69 |
70 | if not json_file.isfile():
71 | continue
72 |
73 | document_id, part_id = parse_sketch_id(json_file.name)
74 | data = directory.extractfile(json_file).read()
75 | if len(data) == 0:
76 | # skip empty files
77 | continue
78 |
79 | try:
80 | sketches_json = json.loads(data)
81 | except json.JSONDecodeError as exc:
82 |                         raise ValueError('Error decoding JSON for document {0} part {1}.'.format(document_id, part_id)) from exc
83 | for i, sketch_json in enumerate(sketches_json):
84 | yield (document_id, part_id, i), Sketch.from_fs_json(sketch_json)
85 |
86 |
87 |
88 | def _worker(paths_queue, processed_sketches, max_sketches, sketch_counter):
89 | num_filtered = 0
90 | num_invalid = 0
91 |
92 | while max_sketches is None or sketch_counter.value < max_sketches:
93 | paths = paths_queue.get()
94 |
95 | if paths is None:
96 | break
97 |
98 | sketches = []
99 |
100 | for path in paths:
101 | sketch_list = _load_json(path)
102 |
103 | for sketch_json in sketch_list:
104 | try:
105 | sketch = Sketch.from_fs_json(sketch_json)
106 | except Exception as err:
107 | num_invalid += 1
108 | print('Error processing sketch in file {0}'.format(path))
109 | traceback.print_exception(type(err), err, err.__traceback__)
110 |                     continue
111 | if filter_sketch(sketch):
112 | num_filtered += 1
113 | continue
114 |
115 | sketches.append(sketch)
116 |
117 | offsets, data = flat_array.raw_list_flat(sketches)
118 |
119 | processed_sketches.put((offsets, data))
120 |
121 | with sketch_counter.get_lock():
122 | sketch_counter.value += len(sketches)
123 |
124 | processed_sketches.put({
125 | 'num_filtered': num_filtered,
126 | 'num_invalid': num_invalid
127 | })
128 |
129 |
130 | def process(paths, threads, max_sketches=None):
131 | path_queue = mp.Queue()
132 | sketch_queue = mp.Queue()
133 | sketch_counter = mp.Value('q', 0)
134 |
135 | # Enqueue all the objects
136 | print('Enqueueing files to process.')
137 | paths_it = iter(paths)
138 | while True:
139 | path_chunk = list(itertools.islice(paths_it, 128))
140 | if len(path_chunk) == 0:
141 | break
142 |
143 | path_queue.put_nowait(path_chunk)
144 |
145 | workers = []
146 |
147 | for _ in range(threads or mp.cpu_count()):
148 | workers.append(
149 | mp.Process(
150 | target=_worker,
151 | args=(path_queue, sketch_queue, max_sketches, sketch_counter)))
152 |
153 | for worker in workers:
154 | path_queue.put_nowait(None)
155 | worker.start()
156 |
157 | active_workers = len(workers)
158 |
159 | offsets_arrays = []
160 | data_arrays = []
161 |
162 | statistics = collections.Counter()
163 |
164 | # Read-in data
165 | with tqdm.tqdm(total=len(paths)) as pbar:
166 | while active_workers > 0:
167 | result = sketch_queue.get()
168 |
169 | if isinstance(result, dict):
170 | statistics += collections.Counter(result)
171 | active_workers -= 1
172 | continue
173 |
174 | offsets, data = result
175 | offsets_arrays.append(offsets)
176 | data_arrays.append(data)
177 |
178 | pbar.update(128)
179 |
180 | # Finalize workers
181 | for worker in workers:
182 | worker.join()
183 |
184 | # Merge final flat array
185 | all_offsets, all_data = flat_array.merge_raw_list(offsets_arrays, data_arrays)
186 | total_sketches = len(all_offsets) - 1
187 | del offsets_arrays
188 | del data_arrays
189 |
190 | # Pack as required
191 | flat_data = flat_array.pack_list_flat(all_offsets, all_data)
192 | del all_offsets
193 | del all_data
194 |
195 | print('Done processing data.\nProcessed sketches: {0}'.format(total_sketches))
196 | print('Filtered sketches: {0}'.format(statistics['num_filtered']))
197 | print('Invalid sketches: {0}'.format(statistics['num_invalid']))
198 | return flat_data
199 |
200 |
201 | def gather_sorted_paths(patterns):
202 | if isinstance(patterns, str):
203 | patterns = [patterns]
204 | out = []
205 | for pattern in patterns:
206 | out.extend(glob.glob(pattern))
207 | out.sort()
208 | return out
209 |
210 |
211 | def main():
212 | parser = argparse.ArgumentParser(description='Process json files to create sketch dataset')
213 | parser.add_argument('--glob_pattern', required=True, action='append',
214 | help='Glob pattern(s) for json / json.gz files.')
215 | parser.add_argument('--output_path', required=True, help='Path for output file.')
216 | parser.add_argument('--max_files', type=int, help='Max number of json files to consider.')
217 | parser.add_argument('--max_sketches', type=int, help='Maximum number of sketches to consider.')
218 | parser.add_argument('--num_threads', type=int, default=0, help='Number of multiprocessing workers.')
219 |
220 | args = parser.parse_args()
221 |
222 | print('Globbing for sketch files to include.')
223 | paths = gather_sorted_paths(args.glob_pattern)
224 | print('Found %i files.' % len(paths))
225 | if args.max_files is not None:
226 | paths = paths[:args.max_files]
227 |
228 | result = process(paths, args.num_threads, args.max_sketches)
229 |
230 | print('Saving data to {0}'.format(args.output_path))
231 | np.save(args.output_path, result)
232 |
233 |
234 | if __name__ == '__main__':
235 | main()
236 |
--------------------------------------------------------------------------------
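Usage note for the dataset-creation script above: the same pipeline can be driven programmatically instead of through the command line. The snippet below is a minimal sketch, assuming the file is importable as `sketchgraphs.pipeline.make_sketch_dataset` (as in the repository layout); the glob pattern and output path are hypothetical placeholders.

import numpy as np

from sketchgraphs.pipeline import make_sketch_dataset as msd

if __name__ == '__main__':
    # Gather shard files (hypothetical location), then process them on 4 workers.
    paths = msd.gather_sorted_paths('data/shards/*.json.gz')
    flat_data = msd.process(paths, threads=4, max_sketches=10000)
    # Save the packed flat array, mirroring what main() does with --output_path.
    np.save('sketch_dataset.npy', flat_data)
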
/sketchgraphs_models/autoconstraint/dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset for auto-constraint model."""
2 |
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from sketchgraphs.data import sequence as datalib
9 | from sketchgraphs.pipeline.graph_model.target import NODE_TYPES, EDGE_TYPES, EDGE_TYPES_PREDICTED, NODE_IDX_MAP, EDGE_IDX_MAP
10 | from sketchgraphs.pipeline import graph_model as graph_utils
11 |
12 | from sketchgraphs_models.graph.dataset import EntityFeatureMapping, EdgeFeatureMapping, _sparse_feature_to_torch
13 |
14 |
15 | def _reindex_sparse_batch(sparse_batch, pack_batch_offsets):
16 | return graph_utils.SparseFeatureBatch(
17 | pack_batch_offsets[sparse_batch.index],
18 | sparse_batch.value)
19 |
20 |
21 | def collate(batch):
22 | # Sort batch for packing
23 | node_lengths = [len(x['node_features']) for x in batch]
24 | sorted_indices = np.argsort(node_lengths)[::-1].copy()
25 |
26 | batch = [batch[i] for i in sorted_indices]
27 |
28 | graph = graph_utils.GraphInfo.merge(*[x['graph'] for x in batch])
29 | edge_label = torch.tensor(
30 | [x['target_edge_label'] for x in batch if x['target_edge_label'] != -1], dtype=torch.int64)
31 | node_features = torch.nn.utils.rnn.pack_sequence([x['node_features'] for x in batch])
32 | batch_offsets = graph_utils.offsets_from_counts(node_features.batch_sizes)
33 |
34 | node_features_graph_index = torch.cat([
35 | i + batch_offsets[:graph.node_counts[i]] for i in range(len(batch))
36 | ], dim=0)
37 |
38 | sparse_node_features = {}
39 |
40 | for k in batch[0]['sparse_node_features']:
41 | sparse_node_features[k] = graph_utils.SparseFeatureBatch.merge(
42 | [_reindex_sparse_batch(x['sparse_node_features'][k], batch_offsets) for x in batch], range(len(batch)))
43 |
44 | last_graph_node_index = batch_offsets[graph.node_counts - 1] + torch.arange(len(graph.node_counts), dtype=torch.int64)
45 |
46 | partner_index_index = []
47 | partner_index = []
48 |
49 | stop_partner_index_index = []
50 |
51 | for i, x in enumerate(batch):
52 | if x['partner_index'] == -1:
53 | stop_partner_index_index.append(i)
54 | continue
55 |
56 | partner_index_index.append(i)
57 | partner_index.append(x['partner_index'] + graph.node_offsets[i])
58 |
59 | partner_index = graph_utils.SparseFeatureBatch(
60 | torch.tensor(partner_index_index, dtype=torch.int64),
61 | torch.tensor(partner_index, dtype=torch.int64)
62 | )
63 |
64 | stop_partner_index_index = torch.tensor(stop_partner_index_index, dtype=torch.int64)
65 |
66 | return {
67 | 'graph': graph,
68 | 'edge_label': edge_label,
69 | 'partner_index': partner_index,
70 | 'stop_partner_index_index': stop_partner_index_index,
71 | 'node_features': node_features,
72 | 'node_features_graph_index': node_features_graph_index,
73 | 'sparse_node_features': sparse_node_features,
74 | 'last_graph_node_index': last_graph_node_index,
75 | 'sorted_indices': torch.as_tensor(sorted_indices)
76 | }
77 |
78 |
79 |
80 | def process_node_and_edge_ops(node_ops, edge_ops_in_graph, num_nodes_in_graph, node_feature_mappings: Optional[EntityFeatureMapping]):
81 | all_node_labels = torch.tensor([NODE_IDX_MAP[op.label] for op in node_ops], dtype=torch.int64)
82 | edge_labels = torch.tensor([EDGE_IDX_MAP[op.label] for op in edge_ops_in_graph], dtype=torch.int64)
83 |
84 | if len(edge_ops_in_graph) > 0:
85 | incidence = torch.tensor([(op.references[0], op.references[-1]) for op in edge_ops_in_graph],
86 | dtype=torch.int64).T.contiguous()
87 | incidence = torch.cat((incidence, torch.flip(incidence, [0])), dim=1)
88 | else:
89 | incidence = torch.empty([2, 0], dtype=torch.int64)
90 |
91 | edge_features = edge_labels.repeat(2)
92 |
93 | if node_feature_mappings is not None:
94 | sparse_node_features = _sparse_feature_to_torch(node_feature_mappings.all_sparse_features(node_ops))
95 | else:
96 | sparse_node_features = None
97 |
98 | graph = graph_utils.GraphInfo.from_single_graph(incidence, None, edge_features, num_nodes_in_graph)
99 |
100 | return {
101 | 'graph': graph,
102 | 'node_features': all_node_labels,
103 | 'sparse_node_features': sparse_node_features
104 | }
105 |
106 |
107 |
108 | class AutoconstraintDataset(torch.utils.data.Dataset):
109 | def __init__(self, sequences, node_feature_mappings, seed=10):
110 | self.sequences = sequences
111 | self.node_feature_mappings = node_feature_mappings
112 | self._rng = np.random.Generator(np.random.Philox(seed))
113 |
114 | def __getitem__(self, idx):
115 | idx = idx % len(self.sequences)
116 | seq = self.sequences[idx]
117 |
118 | if not isinstance(seq[0], datalib.NodeOp):
119 | raise ValueError('First operation in sequence is not a NodeOp')
120 |
121 | if seq[-1].label != datalib.EntityType.Stop:
122 | seq.append(datalib.NodeOp(datalib.EntityType.Stop, {}))
123 |
124 | node_ops = [seq[0]]
125 | edge_ops = []
126 |
127 | num_predicted_edge_ops_per_node = []
128 | num_non_predicted_edge_ops_per_node = []
129 |
130 | predicted_edge_ops_for_current_node = 0
131 | non_predicted_edge_ops_for_current_node = 0
132 |
133 | for op in seq[1:]:
134 | if isinstance(op, datalib.NodeOp):
135 | num_predicted_edge_ops_per_node.append(predicted_edge_ops_for_current_node)
136 | num_non_predicted_edge_ops_per_node.append(non_predicted_edge_ops_for_current_node)
137 |
138 | predicted_edge_ops_for_current_node = 0
139 | non_predicted_edge_ops_for_current_node = 0
140 |
141 | node_ops.append(op)
142 | else:
143 | if op.label in EDGE_TYPES_PREDICTED:
144 | predicted_edge_ops_for_current_node += 1
145 | else:
146 | non_predicted_edge_ops_for_current_node += 1
147 |
148 | edge_ops.append(op)
149 |
150 | node_ops = node_ops[:-1]
151 |
152 | num_predicted_edge_ops_per_node = np.array(num_predicted_edge_ops_per_node, dtype=np.int64)
153 | num_non_predicted_edge_ops_per_node = np.array(num_non_predicted_edge_ops_per_node, dtype=np.int64)
154 |
155 | predicted_edge_ops_offsets = num_predicted_edge_ops_per_node.cumsum()
156 | non_predicted_edge_ops_offsets = num_non_predicted_edge_ops_per_node.cumsum()
157 |
158 | num_predicted_edge_ops = predicted_edge_ops_offsets[-1]
159 |
160 | stop_target = self._rng.uniform() < len(node_ops) / (len(node_ops) + num_predicted_edge_ops)
161 |
162 | if stop_target:
163 | target_node_idx = self._rng.integers(len(node_ops))
164 | num_nodes_in_graph = target_node_idx + 1
165 | edge_ops_in_graph = edge_ops[:predicted_edge_ops_offsets[target_node_idx] + non_predicted_edge_ops_offsets[target_node_idx]]
166 | target_edge_label = -1
167 | partner_index = -1
168 | else:
169 | target_predicted_edge_idx = self._rng.integers(num_predicted_edge_ops)
170 | target_node_idx = np.searchsorted(predicted_edge_ops_offsets, target_predicted_edge_idx, side='right')
171 | num_nodes_in_graph = target_node_idx + 1
172 |
173 | target_edge_idx = target_predicted_edge_idx + non_predicted_edge_ops_offsets[target_node_idx]
174 | target_edge = edge_ops[target_edge_idx]
175 | edge_ops_in_graph = edge_ops[:target_edge_idx]
176 | target_edge_label = EDGE_IDX_MAP[target_edge.label]
177 | partner_index = target_edge.references[-1]
178 | assert target_edge_label < len(EDGE_TYPES_PREDICTED)
179 |
180 | input_features = process_node_and_edge_ops(
181 | node_ops, edge_ops_in_graph, num_nodes_in_graph, self.node_feature_mappings)
182 |
183 | return {
184 | **input_features,
185 | 'target_edge_label': target_edge_label,
186 | 'partner_index': partner_index,
187 | }
188 |
189 | __all__ = [
190 | 'NODE_TYPES', 'EDGE_TYPES', 'EDGE_TYPES_PREDICTED', 'NODE_IDX_MAP', 'EDGE_IDX_MAP',
191 | 'EntityFeatureMapping', 'EdgeFeatureMapping', 'collate', 'AutoconstraintDataset'
192 | ]
193 |
--------------------------------------------------------------------------------
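A minimal sketch of how the pieces above are meant to be combined with a standard PyTorch DataLoader; where `sequences` and `node_feature_mappings` come from (a flat array of construction sequences and an `EntityFeatureMapping`, respectively) is assumed and not shown here.

import torch

from sketchgraphs_models.autoconstraint.dataset import AutoconstraintDataset, collate

def make_autoconstraint_loader(sequences, node_feature_mappings, batch_size=32):
    # Each item is a randomly sampled prediction target for one construction sequence;
    # `collate` packs the variable-length node features and merges the graphs.
    dataset = AutoconstraintDataset(sequences, node_feature_mappings, seed=10)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=collate, shuffle=True)
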
/sketchgraphs/onshape/onshape.py:
--------------------------------------------------------------------------------
1 | '''
2 | onshape
3 | =======
4 |
5 | Provides access to the Onshape REST API
6 | '''
7 |
8 | from . import utils
9 |
10 | import os
11 | import random
12 | import string
13 | import json
14 | import hmac
15 | import hashlib
16 | import base64
17 | import urllib.request
18 | import urllib.parse
19 | import urllib.error
20 | import datetime
21 | import requests
22 | from urllib.parse import urlparse
23 | from urllib.parse import parse_qs
24 |
25 | __all__ = [
26 | 'Onshape'
27 | ]
28 |
29 |
30 | class Onshape():
31 | '''
32 | Provides access to the Onshape REST API.
33 |
34 | Attributes:
35 | - stack (str): Base URL
36 | - creds (str, default='./sketchgraphs/onshape/creds/creds.json'): Credentials location
37 | - logging (bool, default=True): Turn logging on or off
38 | '''
39 |
40 | def __init__(self, stack, creds='./sketchgraphs/onshape/creds/creds.json', logging=True):
41 | '''
42 | Instantiates an instance of the Onshape class. Reads credentials from a JSON file
43 | of this format:
44 |
45 | {
46 |             "https://cad.onshape.com": {
47 | "access_key": "YOUR KEY HERE",
48 | "secret_key": "YOUR KEY HERE"
49 | },
50 | etc... add new object for each stack to test on
51 | }
52 |
53 | The creds.json file should be stored in the root project folder; optionally,
54 | you can specify the location of a different file.
55 |
56 | Args:
57 | - stack (str): Base URL
58 | - creds (str, default='./sketchgraphs/onshape/creds/creds.json'): Credentials location
59 | '''
60 |
61 | if not os.path.isfile(creds):
62 | raise IOError('%s is not a file' % creds)
63 |
64 | with open(creds) as f:
65 | try:
66 | stacks = json.load(f)
67 | if stack in stacks:
68 | self._url = stack
69 | self._access_key = stacks[stack]['access_key'].encode(
70 | 'utf-8')
71 | self._secret_key = stacks[stack]['secret_key'].encode(
72 | 'utf-8')
73 | self._logging = logging
74 | else:
75 | raise ValueError('specified stack not in file')
76 |             except (TypeError, json.JSONDecodeError) as exc:
77 |                 raise ValueError('%s is not valid json' % creds) from exc
78 |
79 | if self._logging:
80 | utils.log('onshape instance created: url = %s, access key = %s' % (
81 | self._url, self._access_key))
82 |
83 | def _make_nonce(self):
84 | '''
85 | Generate a unique ID for the request, 25 chars in length
86 |
87 | Returns:
88 | - str: Cryptographic nonce
89 | '''
90 |
91 | chars = string.digits + string.ascii_letters
92 | nonce = ''.join(random.choice(chars) for i in range(25))
93 |
94 | if self._logging:
95 | utils.log('nonce created: %s' % nonce)
96 |
97 | return nonce
98 |
99 | def _make_auth(self, method, date, nonce, path, query={}, ctype='application/json'):
100 | '''
101 | Create the request signature to authenticate
102 |
103 | Args:
104 | - method (str): HTTP method
105 | - date (str): HTTP date header string
106 | - nonce (str): Cryptographic nonce
107 | - path (str): URL pathname
108 | - query (dict, default={}): URL query string in key-value pairs
109 | - ctype (str, default='application/json'): HTTP Content-Type
110 | '''
111 |
112 | query = urllib.parse.urlencode(query)
113 |
114 | hmac_str = (method + '\n' + nonce + '\n' + date + '\n' + ctype + '\n' + path +
115 | '\n' + query + '\n').lower().encode('utf-8')
116 |
117 | signature = base64.b64encode(
118 | hmac.new(self._secret_key, hmac_str, digestmod=hashlib.sha256).digest())
119 | auth = 'On ' + self._access_key.decode('utf-8') + \
120 | ':HmacSHA256:' + signature.decode('utf-8')
121 |
122 | if self._logging:
123 | utils.log({
124 | 'query': query,
125 | 'hmac_str': hmac_str,
126 | 'signature': signature,
127 | 'auth': auth
128 | })
129 |
130 | return auth
131 |
132 | def _make_headers(self, method, path, query={}, headers={}):
133 | '''
134 | Creates a headers object to sign the request
135 |
136 | Args:
137 | - method (str): HTTP method
138 | - path (str): Request path, e.g. /api/documents. No query string
139 | - query (dict, default={}): Query string in key-value format
140 | - headers (dict, default={}): Other headers to pass in
141 |
142 | Returns:
143 | - dict: Dictionary containing all headers
144 | '''
145 |
146 | date = datetime.datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT')
147 | nonce = self._make_nonce()
148 | ctype = headers.get(
149 | 'Content-Type') if headers.get('Content-Type') else 'application/json'
150 |
151 | auth = self._make_auth(method, date, nonce, path,
152 | query=query, ctype=ctype)
153 |
154 | req_headers = {
155 | 'Content-Type': 'application/json',
156 | 'Date': date,
157 | 'On-Nonce': nonce,
158 | 'Authorization': auth,
159 | 'User-Agent': 'Onshape Python Sample App',
160 | 'Accept': 'application/json'
161 | }
162 |
163 | # add in user-defined headers
164 | for h in headers:
165 | req_headers[h] = headers[h]
166 |
167 | return req_headers
168 |
169 | def request(self, method, path, query={}, headers={}, body={}, base_url=None, timeout=None, check_status=True):
170 | '''
171 | Issues a request to Onshape
172 |
173 | Args:
174 | - method (str): HTTP method
175 | - path (str): Path e.g. /api/documents/:id
176 | - query (dict, default={}): Query params in key-value pairs
177 | - headers (dict, default={}): Key-value pairs of headers
178 | - body (dict, default={}): Body for POST request
179 | - base_url (str, default=None): Host, including scheme and port (if different from creds file)
180 | - timeout (float, default=None): Timeout to use with requests.request().
181 | - check_status (bool, default=True): Raise exception if response status code is unsuccessful.
182 |
183 | Returns:
184 | - requests.Response: Object containing the response from Onshape
185 | '''
186 |
187 | req_headers = self._make_headers(method, path, query, headers)
188 | if base_url is None:
189 | base_url = self._url
190 | url = base_url + path + '?' + urllib.parse.urlencode(query)
191 |
192 | if self._logging:
193 | utils.log(body)
194 | utils.log(req_headers)
195 | utils.log('request url: ' + url)
196 |
197 |         # serialize the body to a JSON string only if it was passed as a dict
198 |         body = json.dumps(body) if isinstance(body, dict) else body
199 |
200 | res = requests.request(
201 | method, url, headers=req_headers, data=body, allow_redirects=False, stream=True,
202 | timeout=timeout)
203 |
204 | if res.status_code == 307:
205 | location = urlparse(res.headers["Location"])
206 | querystring = parse_qs(location.query)
207 |
208 | if self._logging:
209 | utils.log('request redirected to: ' + location.geturl())
210 |
211 | new_query = {}
212 | new_base_url = location.scheme + '://' + location.netloc
213 |
214 | for key in querystring:
215 | # won't work for repeated query params
216 | new_query[key] = querystring[key][0]
217 |
218 | return self.request(method, location.path, query=new_query, headers=headers, base_url=new_base_url)
219 | elif not 200 <= res.status_code <= 206:
220 | if self._logging:
221 | utils.log('request failed, details: ' + res.text, level=1)
222 | else:
223 | if self._logging:
224 | utils.log('request succeeded, details: ' + res.text)
225 | if check_status:
226 | res.raise_for_status()
227 | return res
--------------------------------------------------------------------------------
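A minimal sketch of using the client above, assuming a creds.json keyed by the same base URL that is passed as `stack`; the endpoint and query parameters below are only illustrative.

from sketchgraphs.onshape.onshape import Onshape

api = Onshape(stack='https://cad.onshape.com',
              creds='./sketchgraphs/onshape/creds/creds.json',
              logging=False)

# `request` builds the signed headers via _make_headers/_make_auth and returns a
# requests.Response object.
res = api.request('get', '/api/documents', query={'limit': 5})
print(res.status_code)
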
/sketchgraphs_models/graph/train/harness.py:
--------------------------------------------------------------------------------
1 | """This module contains the main training harness for training graph models. """
2 |
3 | import collections
4 | import itertools
5 | import os
6 | import pickle
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from sketchgraphs_models import training
12 | from sketchgraphs_models.graph import model as graph_model
13 | from sketchgraphs_models.nn import summary
14 |
15 |
16 | def _detach(x):
17 | if isinstance(x, torch.Tensor):
18 | return x.detach()
19 | else:
20 | return x
21 |
22 |
23 | def _mean_value(v):
24 | values = [a.mean().cpu().numpy() for a in v]
25 | return np.mean(values) if values else np.nan
26 |
27 |
28 | def _total_loss(losses):
29 | result = 0
30 |
31 | for v in losses.values():
32 | if v is None:
33 | continue
34 |
35 | if isinstance(v, dict):
36 | result += _total_loss(v)
37 | else:
38 | result += v.sum()
39 |
40 | return result
41 |
42 |
43 | class GraphModelHarness(training.TrainingHarness):
44 | """This class is the main harness for training graph models.
45 |
46 | The harness is responsible for coordinating all the procedures that surround training,
47 | such as learning rate scheduling, data loading, and logging.
48 | """
49 | def __init__(self, model, opt, node_feature_dimension, edge_feature_dimension,
50 | config_train, config_eval=None, scheduler=None, output_dir=None, dist_config=None,
51 | profile_enabled=False, additional_model_information=None):
52 | super(GraphModelHarness, self).__init__(model, opt, config_train, config_eval, dist_config)
53 | self.scheduler = scheduler
54 | self.output_dir = output_dir
55 |
56 | self.node_feature_dimension = node_feature_dimension
57 | self.edge_feature_dimension = edge_feature_dimension
58 | self.feature_dimension = {**node_feature_dimension, **edge_feature_dimension}
59 | self.profile_enabled = profile_enabled
60 | self._last_profile_step = 0
61 | self.additional_model_information = additional_model_information or {}
62 |
63 | def _make_feature_summary(fd):
64 | return collections.OrderedDict(
65 | (t, collections.OrderedDict(
66 | (feature_name, summary.ClassificationSummary(dim))
67 | for feature_name, dim in feature_description.items()
68 | )) for t, feature_description in fd.items())
69 |
70 | self.edge_feature_summaries = _make_feature_summary(edge_feature_dimension)
71 | self.node_feature_summaries = _make_feature_summary(node_feature_dimension)
72 |
73 |
74 | def _get_profile_path(self, global_step):
75 | if not self.profile_enabled:
76 | return None
77 |
78 | if self._last_profile_step is None or global_step - self._last_profile_step > 100000:
79 | self._last_profile_step = global_step
80 | return 'profile_step_{0}.pkl'.format(global_step)
81 |
82 | return None
83 |
84 | def single_step(self, batch, global_step):
85 | self.opt.zero_grad()
86 |
87 | profile_path = self._get_profile_path(global_step)
88 | with torch.autograd.profiler.profile(enabled=profile_path is not None, use_cuda=True) as trace:
89 | with torch.autograd.profiler.record_function("forward"):
90 | readout = self.model(batch)
91 | losses, accuracy, edge_metrics, node_metrics = graph_model.compute_losses(
92 | readout, batch, self.feature_dimension)
93 | total_loss = _total_loss(losses)
94 | if self.model.training:
95 | with torch.autograd.profiler.record_function("backward"):
96 | total_loss.backward()
97 |
98 | with torch.autograd.profiler.record_function("opt_update"):
99 | self.opt.step()
100 |
101 | if profile_path is not None:
102 | with open(profile_path, 'wb') as f:
103 | pickle.dump(trace, f, pickle.HIGHEST_PROTOCOL)
104 |
105 | losses = training.map_structure_flat(losses, _detach)
106 | losses = graph_model.compute_average_losses(losses, batch['graph_counts'])
107 | avg_loss = total_loss.detach() / float(sum(batch['graph_counts']))
108 | losses['average'] = avg_loss
109 |
110 | if self.is_leader():
111 | def _record_classification_summaries(metrics, summaries):
112 | for t, (labels, preds) in metrics.items():
113 | for i, cs in enumerate(summaries[t].values()):
114 | cs.record_statistics(labels[:, i], preds[:, i])
115 |
116 | _record_classification_summaries(edge_metrics, self.edge_feature_summaries)
117 | _record_classification_summaries(node_metrics, self.node_feature_summaries)
118 |
119 | return losses, accuracy
120 |
121 | def on_epoch_end(self, epoch, global_step):
122 | if self.scheduler is not None:
123 | self.scheduler.step()
124 |
125 |         if self.scheduler is not None and self.config_train.tb_writer is not None and self.is_leader():
126 | lr = self.scheduler.get_last_lr()[0]
127 | self.config_train.tb_writer.add_scalar('learning_rate', lr, global_step)
128 |
129 | if self.is_leader() and self.output_dir is not None and (epoch + 1) % 10 == 0:
130 | self.log('Saving checkpoint for epoch {}'.format(epoch + 1))
131 | torch.save(
132 | {
133 | 'opt': self.opt.state_dict(),
134 | 'model': self.model.state_dict(),
135 | 'epoch': epoch,
136 | 'global_step': global_step,
137 | **self.additional_model_information,
138 | },
139 | os.path.join(self.output_dir, 'model_state_{0}.pt'.format(epoch + 1)))
140 |
141 | def write_summaries(self, global_step, losses, accuracies, tb_writer):
142 | if tb_writer is None:
143 | return
144 |
145 | for k, v in losses.items():
146 | if k == 'edge_features' or k == 'node_features':
147 | v = _mean_value(v.values())
148 | tb_writer.add_scalar('loss/' + k, v, global_step)
149 |
150 | for k, v in accuracies.items():
151 | if k == 'edge_features' or k == 'node_features':
152 | v = _mean_value(v.values())
153 | tb_writer.add_scalar('accuracy/' + k, v, global_step)
154 |
155 | for t, feature_schema in self.edge_feature_summaries.items():
156 | for n, cs in feature_schema.items():
157 | cs.write_tensorboard(tb_writer, 'kappa' + '/' + t.name + '/' + n, global_step)
158 |
159 | for t, feature_schema in self.node_feature_summaries.items():
160 | for n, cs in feature_schema.items():
161 | cs.write_tensorboard(tb_writer, 'kappa' + '/' + t.name + '/' + n, global_step)
162 |
163 | def print_statistics(self, loss_acc, accuracy_acc):
164 | self.log(f'\tLoss ({loss_acc["average"]:.3f}): Node({loss_acc["node_label"]:.3f}, {loss_acc["node_stop"]:.3f}) '
165 | f'Edge({loss_acc["edge_label"]:.3f}, {loss_acc["edge_partner"]:.3f}, '
166 | f'{_mean_value(loss_acc["edge_features"].values()):.3f}) '
167 | f'Subnode({loss_acc["subnode_stop"]:.3f}).')
168 | self.log(f'\tAccuracy: Node({accuracy_acc["node_label"]:4.1%}, {accuracy_acc["node_stop"]:4.1%}) '
169 | f'Edge({accuracy_acc["edge_label"]:4.1%}, {accuracy_acc["edge_partner"]:4.1%}, '
170 | f'{_mean_value(accuracy_acc["edge_features"].values()):4.1%}) '
171 | f'Subnode({accuracy_acc["subnode_stop"]:4.1%})')
172 | self.log()
173 |
174 | def _summary_text(target, features):
175 | return f'{target.name}: ' + '; '.join(f'{n} ({cs.cohen_kappa():.3f})' for n, cs in features.items())
176 |
177 | if self.node_feature_summaries:
178 | self.log('Kappa Entity')
179 | for t, features in self.node_feature_summaries.items():
180 | self.log(_summary_text(t, features))
181 |
182 | self.log()
183 |
184 | if self.edge_feature_summaries:
185 | self.log('Kappa Edges')
186 | for t, features in self.edge_feature_summaries.items():
187 | self.log(_summary_text(t, features))
188 | self.log()
189 |
190 | def reset_statistics(self):
191 | for features in itertools.chain(self.node_feature_summaries.values(), self.edge_feature_summaries.values()):
192 | for classification_summary in features.values():
193 | classification_summary.reset_statistics()
194 |
195 | super(GraphModelHarness, self).reset_statistics()
196 |
--------------------------------------------------------------------------------
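A small illustration of how `_total_loss` above reduces the (possibly nested) loss dictionary returned by `graph_model.compute_losses`: tensors are summed, nested dictionaries are reduced recursively, and `None` entries are skipped. The keys and values below are made up for the example.

import torch

from sketchgraphs_models.graph.train.harness import _total_loss

losses = {
    'node_label': torch.tensor([0.7, 0.3]),
    'node_stop': None,                      # skipped
    'edge_features': {                      # nested dict, reduced recursively
        'angle': torch.tensor([0.1]),
        'length': torch.tensor([0.2, 0.2]),
    },
}

print(_total_loss(losses))  # tensor(1.5000)
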
/sketchgraphs/data/constraint_checks.py:
--------------------------------------------------------------------------------
1 | """This module implements a basic constraint checker for a solved graph.
2 |
3 | This module implements a number of functions to help check basic relational constraints
4 | between entities. Note that only checking is implemented: this is not an implementation
5 | of a solver, and cannot solve for the desired constraints.
6 |
7 | """
8 |
9 | import numpy as np
10 | from numpy.linalg import norm
11 |
12 | from ._entity import Point, Line, Circle, Arc, SubnodeType, ENTITY_TYPE_TO_CLASS, Entity
13 | from ._constraint import ConstraintType
14 | from .sequence import NodeOp, EdgeOp
15 |
16 |
17 | def get_entity_by_idx(seq, idx: int) -> Entity:
18 | """Returns the entity or sub-entity corresponding to idx.
19 |
20 | Parameters
21 | ----------
22 | seq : List[Union[NodeOp, EdgeOp]]
23 | A list of node and edge operations representing the construction sequence.
24 | idx : int
25 | An integer representing the index of the desired entity.
26 |
27 | Returns
28 | -------
29 | Entity
30 | An entity object representing the entity corresponding to the `NodeOp` at the given index.
31 | """
32 | node_ops = [op for op in seq if isinstance(op, NodeOp)]
33 | label = node_ops[idx].label
34 |
35 | def _entity_from_op(op):
36 | entity = ENTITY_TYPE_TO_CLASS[op.label]('dummy_id')
37 | for param_id, val in op.parameters.items():
38 | setattr(entity, param_id, val)
39 | return entity
40 |
41 | if not isinstance(label, SubnodeType):
42 | return _entity_from_op(node_ops[idx])
43 |
44 | for i in reversed(range(idx)):
45 | this_label = node_ops[i].label
46 | if not isinstance(this_label, SubnodeType):
47 | parent = _entity_from_op(node_ops[i])
48 | break
49 |
50 | if label == SubnodeType.SN_Start:
51 | return parent.start_point
52 | elif label == SubnodeType.SN_End:
53 | return parent.end_point
54 | elif label == SubnodeType.SN_Center:
55 | return parent.center_point
56 |
57 | raise ValueError('Could not find entity corresponding to idx.')
58 |
59 |
60 | def check_edge_satisfied(seq, op: EdgeOp):
61 | """Determines whether the given EdgeOp instance is geometrically satisfied in the graph represented by `seq`.
62 |
63 | Parameters
64 | ----------
65 | seq : List[Union[EdgeOp, NodeOp]]
66 | A construction sequence representing the underlying graph.
67 | op : EdgeOp
68 | An edge op to check for validity in the current graph.
69 |
70 | Returns
71 | -------
72 | bool
73 | `True` if the current edge constraint is satisified, otherwise `False`.
74 |
75 | Raises
76 | ------
77 | ValueError
78 |         If the constraint type is not supported, or the constraint refers to an
79 |         external entity (external references are not supported).
80 | """
81 | if 0 in op.references:
82 | raise ValueError('External constraints not supported.')
83 |
84 | entities = [get_entity_by_idx(seq, ref) for ref in op.references]
85 |
86 | try:
87 | constraint_f = CONSTRAINT_BY_LABEL[op.label]
88 | except KeyError:
89 | raise ValueError('%s not currently supported.' % op.label)
90 |
91 | return constraint_f(*entities)
92 |
93 |
94 | def get_sorted_types(entities):
95 | """Obtains the types and sorts the entities based on their type order.
96 |
97 | Parameters
98 | ----------
99 | entities : iterable of `Entity`
100 |         A list of entities to be sorted.
101 |
102 | Returns
103 | -------
104 | types : List
105 | A list of types representing the type of each entity
106 | entities : List
107 | A list of entities, containing the same elements as the input iterable,
108 | but sorted in the order given by types.
109 | """
110 | types = [Point if isinstance(ent, np.ndarray) else type(ent) for ent in entities]
111 | type_names = [t.__name__ for t in types]
112 | idxs = np.argsort(type_names)
113 | return [types[idx] for idx in idxs], [entities[idx] for idx in idxs]
114 |
115 |
116 | def _ensure_array(point):
117 | if isinstance(point, np.ndarray):
118 | return point
119 | else:
120 | return np.array([point.x, point.y])
121 |
122 |
123 | def coincident(*entities):
124 | types, entities = get_sorted_types(entities)
125 |
126 | if types == [Point, Point]:
127 | return np.allclose(_ensure_array(entities[0]), _ensure_array(entities[1]))
128 |
129 | elif types == [Line, Point]:
130 | vec1 = entities[0].end_point - entities[0].start_point
131 | vec2 = _ensure_array(entities[1]) - entities[0].start_point
132 | return np.isclose(np.cross(vec1, vec2), 0)
133 |
134 | elif types == [Line, Line]:
135 | return coincident(entities[0], entities[1].start_point) and coincident(entities[0], entities[1].end_point)
136 |
137 | elif types in [[Arc, Point], [Circle, Point]]:
138 | circle_or_arc, point = entities
139 | dist = norm(_ensure_array(point) - circle_or_arc.center_point)
140 | return np.isclose(circle_or_arc.radius, dist)
141 |
142 | elif types in [[Circle, Circle], [Arc, Arc], [Arc, Circle]]:
143 | return np.allclose([entities[0].xCenter, entities[0].yCenter, entities[0].radius],
144 | [entities[1].xCenter, entities[1].yCenter, entities[1].radius])
145 |
146 | else:
147 | return None
148 |
149 |
150 | def parallel(*ents):
151 | types, ents = get_sorted_types(ents)
152 |
153 | if types == [Line, Line]:
154 | vec1 = ents[0].end_point - ents[0].start_point
155 | vec2 = ents[1].end_point - ents[1].start_point
156 | return np.isclose(np.cross(vec1, vec2), 0)
157 |
158 | else:
159 | return None
160 |
161 |
162 | def horizontal(*ents):
163 | types, ents = get_sorted_types(ents)
164 |
165 | if types == [Line]:
166 | return horizontal(ents[0].start_point, ents[0].end_point)
167 | elif types == [Point, Point]:
168 | return np.isclose(ents[0][1], ents[1][1], atol=1e-6)
169 | else:
170 | return None
171 |
172 |
173 | def vertical(*ents):
174 | types, ents = get_sorted_types(ents)
175 |
176 | if types == [Line]:
177 | return vertical(ents[0].start_point, ents[0].end_point)
178 | elif types == [Point, Point]:
179 | return np.isclose(ents[0][0], ents[1][0], atol=1e-6)
180 | else:
181 | return None
182 |
183 |
184 | def perpendicular(*ents):
185 | types, ents = get_sorted_types(ents)
186 |
187 | if types == [Line, Line]:
188 | vec1 = ents[0].start_point - ents[0].end_point
189 | vec2 = ents[1].start_point - ents[1].end_point
190 | return np.isclose(np.dot(vec1, vec2), 0)
191 | else:
192 | return None
193 |
194 |
195 | def tangent(*ents):
196 | types, ents = get_sorted_types(ents)
197 |
198 | if types == [Circle, Line]:
199 | circle, line = ents
200 | p1, p2 = (line.start_point, line.end_point)
201 | p3 = circle.center_point
202 |
203 | line_dir = p2 - p1
204 | line_dir_norm = norm(line_dir)
205 |
206 | if np.abs(line_dir_norm) < 1e-6:
207 | dist = norm(p1 - p3)
208 | else:
209 | dist = norm(np.cross(line_dir, p1-p3)) / line_dir_norm
210 |
211 | return np.isclose(circle.radius, dist)
212 |
213 | elif types == [Arc, Line]:
214 | arc, line = ents
215 | circle = Circle('', xCenter=arc.xCenter, yCenter=arc.yCenter, radius=arc.radius)
216 | return tangent(circle, line)
217 | elif types in [[Arc, Arc], [Arc, Circle], [Circle, Circle]]:
218 | dist = norm(ents[1].center_point - ents[0].center_point)
219 | return (np.isclose(dist, ents[0].radius + ents[1].radius)
220 | or np.isclose(dist, np.abs(ents[0].radius - ents[1].radius)))
221 | else:
222 | return None
223 |
224 |
225 | def equal(*ents):
226 | types, ents = get_sorted_types(ents)
227 |
228 | if types == [Line, Line]:
229 | line0, line1 = ents
230 | vec0 = line0.end_point - line0.start_point
231 | vec1 = line1.end_point - line1.start_point
232 | return np.isclose(norm(vec0), norm(vec1))
233 |
234 | elif types in [[Circle, Circle], [Arc, Arc], [Arc, Circle]]:
235 | return np.isclose(ents[0].radius, ents[1].radius)
236 |
237 | else:
238 | return None
239 |
240 |
241 | def midpoint(*ents):
242 | types, ents = get_sorted_types(ents)
243 |
244 | if types == [Line, Point]:
245 | line, point = ents
246 | mid_coords = (line.start_point + line.end_point) / 2
247 | return np.allclose(mid_coords, _ensure_array(point))
248 |
249 | else:
250 | return None
251 |
252 |
253 | def concentric(*ents):
254 | types, ents = get_sorted_types(ents)
255 |
256 | if types in [[Circle, Circle], [Arc, Arc], [Arc, Circle]]:
257 | return coincident(ents[0].center_point, ents[1].center_point)
258 |
259 | else:
260 | return None
261 |
262 |
263 | CONSTRAINT_BY_LABEL = {
264 | ConstraintType.Coincident: coincident,
265 | ConstraintType.Parallel: parallel,
266 | ConstraintType.Horizontal: horizontal,
267 | ConstraintType.Vertical: vertical,
268 | ConstraintType.Perpendicular: perpendicular,
269 | ConstraintType.Tangent: tangent,
270 | ConstraintType.Equal: equal,
271 | ConstraintType.Midpoint: midpoint,
272 | ConstraintType.Concentric: concentric
273 | }
274 |
--------------------------------------------------------------------------------
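A minimal sketch of calling the geometric checkers above directly. Bare numpy arrays are treated as points by `get_sorted_types`, so no `Entity` construction is required for these point-based checks.

import numpy as np

from sketchgraphs.data.constraint_checks import coincident, horizontal, vertical

p0 = np.array([1.0, 2.0])
p1 = np.array([1.0, 2.0])
p2 = np.array([3.0, 2.0])

print(coincident(p0, p1))  # True:  identical coordinates
print(horizontal(p0, p2))  # True:  equal y coordinates
print(vertical(p0, p2))    # False: different x coordinates
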
/sketchgraphs/onshape/call.py:
--------------------------------------------------------------------------------
1 | """Simple command line utilties for interacting with Onshape API.
2 |
3 | A sketch is considered a feature of an Onshape PartStudio. This script enables adding a sketch to a part (add_feature), retrieving all features from a part including sketches (get_features), retrieving the possibly updated state of each sketch's entities after constraint solving (get_info), querying the solve status of each sketch (get_states), and updating the version identifiers in the feature template (update_template).
4 | """
5 | import argparse
6 | import json
7 | import urllib.parse
8 |
9 | from . import Client
10 |
11 |
12 | TEMPLATE_PATH = 'sketchgraphs/onshape/feature_template.json'
13 |
14 |
15 | def _parse_resp(resp):
16 | """Parse the response of a retrieval call.
17 | """
18 | parsed_resp = json.loads(resp.content.decode('utf8').replace("'", '"'))
19 | return parsed_resp
20 |
21 |
22 | def _save_or_print_resp(resp_dict, output_path=None, indent=4):
23 | """Saves or prints the given response dict.
24 | """
25 | if output_path:
26 | with open(output_path, 'w') as fh:
27 | json.dump(resp_dict, fh, indent=indent)
28 | else:
29 | print(json.dumps(resp_dict, indent=indent))
30 |
31 |
32 | def _create_client(logging):
33 | """Creates a `Client` with the given bool value for `logging`.
34 | """
35 | client = Client(stack='https://cad.onshape.com',
36 | logging=logging)
37 | return client
38 |
39 |
40 | def _parse_url(url):
41 | """Extracts doc, workspace, element ids from url.
42 | """
43 | _, _, docid, _, wid, _, eid = urllib.parse.urlparse(url).path.split('/')
44 | return docid, wid, eid
45 |
46 |
47 | def update_template(url, logging=False):
48 | """Updates version identifiers in feature_template.json.
49 |
50 | Parameters
51 | ----------
52 | url : str
53 | URL of Onshape PartStudio
54 | logging: bool
55 | Whether to log API messages (default False)
56 |
57 | Returns
58 | -------
59 | None
60 | """
61 | # Get PartStudio features (including version IDs)
62 | features = get_features(url, logging)
63 | # Get current feature template
64 | with open(TEMPLATE_PATH, 'r') as fh:
65 | template = json.load(fh)
66 | for version_key in ['serializationVersion', 'sourceMicroversion', 'libraryVersion']:
67 | template[version_key] = features[version_key]
68 | # Save updated feature template
69 | with open(TEMPLATE_PATH, 'w') as fh:
70 | json.dump(template, fh, indent=4)
71 |
72 |
73 | def add_feature(url, sketch_dict, sketch_name=None, logging=False):
74 | """Adds a sketch to a part.
75 |
76 | Parameters
77 | ----------
78 | url : str
79 | URL of Onshape PartStudio
80 | sketch_dict: dict
81 | A dictionary representing a `Sketch` instance with keys `entities` and `constraints`
82 | sketch_name: str
83 | Optional name for the sketch. If none provided, defaults to 'My Sketch'.
84 | logging: bool
85 | Whether to log API messages (default False)
86 |
87 | Returns
88 | -------
89 | None
90 | """
91 | # Get doc ids and create Client
92 | docid, wid, eid = _parse_url(url)
93 | client = _create_client(logging)
94 | # Get feature template
95 | with open(TEMPLATE_PATH, 'r') as fh:
96 | template = json.load(fh)
97 | # Add sketch's entities and constraints to the template
98 | template['feature']['message']['entities'] = sketch_dict['entities']
99 | template['feature']['message']['constraints'] = sketch_dict['constraints']
100 | if not sketch_name:
101 | sketch_name = 'My Sketch'
102 | template['feature']['message']['name'] = sketch_name
103 | # Send to Onshape
104 | client.add_feature(docid, wid, eid, payload=template)
105 |
106 |
107 | def get_features(url, logging=False):
108 | """Retrieves features from a part.
109 |
110 | Parameters
111 | ----------
112 | url : str
113 | URL of Onshape PartStudio
114 | logging : bool
115 | Whether to log API messages (default False)
116 |
117 | Returns
118 | -------
119 | features : dict
120 | A dictionary containing the part's features
121 |
122 | """
123 | # Get doc ids and create Client
124 | docid, wid, eid = _parse_url(url)
125 | client = _create_client(logging)
126 | # Get features
127 | resp = client.get_features(docid, wid, eid)
128 | features = _parse_resp(resp)
129 | return features
130 |
131 |
132 | def get_info(url, sketch_name=None, logging=False):
133 | """Retrieves possibly updated states of entities in a part's sketches.
134 |
135 | Parameters
136 | ----------
137 | url : str
138 | URL of Onshape PartStudio
139 | sketch_name : str
140 | If provided, only the entity info for the specified sketch will be returned. Otherwise, the full response is returned.
141 | logging : bool
142 | Whether to log API messages (default False)
143 |
144 | Returns
145 | -------
146 | sketch_info : dict
147 | A dictionary containing entity info for sketches
148 |
149 | """
150 | # Get doc ids and create Client
151 | docid, wid, eid = _parse_url(url)
152 | client = _create_client(logging)
153 | # Get features
154 | resp = client.sketch_information(docid, wid, eid)
155 | sketch_info = _parse_resp(resp)
156 | if sketch_name:
157 | sketch_found = False
158 | for sk in sketch_info['sketches']:
159 | if sk['sketch'] == sketch_name:
160 | sketch_info = sk
161 | sketch_found = True
162 | break
163 | if not sketch_found:
164 | raise ValueError("No sketch found with given name.")
165 | return sketch_info
166 |
167 |
168 | def get_states(url, logging=False):
169 | """Retrieves states of sketches in a part.
170 |
171 | If there are no issues with a sketch, the feature state is `OK`. If there
172 | are issues, e.g., unsolved constraints, the state is `WARNING`. All sketches
173 | in the queried PartStudio must have unique names.
174 |
175 | Parameters
176 | ----------
177 | url : str
178 | URL of Onshape PartStudio
179 | logging : bool
180 | Whether to log API messages (default False)
181 |
182 | Returns
183 | -------
184 | sketch_states : dict
185 | A dictionary containing sketch names as keys and associated states as
186 | values
187 | """
188 | # Get features for the given part url
189 | features = get_features(url, logging=logging)
190 | # Gather all feature states
191 |     feat_states = {f['key']: f['value']['message']['featureStatus']
192 | for f in features['featureStates']}
193 | # Gather sketch states
194 | sketch_states = {}
195 | for feat in features['features']:
196 | if feat['typeName'] != 'BTMSketch':
197 | continue
198 | sk_name = feat['message']['name']
199 | sk_id = feat['message']['featureId']
200 | # Check if name already encountered
201 | if sk_name in sketch_states:
202 | raise ValueError("Each sketch must have a unique name.")
203 | sketch_states[sk_name] = feat_states[sk_id]
204 | return sketch_states
205 |
206 |
207 | def main():
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument('--url',
210 |                         help='URL of Onshape PartStudio', required=True)
211 | parser.add_argument('--action',
212 | help='The API call to perform', required=True,
213 | choices=['add_feature', 'get_features', 'get_info', 'get_states',
214 | 'update_template'])
215 | parser.add_argument('--payload_path',
216 | help='Path to payload being sent to Onshape', default=None)
217 | parser.add_argument('--output_path',
218 | help='Path to save result of API call', default=None)
219 | parser.add_argument('--enable_logging',
220 | help='Whether to log API messages', action='store_true')
221 | parser.add_argument('--sketch_name',
222 | help='Optional name for sketch', default=None)
223 |
224 | args = parser.parse_args()
225 |
226 | # Parse the URL
227 | _, _, docid, _, wid, _, eid = urllib.parse.urlparse(args.url).path.split('/')
228 |
229 | # Create client
230 | client = Client(stack='https://cad.onshape.com',
231 | logging=args.enable_logging)
232 |
233 | # Perform the specified action
234 |     if args.action == 'add_feature':
235 | # Add a sketch to a part
236 | if not args.payload_path:
237 | raise ValueError("payload_path required when adding a feature")
238 | with open(args.payload_path, 'r') as fh:
239 | sketch_dict = json.load(fh)
240 | add_feature(args.url, sketch_dict, args.sketch_name,
241 | args.enable_logging)
242 |
243 | elif args.action == 'get_features':
244 | # Retrieve features from a part
245 | features = get_features(args.url, args.enable_logging)
246 | _save_or_print_resp(features, output_path=args.output_path)
247 |
248 | elif args.action == 'get_info':
249 | # Retrieve possibly updated states of entities in a part's sketches
250 | sketch_info = get_info(args.url, args.sketch_name, args.enable_logging)
251 | _save_or_print_resp(sketch_info, output_path=args.output_path)
252 |
253 | elif args.action == 'get_states':
254 | # Retrieve states of sketches in a part
255 | sketch_states = get_states(args.url, args.enable_logging)
256 | _save_or_print_resp(sketch_states, output_path=args.output_path)
257 |
258 | elif args.action == 'update_template':
259 | # Updates version identifiers in template
260 | update_template(args.url, args.enable_logging)
261 |
262 |
263 | if __name__ == '__main__':
264 | main()
--------------------------------------------------------------------------------
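The helpers above can also be used directly from Python instead of through the command line; a minimal sketch, with placeholder document/workspace/element ids and under the assumption that a valid creds.json is present:

from sketchgraphs.onshape.call import get_info, get_states

# Placeholder PartStudio URL; the path must follow /documents/<did>/w/<wid>/e/<eid>.
url = 'https://cad.onshape.com/documents/<did>/w/<wid>/e/<eid>'

states = get_states(url)   # e.g. {'My Sketch': 'OK'}
info = get_info(url)       # full entity info for all sketches in the part

# Equivalent CLI call for the first query:
#   python -m sketchgraphs.onshape.call --url <URL> --action get_states
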