├── tests
│   ├── __init__.py
│   ├── match_test.py
│   ├── edit_test.py
│   ├── subgraph_test.py
│   ├── reroute_test.py
│   ├── function_graph_test.py
│   ├── util_test.py
│   ├── select_test.py
│   └── transform_test.py
├── .gitignore
├── scripts
│   ├── test.sh
│   └── env.sh
├── graph_def_editor
│   ├── visualization
│   │   ├── __init__.py
│   │   ├── graphviz_style.py
│   │   ├── jupyter_helper.py
│   │   └── graphviz_wrapper.py
│   ├── __init__.py
│   ├── tensor.py
│   ├── match.py
│   ├── edit.py
│   ├── variable.py
│   ├── function_graph.py
│   ├── base_graph.py
│   └── reroute.py
├── setup.py
├── examples
│   ├── edit_graph_example.py
│   ├── batch_size_example.py
│   ├── mobilenet_example.py
│   └── coco_example.py
├── README.md
└── LICENSE
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Dummy __init__.py to keep pytest happy."""
2 |
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .ipynb_checkpoints
3 | env
4 | *.swp
5 | */__pycache__
6 | *.pyc
7 | *.iml
8 | test.out
9 | example.out
10 |
11 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash
################################################################################
# test.sh
#
# Run regression tests for this project.
#
# Usage:
#   ./scripts/test.sh
#
# Run from the root of the project: the pytest path below is relative to it.
# Test output is shown on the terminal and also saved to test.out.
# The script exits with pytest's exit status, so it can gate CI.
################################################################################

# Without pipefail the exit status of "pytest | tee" is tee's (almost always
# 0), which would silently hide test failures.
set -o pipefail

#conda activate ./env
./env/bin/pytest --ignore=env | tee test.out
#conda deactivate
14 |
--------------------------------------------------------------------------------
/graph_def_editor/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 | # Copyright 2019 IBM. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Visualization helpers (GraphViz styling and Jupyter display) for gde graphs."""
17 |
18 |
--------------------------------------------------------------------------------
/graph_def_editor/visualization/graphviz_style.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 | # Copyright 2019 IBM. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Default GraphViz styles to use."""
17 |
18 |
# Shared palette: dark grey for text, light grey for outlines and edges.
_TEXT_COLOR = '#414141'
_LINE_COLOR = '#aaaaaa'

# Top-level graph attributes.
graph_pref = dict(
    fontcolor=_TEXT_COLOR,
    style='rounded',
)

# Graph attributes for name-scope subgraphs.
name_scope_graph_pref = dict(
    bgcolor='#eeeeee',
    color=_LINE_COLOR,
    penwidth='2',
)

# Graph attributes for subgraphs that are not name scopes (white on white,
# i.e. visually borderless).
non_name_scope_graph_pref = dict(
    fillcolor='white',
    color='white',
)

# Node attributes: white filled boxes with a grey outline.
node_pref = dict(
    style='filled',
    fillcolor='white',
    color=_LINE_COLOR,
    penwidth='2',
    fontcolor=_TEXT_COLOR,
)

# Edge attributes.
edge_pref = dict(
    color=_LINE_COLOR,
    arrowsize='1.2',
    penwidth='2.5',
    fontcolor=_TEXT_COLOR,
)
49 |
--------------------------------------------------------------------------------
/scripts/env.sh:
--------------------------------------------------------------------------------
#! /bin/bash

################################################################################
# env.sh
#
# Set up an Anaconda virtualenv in the directory ./env
#
# Run this script from the root of the project, i.e.
#   ./scripts/env.sh
#
# Requires that conda be installed and set up for calling from bash scripts.
#
# Also requires that you set the environment variable CONDA_HOME to the
# location of the root of your anaconda/miniconda distribution.
#
# Any existing ./env directory is deleted and rebuilt from scratch.
# Exits with a non-zero status if CONDA_HOME is unset or wrong, or if creating
# the environment fails.
################################################################################

PYTHON_VERSION=3.6

############################
# HACK ALERT *** HACK ALERT
# The friendly folks at Anaconda thought it would be a good idea to make the
# "conda" command a shell function.
# See https://github.com/conda/conda/issues/7126
# The following workaround will probably be fragile.
if [ -z "${CONDA_HOME}" ]
then
  echo "Error: CONDA_HOME not set" >&2
  exit 1
fi
CONDA_SH="${CONDA_HOME}/etc/profile.d/conda.sh"
if [ ! -f "${CONDA_SH}" ]
then
  echo "Error: ${CONDA_SH} not found; check CONDA_HOME" >&2
  exit 1
fi
# Quoted so that a CONDA_HOME containing spaces still works.
. "${CONDA_SH}"
# END HACK
############################

################################################################################
# Remove any previous outputs of this script

rm -rf ./env


################################################################################
# Create the environment
if ! conda create -y --prefix ./env \
    python=${PYTHON_VERSION} \
    numpy \
    tensorflow \
    jupyterlab \
    pytest \
    keras \
    pillow \
    nomkl
then
  echo "Error: conda create failed; ./env may be incomplete" >&2
  exit 1
fi

# "echo" ignores standard input, so the here-document must go to "cat" for
# the message to actually be printed.
cat << EOM
Anaconda virtualenv installed in ./env.
Run "conda activate ./env" to use it.
EOM
57 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
import io
import os

from setuptools import find_packages, setup

# Read README.md from the directory that contains this file rather than the
# current working directory, and decode it as UTF-8 regardless of the user's
# locale. io.open is used because it accepts `encoding` on every Python
# version listed in the classifiers below (including 2.7).
_HERE = os.path.dirname(os.path.abspath(__file__))
with io.open(os.path.join(_HERE, 'README.md'), encoding='utf-8') as f:
  long_description = f.read()

setup(
    name='graph_def_editor',
    version='0.1.0',
    description='TensorFlow Graph Def Editor',
    long_description=long_description,
    # README.md is Markdown; tell PyPI how to render it.
    long_description_content_type='text/markdown',
    packages=find_packages(),
    install_requires=['numpy', 'tensorflow', 'six'],
    include_package_data=True,
    zip_safe=False,
    classifiers=[
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Mathematics',
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Topic :: Software Development :: Libraries',
    ],
    license='Apache License, Version 2.0',
    maintainer='Graph Def Developers',
    maintainer_email='',
)
49 |
--------------------------------------------------------------------------------
/graph_def_editor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """GDE: A GraphDef Editor for TensorFlow
17 |
18 | A version of the old [`contrib.graph_editor`](https://github.com/tensorflow/tensorflow/tree/r1.12/tensorflow/contrib/graph_editor) API that operates over serialized TensorFlow graphs represented as GraphDef protocol buffer messages.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | # pylint: disable=wildcard-import
26 | from graph_def_editor.base_graph import *
27 | from graph_def_editor.edit import *
28 | from graph_def_editor.graph import *
29 | from graph_def_editor.match import *
30 | from graph_def_editor.node import *
31 | from graph_def_editor.reroute import *
32 | from graph_def_editor.select import *
33 | from graph_def_editor.subgraph import *
34 | from graph_def_editor.tensor import *
35 | from graph_def_editor.transform import *
36 | from graph_def_editor.util import *
37 | from graph_def_editor.variable import *
38 | # pylint: enable=wildcard-import
39 |
40 | # Other parts go under sub-packages
41 | from graph_def_editor import rewrite
42 |
43 | # some useful aliases
44 | # pylint: disable=g-bad-import-order
45 | from graph_def_editor import subgraph as _subgraph
46 | from graph_def_editor import util as _util
47 | # pylint: enable=g-bad-import-order
# Short-hand aliases for frequently used helpers:
#   ph        -> util.make_placeholder_from_dtype_and_shape
#   sgv       -> subgraph.make_view
#   sgv_scope -> subgraph.make_view_from_scope
ph = _util.make_placeholder_from_dtype_and_shape
sgv = _subgraph.make_view
sgv_scope = _subgraph.make_view_from_scope

# Remove the __future__ feature names imported at the top of this module so
# they do not show up as attributes of the graph_def_editor package.
del absolute_import
del division
del print_function
55 |
--------------------------------------------------------------------------------
/examples/edit_graph_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """Simple example of the GraphDef Editor.
18 |
19 | To run this example from the root of the project, type:
20 | PYTHONPATH=$PWD env/bin/python examples/edit_graph_example.py
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import numpy as np
28 | import tensorflow as tf
29 | import graph_def_editor as gde
30 | import textwrap
31 |
# NOTE(review): FLAGS is not referenced anywhere else in this script, and
# `tf.flags` exists only in the TensorFlow 1.x API (tf.compat.v1 in TF 2).
# Confirm whether it can be removed.
FLAGS = tf.flags.FLAGS
33 |
34 |
35 | def _indent(s):
36 | return textwrap.indent(str(s), " ")
37 |
38 |
def main(_):
  """Builds a tiny TF graph, rewires it with gde, and evaluates the result.

  The original graph computes c = p0 + p1 for two float32 placeholders. The
  edit replaces c's inputs with the existing constant `a` (all 1.0) and a new
  constant `b` (all 2.0), so evaluating c should give a 2x3 matrix of 3.0.

  Args:
    _: Unused; the argument list that tf.app.run() passes to main.
  """
  tf_graph = tf.Graph()
  with tf_graph.as_default():
    const_a = tf.constant(1.0, shape=[2, 3], name="a")
    lhs = tf.placeholder(dtype=np.float32)
    rhs = tf.placeholder(dtype=np.float32)
    sum_c = tf.add(lhs, rhs, name="c")

  # Wrap the serialized GraphDef in a gde.Graph so that it can be edited.
  graph = gde.Graph(tf_graph.as_graph_def())
  print("Before:\n{}".format(_indent(graph.to_graph_def())))

  # Swap out both placeholder inputs of c: the first for the pre-existing
  # constant a, the second for a freshly created constant b.
  const_b = gde.make_const(graph, "b", np.full([2, 3], 2.0, dtype=np.float32))
  target_op = graph[sum_c.op.name]
  gde.swap_inputs(target_op, [graph[const_a.name], const_b.output(0)])

  print("After:\n{}".format(_indent(graph.to_graph_def())))

  # Turn the edited GraphDef back into a TensorFlow graph and evaluate c.
  edited_tf_graph = graph.to_tf_graph()
  with edited_tf_graph.as_default(), tf.Session() as sess:
    result = sess.run(sum_c.name)
    print("Result is:\n{}".format(_indent(result)))
68 |
69 |
# tf.app.run() parses command-line flags and then calls main(argv).
if __name__ == "__main__":
  tf.app.run()
72 |
--------------------------------------------------------------------------------
/tests/match_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests for gde.match."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow.compat.v1 as tf
22 | tf.disable_eager_execution()
23 |
24 | import unittest
25 |
26 | import graph_def_editor as gde
27 |
28 |
class MatchTest(unittest.TestCase):
  """Checks gde.OpMatcher against a small hand-built graph."""

  def setUp(self):
    """Builds the fixture graph and keeps the gde node for "foo/bar/f".

    Layout (constants are length-2 float vectors):
      a, foo/b, foo/d : constants
      foo/c     = a + foo/b
      foo/bar/e = foo/c + foo/d
      foo/bar/f = foo/c + foo/d
      foo/bar/g = foo/c + a
      foo/bar/h = foo/bar/f + foo/bar/g  (control dependency on foo/c)
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
      a = tf.constant([1., 1.], shape=[2], name="a")
      with tf.name_scope("foo"):
        b = tf.constant([2., 2.], shape=[2], name="b")
        c = tf.add(a, b, name="c")
        d = tf.constant([3., 3.], shape=[2], name="d")
        with tf.name_scope("bar"):
          tf.add(c, d, name="e")
          f = tf.add(c, d, name="f")
          g = tf.add(c, a, name="g")
          with tf.control_dependencies([c.op]):
            tf.add(f, g, name="h")

    self.graph = gde.Graph(tf_graph)
    self.f_op = self.graph[f.op.name]

  def test_simple_match(self):
    def f_matcher():
      # Build a new matcher for every check, because the builder methods may
      # modify the matcher they are called on.
      return gde.OpMatcher("^.*/f$")

    f_op = self.f_op
    self.assertTrue(f_matcher()(f_op))
    self.assertTrue(f_matcher().input_ops("^.*/c$", "^.*/d$")(f_op))
    self.assertTrue(f_matcher().input_ops(True, "^.*/d$")(f_op))

    # Depending on the TensorFlow version, tf.add() emits "Add" or "AddV2".
    matches_add = f_matcher().input_ops(
        gde.op_type("Add"), gde.op_type("Const"))(f_op)
    self.assertTrue(
        matches_add or
        f_matcher().input_ops(gde.op_type("AddV2"), gde.op_type("Const"))(f_op))

    # f consumes c and d and feeds h, which has a control input from c.
    f_feeds_h = f_matcher().input_ops("^.*/c$", "^.*/d$")
    f_feeds_h = f_feeds_h.output_ops(
        gde.OpMatcher("^.*/h$").control_input_ops("^.*/c$"))
    self.assertTrue(f_feeds_h(f_op))

    # Same as above, additionally requiring that h has no consumers.
    f_feeds_leaf_h = f_matcher().input_ops("^.*/c$", "^.*/d$").output_ops(
        gde.OpMatcher("^.*/h$").control_input_ops("^.*/c$").output_ops([]))
    self.assertTrue(f_feeds_leaf_h(f_op))
68 |
69 |
# Allows running this file directly (python tests/match_test.py). Under pytest
# the module is imported, so __name__ is not "__main__" and this is skipped.
if __name__ == "__main__":
  unittest.main()
72 |
--------------------------------------------------------------------------------
/graph_def_editor/visualization/jupyter_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 | # Copyright 2019 IBM. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Helper methods to display gde graph in Jupyter Notebook or Colab."""
17 |
18 | import time
19 |
def _import_ipython():
  """Imports and returns IPython's HTML display class.

  Returns:
    The `IPython.display.HTML` class.

  Raises:
    ModuleNotFoundError: If IPython is not installed. The message explains
      how to install it, and the original import error is chained as the
      cause.
  """
  try:
    from IPython.display import HTML
  except ModuleNotFoundError as error:
    raise ModuleNotFoundError(
        "You need to install ipython or Jupyter to be able to use this "
        "functionality. See https://ipython.org/install.html or "
        "https://jupyter.org/install for details.") from error
  return HTML


def jupyter_show_as_svg(dg):
  """Shows object as SVG (by default it is rendered as image).

  Args:
    dg: digraph object. Its `pipe(format="svg")` must return bytes, as
      graphviz.Digraph's does.

  Returns:
    An IPython HTML object containing the graph rendered in SVG format.

  Raises:
    ModuleNotFoundError: If IPython is not installed.
  """
  # The import inside _import_ipython() binds HTML only in that function's
  # scope, so the class must be taken from its return value.
  HTML = _import_ipython()  # pylint: disable=invalid-name
  return HTML(dg.pipe(format="svg").decode("utf-8"))
40 |
41 |
def jupyter_pan_and_zoom(
    dg,
    element_styles="height:auto",
    container_styles="overflow:hidden",
    pan_zoom_json="{controlIconsEnabled: true, zoomScaleSensitivity: 0.4, "
    "minZoom: 0.2}"):
  """Embeds SVG object into Jupyter cell with ability to pan and zoom.

  Args:
    dg: digraph object. Its `pipe(format="svg")` must return bytes.
    element_styles: CSS styles for embedded SVG element.
    container_styles: CSS styles for container div element.
    pan_zoom_json: pan and zoom settings, see
      https://github.com/bumbu/svg-pan-zoom

  Returns:
    Graph rendered as HTML using javascript for Pan and Zoom functionality.

  Raises:
    ModuleNotFoundError: If IPython is not installed.
  """
  # Check for IPython before rendering, so a missing dependency is reported
  # right away with an actionable message.
  _import_ipython()
  # The import inside _import_ipython() does not bind HTML in this module, so
  # import it here as well (it is known to be available at this point).
  from IPython.display import HTML  # pylint: disable=g-import-not-at-top

  svg_txt = dg.pipe(format="svg").decode("utf-8")
  html_container_class_name = F"svg_container_{int(time.time())}"
  # NOTE(review): in this copy of the file the template below contains only
  # the SVG text. element_styles, container_styles, pan_zoom_json and
  # html_container_class_name are not referenced by it, so no pan/zoom markup
  # is emitted. The container div and svg-pan-zoom script appear to have been
  # lost; restore them from upstream and confirm.
  html = F"""




{svg_txt}

"""
  return HTML(html)
94 |
--------------------------------------------------------------------------------
/tests/edit_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests for gde.edit"""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow.compat.v1 as tf
23 | tf.disable_eager_execution()
24 | import unittest
25 |
26 | import graph_def_editor as gde
27 |
28 |
class EditTest(unittest.TestCase):
  """Tests for the gde edit module.

  Each test follows the same two-step recipe:
    1. apply an edit (detach / connect / bypass) to the fixture graph;
    2. check the resulting topology with a gde.OpMatcher.
  """

  # TODO(frreiss): Merge duplicate setup code across test cases
  def setUp(self):
    """Builds the fixture graph and stores its gde tensors as self.a..self.h.

    Layout:
      a, foo/b, foo/d : constants
      foo/c     = a + foo/b
      foo/bar/e = foo/c + foo/d
      foo/bar/f = foo/c + foo/d
      foo/bar/g = foo/c + a
      foo/bar/h = foo/bar/f + foo/bar/g  (control dependency on foo/c)
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
      a = tf.constant([1., 1.], shape=[2], name="a")
      with tf.name_scope("foo"):
        b = tf.constant([2., 2.], shape=[2], name="b")
        c = tf.add(a, b, name="c")
        d = tf.constant([3., 3.], shape=[2], name="d")
        with tf.name_scope("bar"):
          e = tf.add(c, d, name="e")
          f = tf.add(c, d, name="f")
          g = tf.add(c, a, name="g")
          with tf.control_dependencies([c.op]):
            h = tf.add(f, g, name="h")
    self.graph = gde.Graph(tf_graph)
    # Look up the gde tensor that corresponds to each TF tensor, by name.
    named_tensors = (("a", a), ("b", b), ("c", c), ("d", d),
                     ("e", e), ("f", f), ("g", g), ("h", h))
    for attr_name, tf_tensor in named_tensors:
      setattr(self, attr_name, self.graph.get_tensor_by_name(tf_tensor.name))

  def test_detach(self):
    """Test for gde.detach."""
    view = gde.sgv(self.c.op, self.a.op)
    control_outputs = gde.ControlOutputs(self.graph)
    gde.detach(view, control_ios=control_outputs)
    # After detaching, foo/c reads from a and from "geph__b_0", which by its
    # name appears to be a generated stand-in for the former input foo/b.
    detached_c = gde.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")
    self.assertTrue(detached_c(self.c.op))

  def test_connect(self):
    """Test for gde.connect."""
    # The TF code this mirrors:
    #   x = constant_op.constant([1., 1.], shape=[2], name="x")
    #   y = constant_op.constant([2., 2.], shape=[2], name="y")
    #   z = math_ops.add(x, y, name="z")
    x = gde.make_const(self.graph, "x", np.array([1., 1.], dtype=np.float32))
    y = gde.make_const(self.graph, "y", np.array([2., 2.], dtype=np.float32))
    z = self.graph.add_node("z", "Add")
    z.add_attr("T", tf.float32)
    z.set_inputs([x.outputs[0], y.outputs[0]])
    z.infer_outputs()

    # Connect the x/y/z subgraph to input 0 of foo/bar/e.
    gde.connect(gde.sgv(x, y, z), gde.sgv(self.e.op).remap_inputs([0]))
    rewired_e = gde.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")
    self.assertTrue(rewired_e(self.e.op))

  def test_bypass(self):
    """Test for gde.bypass."""
    gde.bypass(gde.sgv(self.f.op).remap_inputs([0]))
    # With foo/bar/f bypassed, foo/bar/h reads foo/c in its place.
    bypassed_h = gde.OpMatcher("^foo/bar/h$").input_ops("^foo/c$",
                                                        "foo/bar/g$")
    self.assertTrue(bypassed_h(self.h.op))
96 |
97 |
# Allows running this file directly (python tests/edit_test.py). Under pytest
# the module is imported, so __name__ is not "__main__" and this is skipped.
if __name__ == "__main__":
  unittest.main()
100 |
--------------------------------------------------------------------------------
/tests/subgraph_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests for gde.subgraph."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow.compat.v1 as tf
22 | tf.disable_eager_execution()
23 |
24 | import unittest
25 |
26 | import graph_def_editor as gde
27 |
28 |
class SubgraphTest(unittest.TestCase):
  """Tests for gde subgraph views (gde.sgv / gde.sgv_scope)."""

  # TODO(frreiss): Merge duplicate setup code across test cases
  def setUp(self):
    """Builds the fixture graph and stores its gde tensors as self.a..self.h.

    Layout:
      a, foo/b, foo/d : constants
      foo/c     = a + foo/b
      foo/bar/e = foo/c + foo/d
      foo/bar/f = foo/c + foo/d
      foo/bar/g = foo/c + a
      foo/bar/h = foo/bar/f + foo/bar/g  (control dependency on foo/c)
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
      a = tf.constant([1., 1.], shape=[2], name="a")
      with tf.name_scope("foo"):
        b = tf.constant([2., 2.], shape=[2], name="b")
        c = tf.add(a, b, name="c")
        d = tf.constant([3., 3.], shape=[2], name="d")
        with tf.name_scope("bar"):
          e = tf.add(c, d, name="e")
          f = tf.add(c, d, name="f")
          g = tf.add(c, a, name="g")
          with tf.control_dependencies([c.op]):
            h = tf.add(f, g, name="h")
    self.graph = gde.Graph(tf_graph)
    # Look up the gde tensor that corresponds to each TF tensor, by name.
    named_tensors = (("a", a), ("b", b), ("c", c), ("d", d),
                     ("e", e), ("f", f), ("g", g), ("h", h))
    for attr_name, tf_tensor in named_tensors:
      setattr(self, attr_name, self.graph.get_tensor_by_name(tf_tensor.name))

  def test_subgraph(self):
    # View over the whole graph.
    whole = gde.sgv(self.graph)
    self.assertEqual(list(whole.outputs), [self.e, self.h])
    self.assertEqual(list(whole.inputs), [])
    self.assertEqual(len(whole.ops), 8)

    # View over just f and g.
    pair = gde.sgv(self.f.op, self.g.op)
    self.assertEqual(list(pair.outputs), [self.f, self.g])
    self.assertEqual(list(pair.inputs), [self.c, self.d, self.a])

    # View built from a name scope.
    scoped = gde.sgv_scope("foo/bar", graph=self.graph)
    self.assertEqual(
        list(scoped.ops), [self.e.op, self.f.op, self.g.op, self.h.op])

  def test_subgraph_remap(self):
    view = gde.sgv(self.c.op)
    self.assertEqual(list(view.outputs), [self.c])
    self.assertEqual(list(view.inputs), [self.a, self.b])

    # Remap to inputs [a] and outputs [c, c].
    view = gde.sgv(self.c.op).remap([self.a], [0, self.c])
    self.assertEqual(list(view.outputs), [self.c, self.c])
    self.assertEqual(list(view.inputs), [self.a])

    view = view.remap_outputs_to_consumers()
    self.assertEqual(list(view.outputs), [self.c, self.c, self.c])
    view = view.remap_outputs_make_unique()
    self.assertEqual(list(view.outputs), [self.c])

    # Empty index lists leave the view with no inputs or outputs...
    view = view.remap(new_input_indices=[], new_output_indices=[])
    self.assertEqual(len(view.inputs), 0)
    self.assertEqual(len(view.outputs), 0)
    # ...and remap_default() brings back the same inputs/outputs as the
    # first view of c.
    view = view.remap_default()
    self.assertEqual(list(view.outputs), [self.c])
    self.assertEqual(list(view.inputs), [self.a, self.b])

  def test_remove_unused_ops(self):
    whole = gde.sgv(self.graph)
    self.assertEqual(list(whole.outputs), [self.e, self.h])
    self.assertEqual(len(whole.ops), 8)

    # Keeping only output 1 (h) and pruning leaves one op fewer
    # (presumably foo/bar/e, which is no longer an output).
    pruned = whole.remap_outputs(new_output_indices=[1]).remove_unused_ops()
    self.assertEqual(list(pruned.outputs), [self.h])
    self.assertEqual(len(pruned.ops), 7)
99 |
100 |
# Allows running this file directly (python tests/subgraph_test.py). This
# module imports `unittest` and has no `test` name in scope, so the runner is
# unittest.main(). Under pytest this block is skipped.
if __name__ == "__main__":
  unittest.main()
103 |
--------------------------------------------------------------------------------
/examples/batch_size_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | """
18 | Example of using the GraphDef editor to adjust the batch size of a pretrained
19 | model.
20 |
21 | To run this example from the root of the project, type:
22 | PYTHONPATH=$PWD env/bin/python examples/batch_size_example.py
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import numpy as np
30 | import os
31 | import tensorflow as tf
32 | import graph_def_editor as gde
33 | import shutil
34 | import tarfile
35 | import textwrap
36 | import urllib.request
37 |
# NOTE(review): FLAGS is not referenced anywhere else in this script, and
# `tf.flags` exists only in the TensorFlow 1.x API (tf.compat.v1 in TF 2).
# Confirm whether it can be removed.
FLAGS = tf.flags.FLAGS
39 |
40 |
41 | def _indent(s):
42 | return textwrap.indent(str(s), " ")
43 |
44 |
45 | _TMP_DIR = "/tmp/batch_size_example"
46 | _MODEL_URL = "http://download.tensorflow.org/models/official/20181001_resnet" \
47 | "/savedmodels/resnet_v2_fp16_savedmodel_NHWC.tar.gz"
48 | _MODEL_TARBALL = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC.tar.gz"
49 | _SAVED_MODEL_DIR = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC/1538686978"
50 | _AFTER_MODEL_DIR = _TMP_DIR + "/rewritten_model"
51 |
52 |
def _safe_extract(tarball, dest_dir):
  """Extracts every member of `tarball` into `dest_dir`, safely.

  A bare TarFile.extractall() trusts member names and link targets, so a
  tampered archive could write outside the destination ("../" paths,
  absolute paths, or links pointing elsewhere). This checks every member
  before extracting anything.

  Args:
    tarball: Path of the tar archive (compression is auto-detected).
    dest_dir: Directory to unpack into.

  Raises:
    ValueError: If a member, or the target of a link member, would resolve
      to a location outside `dest_dir`.
  """
  dest_root = os.path.realpath(dest_dir)

  def _is_within_dest(path):
    resolved = os.path.realpath(path)
    return os.path.commonpath([dest_root, resolved]) == dest_root

  with tarfile.open(tarball) as archive:
    members = archive.getmembers()
    for member in members:
      member_path = os.path.join(dest_root, member.name)
      if not _is_within_dest(member_path):
        raise ValueError("Refusing to extract {!r}: path escapes {}".format(
            member.name, dest_dir))
      if member.issym():
        # Symlink targets are relative to the directory holding the link.
        link_path = os.path.join(os.path.dirname(member_path), member.linkname)
      elif member.islnk():
        # Hard-link targets are paths inside the archive.
        link_path = os.path.join(dest_root, member.linkname)
      else:
        continue
      if not _is_within_dest(link_path):
        raise ValueError(
            "Refusing to extract {!r}: link target {!r} escapes {}".format(
                member.name, member.linkname, dest_dir))
    archive.extractall(dest_root, members=members)


def _download_model():
  """Fetches and unpacks the model into _TMP_DIR unless it is already there.

  The tarball is cached, so repeated runs do not download it again.
  """
  if os.path.isdir(_SAVED_MODEL_DIR):
    return
  if os.path.isdir(_TMP_DIR):
    shutil.rmtree(_TMP_DIR)
  os.mkdir(_TMP_DIR)
  print("Downloading model tarball from {}".format(_MODEL_URL))
  urllib.request.urlretrieve(_MODEL_URL, _MODEL_TARBALL)
  print("Unpacking SavedModel from {} to {}".format(_MODEL_TARBALL, _TMP_DIR))
  _safe_extract(_MODEL_TARBALL, _TMP_DIR)


def main(_):
  """Rewrites the batch size of a pretrained ResNet50 SavedModel to None.

  Steps:
    1. Fetch the official ResNet50 fp16 SavedModel (cached under _TMP_DIR).
    2. Print the shapes of its input and softmax tensors.
    3. Rewrite the batch dimension with gde.rewrite.change_batch_size.
    4. Save the result to _AFTER_MODEL_DIR and reload it.
    5. Print the new shapes and run one all-zeros image through the model.

  Args:
    _: Unused; the argument list that tf.app.run() passes to main.
  """
  # Grab a copy of the official TensorFlow ResNet50 model in fp16.
  # See https://github.com/tensorflow/models/tree/master/official/resnet
  _download_model()

  # Load the SavedModel
  tf_g = tf.Graph()
  with tf.Session(graph=tf_g) as sess:
    tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING],
                        _SAVED_MODEL_DIR)

  # print("Graph is:\n{}".format(tf_g.as_graph_def()))

  # Print out some statistics about tensor shapes
  print("BEFORE:")
  print(" Input tensor is {}".format(tf_g.get_tensor_by_name(
      "input_tensor:0")))
  print(" Softmax tensor is {}".format(tf_g.get_tensor_by_name(
      "softmax_tensor:0")))

  # Convert the SavedModel to a gde.Graph and rewrite the batch size to None
  g = gde.saved_model_to_graph(_SAVED_MODEL_DIR)
  gde.rewrite.change_batch_size(g, new_size=None, inputs=[g["input_tensor"]])
  if os.path.exists(_AFTER_MODEL_DIR):
    shutil.rmtree(_AFTER_MODEL_DIR)
  g.to_saved_model(_AFTER_MODEL_DIR)

  # Load the rewritten SavedModel into a TensorFlow graph
  after_tf_g = tf.Graph()
  with tf.Session(graph=after_tf_g) as sess:
    tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING],
                        _AFTER_MODEL_DIR)
    print("AFTER:")
    print(" Input tensor is {}".format(after_tf_g.get_tensor_by_name(
        "input_tensor:0")))
    print(" Softmax tensor is {}".format(after_tf_g.get_tensor_by_name(
        "softmax_tensor:0")))

    # Feed a single array of zeros through the graph
    print("Running inference on dummy data")
    result = sess.run("softmax_tensor:0",
                      {"input_tensor:0": np.zeros([1, 224, 224, 3])})
    print("Result is {}".format(result))
105 |
106 |
# tf.app.run() parses command-line flags and then calls main(argv).
if __name__ == "__main__":
  tf.app.run()
109 |
--------------------------------------------------------------------------------
/tests/reroute_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Tests for tensorflow.contrib.graph_editor."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import numpy as np
22 | import tensorflow.compat.v1 as tf
23 | tf.disable_eager_execution()
24 |
25 | import unittest
26 |
27 | import graph_def_editor as gde
28 |
29 |
class RerouteTest(unittest.TestCase):
  """Exercises gde's swap/reroute helpers on small constant-add graphs."""

  def setUp(self):
    # Three independent "x + y" graphs: (a0, b0) -> c0, (a1, b1) -> c1 and
    # (a2, b2) -> c2. The last one uses shape [3] instead of [2]; presumably
    # that mismatch is what makes the reroute in test_compatibility fail.
    specs = [
        ("a0", 1.0, "b0", 2.0, "c0", [2]),
        ("a1", 3.0, "b1", 4.0, "c1", [2]),
        ("a2", 3.0, "b2", 4.0, "c2", [3]),
    ]
    tf_graph = tf.Graph()
    with tf_graph.as_default():
      for lhs_name, lhs_value, rhs_name, rhs_value, sum_name, shape in specs:
        lhs = tf.constant(lhs_value, shape=shape, name=lhs_name)
        rhs = tf.constant(rhs_value, shape=shape, name=rhs_name)
        tf.add(lhs, rhs, name=sum_name)

    self.graph = gde.Graph(tf_graph)
    # Expose output 0 of every op as an attribute: self.a0, ..., self.c2.
    op_names = [letter + digit
                for letter in ["a", "b", "c"]
                for digit in ["0", "1", "2"]]
    for op_name in op_names:
      setattr(self, op_name, self.graph[op_name].output(0))

  def _assert_inputs(self, op_name, input_specs, node):
    """Assert that gde.OpMatcher(op_name).input_ops(*input_specs) accepts node.

    Args:
      op_name: Name the matcher expects the node to have.
      input_specs: Sequence of input specs forwarded to `input_ops`; each item
        is an op name or a nested matcher.
      node: The gde node under test.
    """
    matcher = gde.OpMatcher(op_name).input_ops(*input_specs)
    self.assertTrue(matcher(node))

  def test_swap(self):
    gde.swap_ts([self.a0, self.b0], [self.a1, self.b1])
    # Each add op now reads the other pair's inputs.
    self._assert_inputs("c0", ("a1", "b1"), self.c0.op)
    self._assert_inputs("c1", ("a0", "b0"), self.c1.op)

  def test_multiswap(self):
    # The upstream TensorFlow test created a3 inside the graph with
    # constant_op.constant(3.0, shape=[2], name="a3"); here an equivalent
    # constant NodeDef is added to the graph directly.
    a3_node = gde.make_const(self.graph, "a3",
                             np.full([2], 3.0, dtype=np.float32))

    # Output 0 of a3 is listed twice so it can stand in for both a0 and a1
    # (verified by the assertions below).
    a3_view = gde.sgv(a3_node).remap_outputs([0, 0])
    gde.swap_ios(a3_view, gde.sgv(self.a0.op, self.a1.op))

    self._assert_inputs("c0", ("a3", "b0"), self.c0.op)
    self._assert_inputs("c1", ("a3", "b1"), self.c1.op)

  def test_reroute(self):
    # Point the consumers of (a1, b1) at (a0, b0): both adds read a0/b0.
    gde.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    self._assert_inputs("c0", ("a0", "b0"), self.c0.op)
    self._assert_inputs("c1", ("a0", "b0"), self.c1.op)

    # Now move every consumer of (a0, b0) over to (a1, b1).
    gde.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    self._assert_inputs("c0", ("a1", "b1"), self.c0.op)
    self._assert_inputs("c1", ("a1", "b1"), self.c1.op)

  def test_compatibility(self):
    # a0/b0 have shape [2] while a2/b2 have shape [3].
    self.assertRaises(ValueError, gde.reroute_ts,
                      [self.a0, self.b0], [self.a2, self.b2])

  def test_reroute_can_modify(self):
    # Build a graph in which "a" is ambiguous with respect to sgv0: it is both
    # an input and an output of the ops in sgv0 (it feeds "c", inside sgv0,
    # and "d", outside it).
    tf_graph = tf.Graph()
    with tf_graph.as_default():
      a_tensor = tf.constant(1.0, shape=[2], name="a")
      b_tensor = tf.constant(2.0, shape=[2], name="b")
      c_tensor = tf.add(a_tensor, b_tensor, name="c")
      tf.add(a_tensor, c_tensor, name="d")
      e_tensor = tf.constant(1.0, shape=[2], name="e")
      f_tensor = tf.constant(2.0, shape=[2], name="f")
      tf.add(e_tensor, f_tensor, name="g")
    graph = gde.Graph(tf_graph)

    sgv0 = gde.sgv(graph["a"], graph["b"], graph["c"])
    sgv1 = gde.sgv(graph["e"], graph["f"])
    gde.swap_outputs(sgv0, sgv1)

    c_matcher = gde.OpMatcher("c").input_ops("a", "b")
    self._assert_inputs("g", ("a", c_matcher), graph["g"])
    self._assert_inputs("d", ("e", "f"), graph["d"])
105 |
106 |
if __name__ == "__main__":
  # Allow running this file directly with `python`, in addition to the pytest
  # invocation used by scripts/test.sh.
  unittest.main()
109 |
--------------------------------------------------------------------------------
/graph_def_editor/tensor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 IBM. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import tensorflow.compat.v1 as tf
17 | import sys
18 | if sys.version >= '3':
19 | from typing import AbstractSet
20 |
# Public names exported by `from graph_def_editor.tensor import *`.
__all__ = [
    "Tensor",
]
24 |
25 |
26 | class Tensor(object):
27 | """
28 | Surrogate object that represents an output of a Node. Corresponds roughly to
29 | a tf.Tensor in the TensorFlow Python API, though serialized TensorFlow graphs
30 | do not contain any separate objects that represent tensors.
31 | """
32 | def __init__(self,
33 | node,
34 | index,
35 | dtype, # type: tf.DType,
36 | shape # type: tf.shape
37 | ):
38 | """
39 | Args:
40 | node: gde.Node object that represents the graph node that produces this
41 | tensor
42 | index: Output index of this tensor among the outputs of the specified node
43 | dtype: Data type of the tensor
44 | shape: Shape of the tensor
45 | """
46 | self._node = node
47 | self._index = index
48 | self._dtype = dtype
49 | self._shape = shape
50 | self._collection_names = set() # Set[str]
51 |
52 | def __str__(self):
53 | return "Tensor '{}' (dtype {}, shape {})".format(self.name, self.dtype,
54 | self.shape)
55 |
56 | def __repr__(self):
57 | return str(self)
58 |
59 | @property
60 | def node(self):
61 | return self._node
62 |
63 | @property
64 | def op(self):
65 | """Alias for self.node, for compatibility with code written for
66 | tf.Tensor"""
67 | return self.node
68 |
69 | @property
70 | def value_index(self):
71 | """
72 | Emulates the behavior of `tf.Tensor.value_index`
73 |
74 | Returns the output index of this Tensor among the outputs of the parent
75 | Node."""
76 | return self._index
77 |
78 | @property
79 | def dtype(self):
80 | # type () -> tf.DType:
81 | return self._dtype
82 |
83 | @dtype.setter
84 | def dtype(self,
85 | value # type: tf.DType
86 | ):
87 | self._dtype = value
88 |
89 | @property
90 | def shape(self):
91 | # type () -> tf.TensorShape
92 | return self._shape
93 |
94 | @shape.setter
95 | def shape(self,
96 | value # type: tf.TensorShape
97 | ):
98 | self._shape = value
99 |
100 | @property
101 | def graph(self):
102 | """Returns the `gde.Graph` object representing the graph in which the
103 | node that produces this tensor resides."""
104 | return self._node.graph
105 |
106 | def consumers(self):
107 | """Returns the `gde.Node` objects representing the ops that consume the
108 | tensor that this object represents."""
109 | # TODO: Maintain a lookup table of graph edges.
110 | # For now we do linear search for correctness.
111 | ret = []
112 | for n in self.graph.nodes:
113 | if self in n.inputs:
114 | ret.append(n)
115 | return ret
116 |
117 | @property
118 | def name(self):
119 | """
120 | Emulates the behavior of `tf.Tensor.name`
121 |
122 | Returns:
123 | A TensorFlow-like tensor name string in the form ":