├── tests ├── __init__.py ├── match_test.py ├── edit_test.py ├── subgraph_test.py ├── reroute_test.py ├── function_graph_test.py ├── util_test.py ├── select_test.py └── transform_test.py ├── .gitignore ├── scripts ├── test.sh └── env.sh ├── graph_def_editor ├── visualization │ ├── __init__.py │ ├── graphviz_style.py │ ├── jupyter_helper.py │ └── graphviz_wrapper.py ├── __init__.py ├── tensor.py ├── match.py ├── edit.py ├── variable.py ├── function_graph.py ├── base_graph.py └── reroute.py ├── setup.py ├── examples ├── edit_graph_example.py ├── batch_size_example.py ├── mobilenet_example.py └── coco_example.py ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Dummy __init__.py to keep pytest happy.""" 2 | 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb_checkpoints 3 | env 4 | *.swp 5 | */__pycache__ 6 | *.pyc 7 | *.iml 8 | test.out 9 | example.out 10 | 11 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # test.sh 3 | # 4 | # Run regression tests for this project. 5 | # 6 | # Usage: 7 | # ./scripts/test.sh 8 | # 9 | 10 | #conda activate ./env 11 | ./env/bin/pytest --ignore=env | tee test.out 12 | #conda deactivate 13 | 14 | -------------------------------------------------------------------------------- /graph_def_editor/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Dummy __init__.py to keep pytest happy.""" 17 | 18 | -------------------------------------------------------------------------------- /graph_def_editor/visualization/graphviz_style.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Default GraphViz styles to use.""" 17 | 18 | 19 | graph_pref = { 20 | 'fontcolor': '#414141', 21 | 'style': 'rounded', 22 | } 23 | 24 | name_scope_graph_pref = { 25 | 'bgcolor': '#eeeeee', 26 | 'color': '#aaaaaa', 27 | 'penwidth': '2', 28 | } 29 | 30 | non_name_scope_graph_pref = { 31 | 'fillcolor': 'white', 32 | 'color': 'white', 33 | } 34 | 35 | node_pref = { 36 | 'style': 'filled', 37 | 'fillcolor': 'white', 38 | 'color': '#aaaaaa', 39 | 'penwidth': '2', 40 | 'fontcolor': '#414141', 41 | } 42 | 43 | edge_pref = { 44 | 'color': '#aaaaaa', 45 | 'arrowsize': '1.2', 46 | 'penwidth': '2.5', 47 | 'fontcolor': '#414141', 48 | } 49 | -------------------------------------------------------------------------------- /scripts/env.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | ################################################################################ 4 | # env.sh 5 | # 6 | # Set up an Anaconda virtualenv in the directory ./env 7 | # 8 | # Run this script from the root of the project, i.e. 9 | # ./scripts/env.sh 10 | # 11 | # Requires that conda be installed and set up for calling from bash scripts. 12 | # 13 | # Also requires that you set the environment variable CONDA_HOME to the 14 | # location of the root of your anaconda/miniconda distribution. 15 | ################################################################################ 16 | 17 | PYTHON_VERSION=3.6 18 | 19 | ############################ 20 | # HACK ALERT *** HACK ALERT 21 | # The friendly folks at Anaconda thought it would be a good idea to make the 22 | # "conda" command a shell function. 23 | # See https://github.com/conda/conda/issues/7126 24 | # The following workaround will probably be fragile. 25 | if [ -z "$CONDA_HOME" ] 26 | then 27 | echo "Error: CONDA_HOME not set" 28 | exit 29 | fi 30 | . ${CONDA_HOME}/etc/profile.d/conda.sh 31 | # END HACK 32 | ############################ 33 | 34 | ################################################################################ 35 | # Remove any previous outputs of this script 36 | 37 | rm -rf ./env 38 | 39 | 40 | ################################################################################ 41 | # Create the environment 42 | conda create -y --prefix ./env \ 43 | python=${PYTHON_VERSION} \ 44 | numpy \ 45 | tensorflow \ 46 | jupyterlab \ 47 | pytest \ 48 | keras \ 49 | pillow \ 50 | nomkl 51 | 52 | echo << EOM 53 | Anaconda virtualenv installed in ./env. 54 | Run \"conda activate ./env\" to use it. 55 | EOM 56 | 57 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | from setuptools import find_packages, setup 18 | 19 | with open('README.md') as f: 20 | long_description = f.read() 21 | 22 | setup( 23 | name='graph_def_editor', 24 | version='0.1.0', 25 | description=('TensorFlow Graph Def Editor'), 26 | long_description=long_description, 27 | long_description_content_type='text/markdown', 28 | packages=find_packages(), 29 | install_requires=['numpy', 'tensorflow', 'six'], 30 | include_package_data=True, 31 | zip_safe=False, 32 | classifiers=[ 33 | 'License :: OSI Approved :: Apache Software License', 34 | 'Programming Language :: Python :: 2.7', 35 | 'Programming Language :: Python :: 3.5', 36 | 'Programming Language :: Python :: 3.6', 37 | 'Programming Language :: Python :: 3.7', 38 | 'Intended Audience :: Developers', 39 | 'Intended Audience :: Education', 40 | 'Intended Audience :: Science/Research', 41 | 'Topic :: Scientific/Engineering :: Mathematics', 42 | 'Topic :: Software Development :: Libraries :: Python Modules', 43 | 'Topic :: Software Development :: Libraries', 44 | ], 45 | license='Apache License, Version 2.0', 46 | maintainer='Graph Def Developers', 47 | maintainer_email='', 48 | ) 49 | -------------------------------------------------------------------------------- /graph_def_editor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """GDE: A GraphDef Editor for TensorFlow 17 | 18 | A version of the old [`contrib.graph_editor`](https://github.com/tensorflow/tensorflow/tree/r1.12/tensorflow/contrib/graph_editor) API that operates over serialized TensorFlow graphs represented as GraphDef protocol buffer messages. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | # pylint: disable=wildcard-import 26 | from graph_def_editor.base_graph import * 27 | from graph_def_editor.edit import * 28 | from graph_def_editor.graph import * 29 | from graph_def_editor.match import * 30 | from graph_def_editor.node import * 31 | from graph_def_editor.reroute import * 32 | from graph_def_editor.select import * 33 | from graph_def_editor.subgraph import * 34 | from graph_def_editor.tensor import * 35 | from graph_def_editor.transform import * 36 | from graph_def_editor.util import * 37 | from graph_def_editor.variable import * 38 | # pylint: enable=wildcard-import 39 | 40 | # Other parts go under sub-packages 41 | from graph_def_editor import rewrite 42 | 43 | # some useful aliases 44 | # pylint: disable=g-bad-import-order 45 | from graph_def_editor import subgraph as _subgraph 46 | from graph_def_editor import util as _util 47 | # pylint: enable=g-bad-import-order 48 | ph = _util.make_placeholder_from_dtype_and_shape 49 | sgv = _subgraph.make_view 50 | sgv_scope = _subgraph.make_view_from_scope 51 | 52 | del absolute_import 53 | del division 54 | del print_function 55 | -------------------------------------------------------------------------------- /examples/edit_graph_example.py: -------------------------------------------------------------------------------- 1 | # Coypright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Simple example of the GraphDef Editor. 18 | 19 | To run this example from the root of the project, type: 20 | PYTHONPATH=$PWD env/bin/python examples/edit_graph_example.py 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import numpy as np 28 | import tensorflow as tf 29 | import graph_def_editor as gde 30 | import textwrap 31 | 32 | FLAGS = tf.flags.FLAGS 33 | 34 | 35 | def _indent(s): 36 | return textwrap.indent(str(s), " ") 37 | 38 | 39 | def main(_): 40 | # Create a graph 41 | tf_g = tf.Graph() 42 | with tf_g.as_default(): 43 | a = tf.constant(1.0, shape=[2, 3], name="a") 44 | c = tf.add( 45 | tf.placeholder(dtype=np.float32), 46 | tf.placeholder(dtype=np.float32), 47 | name="c") 48 | 49 | # Serialize the graph 50 | g = gde.Graph(tf_g.as_graph_def()) 51 | print("Before:\n{}".format(_indent(g.to_graph_def()))) 52 | 53 | # Modify the graph. 54 | # In this case we replace the two input placeholders with constants. 55 | # One of the constants (a) is a node that was in the original graph. 56 | # The other one (b) we create here. 57 | b = gde.make_const(g, "b", np.full([2, 3], 2.0, dtype=np.float32)) 58 | gde.swap_inputs(g[c.op.name], [g[a.name], b.output(0)]) 59 | 60 | print("After:\n{}".format(_indent(g.to_graph_def()))) 61 | 62 | # Reconstitute the modified serialized graph as TensorFlow graph... 63 | with g.to_tf_graph().as_default(): 64 | # ...and print the value of c, which should be 2x3 matrix of 3.0's 65 | with tf.Session() as sess: 66 | res = sess.run(c.name) 67 | print("Result is:\n{}".format(_indent(res))) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.app.run() 72 | -------------------------------------------------------------------------------- /tests/match_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.contrib.graph_editor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | tf.disable_eager_execution() 23 | 24 | import unittest 25 | 26 | import graph_def_editor as gde 27 | 28 | 29 | class MatchTest(unittest.TestCase): 30 | 31 | def setUp(self): 32 | tf_graph = tf.Graph() 33 | with tf_graph.as_default(): 34 | a = tf.constant([1., 1.], shape=[2], name="a") 35 | with tf.name_scope("foo"): 36 | b = tf.constant([2., 2.], shape=[2], name="b") 37 | c = tf.add(a, b, name="c") 38 | d = tf.constant([3., 3.], shape=[2], name="d") 39 | with tf.name_scope("bar"): 40 | _ = tf.add(c, d, name="e") 41 | f = tf.add(c, d, name="f") 42 | g = tf.add(c, a, name="g") 43 | with tf.control_dependencies([c.op]): 44 | _ = tf.add(f, g, name="h") 45 | 46 | self.graph = gde.Graph(tf_graph) 47 | self.f_op = self.graph[f.op.name] 48 | 49 | def test_simple_match(self): 50 | self.assertTrue(gde.OpMatcher("^.*/f$")(self.f_op)) 51 | self.assertTrue( 52 | gde.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f_op)) 53 | self.assertTrue( 54 | gde.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f_op)) 55 | self.assertTrue( 56 | gde.OpMatcher("^.*/f$").input_ops( 57 | gde.op_type("Add"), gde.op_type("Const"))(self.f_op) or 58 | gde.OpMatcher("^.*/f$").input_ops( 59 | gde.op_type("AddV2"), gde.op_type("Const"))(self.f_op)) 60 | self.assertTrue( 61 | gde.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") 62 | .output_ops(gde.OpMatcher("^.*/h$") 63 | .control_input_ops("^.*/c$"))(self.f_op)) 64 | self.assertTrue( 65 | gde.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( 66 | gde.OpMatcher("^.*/h$").control_input_ops("^.*/c$") 67 | .output_ops([]))(self.f_op)) 68 | 69 | 70 | if __name__ == "__main__": 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /graph_def_editor/visualization/jupyter_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Helper methods to display gde graph in Jupyter Notebook or Colab.""" 17 | 18 | import time 19 | 20 | def _import_ipython(): 21 | try: 22 | from IPython.display import HTML 23 | except ModuleNotFoundError as error: 24 | raise ModuleNotFoundError( 25 | "You need to install ipython or Jupyter to be able to use this functionality. " 26 | "See https://ipython.org/install.html or https://jupyter.org/install for details.") 27 | 28 | 29 | def jupyter_show_as_svg(dg): 30 | """Shows object as SVG (by default it is rendered as image). 31 | 32 | Args: 33 | dg: digraph object 34 | 35 | Returns: 36 | Graph rendered in SVG format 37 | """ 38 | _import_ipython() 39 | return HTML(dg.pipe(format="svg").decode("utf-8")) 40 | 41 | 42 | def jupyter_pan_and_zoom( 43 | dg, 44 | element_styles="height:auto", 45 | container_styles="overflow:hidden", 46 | pan_zoom_json="{controlIconsEnabled: true, zoomScaleSensitivity: 0.4, " 47 | "minZoom: 0.2}"): 48 | """Embeds SVG object into Jupyter cell with ability to pan and zoom. 49 | 50 | Args: 51 | dg: digraph object 52 | element_styles: CSS styles for embedded SVG element. 53 | container_styles: CSS styles for container div element. 54 | pan_zoom_json: pan and zoom settings, see 55 | https://github.com/bumbu/svg-pan-zoom 56 | 57 | Returns: 58 | Graph rendered as HTML using javascript for Pan and Zoom functionality. 59 | """ 60 | svg_txt = dg.pipe(format="svg").decode("utf-8") 61 | html_container_class_name = F"svg_container_{int(time.time())}" 62 | html = F""" 63 |
64 | 72 | 73 | 89 | {svg_txt} 90 |
91 | """ 92 | _import_ipython() 93 | return HTML(html) 94 | -------------------------------------------------------------------------------- /tests/edit_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for gde.edit""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow.compat.v1 as tf 23 | tf.disable_eager_execution() 24 | import unittest 25 | 26 | import graph_def_editor as gde 27 | 28 | 29 | class EditTest(unittest.TestCase): 30 | """edit module test. 31 | 32 | Generally the tests are in two steps: 33 | - modify an existing graph. 34 | - then make sure it has the expected topology using the graph matcher. 35 | """ 36 | 37 | # TODO(frreiss): Merge duplicate setup code across test cases 38 | def setUp(self): 39 | tf_graph = tf.Graph() 40 | with tf_graph.as_default(): 41 | a = tf.constant([1., 1.], shape=[2], name="a") 42 | with tf.name_scope("foo"): 43 | b = tf.constant([2., 2.], shape=[2], name="b") 44 | c = tf.add(a, b, name="c") 45 | d = tf.constant([3., 3.], shape=[2], name="d") 46 | with tf.name_scope("bar"): 47 | e = tf.add(c, d, name="e") 48 | f = tf.add(c, d, name="f") 49 | g = tf.add(c, a, name="g") 50 | with tf.control_dependencies([c.op]): 51 | h = tf.add(f, g, name="h") 52 | self.graph = gde.Graph(tf_graph) 53 | self.a = self.graph.get_tensor_by_name(a.name) 54 | self.b = self.graph.get_tensor_by_name(b.name) 55 | self.c = self.graph.get_tensor_by_name(c.name) 56 | self.d = self.graph.get_tensor_by_name(d.name) 57 | self.e = self.graph.get_tensor_by_name(e.name) 58 | self.f = self.graph.get_tensor_by_name(f.name) 59 | self.g = self.graph.get_tensor_by_name(g.name) 60 | self.h = self.graph.get_tensor_by_name(h.name) 61 | 62 | def test_detach(self): 63 | """Test for ge.detach.""" 64 | sgv = gde.sgv(self.c.op, self.a.op) 65 | control_outputs = gde.ControlOutputs(self.graph) 66 | gde.detach(sgv, control_ios=control_outputs) 67 | # make sure the detached graph is as expected. 68 | self.assertTrue( 69 | gde.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) 70 | 71 | def test_connect(self): 72 | """Test for gde.connect.""" 73 | # Original code: 74 | # with self.graph.as_default(): 75 | # x = constant_op.constant([1., 1.], shape=[2], name="x") 76 | # y = constant_op.constant([2., 2.], shape=[2], name="y") 77 | # z = math_ops.add(x, y, name="z") 78 | x = gde.make_const(self.graph, "x", np.array([1., 1.], dtype=np.float32)) 79 | y = gde.make_const(self.graph, "y", np.array([2., 2.], dtype=np.float32)) 80 | z = self.graph.add_node("z", "Add") 81 | z.add_attr("T", tf.float32) 82 | z.set_inputs([x.outputs[0], y.outputs[0]]) 83 | z.infer_outputs() 84 | 85 | sgv = gde.sgv(x, y, z) 86 | gde.connect(sgv, gde.sgv(self.e.op).remap_inputs([0])) 87 | self.assertTrue( 88 | gde.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) 89 | 90 | def test_bypass(self): 91 | """Test for ge.bypass.""" 92 | gde.bypass(gde.sgv(self.f.op).remap_inputs([0])) 93 | self.assertTrue( 94 | gde.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")( 95 | self.h.op)) 96 | 97 | 98 | if __name__ == "__main__": 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /tests/subgraph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.contrib.graph_editor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | tf.disable_eager_execution() 23 | 24 | import unittest 25 | 26 | import graph_def_editor as gde 27 | 28 | 29 | class SubgraphTest(unittest.TestCase): 30 | 31 | # TODO(frreiss): Merge duplicate setup code across test cases 32 | def setUp(self): 33 | tf_graph = tf.Graph() 34 | with tf_graph.as_default(): 35 | a = tf.constant([1., 1.], shape=[2], name="a") 36 | with tf.name_scope("foo"): 37 | b = tf.constant([2., 2.], shape=[2], name="b") 38 | c = tf.add(a, b, name="c") 39 | d = tf.constant([3., 3.], shape=[2], name="d") 40 | with tf.name_scope("bar"): 41 | e = tf.add(c, d, name="e") 42 | f = tf.add(c, d, name="f") 43 | g = tf.add(c, a, name="g") 44 | with tf.control_dependencies([c.op]): 45 | h = tf.add(f, g, name="h") 46 | self.graph = gde.Graph(tf_graph) 47 | self.a = self.graph.get_tensor_by_name(a.name) 48 | self.b = self.graph.get_tensor_by_name(b.name) 49 | self.c = self.graph.get_tensor_by_name(c.name) 50 | self.d = self.graph.get_tensor_by_name(d.name) 51 | self.e = self.graph.get_tensor_by_name(e.name) 52 | self.f = self.graph.get_tensor_by_name(f.name) 53 | self.g = self.graph.get_tensor_by_name(g.name) 54 | self.h = self.graph.get_tensor_by_name(h.name) 55 | 56 | def test_subgraph(self): 57 | sgv = gde.sgv(self.graph) 58 | self.assertEqual(list(sgv.outputs), [self.e, self.h]) 59 | self.assertEqual(list(sgv.inputs), []) 60 | self.assertEqual(len(sgv.ops), 8) 61 | 62 | sgv = gde.sgv(self.f.op, self.g.op) 63 | self.assertEqual(list(sgv.outputs), [self.f, self.g]) 64 | self.assertEqual(list(sgv.inputs), [self.c, self.d, self.a]) 65 | 66 | sgv = gde.sgv_scope("foo/bar", graph=self.graph) 67 | self.assertEqual( 68 | list(sgv.ops), [self.e.op, self.f.op, self.g.op, self.h.op]) 69 | 70 | def test_subgraph_remap(self): 71 | sgv = gde.sgv(self.c.op) 72 | self.assertEqual(list(sgv.outputs), [self.c]) 73 | self.assertEqual(list(sgv.inputs), [self.a, self.b]) 74 | 75 | sgv = gde.sgv(self.c.op).remap([self.a], [0, self.c]) 76 | self.assertEqual(list(sgv.outputs), [self.c, self.c]) 77 | self.assertEqual(list(sgv.inputs), [self.a]) 78 | 79 | sgv = sgv.remap_outputs_to_consumers() 80 | self.assertEqual(list(sgv.outputs), [self.c, self.c, self.c]) 81 | sgv = sgv.remap_outputs_make_unique() 82 | self.assertEqual(list(sgv.outputs), [self.c]) 83 | 84 | sgv = sgv.remap(new_input_indices=[], new_output_indices=[]) 85 | self.assertEqual(len(sgv.inputs), 0) 86 | self.assertEqual(len(sgv.outputs), 0) 87 | sgv = sgv.remap_default() 88 | self.assertEqual(list(sgv.outputs), [self.c]) 89 | self.assertEqual(list(sgv.inputs), [self.a, self.b]) 90 | 91 | def test_remove_unused_ops(self): 92 | sgv = gde.sgv(self.graph) 93 | self.assertEqual(list(sgv.outputs), [self.e, self.h]) 94 | self.assertEqual(len(sgv.ops), 8) 95 | 96 | sgv = sgv.remap_outputs(new_output_indices=[1]).remove_unused_ops() 97 | self.assertEqual(list(sgv.outputs), [self.h]) 98 | self.assertEqual(len(sgv.ops), 7) 99 | 100 | 101 | if __name__ == "__main__": 102 | test.main() 103 | -------------------------------------------------------------------------------- /examples/batch_size_example.py: -------------------------------------------------------------------------------- 1 | # Coypright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """ 18 | Example of using the GraphDef editor to adjust the batch size of a pretrained 19 | model. 20 | 21 | To run this example from the root of the project, type: 22 | PYTHONPATH=$PWD env/bin/python examples/batch_size_example.py 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import numpy as np 30 | import os 31 | import tensorflow as tf 32 | import graph_def_editor as gde 33 | import shutil 34 | import tarfile 35 | import textwrap 36 | import urllib.request 37 | 38 | FLAGS = tf.flags.FLAGS 39 | 40 | 41 | def _indent(s): 42 | return textwrap.indent(str(s), " ") 43 | 44 | 45 | _TMP_DIR = "/tmp/batch_size_example" 46 | _MODEL_URL = "http://download.tensorflow.org/models/official/20181001_resnet" \ 47 | "/savedmodels/resnet_v2_fp16_savedmodel_NHWC.tar.gz" 48 | _MODEL_TARBALL = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC.tar.gz" 49 | _SAVED_MODEL_DIR = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC/1538686978" 50 | _AFTER_MODEL_DIR = _TMP_DIR + "/rewritten_model" 51 | 52 | 53 | def main(_): 54 | # Grab a copy of the official TensorFlow ResNet50 model in fp16. 55 | # See https://github.com/tensorflow/models/tree/master/official/resnet 56 | # Cache the tarball so we don't download it repeatedly 57 | if not os.path.isdir(_SAVED_MODEL_DIR): 58 | if os.path.isdir(_TMP_DIR): 59 | shutil.rmtree(_TMP_DIR) 60 | os.mkdir(_TMP_DIR) 61 | print("Downloading model tarball from {}".format(_MODEL_URL)) 62 | urllib.request.urlretrieve(_MODEL_URL, _MODEL_TARBALL) 63 | print("Unpacking SavedModel from {} to {}".format(_MODEL_TARBALL, _TMP_DIR)) 64 | with tarfile.open(_MODEL_TARBALL) as t: 65 | t.extractall(_TMP_DIR) 66 | 67 | # Load the SavedModel 68 | tf_g = tf.Graph() 69 | with tf.Session(graph=tf_g) as sess: 70 | tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], 71 | _SAVED_MODEL_DIR) 72 | 73 | # print("Graph is:\n{}".format(tf_g.as_graph_def())) 74 | 75 | # Print out some statistics about tensor shapes 76 | print("BEFORE:") 77 | print(" Input tensor is {}".format(tf_g.get_tensor_by_name( 78 | "input_tensor:0"))) 79 | print(" Softmax tensor is {}".format(tf_g.get_tensor_by_name( 80 | "softmax_tensor:0"))) 81 | 82 | # Convert the SavedModel to a gde.Graph and rewrite the batch size to None 83 | g = gde.saved_model_to_graph(_SAVED_MODEL_DIR) 84 | gde.rewrite.change_batch_size(g, new_size=None, inputs=[g["input_tensor"]]) 85 | if os.path.exists(_AFTER_MODEL_DIR): 86 | shutil.rmtree(_AFTER_MODEL_DIR) 87 | g.to_saved_model(_AFTER_MODEL_DIR) 88 | 89 | # Load the rewritten SavedModel into a TensorFlow graph 90 | after_tf_g = tf.Graph() 91 | with tf.Session(graph=after_tf_g) as sess: 92 | tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], 93 | _AFTER_MODEL_DIR) 94 | print("AFTER:") 95 | print(" Input tensor is {}".format(after_tf_g.get_tensor_by_name( 96 | "input_tensor:0"))) 97 | print(" Softmax tensor is {}".format(after_tf_g.get_tensor_by_name( 98 | "softmax_tensor:0"))) 99 | 100 | # Feed a single array of zeros through the graph 101 | print("Running inference on dummy data") 102 | result = sess.run("softmax_tensor:0", 103 | {"input_tensor:0": np.zeros([1, 224, 224, 3])}) 104 | print("Result is {}".format(result)) 105 | 106 | 107 | if __name__ == "__main__": 108 | tf.app.run() 109 | -------------------------------------------------------------------------------- /tests/reroute_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.contrib.graph_editor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow.compat.v1 as tf 23 | tf.disable_eager_execution() 24 | 25 | import unittest 26 | 27 | import graph_def_editor as gde 28 | 29 | 30 | class RerouteTest(unittest.TestCase): 31 | 32 | def setUp(self): 33 | tf_graph = tf.Graph() 34 | with tf_graph.as_default(): 35 | a0 = tf.constant(1.0, shape=[2], name="a0") 36 | b0 = tf.constant(2.0, shape=[2], name="b0") 37 | _ = tf.add(a0, b0, name="c0") 38 | a1 = tf.constant(3.0, shape=[2], name="a1") 39 | b1 = tf.constant(4.0, shape=[2], name="b1") 40 | _ = tf.add(a1, b1, name="c1") 41 | a2 = tf.constant(3.0, shape=[3], name="a2") 42 | b2 = tf.constant(4.0, shape=[3], name="b2") 43 | _ = tf.add(a2, b2, name="c2") 44 | 45 | self.graph = gde.Graph(tf_graph) 46 | # Programmatically add all the tensors as fields of this object. 47 | for letter in ["a", "b", "c"]: 48 | for number in ["0", "1", "2"]: 49 | op_name = letter + number 50 | self.__dict__[op_name] = self.graph[op_name].output(0) 51 | 52 | def test_swap(self): 53 | gde.swap_ts([self.a0, self.b0], [self.a1, self.b1]) 54 | self.assertTrue(gde.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) 55 | self.assertTrue(gde.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) 56 | 57 | def test_multiswap(self): 58 | # Original code: 59 | # with self.graph.as_default(): 60 | # a3 = constant_op.constant(3.0, shape=[2], name="a3") 61 | # New code adds a NodeDef to the graph: 62 | a3_node = gde.make_const(self.graph, "a3", np.full([2], 3.0, 63 | dtype=np.float32)) 64 | 65 | gde.swap_ios(gde.sgv(a3_node).remap_outputs([0, 0]), 66 | gde.sgv(self.a0.op, self.a1.op)) 67 | self.assertTrue(gde.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op)) 68 | self.assertTrue(gde.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op)) 69 | 70 | def test_reroute(self): 71 | gde.reroute_ts([self.a0, self.b0], [self.a1, self.b1]) 72 | self.assertTrue(gde.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op)) 73 | self.assertTrue(gde.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) 74 | 75 | gde.reroute_ts([self.a1, self.b1], [self.a0, self.b0]) 76 | self.assertTrue(gde.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) 77 | self.assertTrue(gde.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op)) 78 | 79 | def test_compatibility(self): 80 | with self.assertRaises(ValueError): 81 | gde.reroute_ts([self.a0, self.b0], [self.a2, self.b2]) 82 | 83 | def test_reroute_can_modify(self): 84 | # create a special graph where "a" is an ambiguous tensor. That is 85 | # it is both an input and an output of the ops in sgv0. 86 | tf_graph = tf.Graph() 87 | with tf_graph.as_default(): 88 | a_tensor = tf.constant(1.0, shape=[2], name="a") 89 | b_tensor = tf.constant(2.0, shape=[2], name="b") 90 | c_tensor = tf.add(a_tensor, b_tensor, name="c") 91 | _ = tf.add(a_tensor, c_tensor, name="d") 92 | e_tensor = tf.constant(1.0, shape=[2], name="e") 93 | f_tensor = tf.constant(2.0, shape=[2], name="f") 94 | _ = tf.add(e_tensor, f_tensor, name="g") 95 | g = gde.Graph(tf_graph) 96 | 97 | sgv0 = gde.sgv(g["a"], g["b"], g["c"]) 98 | sgv1 = gde.sgv(g["e"], g["f"]) 99 | 100 | gde.swap_outputs(sgv0, sgv1) 101 | self.assertTrue( 102 | gde.OpMatcher("g").input_ops( 103 | "a", gde.OpMatcher("c").input_ops("a", "b"))(g["g"])) 104 | self.assertTrue(gde.OpMatcher("d").input_ops("e", "f")(g["d"])) 105 | 106 | 107 | if __name__ == "__main__": 108 | unittest.main() 109 | -------------------------------------------------------------------------------- /graph_def_editor/tensor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import tensorflow.compat.v1 as tf 17 | import sys 18 | if sys.version >= '3': 19 | from typing import AbstractSet 20 | 21 | __all__ = [ 22 | "Tensor", 23 | ] 24 | 25 | 26 | class Tensor(object): 27 | """ 28 | Surrogate object that represents an output of a Node. Corresponds roughly to 29 | a tf.Tensor in the TensorFlow Python API, though serialized TensorFlow graphs 30 | do not contain any separate objects that represent tensors. 31 | """ 32 | def __init__(self, 33 | node, 34 | index, 35 | dtype, # type: tf.DType, 36 | shape # type: tf.shape 37 | ): 38 | """ 39 | Args: 40 | node: gde.Node object that represents the graph node that produces this 41 | tensor 42 | index: Output index of this tensor among the outputs of the specified node 43 | dtype: Data type of the tensor 44 | shape: Shape of the tensor 45 | """ 46 | self._node = node 47 | self._index = index 48 | self._dtype = dtype 49 | self._shape = shape 50 | self._collection_names = set() # Set[str] 51 | 52 | def __str__(self): 53 | return "Tensor '{}' (dtype {}, shape {})".format(self.name, self.dtype, 54 | self.shape) 55 | 56 | def __repr__(self): 57 | return str(self) 58 | 59 | @property 60 | def node(self): 61 | return self._node 62 | 63 | @property 64 | def op(self): 65 | """Alias for self.node, for compatibility with code written for 66 | tf.Tensor""" 67 | return self.node 68 | 69 | @property 70 | def value_index(self): 71 | """ 72 | Emulates the behavior of `tf.Tensor.value_index` 73 | 74 | Returns the output index of this Tensor among the outputs of the parent 75 | Node.""" 76 | return self._index 77 | 78 | @property 79 | def dtype(self): 80 | # type () -> tf.DType: 81 | return self._dtype 82 | 83 | @dtype.setter 84 | def dtype(self, 85 | value # type: tf.DType 86 | ): 87 | self._dtype = value 88 | 89 | @property 90 | def shape(self): 91 | # type () -> tf.TensorShape 92 | return self._shape 93 | 94 | @shape.setter 95 | def shape(self, 96 | value # type: tf.TensorShape 97 | ): 98 | self._shape = value 99 | 100 | @property 101 | def graph(self): 102 | """Returns the `gde.Graph` object representing the graph in which the 103 | node that produces this tensor resides.""" 104 | return self._node.graph 105 | 106 | def consumers(self): 107 | """Returns the `gde.Node` objects representing the ops that consume the 108 | tensor that this object represents.""" 109 | # TODO: Maintain a lookup table of graph edges. 110 | # For now we do linear search for correctness. 111 | ret = [] 112 | for n in self.graph.nodes: 113 | if self in n.inputs: 114 | ret.append(n) 115 | return ret 116 | 117 | @property 118 | def name(self): 119 | """ 120 | Emulates the behavior of `tf.Tensor.name` 121 | 122 | Returns: 123 | A TensorFlow-like tensor name string in the form ":" 124 | """ 125 | return "{}:{}".format(self.node.name, self.value_index) 126 | 127 | @property 128 | def collection_names(self): 129 | # type -> AbstractSet[str] 130 | """ 131 | Returns the names of all collections this tensor is a member of in the 132 | parent graph. 133 | """ 134 | return frozenset(self._collection_names) 135 | 136 | def add_to_collection(self, 137 | collection_name # type: str 138 | ): 139 | """ 140 | Add the tensor to the indicated collection. 141 | """ 142 | if collection_name not in self._collection_names: 143 | self._collection_names.add(collection_name) 144 | # Invalidate any information the parent graph may have cached about 145 | # collections. 146 | self.node._graph.increment_version_counter() 147 | 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphDef Editor 2 | 3 | ### A port of the TensorFlow `contrib.graph_editor` package that operates over serialized graphs 4 | 5 | TensorFlow versions prior to version 2.0 had a Python graph editor in 6 | `contrib.graph_editor`. This functionality is slated to be removed in 7 | TensorFlow 2.0, along with the rest of the `contrib` package (see the 8 | [RFC](https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md)). 9 | This project brings back the graph editor as a standalone Python package. 10 | 11 | The original graph editor operated over TensorFlow's Python classes `Graph`, 12 | `Variable`, `Operator`, etc., often poking into the internals of these classes. 13 | As a result of this design, the graph editor needed to be updated whenever the 14 | underlying classes changed. 15 | 16 | The GraphDef Editor operates over *serialized* TensorFlow graphs represented as 17 | `GraphDef` protocol buffer messages. Although TensorFlow's serialization format 18 | is technically not a public API, there is public 19 | [documentation](https://www.tensorflow.org/guide/extend/model_files) 20 | for its structure, and the format changes much less frequently than the Python 21 | classes that the original graph editor depended on. TensorFlow's C++ 22 | [Graph Transform Tool](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md) 23 | also operates over serialized graphs. 24 | 25 | Example usage: 26 | 27 | ```python 28 | import numpy as np 29 | import tensorflow as tf 30 | import graph_def_editor as gde 31 | # Create a graph 32 | tf_g = tf.Graph() 33 | with tf_g.as_default(): 34 | a = tf.constant(1.0, shape=[2, 3], name="a") 35 | c = tf.add( 36 | tf.placeholder(dtype=np.float32), 37 | tf.placeholder(dtype=np.float32), 38 | name="c") 39 | 40 | # Serialize the graph 41 | g = gde.Graph(tf_g.as_graph_def()) 42 | 43 | # Modify the graph. 44 | # In this case we replace the two input placeholders with constants. 45 | # One of the constants (a) is a node that was in the original graph. 46 | # The other one (b) we create here. 47 | b = gde.make_const(g, "b", np.full([2, 3], 2.0, dtype=np.float32)) 48 | gde.swap_inputs(g[c.op.name], [g[a.name], b.output(0)]) 49 | 50 | # Reconstitute the modified serialized graph as TensorFlow graph 51 | with g.to_tf_graph().as_default(): 52 | 53 | # Run a session using the modified graph and print the value of c 54 | with tf.Session() as sess: 55 | res = sess.run(c.name) 56 | print("Result is:\n{}".format(res)) 57 | ``` 58 | 59 | ``` 60 | Result is: 61 | [[3. 3. 3.] 62 | [3. 3. 3.]] 63 | ``` 64 | 65 | ## Project status 66 | 67 | **This project is a work in progress.** 68 | 69 | Current status: 70 | 71 | * All of the original project's regression tests pass. We have added 20 72 | additional regression tests to cover new functionality. 73 | * We have added new features to support graph rewrites, including structural 74 | pattern matching and fixed-point graph modification. 75 | * We have implemented several new graph rewrites. 76 | * The simple example script from the original project runs. We have also added 77 | new examples of new functionality; see the `examples` directory. 78 | 79 | ## Contents of root directory: 80 | 81 | * `LICENSE`: This project is released under an Apache v2 license 82 | * `env`: Not in git repo; create by running `scripts/env.sh`. Anaconda virtualenv 83 | for running tests and notebooks in this project. 84 | * `examples`: Example scripts. To run these scripts from the root directory 85 | of this project, first run `scripts/env.sh` to create an Anaconda 86 | environment, then use the command 87 | ``` 88 | PYTHONPATH=$PWD env/bin/python examples/script_name.py 89 | ``` 90 | where `script_name.py` is the name of the example script. 91 | * `notebooks`: Jupyter notebooks. 92 | * `graph_def_editor`: Source code for the Python package 93 | * `scripts`: Useful shell scripts for development. 94 | * `setup.py`: Setup script to make this project pip-installable with 95 | [`setuptools`](https://setuptools.readthedocs.io/en/latest/) 96 | * `tests`: pytest tests. To run these tests, create `env` and run 97 | `scripts/test.sh` 98 | 99 | ## Pip install instructions 100 | 101 | We have not yet posted a binary release of this library, but you can `pip 102 | install` this project directly from the source tree. We recommend using a 103 | virtualenv or an Anaconda environment for this purpose. 104 | Here is an example series of shell commands to create an Anaconda environment 105 | and `pip` install this project from source: 106 | 107 | ``` 108 | $ conda create -y --prefix ./myenv python=3.6 numpy tensorflow 109 | $ conda activate ./myenv 110 | $ git clone https://github.com/CODAIT/graph_def_editor.git 111 | $ pip install ./graph_def_editor 112 | ``` 113 | 114 | 115 | ## IDE setup instructions 116 | 117 | 1. Install IntelliJ and the community Python plugin. 118 | 2. Run the script `scripts/env.sh` to create an Anaconda enviroment under `env`. 119 | 3. Import the root directory of this repository as a new project. 120 | Use the Anaconda environment at `env/bin/python` as the Python for 121 | the project. 122 | 4. In the "Project" view of IntelliJ, right-click on `env` and select 123 | `Mark directory as ==> Excluded`. `env` shoud turn red. 124 | 5. Configure your editor to use 2 spaces for indents. Disable the PEP8 warnings 125 | in IntelliJ about indents not being a multiple of 4. 126 | 6. To run tests from within IntelliJ, open up the `Terminal` pane and type 127 | `./scripts/test.sh`. The outputs of the test run will be teed to the file 128 | `test.out` at the root of the project. 129 | 130 | 131 | ## TensorFlow versions compatibility 132 | 133 | GraphDef Editor is fully supported for TensorFlow versions 1.14.x and 1.15.x. 134 | For TensorFlow 2.x some transforms might not work. 135 | 136 | To execute tests for a specific TensorFlow version run the following command from the repository root: 137 | ```sh 138 | docker run -v ${PWD}:/v -w /v tensorflow/tensorflow:[-py3] \ 139 | bash -c "pip3 install -U pytest && pytest" 140 | ``` 141 | 142 | Pre 2.2.0 TensorFlow versions have -py3 suffix indicating that Python3 should be used. 143 | 144 | To execute a specific test: 145 | ```sh 146 | docker run -v ${PWD}:/v -w /v tensorflow/tensorflow:[-py3] python -m tests.transform_test 147 | ``` 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /graph_def_editor/match.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Simple graph matching functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from six import string_types 23 | 24 | from graph_def_editor import node, select 25 | 26 | __all__ = [ 27 | "op_type", 28 | "OpMatcher", 29 | ] 30 | 31 | 32 | def _make_graph_match(graph_match): 33 | """Convert to a OpMatcher instance.""" 34 | if graph_match is None: 35 | return None 36 | if not isinstance(graph_match, OpMatcher): 37 | graph_match = OpMatcher(graph_match) 38 | return graph_match 39 | 40 | 41 | def op_type(op_types, op=None): 42 | """Check if an op is of the given type. 43 | 44 | Args: 45 | op_types: tuple of strings containing the types to check against. 46 | For instance: ("Add", "Const") 47 | op: the operation to check (or None). 48 | Returns: 49 | if op is not None, return True if the op is of the correct type. 50 | if op is None, return a lambda function which does the type checking. 51 | """ 52 | if isinstance(op_types, string_types): 53 | op_types = (op_types,) 54 | if op is None: 55 | return lambda operator: operator.op_type in op_types 56 | else: 57 | return op.node_def.op in op_types 58 | 59 | 60 | class OpMatcher(object): 61 | """Graph match class.""" 62 | 63 | def __init__(self, positive_filter): 64 | """Graph match constructor.""" 65 | self.positive_filters = [] 66 | self.input_op_matches = None 67 | self.control_input_op_matches = None 68 | self.output_op_matches = None 69 | positive_filter = self._finalize_positive_filter(positive_filter) 70 | self.positive_filters.append(positive_filter) 71 | 72 | @staticmethod 73 | def _finalize_positive_filter(elem): 74 | """Convert to a filter function.""" 75 | if select.can_be_regex(elem): 76 | regex_ = select.make_regex(elem) 77 | return lambda op, regex=regex_: regex.search(op.name) is not None 78 | elif isinstance(elem, node.Node): 79 | return lambda op, match_op=elem: op is match_op 80 | elif callable(elem): 81 | return elem 82 | elif elem is True: 83 | return lambda op: True 84 | else: 85 | raise ValueError("Cannot finalize the positive filter: {}".format(elem)) 86 | 87 | def __call__(self, op): 88 | """Evaluate if the op matches or not.""" 89 | if not isinstance(op, node.Node): 90 | raise TypeError("Expect gde.Node, got: {}".format(type(op))) 91 | for positive_filter in self.positive_filters: 92 | if not positive_filter(op): 93 | return False 94 | if self.input_op_matches is not None: 95 | if len(op.inputs) != len(self.input_op_matches): 96 | return False 97 | for input_t, input_op_match in zip(op.inputs, self.input_op_matches): 98 | if input_op_match is None: 99 | continue 100 | if not input_op_match(input_t.node): 101 | return False 102 | if self.control_input_op_matches is not None: 103 | if len(op.control_inputs) != len(self.control_input_op_matches): 104 | return False 105 | for cinput_op, cinput_op_match in zip(op.control_inputs, 106 | self.control_input_op_matches): 107 | if cinput_op_match is None: 108 | continue 109 | if not cinput_op_match(cinput_op): 110 | return False 111 | if self.output_op_matches is not None: 112 | if len(op.outputs) != len(self.output_op_matches): 113 | return False 114 | for output_t, output_op_matches in zip(op.outputs, 115 | self.output_op_matches): 116 | if output_op_matches is None: 117 | continue 118 | if len(output_t.consumers()) != len(output_op_matches): 119 | return False 120 | for consumer_op, consumer_op_match in zip(output_t.consumers(), 121 | output_op_matches): 122 | if consumer_op_match is None: 123 | continue 124 | if not consumer_op_match(consumer_op): 125 | return False 126 | return True 127 | 128 | def input_ops(self, *args): 129 | """Add input matches.""" 130 | if self.input_op_matches is not None: 131 | raise ValueError("input_op_matches is already set.") 132 | self.input_op_matches = [] 133 | for input_match in args: 134 | self.input_op_matches.append(_make_graph_match(input_match)) 135 | return self 136 | 137 | def control_input_ops(self, *args): 138 | """Add input matches.""" 139 | if self.control_input_op_matches is not None: 140 | raise ValueError("control_input_op_matches is already set.") 141 | self.control_input_op_matches = [] 142 | for input_match in args: 143 | self.control_input_op_matches.append(_make_graph_match(input_match)) 144 | return self 145 | 146 | def output_ops(self, *args): 147 | """Add output matches.""" 148 | if self.output_op_matches is not None: 149 | raise ValueError("output_op_matches is already set.") 150 | self.output_op_matches = [] 151 | for consumer_op_matches in args: 152 | if consumer_op_matches is None: 153 | self.output_op_matches.append(None) 154 | if not isinstance(consumer_op_matches, list): 155 | consumer_op_matches = [consumer_op_matches] 156 | consumer_op_matches = [_make_graph_match(consumer_op_match) 157 | for consumer_op_match in consumer_op_matches] 158 | self.output_op_matches.append(consumer_op_matches) 159 | return self 160 | -------------------------------------------------------------------------------- /tests/function_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for function_graph.py in the GraphDef Editor.""" 17 | 18 | import unittest 19 | import tensorflow.compat.v1 as tf 20 | tf.disable_eager_execution() 21 | import shutil 22 | import tempfile 23 | import numpy as np 24 | 25 | import graph_def_editor as gde 26 | 27 | 28 | class FunctionGraphTest(unittest.TestCase): 29 | 30 | def setUp(self): 31 | # Create a temporary directory for SavedModel files. 32 | self.temp_dir = tempfile.mkdtemp() 33 | 34 | def tearDown(self): 35 | # Remove the directory after the test. 36 | # Comment out this line to prevent deleting temps. 37 | shutil.rmtree(self.temp_dir) 38 | pass # In case previous line gets commented out 39 | 40 | def build_tf_graph(self): 41 | """Builds a tf graph for function (x + y) * 10.0 .""" 42 | @tf.function 43 | def multiplier_function(x): 44 | return tf.constant(10.0, name="function_multiplier") * x 45 | 46 | tf_g = tf.Graph() 47 | with tf_g.as_default(): 48 | x = tf.placeholder(name="x", dtype=tf.float32, shape=[]) 49 | y = tf.placeholder(name="y", dtype=tf.float32, shape=[]) 50 | result_op = tf.add(x, y, name="add") 51 | _ = multiplier_function(result_op) 52 | return tf_g 53 | 54 | def run_tf_graph(self, tf_g, x, y): 55 | with tf.Session(graph=tf_g) as sess: 56 | x_tensor = tf_g.get_tensor_by_name("x:0") 57 | y_tensor = tf_g.get_tensor_by_name("y:0") 58 | output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0") 59 | return sess.run(output_tensor, {x_tensor: x, y_tensor: y}) 60 | 61 | def save_tf_graph(self, tf_g, model_dir): 62 | x_tensor = tf_g.get_tensor_by_name("x:0") 63 | y_tensor = tf_g.get_tensor_by_name("y:0") 64 | output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0") 65 | with tf.Session(graph=tf_g) as sess: 66 | tf.saved_model.simple_save(sess, model_dir, 67 | inputs={"x": x_tensor, "y": y_tensor}, 68 | outputs={"out": output_tensor}) 69 | 70 | def test_function_rewrite(self): 71 | tf_g = self.build_tf_graph() 72 | self.assertEqual(30.0, self.run_tf_graph(tf_g, 1.0, 2.0)) 73 | graph = gde.Graph(tf_g) 74 | add_op = graph.get_node_by_name("add") 75 | function_name = add_op.outputs[0].consumers()[0].get_attr("f").name 76 | self.assertIn(function_name, graph.function_names) 77 | 78 | function_graph = graph.get_function_graph_by_name(function_name) 79 | function_multiplier_op = \ 80 | function_graph.get_node_by_name("function_multiplier") 81 | self.assertEqual(10.0, function_multiplier_op.get_attr("value")) 82 | function_multiplier_op.replace_attr("value", 83 | np.array(1000.0, dtype=np.float32)) 84 | 85 | self.assertEqual(3000.0, self.run_tf_graph(graph.to_tf_graph(), 1.0, 2.0)) 86 | return graph 87 | 88 | def test_export_saved_model(self): 89 | g = self.test_function_rewrite() 90 | model_dir = self.temp_dir + "/saved_model" 91 | g.to_saved_model(model_dir) 92 | tf_g = tf.Graph() 93 | with tf.Session(graph=tf_g) as sess: 94 | _ = tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], 95 | model_dir) 96 | self.assertEqual(3000.0, self.run_tf_graph(tf_g, 1.0, 2.0)) 97 | 98 | def test_import_saved_model(self): 99 | g = self.test_function_rewrite() 100 | model_dir = self.temp_dir + "/saved_model" 101 | self.save_tf_graph(g.to_tf_graph(), model_dir) 102 | 103 | g = gde.saved_model_to_graph(model_dir) 104 | self.assertEqual(3000.0, self.run_tf_graph(g.to_tf_graph(), 1.0, 2.0)) 105 | 106 | def test_number_attr_support(self): 107 | model_dir = self.temp_dir + "/saved_model" 108 | 109 | @tf.function 110 | def test_function(c): 111 | cdim = tf.constant(1, tf.int32) 112 | c1 = tf.constant([2, 1, 5], tf.int32, name="FuncConst") 113 | c2 = tf.constant([2, 1, 5], tf.int32) 114 | # ConcatOffset has variable number of intputs and outputs 115 | # that is using number_attr in functions 116 | concat_offset = tf.raw_ops.ConcatOffset( 117 | concat_dim=cdim, shape=[c, c1, c2]) 118 | out = tf.math.reduce_sum(concat_offset) 119 | return out 120 | 121 | tf_g = tf.Graph() 122 | with tf_g.as_default(): 123 | with tf.Session() as sess: 124 | c = tf.placeholder(name="c", dtype=tf.int32) 125 | out_func = test_function(c) 126 | c = tf_g.get_tensor_by_name("c:0") 127 | self.assertEqual(3, sess.run(out_func, {c: [2, 1, 5]})) 128 | 129 | tf.saved_model.simple_save( 130 | sess, model_dir, inputs={"c": c}, outputs={"out_func": out_func}) 131 | 132 | g = gde.saved_model_to_graph(model_dir) 133 | 134 | tf_g = g.to_tf_graph() 135 | with tf.Session(graph=tf_g) as sess: 136 | output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0") 137 | c = tf_g.get_tensor_by_name("c:0") 138 | self.assertEqual(3, sess.run(output_tensor, {c: [2, 1, 5]})) 139 | 140 | f = g.get_function_graph_by_name(g.function_names[0]) 141 | func_const_op = f.get_node_by_name("FuncConst") 142 | func_const_op.replace_attr("value", np.array([2, 2, 5], dtype=np.int32)) 143 | 144 | tf_g = g.to_tf_graph() 145 | with tf.Session(graph=tf_g) as sess: 146 | output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0") 147 | c = tf_g.get_tensor_by_name("c:0") 148 | self.assertEqual(4, sess.run(output_tensor, {c: [2, 1, 5]})) 149 | 150 | def test_visialize(self): 151 | try: 152 | import graphviz 153 | except ModuleNotFoundError as error: 154 | print("WARNING: graphviz is not installed, skipping test") 155 | return 156 | tf_g = self.build_tf_graph() 157 | graph = gde.Graph(tf_g) 158 | function_graph = graph.get_function_graph_by_name(graph.function_names[0]) 159 | gv_graph = gde.util.parse_graphviz_json( 160 | function_graph.visualize(format="json").decode()) 161 | 162 | expected_gv_graph = { 163 | "x": ["mul"], 164 | "function_multiplier": ["mul"], 165 | "mul": ["Identity"], 166 | "Identity": [] 167 | } 168 | self.assertEqual(expected_gv_graph, gv_graph) 169 | 170 | 171 | if __name__ == "__main__": 172 | unittest.main() 173 | -------------------------------------------------------------------------------- /tests/util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.contrib.graph_editor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | tf.disable_eager_execution() 23 | 24 | import unittest 25 | 26 | import graph_def_editor as gde 27 | 28 | 29 | class UtilTest(unittest.TestCase): 30 | 31 | def test_list_view(self): 32 | """Test for gde.util.ListView.""" 33 | l = [0, 1, 2] 34 | lv = gde.util.ListView(l) 35 | # Should not be the same id. 36 | self.assertIsNot(l, lv) 37 | # Should behave the same way than the original list. 38 | self.assertTrue(len(lv) == 3 and lv[0] == 0 and lv[1] == 1 and lv[2] == 2) 39 | # Should be read only. 40 | with self.assertRaises(TypeError): 41 | lv[0] = 0 42 | 43 | def test_is_iterable(self): 44 | """Test for gde.util.is_iterable.""" 45 | self.assertTrue(gde.util.is_iterable([0, 1, 2])) 46 | self.assertFalse(gde.util.is_iterable(3)) 47 | 48 | def test_unique_graph(self): 49 | """Test for gde.util.check_graphs and gde.util.get_unique_graph.""" 50 | g0_graph = tf.Graph() 51 | with g0_graph.as_default(): 52 | tf.constant(1, name="a") 53 | tf.constant(2, name="b") 54 | g1_graph = tf.Graph() 55 | with g1_graph.as_default(): 56 | tf.constant(1, name="a") 57 | tf.constant(2, name="b") 58 | 59 | g0 = gde.Graph(g0_graph.as_graph_def()) 60 | g1 = gde.Graph(g1_graph.as_graph_def()) 61 | a0, b0, a1, b1 = (g0["a"], g0["b"], g1["a"], g1["b"]) 62 | 63 | print("g0['a'] returns {} (type {})".format(g0['a'], type(g0['a']))) 64 | 65 | # Same graph, should be fine. 66 | self.assertIsNone(gde.util.check_graphs(a0, b0)) 67 | # Two different graphs, should assert. 68 | with self.assertRaises(ValueError): 69 | gde.util.check_graphs(a0, b0, a1, b1) 70 | # a0 and b0 belongs to the same graph, should be fine. 71 | self.assertEqual(gde.util.get_unique_graph([a0, b0]), g0) 72 | # Different graph, should raise an error. 73 | with self.assertRaises(ValueError): 74 | gde.util.get_unique_graph([a0, b0, a1, b1]) 75 | 76 | def test_make_list_of_node(self): 77 | """Test for gde.util.make_list_of_op.""" 78 | g0_graph = tf.Graph() 79 | with g0_graph.as_default(): 80 | tf.constant(1, name="a0") 81 | tf.constant(2, name="b0") 82 | g0 = gde.Graph(g0_graph) 83 | 84 | # Should extract the ops from the graph. 85 | self.assertEqual(len(gde.util.make_list_of_op(g0)), 2) 86 | # Should extract the ops from the tuple. 87 | self.assertEqual(len(gde.util.make_list_of_op((g0["a0"], g0["b0"]))), 2) 88 | 89 | def test_make_list_of_t(self): 90 | """Test for gde.util.make_list_of_t.""" 91 | g0_graph = tf.Graph() 92 | with g0_graph.as_default(): 93 | a0_op = tf.constant(1, name="a0") 94 | b0_op = tf.constant(2, name="b0") 95 | tf.add(a0_op, b0_op) 96 | g0 = gde.Graph(g0_graph) 97 | a0 = g0["a0"].output(0) 98 | b0 = g0["b0"].output(0) 99 | 100 | # Should extract the tensors from the graph. 101 | self.assertEqual(len(gde.util.make_list_of_t(g0)), 3) 102 | # Should extract the tensors from the tuple 103 | self.assertEqual(len(gde.util.make_list_of_t((a0, b0))), 2) 104 | # Should extract the tensors and ignore the ops. 105 | self.assertEqual( 106 | len(gde.util.make_list_of_t( 107 | (a0, a0.node, b0), ignore_ops=True)), 2) 108 | 109 | def test_get_generating_consuming(self): 110 | """Test for gde.util.get_generating_ops and gde.util.get_generating_ops.""" 111 | g0_graph = tf.Graph() 112 | with g0_graph.as_default(): 113 | a0_tensor = tf.constant(1, name="a0") 114 | b0_tensor = tf.constant(2, name="b0") 115 | tf.add(a0_tensor, b0_tensor, name="c0") 116 | g0 = gde.Graph(g0_graph) 117 | a0 = g0["a0"].output(0) 118 | b0 = g0["b0"].output(0) 119 | c0 = g0["c0"].output(0) 120 | 121 | self.assertEqual(len(gde.util.get_generating_ops([a0, b0])), 2) 122 | self.assertEqual(len(gde.util.get_consuming_ops([a0, b0])), 1) 123 | self.assertEqual(len(gde.util.get_generating_ops([c0])), 1) 124 | self.assertEqual(gde.util.get_consuming_ops([c0]), []) 125 | 126 | def test_control_outputs(self): 127 | """Test for the gde.util.ControlOutputs class.""" 128 | g0_graph = tf.Graph() 129 | with g0_graph.as_default(): 130 | a0_tensor = tf.constant(1, name="a0") 131 | b0_tensor = tf.constant(2, name="b0") 132 | x0_tensor = tf.constant(3, name="x0") 133 | with tf.control_dependencies([x0_tensor.op]): 134 | tf.add(a0_tensor, b0_tensor, name="c0") 135 | 136 | g0 = gde.Graph(g0_graph) 137 | x0_node = g0["x0"] 138 | c0_node = g0["c0"] 139 | control_outputs = gde.util.ControlOutputs(g0).get_all() 140 | self.assertEqual(len(control_outputs), 1) 141 | self.assertEqual(len(control_outputs[x0_node]), 1) 142 | self.assertIs(list(control_outputs[x0_node])[0], c0_node) 143 | 144 | def test_scope(self): 145 | """Test simple path scope functionalities.""" 146 | self.assertEqual(gde.util.scope_finalize("foo/bar"), "foo/bar/") 147 | self.assertEqual(gde.util.scope_dirname("foo/bar/op"), "foo/bar/") 148 | self.assertEqual(gde.util.scope_basename("foo/bar/op"), "op") 149 | 150 | def test_placeholder(self): 151 | """Test placeholder functionalities.""" 152 | g0_graph = tf.Graph() 153 | with g0_graph.as_default(): 154 | tf.constant(1, name="foo") 155 | 156 | g0 = gde.Graph(g0_graph) 157 | a0 = g0["foo"].output(0) 158 | 159 | # Test placeholder name. 160 | self.assertEqual(gde.util.placeholder_name(a0), "geph__foo_0") 161 | self.assertEqual(gde.util.placeholder_name(None), "geph") 162 | self.assertEqual( 163 | gde.util.placeholder_name( 164 | a0, scope="foo/"), "foo/geph__foo_0") 165 | self.assertEqual( 166 | gde.util.placeholder_name( 167 | a0, scope="foo"), "foo/geph__foo_0") 168 | self.assertEqual(gde.util.placeholder_name(None, scope="foo/"), "foo/geph") 169 | self.assertEqual(gde.util.placeholder_name(None, scope="foo"), "foo/geph") 170 | 171 | # Test placeholder creation. 172 | g1_graph = tf.Graph() 173 | with g1_graph.as_default(): 174 | tf.constant(1, dtype=tf.float32, name="a1") 175 | 176 | g1 = gde.Graph(g1_graph) 177 | a1_tensor = g1["a1"].output(0) 178 | print("Type of a1_tensor is {}".format(type(a1_tensor))) 179 | 180 | ph1 = gde.util.make_placeholder_from_tensor(g1, a1_tensor) 181 | ph2 = gde.util.make_placeholder_from_dtype_and_shape(g1, dtype=tf.float32) 182 | self.assertEqual(ph1.name, "geph__a1_0") 183 | self.assertEqual(ph2.name, "geph") 184 | 185 | def test_identity(self): 186 | tf_g = tf.Graph() 187 | with tf_g.as_default(): 188 | c = tf.constant(42) 189 | i1 = tf.identity(c, name="identity_tf") 190 | 191 | g = gde.Graph(tf_g) 192 | i2_node = gde.util.make_identity(g, "identity_gde", g.get_tensor_by_name(c.name)) 193 | i2 = i2_node.outputs[0] 194 | 195 | with g.to_tf_graph().as_default(): 196 | with tf.Session() as sess: 197 | result1 = sess.run(i1.name) 198 | result2 = sess.run(i2.name) 199 | self.assertEqual(result1, result2) 200 | 201 | 202 | if __name__ == "__main__": 203 | unittest.main() 204 | -------------------------------------------------------------------------------- /graph_def_editor/edit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Various function for graph editing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from graph_def_editor import reroute, select, subgraph, util 22 | 23 | 24 | __all__ = [ 25 | "detach_control_inputs", 26 | "detach_control_outputs", 27 | "detach_inputs", 28 | "detach_outputs", 29 | "detach", 30 | "connect", 31 | "bypass", 32 | ] 33 | 34 | 35 | def detach_control_inputs(sgv): 36 | """Detach all the external control inputs of the subgraph sgv. 37 | 38 | Args: 39 | sgv: the subgraph view to be detached. This argument is converted to a 40 | subgraph using the same rules as the function subgraph.make_view. 41 | """ 42 | sgv = subgraph.make_view(sgv) 43 | for op in sgv.ops: 44 | cops = [cop for cop in op.control_inputs if cop not in sgv.ops] 45 | reroute.remove_control_inputs(op, cops) 46 | 47 | 48 | def detach_control_outputs(sgv, control_outputs): 49 | """Detach all the external control outputs of the subgraph sgv. 50 | 51 | Args: 52 | sgv: the subgraph view to be detached. This argument is converted to a 53 | subgraph using the same rules as the function subgraph.make_view. 54 | control_outputs: a util.ControlOutputs instance. 55 | """ 56 | if not isinstance(control_outputs, util.ControlOutputs): 57 | raise TypeError("Expected a util.ControlOutputs, got: {}", 58 | type(control_outputs)) 59 | control_outputs.update() 60 | sgv = subgraph.make_view(sgv) 61 | for op in sgv.ops: 62 | for cop in control_outputs.get(op): 63 | if cop not in sgv.ops: 64 | reroute.remove_control_inputs(cop, op) 65 | 66 | 67 | def detach_inputs(sgv, control_inputs=False): 68 | """Detach the inputs of a subgraph view. 69 | 70 | Args: 71 | sgv: the subgraph view to be detached. This argument is converted to a 72 | subgraph using the same rules as the function subgraph.make_view. 73 | Note that sgv is modified in place. 74 | control_inputs: if True control_inputs are also detached. 75 | Returns: 76 | A tuple `(sgv, input_placeholders)` where 77 | `sgv` is a new subgraph view of the detached subgraph; 78 | `input_placeholders` is a list of the created input placeholders. 79 | Raises: 80 | StandardError: if sgv cannot be converted to a SubGraphView using 81 | the same rules than the function subgraph.make_view. 82 | """ 83 | sgv = subgraph.make_view(sgv) 84 | 85 | # Old code: 86 | # with sgv.graph.as_default(): 87 | # input_placeholders = [ 88 | # tf.placeholder( 89 | # dtype=input_t.dtype, name=util.placeholder_name(input_t)) 90 | # for input_t in sgv.inputs 91 | # ] 92 | # New code: 93 | input_placeholders = [ 94 | util.make_placeholder(sgv.graph, util.placeholder_name(input_t), 95 | input_t.dtype, input_t.shape).output(0) 96 | for input_t in sgv.inputs 97 | ] 98 | 99 | reroute.swap_inputs(sgv, input_placeholders) 100 | if control_inputs: 101 | detach_control_inputs(sgv) 102 | return sgv, input_placeholders 103 | 104 | 105 | def detach_outputs(sgv, control_outputs=None): 106 | """Detach the output of a subgraph view. 107 | 108 | Args: 109 | sgv: the subgraph view to be detached. This argument is converted to a 110 | subgraph using the same rules as the function subgraph.make_view. 111 | Note that sgv is modified in place. 112 | control_outputs: a util.ControlOutputs instance or None. If not None the 113 | control outputs are also detached. 114 | Returns: 115 | A tuple `(sgv, output_placeholders)` where 116 | `sgv` is a new subgraph view of the detached subgraph; 117 | `output_placeholders` is a list of the created output placeholders. 118 | Raises: 119 | StandardError: if sgv cannot be converted to a SubGraphView using 120 | the same rules than the function subgraph.make_view. 121 | """ 122 | sgv = subgraph.make_view(sgv) 123 | # only select outputs with consumers 124 | sgv_ = sgv.remap_outputs([output_id 125 | for output_id, output_t in enumerate(sgv.outputs) 126 | if output_t.consumers()]) 127 | # create consumer subgraph and remap 128 | consumers_sgv = subgraph.SubGraphView(sgv_.consumers()) 129 | consumers_sgv = consumers_sgv.remap_inputs( 130 | [input_id for input_id, input_t in enumerate(consumers_sgv.inputs) 131 | if input_t in sgv_.outputs]) 132 | 133 | # Old code: 134 | # with sgv_.graph.as_default(): 135 | # output_placeholders = [ 136 | # util.make_placeholder_from_tensor(input_t) 137 | # for input_t in consumers_sgv.inputs 138 | # ] 139 | # New code: 140 | output_placeholders = [ 141 | util.make_placeholder_from_tensor(consumers_sgv.graph, input_t).output(0) 142 | for input_t in consumers_sgv.inputs 143 | ] 144 | 145 | reroute.swap_outputs(sgv_, output_placeholders) 146 | if control_outputs is not None: 147 | detach_control_outputs(sgv_, control_outputs) 148 | return sgv_, output_placeholders 149 | 150 | 151 | def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None): 152 | """Detach both the inputs and the outputs of a subgraph view. 153 | 154 | Args: 155 | sgv: the subgraph view to be detached. This argument is converted to a 156 | subgraph using the same rules as the function subgraph.make_view. 157 | Note that sgv is modified in place. 158 | control_inputs: A boolean indicating whether control inputs are enabled. 159 | control_outputs: An instance of util.ControlOutputs or None. If not None, 160 | control outputs are enabled. 161 | control_ios: An instance of util.ControlOutputs or None. If not None, both 162 | control inputs and control outputs are enabled. This is equivalent to set 163 | control_inputs to True and control_outputs to the util.ControlOutputs 164 | instance. 165 | Returns: 166 | A tuple `(sgv, detached_inputs, detached_outputs)` where: 167 | `sgv` is a new subgraph view of the detached subgraph; 168 | `detach_inputs` is a list of the created input placeholders; 169 | `detach_outputs` is a list of the created output placeholders. 170 | Raises: 171 | StandardError: if sgv cannot be converted to a SubGraphView using 172 | the same rules than the function subgraph.make_view. 173 | """ 174 | control_inputs, control_outputs = select.check_cios(control_inputs, 175 | control_outputs, 176 | control_ios) 177 | _, detached_inputs = detach_inputs(sgv, control_inputs) 178 | _, detached_outputs = detach_outputs(sgv, control_outputs) 179 | return sgv, detached_inputs, detached_outputs 180 | 181 | 182 | def connect(sgv0, sgv1, disconnect_first=False): 183 | """Connect the outputs of sgv0 to the inputs of sgv1. 184 | 185 | Args: 186 | sgv0: the first subgraph to have its outputs swapped. This argument is 187 | converted to a subgraph using the same rules as the function 188 | subgraph.make_view. 189 | Note that sgv0 is modified in place. 190 | sgv1: the second subgraph to have its outputs swapped. This argument is 191 | converted to a subgraph using the same rules as the function 192 | subgraph.make_view. 193 | Note that sgv1 is modified in place. 194 | disconnect_first: if True the current outputs of sgv0 are disconnected. 195 | Returns: 196 | A tuple `(sgv0, sgv1)` of the now connected subgraphs. 197 | Raises: 198 | StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 199 | the same rules than the function subgraph.make_view. 200 | """ 201 | sgv0 = subgraph.make_view(sgv0) 202 | sgv1 = subgraph.make_view(sgv1) 203 | util.check_graphs(sgv0, sgv1) 204 | if disconnect_first: 205 | detach_outputs(sgv0) 206 | sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs) 207 | reroute.reroute_inputs(sgv0_outputs, sgv1) 208 | return sgv0, sgv1 209 | 210 | 211 | def bypass(sgv): 212 | """Bypass the given subgraph by connecting its inputs to its outputs. 213 | 214 | Args: 215 | sgv: the subgraph view to be bypassed. This argument is converted to a 216 | subgraph using the same rules than the function subgraph.make_view. 217 | Note that sgv is modified in place. 218 | Returns: 219 | A tuple `(sgv, detached_inputs)` where: 220 | `sgv` is a new subgraph view of the bypassed subgraph; 221 | `detached_inputs` is a list of the created input placeholders. 222 | Raises: 223 | StandardError: if sgv cannot be converted to a SubGraphView using 224 | the same rules than the function subgraph.make_view. 225 | """ 226 | # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers 227 | sgv = subgraph.make_view(sgv) 228 | sgv_inputs = list(sgv.inputs) 229 | sgv, detached_inputs = detach_inputs(sgv) 230 | reroute.reroute_ts(sgv_inputs, sgv.outputs) 231 | return sgv, detached_inputs 232 | -------------------------------------------------------------------------------- /graph_def_editor/variable.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # TODO: Move this protobuf into this project so we don't depend on 17 | # tf.core.framework 18 | from tensorflow.core.framework import variable_pb2 19 | 20 | import sys 21 | if sys.version >= '3': 22 | from graph_def_editor import base_graph 23 | from typing import AbstractSet, Union 24 | 25 | 26 | __all__ = [ 27 | "Variable", 28 | ] 29 | 30 | 31 | class Variable(object): 32 | """ 33 | Surrogate object that represents the contents of a 34 | `tensorflow.core.framework.variable_pb2.VariableDef` 35 | protocol buffer, which in turn represents a TensorFlow variable. 36 | 37 | Note that the `VariableDef` protobuf is not precisely a public API, 38 | but it's the closest thing that TensorFlow has to one. Also, you can't 39 | serialize a general graph in a meaningful way without serializing variables. 40 | 41 | TensorFlow variables are composed of multiple ops and tensors internally. 42 | TensorFlow's Python API has a class `tf.Variable` that tracks these 43 | objects. This class tracks a similar set of pointers in protobuf land. 44 | """ 45 | def __init__(self, 46 | g # type: base_graph.BaseGraph 47 | ): 48 | """ 49 | Do not call this constructor directly. 50 | 51 | This constructor should only be called from `Graph.add_variable()`. 52 | 53 | Args: 54 | g: gde.Graph object representing the containing graph 55 | """ 56 | if g.has_passthrough_saver: 57 | # The internals of a tf.Saver are opaque to us. 58 | raise ValueError("Attempted to add a variable to Graph '{}', which has " 59 | "an immutable serialized tf.Saver " 60 | "object.".format(g.name)) 61 | self._graph = g 62 | self._collection_names = set() # Set[str] 63 | 64 | # Core fields are modeled after those of VariableDef. 65 | self._variable_name = None # str 66 | self._initial_value_name = None # str 67 | self._initializer_name = None # str 68 | self._snapshot_name = None # str 69 | self._trainable = None # bool 70 | 71 | def __str__(self): 72 | return "Var[{}]".format(self.name) 73 | 74 | def __repr__(self): 75 | return "Var[name={}, init={}, val={}, " \ 76 | "snap={}, t={}]".format(self.name, self._initializer_name, 77 | self._initial_value_name, 78 | self._snapshot_name, 79 | self._trainable) 80 | 81 | def is_same_variable(self, 82 | other # type: Variable 83 | ): 84 | """ 85 | Returns true if is variable and `other` are the same, ignoring graph and 86 | collection information. 87 | """ 88 | if self.name != other.name: 89 | return False 90 | elif self.initial_value_name != other.initial_value_name: 91 | return False 92 | elif self.initializer_name != other.initializer_name: 93 | return False 94 | elif self.snapshot_name != other.snapshot_name: 95 | return False 96 | elif self.trainable != other.trainable: 97 | return False 98 | else: 99 | return True 100 | 101 | def from_proto(self, 102 | variable_def, # type: Union[variable_pb2.VariableDef, bytes] 103 | validate=True, # type: bool 104 | allow_duplicates=False # type: bool 105 | ): 106 | """ 107 | Populate the fields of this object from a serialized TensorFlow variable. 108 | 109 | variable_def: Protocol buffer representation of a TensorFlow variable. In a 110 | serialized graph, you will find these VariableDef protocol buffer 111 | messages stuffed into the `bytes_list` field of a `CollectionDef` proto 112 | inside a `MetaGraphDef` message. Otherwise you can create a 113 | `VariableDef` proto by calling `tf.Variable.to_proto()`. 114 | May be serialized as a `bytes` object. 115 | validate: True to validate any names used here. False to skip 116 | validation (e.g because you are creating the variable before creating 117 | the nodes it references). 118 | The variable name is checked for duplicates regardless of whether this 119 | flag is set to True. 120 | allow_duplicate: Don't complain if the graph contains a variable of the 121 | same name, provided that the two variables are equal. 122 | Raises: 123 | NameError if a variable with the indicated name already exists, 124 | ValueError if another validation fails 125 | """ 126 | if isinstance(variable_def, bytes): 127 | variable_def = variable_pb2.VariableDef.FromString(variable_def) 128 | self._variable_name = variable_def.variable_name 129 | self._initial_value_name = variable_def.initial_value_name 130 | self._initializer_name = variable_def.initializer_name 131 | self._snapshot_name = variable_def.snapshot_name 132 | self._trainable = variable_def.trainable 133 | # TODO(frreiss): Figure out what to do with the is_resource field 134 | # TODO(frreiss): Figure out what to do with the save_slice_info_def field 135 | if validate: 136 | self.validate(allow_duplicates) 137 | 138 | def to_proto(self): 139 | """ 140 | Inverse of `from_proto()` method. 141 | 142 | Returns a `VariableDef` protocol buffer message that represents this 143 | variable. 144 | """ 145 | ret = variable_pb2.VariableDef() 146 | ret.variable_name = self._variable_name 147 | ret.initial_value_name = self._initial_value_name 148 | ret.initializer_name = self._initializer_name 149 | ret.snapshot_name = self._snapshot_name 150 | ret.trainable = self._trainable 151 | return ret 152 | 153 | def validate(self, 154 | allow_duplicate=False # type: bool 155 | ): 156 | """ 157 | Verify that all the names this variable references are valid in the 158 | parent graph and that no conflicting variables exist. 159 | 160 | Args: 161 | allow_duplicate: Don't complain if the graph contains a variable of the 162 | same name, provided that the two variables are equal. 163 | """ 164 | if self._variable_name in self.graph.variable_names: 165 | other_var = self.graph.get_variable_by_name(self._variable_name) 166 | if other_var is not self: 167 | if not self.is_same_variable(other_var): 168 | raise ValueError("Existing '{}' in graph conflicts with this one " 169 | "({} != {})".format(self._variable_name, repr(self), 170 | repr(other_var))) 171 | elif not allow_duplicate: 172 | raise ValueError("Graph already has a variable called '{}'".format( 173 | self._variable_name)) 174 | # self._initializer_name should reference a node. Other names should 175 | # reference tensors. 176 | _initializer_name = self._initializer_name 177 | if _initializer_name and _initializer_name.rfind(":") > 0: 178 | # Adding extra check in case _initializer_name refers to a tensor. 179 | _initializer_name = _initializer_name[:_initializer_name.rfind(":")] 180 | if not self.graph.contains_node(_initializer_name): 181 | raise ValueError("Initializer name '{}' does not correspond to any " 182 | "node in graph".format(self._initializer_name)) 183 | _ = self.graph.get_tensor_by_name(self._initial_value_name, 184 | "Invalid initial value name '{}': {}") 185 | _ = self.graph.get_tensor_by_name(self._snapshot_name, 186 | "Invalid snapshot name '{}': {}") 187 | 188 | def to_proto(self): 189 | # type: () -> variable_pb2.VariableDef 190 | """ 191 | Convert this object into its equivalent TensorFlow protocol buffer 192 | message. 193 | 194 | Returns a `VariableDef` protobuf equivalent to this object. 195 | """ 196 | ret = variable_pb2.VariableDef() 197 | ret.variable_name = self.name 198 | ret.initial_value_name = self.initial_value_name 199 | ret.initializer_name = self.initializer_name 200 | ret.snapshot_name = self.snapshot_name 201 | ret.trainable = self.trainable 202 | # TODO(frreiss): Figure out what to do with the is_resource field 203 | # TODO(frreiss): Figure out what to do with the save_slice_info_def field 204 | return ret 205 | 206 | 207 | @property 208 | def graph(self): 209 | return self._graph 210 | 211 | @property 212 | def name(self): 213 | return self._variable_name 214 | 215 | @name.setter 216 | def name(self, 217 | val # type: str 218 | ): 219 | # TODO(frreiss): Should we update the graph here? 220 | self._variable_name = val 221 | 222 | @property 223 | def initial_value_name(self): 224 | return self._initial_value_name 225 | 226 | @property 227 | def initializer_name(self): 228 | return self._initializer_name 229 | 230 | @property 231 | def snapshot_name(self): 232 | return self._snapshot_name 233 | 234 | @property 235 | def trainable(self): 236 | return self._trainable 237 | 238 | @property 239 | def collection_names(self): 240 | # type: () -> AbstractSet[str] 241 | """ 242 | Returns the names of all collections this variable is a member of in the 243 | parent graph. 244 | """ 245 | return frozenset(self._collection_names) 246 | 247 | def add_to_collection(self, 248 | collection_name # type: str 249 | ): 250 | """ 251 | Add the variable to the indicated collection. 252 | """ 253 | if collection_name not in self._collection_names: 254 | self._collection_names.add(collection_name) 255 | # Invalidate any information the parent graph may have cached about 256 | # collections. 257 | self._graph.increment_version_counter() 258 | 259 | 260 | 261 | ############################################################################### 262 | # Functions below this line are private to this file. 263 | 264 | -------------------------------------------------------------------------------- /examples/mobilenet_example.py: -------------------------------------------------------------------------------- 1 | # Coypright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """ 18 | Example of using the GraphDef editor and the Graph Transform Tool to prep a 19 | copy of MobileNetV2 for inference. 20 | 21 | Requires that the "Pillow" package be installed. 22 | 23 | To run this example from the root of the project, type: 24 | PYTHONPATH=$PWD env/bin/python examples/mobilenet_example.py 25 | """ 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | import os 32 | import tensorflow as tf 33 | import graph_def_editor as gde 34 | import numpy as np 35 | # noinspection PyPackageRequirements 36 | import PIL # Pillow 37 | import shutil 38 | import tarfile 39 | import textwrap 40 | import urllib.request 41 | 42 | from tensorflow.tools import graph_transforms 43 | 44 | FLAGS = tf.flags.FLAGS 45 | 46 | 47 | def _indent(s): 48 | return textwrap.indent(str(s), " ") 49 | 50 | 51 | _TMP_DIR = "/tmp/mobilenet_example" 52 | _SAVED_MODEL_DIR = _TMP_DIR + "/original_model" 53 | _FROZEN_GRAPH_FILE = "{}/frozen_graph.pbtext".format(_TMP_DIR) 54 | _TF_REWRITES_GRAPH_FILE = "{}/after_tf_rewrites_graph.pbtext".format(_TMP_DIR) 55 | _GDE_REWRITES_GRAPH_FILE = "{}/after_gde_rewrites_graph.pbtext".format(_TMP_DIR) 56 | _AFTER_MODEL_FILES = [ 57 | _FROZEN_GRAPH_FILE, _TF_REWRITES_GRAPH_FILE, _GDE_REWRITES_GRAPH_FILE 58 | ] 59 | _USE_KERAS = False 60 | 61 | # Panda pic from Wikimedia; also used in 62 | # https://github.com/tensorflow/models/blob/master/research/slim/nets ... 63 | # ... /mobilenet/mobilenet_example.ipynb 64 | _PANDA_PIC_URL = ("https://upload.wikimedia.org/wikipedia/commons/f/fe/" 65 | "Giant_Panda_in_Beijing_Zoo_1.JPG") 66 | _PANDA_PIC_FILE = _TMP_DIR + "/panda.jpg" 67 | 68 | def _clear_dir(path): 69 | # type: (str) -> None 70 | if os.path.isdir(path): 71 | shutil.rmtree(path) 72 | os.mkdir(path) 73 | 74 | 75 | def _protobuf_to_file(pb, path, human_readable_name): 76 | # type: (Any, str, str) -> None 77 | with open(path, "w") as f: 78 | f.write(str(pb)) 79 | print("{} written to {}".format(human_readable_name, path)) 80 | 81 | 82 | def get_keras_frozen_graph(): 83 | # type: () -> Tuple[tf.GraphDef, str, str] 84 | """ 85 | Generate a frozen graph for the Keras MobileNet_v2 model. 86 | 87 | This should work, but does NOT work as of TensorFlow 1.12. The 88 | save_keras_model() function in TensorFlow creates a graph that the 89 | convert_variables_to_constants() function can't consume correctly. 90 | 91 | Returns GraphDef, input node name, output node name 92 | """ 93 | # Start with the pretrained MobileNetV2 model from keras.applications, 94 | # wrapped as a tf.keras model. 95 | mobilenet = tf.keras.applications.MobileNetV2() 96 | # Because we're using a tf.keras model instead of a keras model, 97 | # the backing TensorFlow session (keras.backend.get_session()) will have a 98 | # ginormous graph with many unused nodes. The only supported API to filter 99 | # down that graph is to write the model out as a SavedModel "file". So 100 | # that's what we do here. 101 | # Note that save_keras_model() doesn't write the model to the path you told 102 | # it to use. It writes the model to a timestamped subdirectory and returns 103 | # the path of the subdirectory as a bytes object (NOT a string). 104 | actual_saved_model_directory_bytes = \ 105 | tf.contrib.saved_model.save_keras_model(mobilenet, _SAVED_MODEL_DIR, 106 | as_text=True) 107 | print("Initial SavedModel file is at {}".format(tf.compat.as_str( 108 | actual_saved_model_directory_bytes))) 109 | # Now we need to freeze the graph, i.e. convert all variables to Const nodes. 110 | # The only supported way to do this with a Keras model is to write out a 111 | # SavedModel file, read the SavedModel file back into a fresh session, 112 | # and invoke the appropriate rewrite from tf.graph_util. 113 | with tf.Session() as sess: 114 | # save_keras_model() uses the "serve" tag for inference graphs, and the 115 | # names of the output nodes are the same as those returned by Model.outputs 116 | tf.saved_model.load(sess, tags=["serve"], 117 | export_dir=actual_saved_model_directory_bytes) 118 | frozen_graph_def = tf.graph_util.convert_variables_to_constants( 119 | sess, sess.graph.as_graph_def(), 120 | output_node_names=[n.op.name for n in mobilenet.outputs]) 121 | return (frozen_graph_def, mobilenet.inputs[0].op.name, 122 | mobilenet.outputs[0].op.name) 123 | 124 | 125 | def get_slim_frozen_graph(): 126 | # type: () -> Tuple[tf.GraphDef, str, str] 127 | """ 128 | Obtains a MobileNet_v2 model from the TensorFlow model zoo 129 | 130 | Returns GraphDef, input op name, output op name 131 | """ 132 | # Download a checkpoint if we don't have a cached one in our temp dir. 133 | # See 134 | # https://github.com/tensorflow/models/tree/master/research/slim/nets ... 135 | # ... /mobilenet 136 | # for a full list of available checkpoints. 137 | _CHECKPOINT_NAME = "mobilenet_v2_1.0_224" 138 | _CHECKPOINT_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints" \ 139 | "/{}.tgz".format(_CHECKPOINT_NAME) 140 | _CHECKPOINT_TGZ = "{}/{}.tgz".format(_TMP_DIR, _CHECKPOINT_NAME) 141 | _FROZEN_GRAPH_MEMBER = "./{}_frozen.pb".format(_CHECKPOINT_NAME) 142 | 143 | if not os.path.exists(_CHECKPOINT_TGZ): 144 | urllib.request.urlretrieve(_CHECKPOINT_URL, _CHECKPOINT_TGZ) 145 | with tarfile.open(_CHECKPOINT_TGZ) as t: 146 | frozen_graph_bytes = t.extractfile(_FROZEN_GRAPH_MEMBER).read() 147 | return (tf.GraphDef.FromString(frozen_graph_bytes), 148 | "input", "MobilenetV2/Predictions/Reshape_1") 149 | 150 | 151 | def run_graph(graph_proto, img, input_node, output_node): 152 | # type: (tf.GraphDef, np.ndarray, str, str) -> None 153 | """ 154 | Run an example image through a MobileNet-like graph and print a summary of 155 | the results to STDOUT. 156 | 157 | graph_proto: GraphDef protocol buffer message holding serialized graph 158 | img: Preprocessed (centered by dividing by 128) numpy array holding image 159 | input_node: Name of input graph node 160 | output_node: Name of output graph node; should produce logits 161 | """ 162 | img_as_batch = img.reshape(tuple([1] + list(img.shape))) 163 | with tf.Graph().as_default(): 164 | with tf.Session() as sess: 165 | tf.import_graph_def(graph_proto, name="") 166 | result = sess.run(output_node + ":0", {input_node + ":0": img_as_batch}) 167 | 168 | result = result.reshape(result.shape[1:]) 169 | # print("Raw result is {}".format(result)) 170 | sorted_indices = result.argsort() 171 | # print("Top 5 indices: {}".format(sorted_indices[-5:])) 172 | 173 | print("Rank Label Weight") 174 | for i in range(5): 175 | print("{:<10}{:<10}{}".format(i + 1, sorted_indices[-(i + 1)], 176 | result[sorted_indices[-(i + 1)]])) 177 | 178 | 179 | def main(_): 180 | # Remove any detritus of previous runs of this script, but leave the temp 181 | # dir in place because the user might have a shell there. 182 | if not os.path.isdir(_TMP_DIR): 183 | os.mkdir(_TMP_DIR) 184 | _clear_dir(_SAVED_MODEL_DIR) 185 | for f in _AFTER_MODEL_FILES: 186 | if os.path.isfile(f): 187 | os.remove(f) 188 | 189 | # Obtain a frozen graph for a MobileNet model 190 | if _USE_KERAS: 191 | frozen_graph_def, input_node, output_node = get_keras_frozen_graph() 192 | else: 193 | frozen_graph_def, input_node, output_node = get_slim_frozen_graph() 194 | 195 | _protobuf_to_file(frozen_graph_def, _FROZEN_GRAPH_FILE, "Frozen graph") 196 | 197 | # Now run through some of TensorFlow's built-in graph rewrites. 198 | # For that we use the undocumented Python APIs under 199 | # tensorflow.tools.graph_transforms 200 | after_tf_rewrites_graph_def = graph_transforms.TransformGraph( 201 | frozen_graph_def, 202 | inputs=[input_node], 203 | outputs=[output_node], 204 | # Use the set of transforms recommended in the README under "Optimizing 205 | # for Deployment" 206 | transforms=['strip_unused_nodes(type=float, shape="1,299,299,3")', 207 | 'remove_nodes(op=Identity, op=CheckNumerics)', 208 | 'fold_constants(ignore_errors=true)', 209 | 'fold_batch_norms', 210 | 'fold_old_batch_norms'] 211 | ) 212 | 213 | _protobuf_to_file(after_tf_rewrites_graph_def, 214 | _TF_REWRITES_GRAPH_FILE, 215 | "Graph after built-in TensorFlow rewrites") 216 | 217 | # Now run the GraphDef editor's fold_batch_norms_up() rewrite 218 | g = gde.Graph(after_tf_rewrites_graph_def) 219 | gde.rewrite.fold_batch_norms(g) 220 | gde.rewrite.fold_old_batch_norms(g) 221 | gde.rewrite.fold_batch_norms_up(g) 222 | after_gde_graph_def = g.to_graph_def(add_shapes=True) 223 | 224 | _protobuf_to_file(after_gde_graph_def, 225 | _GDE_REWRITES_GRAPH_FILE, 226 | "Graph after fold_batch_norms_up() rewrite") 227 | 228 | # Dump some statistics about the number of each type of op 229 | print(" Number of ops in frozen graph: {}".format(len( 230 | frozen_graph_def.node))) 231 | print(" Number of ops after built-in rewrites: {}".format(len( 232 | after_tf_rewrites_graph_def.node))) 233 | print("Number of ops after GDE rewrites: {}".format(len( 234 | after_gde_graph_def.node))) 235 | 236 | # Run model before and after rewrite and compare results 237 | if not os.path.exists(_PANDA_PIC_FILE): 238 | print("Downloading {} to {}".format(_PANDA_PIC_URL, _PANDA_PIC_FILE)) 239 | urllib.request.urlretrieve(_PANDA_PIC_URL, _PANDA_PIC_FILE) 240 | img = np.array(PIL.Image.open(_PANDA_PIC_FILE).resize((224, 224))).astype( 241 | np.float) # / 128 # - 1 242 | # Normalize each channel 243 | channel_means = np.mean(img, axis=(0, 1)) 244 | 245 | print("Channel means are: {}".format(channel_means)) 246 | print("Image shape is {}".format(img.shape)) 247 | 248 | print("Frozen graph results:") 249 | run_graph(frozen_graph_def, img, input_node, output_node) 250 | print("Results after built-in rewrites:") 251 | run_graph(after_tf_rewrites_graph_def, img, input_node, output_node) 252 | print("Results after GDE rewrites:") 253 | run_graph(after_gde_graph_def, img, input_node, output_node) 254 | 255 | 256 | if __name__ == "__main__": 257 | tf.app.run() 258 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tests/select_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.contrib.graph_editor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | 23 | import tensorflow.compat.v1 as tf 24 | tf.disable_eager_execution() 25 | 26 | import unittest 27 | 28 | import graph_def_editor as gde 29 | 30 | 31 | class SelectTest(unittest.TestCase): 32 | 33 | # TODO(frreiss): Merge duplicate setup code across test cases 34 | def setUp(self): 35 | tf_graph = tf.Graph() 36 | with tf_graph.as_default(): 37 | a = tf.constant([1., 1.], shape=[2], name="a") 38 | with tf.name_scope("foo"): 39 | b = tf.constant([2., 2.], shape=[2], name="b") 40 | c = tf.add(a, b, name="c") 41 | d = tf.constant([3., 3.], shape=[2], name="d") 42 | with tf.name_scope("bar"): 43 | e = tf.add(c, d, name="e") 44 | f = tf.add(c, d, name="f") 45 | g = tf.add(c, a, name="g") 46 | with tf.control_dependencies([c.op]): 47 | h = tf.add(f, g, name="h") 48 | self.graph = gde.Graph(tf_graph) 49 | self.a = self.graph.get_tensor_by_name(a.name) 50 | self.b = self.graph.get_tensor_by_name(b.name) 51 | self.c = self.graph.get_tensor_by_name(c.name) 52 | self.d = self.graph.get_tensor_by_name(d.name) 53 | self.e = self.graph.get_tensor_by_name(e.name) 54 | self.f = self.graph.get_tensor_by_name(f.name) 55 | self.g = self.graph.get_tensor_by_name(g.name) 56 | self.h = self.graph.get_tensor_by_name(h.name) 57 | 58 | def test_regex(self): 59 | """Test for ge.can_be_regex and ge.make_regex.""" 60 | self.assertTrue(gde.can_be_regex("foo")) 61 | self.assertTrue(gde.can_be_regex(re.compile("foo"))) 62 | regex = re.compile("foo") 63 | self.assertIs(gde.make_regex(regex), regex) 64 | 65 | def test_get_input_output_ts(self): 66 | """Test for ge._get_input_ts abd ge._get_output_ts.""" 67 | self.assertEqual(len(gde.select._get_input_ts(self.graph)), 6) 68 | self.assertEqual(len(gde.select._get_output_ts(self.graph)), 8) 69 | 70 | def test_get_filter(self): 71 | """Test for various filtering operations on ts ops.""" 72 | # TODO(fkp): parameterize 73 | self.assertEqual(len(gde.filter_ops(self.graph, True)), 8) 74 | self.assertEqual( 75 | len(gde.filter_ops(self.graph, 76 | lambda op: op.op_type == "Const")), 3) 77 | self.assertEqual( 78 | len(gde.filter_ops(self.graph, 79 | lambda op: op.op_type in ["Add", "AddV2"])), 5) 80 | self.assertEqual( 81 | len(gde.filter_ops_from_regex(self.graph, r"^.*\b[abc]$")), 3) 82 | 83 | self.assertEqual(len(gde.filter_ts(self.graph, True)), 8) 84 | self.assertEqual( 85 | len(gde.filter_ts_from_regex(self.graph, r"^.*/[fgh]:\d$")), 3) 86 | 87 | self.assertEqual(len(gde.get_name_scope_ops(self.graph, "foo/")), 7) 88 | self.assertEqual(len(gde.get_name_scope_ops(self.graph, "foo/bar")), 4) 89 | 90 | def test_get_ops_ios(self): 91 | """Test for ge.get_ops_ios.""" 92 | control_outputs = gde.util.ControlOutputs(self.graph) 93 | self.assertEqual( 94 | len(gde.get_ops_ios(self.h.op, control_ios=control_outputs)), 3) 95 | self.assertEqual(len(gde.get_ops_ios(self.h.op)), 2) 96 | self.assertEqual( 97 | len(gde.get_ops_ios(self.c.op, control_ios=control_outputs)), 6) 98 | self.assertEqual(len(gde.get_ops_ios(self.c.op)), 5) 99 | 100 | def test_compute_boundary_ts_0(self): 101 | """Test for ge.compute_boundary_ts.""" 102 | input_ts, output_ts, inside_ts = gde.compute_boundary_ts(self.g.op) 103 | self.assertEqual(list(input_ts), [self.c, self.a]) 104 | self.assertEqual(list(output_ts), [self.g]) 105 | self.assertEqual(list(inside_ts), []) 106 | 107 | def test_compute_boundary_ts_1(self): 108 | """Test for ge.compute_boundary_ts.""" 109 | input_ts, output_ts, inside_ts = gde.compute_boundary_ts( 110 | [self.g.op, self.h.op]) 111 | self.assertEqual(list(input_ts), [self.c, self.a, self.f]) 112 | self.assertEqual(list(output_ts), [self.h]) 113 | self.assertEqual(list(inside_ts), [self.g]) 114 | 115 | def test_compute_boundary_ts_2(self): 116 | """Test for ge.compute_boundary_ts.""" 117 | tf_graph = tf.Graph() 118 | with tf_graph.as_default(): 119 | a_tensor = tf.constant(1, name="a") 120 | b_tensor = tf.constant(1, name="b") 121 | c_tensor = tf.add(a_tensor, b_tensor, name="c") 122 | _ = a_tensor + c_tensor 123 | 124 | g = gde.Graph(tf_graph) 125 | input_ts, output_ts, inside_ts = gde.compute_boundary_ts([g["a"], g["c"]]) 126 | self.assertEqual(list(input_ts), [g["b"].output(0)]) 127 | self.assertEqual(list(output_ts), [g["a"].output(0), g["c"].output(0)]) 128 | self.assertEqual(list(inside_ts), [g["a"].output(0)]) 129 | 130 | def test_get_within_boundary_ops_0(self): 131 | """Test for test_get_within_boundary_ops.""" 132 | control_outputs = gde.util.ControlOutputs(self.graph) 133 | ops = gde.get_within_boundary_ops( 134 | ops=self.graph, 135 | seed_ops=self.f.op, 136 | boundary_ops=[self.c.op, self.h.op], 137 | inclusive=False, 138 | control_ios=control_outputs) 139 | self.assertEqual(len(ops), 3) 140 | 141 | def test_get_within_boundary_ops_1(self): 142 | """Test for ge.test_get_within_boundary_ops.""" 143 | ops = gde.get_within_boundary_ops( 144 | ops=self.graph, seed_ops=self.h.op, boundary_ops=[self.f.op, self.g.op]) 145 | self.assertEqual(len(ops), 3) 146 | 147 | def test_get_walks_intersection(self): 148 | """Test for ge.get_walks_intersection_ops.""" 149 | ops = gde.get_walks_intersection_ops([self.c.op], [self.g.op]) 150 | self.assertEqual(len(ops), 2) 151 | 152 | ops = gde.get_walks_intersection_ops([self.a.op], [self.f.op]) 153 | self.assertEqual(len(ops), 3) 154 | self.assertTrue(self.a.op in ops) 155 | self.assertTrue(self.c.op in ops) 156 | self.assertTrue(self.f.op in ops) 157 | 158 | within_ops = [self.a.op, self.f.op] 159 | ops = gde.get_walks_intersection_ops( 160 | [self.a.op], [self.f.op], within_ops=within_ops) 161 | self.assertEqual(len(ops), 0) 162 | 163 | def within_ops_fn(op): 164 | return op in [self.a.op, self.f.op] 165 | ops = gde.get_walks_intersection_ops( 166 | [self.a.op], [self.f.op], within_ops_fn=within_ops_fn) 167 | self.assertEqual(len(ops), 0) 168 | 169 | def test_get_walks_union(self): 170 | """Test for ge.get_walks_union_ops.""" 171 | ops = gde.get_walks_union_ops([self.f.op], [self.g.op]) 172 | self.assertEqual(len(ops), 6) 173 | 174 | ops = gde.get_walks_union_ops([self.a.op], [self.f.op]) 175 | self.assertEqual(len(ops), 8) 176 | 177 | within_ops = [self.a.op, self.c.op, self.d.op, self.f.op] 178 | ops = gde.get_walks_union_ops([self.a.op], [self.f.op], 179 | within_ops=within_ops) 180 | self.assertEqual(len(ops), 4) 181 | self.assertTrue(self.b.op not in ops) 182 | 183 | def within_ops_fn(op): 184 | return op in [self.a.op, self.c.op, self.f.op] 185 | 186 | ops = gde.get_walks_union_ops([self.a.op], [self.f.op], 187 | within_ops_fn=within_ops_fn) 188 | self.assertEqual(len(ops), 3) 189 | self.assertTrue(self.b.op not in ops) 190 | self.assertTrue(self.d.op not in ops) 191 | 192 | def test_select_ops(self): 193 | parameters = ( 194 | (("^foo/",), 7), 195 | (("^foo/bar/",), 4), 196 | (("^foo/bar/", "a"), 5), 197 | ) 198 | for param, length in parameters: 199 | ops = gde.select_ops(*param, graph=self.graph) 200 | self.assertEqual(len(ops), length) 201 | 202 | def test_select_ts(self): 203 | parameters = ( 204 | (".*:0", 8), 205 | (r".*/bar/\w+:0", 4), 206 | ) 207 | for regex, length in parameters: 208 | ts = gde.select_ts(regex, graph=self.graph) 209 | self.assertEqual(len(ts), length) 210 | 211 | def test_select_ops_and_ts(self): 212 | parameters = ( 213 | (("^foo/.*",), 7, 0), 214 | (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4), 215 | ) 216 | for param, l0, l1 in parameters: 217 | ops, ts = gde.select_ops_and_ts(*param, graph=self.graph) 218 | self.assertEqual(len(ops), l0) 219 | self.assertEqual(len(ts), l1) 220 | 221 | def test_forward_walk_ops(self): 222 | seed_ops = [self.a.op, self.d.op] 223 | # Include all ops except for self.g.op 224 | within_ops = [ 225 | x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] 226 | ] 227 | 228 | # For the fn, exclude self.e.op. 229 | def within_ops_fn(op): 230 | return op not in (self.e.op,) 231 | stop_at_ts = (self.f,) 232 | 233 | # No b.op since it's an independent source node. 234 | # No g.op from within_ops. 235 | # No e.op from within_ops_fn. 236 | # No h.op from stop_at_ts and within_ops. 237 | ops = gde.select.get_forward_walk_ops( 238 | seed_ops, 239 | inclusive=True, 240 | within_ops=within_ops, 241 | within_ops_fn=within_ops_fn, 242 | stop_at_ts=stop_at_ts) 243 | self.assertEqual( 244 | set(ops), {self.a.op, self.c.op, self.d.op, self.f.op }) 245 | 246 | # Also no a.op and d.op when inclusive=False 247 | ops = gde.select.get_forward_walk_ops( 248 | seed_ops, 249 | inclusive=False, 250 | within_ops=within_ops, 251 | within_ops_fn=within_ops_fn, 252 | stop_at_ts=stop_at_ts) 253 | self.assertEqual(set(ops), {self.c.op, self.f.op}) 254 | 255 | # Not using within_ops_fn adds e.op. 256 | ops = gde.select.get_forward_walk_ops( 257 | seed_ops, 258 | inclusive=False, 259 | within_ops=within_ops, 260 | stop_at_ts=stop_at_ts) 261 | self.assertEqual(set(ops), {self.c.op, self.e.op, self.f.op}) 262 | 263 | # Not using stop_at_ts adds back h.op. 264 | ops = gde.select.get_forward_walk_ops( 265 | seed_ops, inclusive=False, within_ops=within_ops) 266 | self.assertEqual( 267 | set(ops), {self.c.op, self.e.op, self.f.op, self.h.op}) 268 | 269 | # Starting just form a (the tensor, not op) omits a, b, d. 270 | ops = gde.select.get_forward_walk_ops([self.a], inclusive=True) 271 | self.assertEqual( 272 | set(ops), {self.c.op, self.e.op, self.f.op, self.g.op, self.h.op}) 273 | 274 | def test_backward_walk_ops(self): 275 | seed_ops = [self.h.op] 276 | # Include all ops except for self.g.op 277 | within_ops = [ 278 | x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] 279 | ] 280 | 281 | # For the fn, exclude self.c.op. 282 | def within_ops_fn(op): 283 | return op not in (self.c.op,) 284 | stop_at_ts = (self.f,) 285 | 286 | # Backward walk only includes h since we stop at f and g is not within. 287 | ops = gde.select.get_backward_walk_ops( 288 | seed_ops, 289 | inclusive=True, 290 | within_ops=within_ops, 291 | within_ops_fn=within_ops_fn, 292 | stop_at_ts=stop_at_ts) 293 | self.assertEqual(set(ops), {self.h.op}) 294 | 295 | # If we do inclusive=False, the result is empty. 296 | ops = gde.select.get_backward_walk_ops( 297 | seed_ops, 298 | inclusive=False, 299 | within_ops=within_ops, 300 | within_ops_fn=within_ops_fn, 301 | stop_at_ts=stop_at_ts) 302 | self.assertEqual(set(ops), set()) 303 | 304 | # Removing stop_at_fs adds f.op, d.op. 305 | ops = gde.select.get_backward_walk_ops( 306 | seed_ops, 307 | inclusive=True, 308 | within_ops=within_ops, 309 | within_ops_fn=within_ops_fn) 310 | self.assertEqual(set(ops), {self.d.op, self.f.op, self.h.op}) 311 | 312 | # Not using within_ops_fn adds back ops for a, b, c. 313 | ops = gde.select.get_backward_walk_ops( 314 | seed_ops, inclusive=True, within_ops=within_ops) 315 | self.assertEqual( 316 | set(ops), 317 | {self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op}) 318 | 319 | # Vanially backward search via self.h.op includes everything except e.op. 320 | ops = gde.select.get_backward_walk_ops(seed_ops, inclusive=True) 321 | self.assertEqual( 322 | set(ops), 323 | {self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op, 324 | self.h.op}) 325 | 326 | 327 | if __name__ == "__main__": 328 | unittest.main() 329 | -------------------------------------------------------------------------------- /graph_def_editor/visualization/graphviz_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Functionality to convert graph_def_editor graph to GraphViz visualization.""" 17 | 18 | import re 19 | import uuid 20 | 21 | from .graphviz_style import * 22 | import graph_def_editor.visualization.jupyter_helper as jupyter_helper 23 | 24 | 25 | FORMAT_JUPYTER_SVG = 'jupyter_svg' 26 | FORMAT_JUPYTER_INTERACTIVE = 'jupyter_interactive' 27 | 28 | _CLUSTER_INDEX = 0 # index of subgraph 29 | _ADD_DIGRAPH_FUNC = None 30 | _ADD_DIGRAPH_NODE_FUNC = None 31 | _ADD_DIGRAPH_EDGE_FUNC = None 32 | 33 | 34 | graph_pref = { 35 | 'fontcolor': '#414141', 36 | 'style': 'rounded', 37 | } 38 | 39 | name_scope_graph_pref = { 40 | 'bgcolor': '#eeeeee', 41 | 'color': '#aaaaaa', 42 | 'penwidth': '2', 43 | } 44 | 45 | non_name_scope_graph_pref = { 46 | 'fillcolor': 'white', 47 | 'color': 'white', 48 | } 49 | 50 | node_pref = { 51 | 'style': 'filled', 52 | 'fillcolor': 'white', 53 | 'color': '#aaaaaa', 54 | 'penwidth': '2', 55 | 'fontcolor': '#414141', 56 | } 57 | 58 | edge_pref = { 59 | 'color': '#aaaaaa', 60 | 'arrowsize': '1.2', 61 | 'penwidth': '2.5', 62 | 'fontcolor': '#414141', 63 | } 64 | 65 | 66 | def add_digraph(name=None, name_scope=None, style=True): 67 | """Return graphviz.dot.Digraph with TensorBoard-like style.""" 68 | try: 69 | import graphviz as gv 70 | except ModuleNotFoundError as error: 71 | raise ModuleNotFoundError( 72 | "You need to install graphviz to be able to use this functionality. " 73 | "See https://graphviz.readthedocs.io/en/stable/manual.html for details.") 74 | 75 | digraph = gv.Digraph(name=name) 76 | if name_scope: 77 | digraph.graph_attr['label'] = name_scope 78 | digraph.graph_attr['tooltip'] = name_scope 79 | 80 | if style is False: 81 | return digraph 82 | 83 | if name_scope: 84 | digraph.graph_attr.update(name_scope_graph_pref) 85 | else: 86 | digraph.graph_attr.update(non_name_scope_graph_pref) 87 | digraph.graph_attr.update(graph_pref) 88 | digraph.node_attr.update(node_pref) 89 | digraph.edge_attr.update(edge_pref) 90 | return digraph 91 | 92 | 93 | def add_digraph_node(digraph, name, op, attributes=None): 94 | """Adds a node to digraph.""" 95 | label = name.split('/')[-1] 96 | tooltip = name 97 | # For possible attribute values see: 98 | # https://graphviz.org/doc/info/attrs.html 99 | if attributes is None: 100 | attributes = [] 101 | if op is not None: 102 | tooltip += ':' + op.op_type 103 | if 'PartitionedCall' in op.op_type: 104 | try: 105 | label = '{}\n{}:{}'.format(label, 'f', op.get_attr('f').name) 106 | except ValueError: 107 | pass 108 | # For example: 109 | # attributes.append(('fillcolor', 'green')) 110 | digraph.node(name, label=label, tooltip=tooltip, _attributes=attributes) 111 | 112 | 113 | def add_digraph_edge(digraph, from_node, to_node, label=None, attributes=None): 114 | """Adds an edge to digraph.""" 115 | if attributes is None: 116 | attributes = [] 117 | digraph.edge(from_node, to_node, label=label, _attributes=attributes) 118 | 119 | 120 | def nested_dict(dict_, keys, val): 121 | """Assign value to dictionary.""" 122 | cloned = dict_.copy() 123 | if len(keys) == 1: 124 | cloned[keys[0]] = val 125 | return cloned 126 | dd = cloned[keys[0]] 127 | for k in keys[1:len(keys) - 1]: 128 | dd = dd[k] 129 | last_key = keys[len(keys) - 1] 130 | dd[last_key] = val 131 | return cloned 132 | 133 | 134 | def node_abs_paths(node): 135 | """Return absolute node path name.""" 136 | node_names = node.name.split('/') 137 | return ['/'.join(node_names[0:i + 1]) for i in range(len(node_names))] 138 | 139 | 140 | def node_table(gde_graph, depth=1, match_func=None): 141 | """Return dictionary of node.""" 142 | table = {} 143 | ops_table = {} 144 | max_depth = depth 145 | ops = gde_graph.nodes 146 | for depth_i in range(max_depth): 147 | for op in ops: 148 | abs_paths = node_abs_paths(op) 149 | if depth_i >= len(abs_paths): 150 | continue 151 | if match_func and not match_func(op.name): 152 | continue 153 | ops_table[op.name] = op 154 | ps = abs_paths[:depth_i + 1] 155 | if len(ps) == 1: 156 | key = '/'.join(abs_paths[0:depth_i + 1]) 157 | if not key in table: 158 | table[key] = {} 159 | else: 160 | table = nested_dict(table, ps, {}) 161 | return table, ops_table 162 | 163 | 164 | def tensor_shape(gde_tensor, depth=1): 165 | """Return node and the children.""" 166 | outpt_name = gde_tensor.name 167 | if len(outpt_name.split('/')) < depth: 168 | return None 169 | on = '/'.join(outpt_name.split('/')[:depth]) # output node 170 | result = re.match(r'(.*):\d*$', on) 171 | if not result: 172 | return None 173 | on = result.groups()[0] 174 | if gde_tensor.shape.ndims is None: 175 | return on, [] 176 | else: 177 | return on, gde_tensor.shape.as_list() 178 | 179 | 180 | def node_input_table(gde_graph, depth=1, match_func=None): 181 | """Return table of operations.""" 182 | table = {} 183 | inpt_op_table = {} 184 | inpt_op_shape_table = {} 185 | for op in gde_graph.nodes: 186 | if match_func and not match_func(op.name): 187 | continue 188 | op_name = op.name.split('/')[0:depth] 189 | opn = '/'.join(op_name) 190 | if not opn in inpt_op_table: 191 | inpt_op_table[opn] = [] 192 | inpt_op_list = ['/'.join(input_tensor.op.name.split('/')[0:depth]) \ 193 | for input_tensor in op.inputs if not match_func or match_func(input_tensor.op.name)] 194 | inpt_op_table[opn].append(inpt_op_list) 195 | for output in op.outputs: 196 | for i in range(depth): 197 | shape = tensor_shape(output, depth=i + 1) 198 | if shape: 199 | inpt_op_shape_table[shape[0]] = shape[1] 200 | for opn in inpt_op_table.keys(): 201 | t_l = [] 202 | for ll in inpt_op_table[opn]: 203 | list.extend(t_l, ll) 204 | table[opn] = list(set(t_l)) 205 | return table, inpt_op_shape_table 206 | 207 | 208 | def add_nodes(node_table, ops_table, name=None, name_scope=None, style=True): 209 | """Add TensorFlow graph's nodes to graphviz.dot.Digraph.""" 210 | global _CLUSTER_INDEX 211 | global _ADD_DIGRAPH_FUNC 212 | global _ADD_DIGRAPH_NODE_FUNC 213 | if name: 214 | digraph = _ADD_DIGRAPH_FUNC(name=name, name_scope=name_scope, style=style) 215 | else: 216 | digraph = _ADD_DIGRAPH_FUNC( 217 | name=str(uuid.uuid4().hex.upper()[0:6]), 218 | name_scope=name_scope, 219 | style=style) 220 | graphs = [] 221 | for key, value in node_table.items(): 222 | if len(value) > 0: 223 | sg = add_nodes( 224 | value, 225 | ops_table, 226 | name='cluster_%i' % _CLUSTER_INDEX, 227 | name_scope=key.split('/')[-1], 228 | style=style) 229 | op = ops_table.get(key, None) 230 | _ADD_DIGRAPH_NODE_FUNC(sg, key, op) 231 | _CLUSTER_INDEX += 1 232 | graphs.append(sg) 233 | else: 234 | op = ops_table.get(key, None) 235 | label = key.split('/')[-1] 236 | _ADD_DIGRAPH_NODE_FUNC(digraph, key, op) 237 | 238 | for tg in graphs: 239 | digraph.subgraph(tg) 240 | return digraph 241 | 242 | 243 | def edge_label(shape): 244 | """Returns texts of graph's edges.""" 245 | if len(shape) == 0: 246 | return '' 247 | if shape[0] is None: 248 | label = '?' 249 | else: 250 | label = '%i' % shape[0] 251 | for s in shape[1:]: 252 | if s is None: 253 | label += '×?' 254 | else: 255 | label += u'×%i' % s 256 | return label 257 | 258 | 259 | def add_edges(digraph, node_inpt_table, node_inpt_shape_table): 260 | """Add graph's edges to graphviz.dot.Digraph.""" 261 | global _ADD_DIGRAPH_EDGE_FUNC 262 | for node, node_inputs in node_inpt_table.items(): 263 | if re.match(r'\^', node): 264 | continue 265 | for ni in node_inputs: 266 | if ni == node: 267 | continue 268 | if re.match(r'\^', ni): 269 | continue 270 | if not ni in node_inpt_shape_table: 271 | _ADD_DIGRAPH_EDGE_FUNC(digraph, ni, node) 272 | else: 273 | shape = node_inpt_shape_table[ni] 274 | _ADD_DIGRAPH_EDGE_FUNC(digraph, ni, node, label=edge_label(shape)) 275 | return digraph 276 | 277 | 278 | def match_func(name_regex, negative_name_regex): 279 | name_re = None 280 | if name_regex: 281 | name_re = re.compile(name_regex) 282 | 283 | negative_name_re = None 284 | if negative_name_regex: 285 | negative_name_re = re.compile(negative_name_regex) 286 | 287 | def _matches(node_name): 288 | return bool( 289 | (not name_re or name_re.search(node_name)) and 290 | (not negative_name_re or not negative_name_re.search(node_name))) 291 | 292 | return _matches 293 | 294 | 295 | def board(gde_graph, 296 | depth=1, 297 | name='G', 298 | style=True, 299 | name_regex='', 300 | negative_name_regex='', 301 | add_digraph_func=None, 302 | add_digraph_node_func=None, 303 | add_digraph_edge_func=None): 304 | """Return GraphViz Digraph rendering of the specified graph. 305 | 306 | Args: 307 | depth: the maximum depth of the graph to display. 308 | name: graph name. 309 | style: whether to apply default styles. 310 | name_regex: only diplay nodes that have name matching this regex. 311 | negative_name_regex: only diplay nodes that have name not matching this 312 | regex. 313 | add_digraph_func: custom override for function for adding subraphs 314 | to the resulting Digraph object. 315 | add_digraph_node_func: custom override for function for adding nodes 316 | (vertices) to the resulting Digraph object. 317 | add_digraph_edge_func: custom override for function for adding edges 318 | to the resulting Digraph object. 319 | 320 | Returns: 321 | graphviz.dot.Digraph object with visual representtion for the specified 322 | graph. 323 | """ 324 | global _ADD_DIGRAPH_FUNC 325 | global _ADD_DIGRAPH_NODE_FUNC 326 | global _ADD_DIGRAPH_EDGE_FUNC 327 | global _CLUSTER_INDEX 328 | _CLUSTER_INDEX = 0 329 | _ADD_DIGRAPH_FUNC = add_digraph_func if add_digraph_func is not None else add_digraph 330 | _ADD_DIGRAPH_NODE_FUNC = add_digraph_node_func if add_digraph_node_func is not None else add_digraph_node 331 | _ADD_DIGRAPH_EDGE_FUNC = add_digraph_edge_func if add_digraph_edge_func is not None else add_digraph_edge 332 | _node_name_matches_func = match_func( 333 | name_regex, negative_name_regex) 334 | 335 | _node_table, _ops_table = node_table( 336 | gde_graph, depth=depth, match_func=_node_name_matches_func) 337 | _node_inpt_table, _node_inpt_shape_table = node_input_table( 338 | gde_graph, depth=depth, match_func=_node_name_matches_func) 339 | digraph = add_nodes(_node_table, _ops_table, name=name, style=style) 340 | digraph = add_edges(digraph, _node_inpt_table, _node_inpt_shape_table) 341 | return digraph 342 | 343 | 344 | def visualize( 345 | gde_graph, 346 | format=None, 347 | depth=1, 348 | name='G', 349 | style=True, 350 | name_regex='', 351 | negative_name_regex='', 352 | add_digraph_func=None, 353 | add_digraph_node_func=None, 354 | add_digraph_edge_func=None): 355 | """Return GraphViz Digraph rendering of the specified graph. 356 | 357 | Args: 358 | gde_graph: Graph to display. 359 | format: GraphViz display format (see https://graphviz.org/docs/outputs/). 360 | In addition to that it supports jupyter_svg, and jupyter_interactive modes 361 | depth: the maximum depth of the graph to display. 362 | name: graph name. 363 | style: whether to apply default styles. 364 | name_regex: only diplay nodes that have name matching this regex. 365 | negative_name_regex: only diplay nodes that have name not matching this 366 | regex. 367 | add_digraph_func: custom override for function for adding subraphs 368 | to the resulting Digraph object. 369 | add_digraph_node_func: custom override for function for adding nodes 370 | (vertices) to the resulting Digraph object. 371 | add_digraph_edge_func: custom override for function for adding edges 372 | to the resulting Digraph object. 373 | 374 | Returns: 375 | graphviz.dot.Digraph object with visual representtion for the specified 376 | graph. 377 | """ 378 | dg = board( 379 | gde_graph, 380 | depth=depth, 381 | name=name, 382 | style=style, 383 | name_regex=name_regex, 384 | negative_name_regex=negative_name_regex, 385 | add_digraph_func=add_digraph_func, 386 | add_digraph_node_func=add_digraph_node_func, 387 | add_digraph_edge_func=add_digraph_edge_func) 388 | 389 | if format is None: 390 | return dg 391 | elif format == FORMAT_JUPYTER_SVG: 392 | return jupyter_helper.jupyter_show_as_svg(dg) 393 | elif format == FORMAT_JUPYTER_INTERACTIVE: 394 | return jupyter_helper.jupyter_pan_and_zoom(dg) 395 | else: 396 | return dg.pipe(format=format) 397 | -------------------------------------------------------------------------------- /tests/transform_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.contrib.graph_editor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import functools 23 | import numpy as np 24 | 25 | import tensorflow.compat.v1 as tf 26 | tf.disable_eager_execution() 27 | 28 | import unittest 29 | 30 | import graph_def_editor as gde 31 | 32 | # Precision tolerance for floating-point value tests. 33 | ERROR_TOLERANCE = 1e-3 34 | 35 | 36 | class TransformTest(unittest.TestCase): 37 | 38 | # Slightly modified version of the method by the same name in tf.TestCase 39 | def assertNear(self, f1, f2, err, msg=None): 40 | """Asserts that two floats are near each other. 41 | Checks that |f1 - f2| < err and asserts a test failure 42 | if not. 43 | Args: 44 | f1: A float value. 45 | f2: A float value. 46 | err: A float value. 47 | msg: An optional string message to append to the failure message. 48 | """ 49 | # f1 == f2 is needed here as we might have: f1, f2 = inf, inf 50 | self.assertTrue( 51 | f1 == f2 or abs(f1 - f2) <= err, 52 | "{:f} != {:f} +/- {:f}{}".format(f1, f2, err, 53 | " ({})".format(msg if msg is not None 54 | else ""))) 55 | 56 | def setUp(self): 57 | tf_graph = tf.Graph() 58 | with tf_graph.as_default(): 59 | c0 = tf.constant(1.0, shape=[10], name="Const") 60 | c0.op._set_attr("_foo", tf.AttrValue(s=b"foo")) 61 | c1 = tf.constant(1.0, shape=[10], name="Const") 62 | c2 = tf.constant(1.0, shape=[10], name="Const") 63 | i = tf.constant(1.0, shape=[10], name="Input") 64 | tf.identity(tf.add(c2, tf.add(c1, tf.add(c0, i))), name="o") 65 | self.graph = gde.Graph(tf_graph) 66 | self.o = self.graph["o"] 67 | 68 | def test_copy(self): 69 | graph = gde.Graph() 70 | _, info = gde.copy(self.graph, graph) 71 | self.assertEqual( 72 | set(op.name for op in self.graph.nodes), 73 | set(op.name for op in graph.nodes)) 74 | src_ops = self.graph.nodes 75 | dst_ops = graph.nodes 76 | for op in src_ops: 77 | op_ = info.transformed(op) 78 | self.assertTrue(op_ in dst_ops) 79 | self.assertEqual(op.name, op_.name) 80 | self.assertEqual(info.original(op_), op) 81 | src_ts = self.graph.tensors 82 | dst_ts = graph.tensors 83 | for t in src_ts: 84 | t_ = info.transformed(t) 85 | self.assertTrue(t_ in dst_ts) 86 | self.assertEqual(t.name, t_.name) 87 | self.assertEqual(info.original(t_), t) 88 | 89 | def test_copy_assert(self): 90 | tf_g = tf.Graph() 91 | with tf_g.as_default(): 92 | a = tf.constant(1, name="a") 93 | b = tf.constant(1, name="b") 94 | eq = tf.equal(a, b, name="EQ") 95 | assert_tf_op = tf.Assert(eq, [a, b]) 96 | with tf.control_dependencies([assert_tf_op]): 97 | _ = tf.add(a, b) 98 | assert_op_name = assert_tf_op.name 99 | 100 | g = gde.Graph(tf_g) 101 | assert_op = g[assert_op_name] 102 | sgv = gde.make_view([assert_op, g["EQ"], g["a"], g["b"]]) 103 | copier = gde.Transformer() 104 | _, info = copier(sgv, sgv.graph, "", "") 105 | new_assert_op = info.transformed(assert_op) 106 | self.assertIsNotNone(new_assert_op) 107 | 108 | def test_transform(self): 109 | transformer = gde.Transformer() 110 | 111 | def my_transform_op_handler(info, op, new_inputs): 112 | add_noise = op.name.startswith("Add") 113 | op_, op_outputs_ = gde.transform.copy_op_handler(info, op, new_inputs) 114 | if not add_noise: 115 | return op_, op_outputs_ 116 | 117 | # add some noise to op 118 | # Old code: 119 | # with info.graph_.as_default(): 120 | # t_ = math_ops.add( 121 | # constant_op.constant(1.0, shape=[10], name="Noise"), 122 | # op_.outputs[0], 123 | # name="AddNoise") 124 | noise_op = gde.make_const(info.graph_, "Noise", 125 | np.full([10], 1., dtype=np.float32), 126 | uniquify_name=True) 127 | add_noise_op = info.graph_.add_node("AddNoise", "Add", uniquify_name=True) 128 | add_noise_op.add_attr("T", tf.float32) 129 | add_noise_op.set_inputs([noise_op.outputs[0], op_.outputs[0]]) 130 | add_noise_op.infer_outputs() 131 | t_ = add_noise_op.outputs[0] 132 | 133 | # return the "noisy" op 134 | return op_, [t_] 135 | 136 | transformer.transform_op_handler = my_transform_op_handler 137 | 138 | graph = gde.Graph() 139 | transformer(self.graph, graph, "", "") 140 | matcher0 = gde.OpMatcher("AddNoise").input_ops( 141 | "Noise", gde.OpMatcher("Add").input_ops("Const", "Input")) 142 | matcher1 = gde.OpMatcher("AddNoise_1").input_ops( 143 | "Noise_1", gde.OpMatcher("Add_1").input_ops("Const_1", matcher0)) 144 | matcher2 = gde.OpMatcher("AddNoise_2").input_ops( 145 | "Noise_2", gde.OpMatcher("Add_2").input_ops("Const_2", matcher1)) 146 | top = gde.select_ops("^AddNoise_2$", graph=graph)[0] 147 | self.assertTrue(matcher2(top)) 148 | 149 | def test_transform_nodedef_fn(self): 150 | transformer = gde.Transformer() 151 | 152 | def nodedef_fn(node_def): 153 | if "_foo" in node_def.attr: 154 | del node_def.attr["_foo"] 155 | node_def.attr["_bar"].s = b"bar" 156 | return node_def 157 | 158 | my_copy_op_handler = functools.partial( 159 | gde.transform.copy_op_handler, nodedef_fn=nodedef_fn) 160 | transformer.transform_op_handler = my_copy_op_handler 161 | 162 | graph = gde.Graph() 163 | transformer(self.graph, graph, "", "") 164 | 165 | c0_before = self.graph["Const"] 166 | c0_after = graph["Const"] 167 | self.assertEqual(c0_before.get_attr("_foo"), "foo") 168 | with self.assertRaises(ValueError): 169 | c0_after.get_attr("_foo") 170 | 171 | all_ops = graph.nodes 172 | for op in all_ops: 173 | self.assertEqual(op.get_attr("_bar"), "bar") 174 | 175 | def test_copy_with_input_replacements(self): 176 | # Original code: 177 | # with self.graph.as_default(): 178 | # _ = tf.constant(10.0, shape=[10], name="Input") 179 | # New code adds node as a NodeDef 180 | ten_node = gde.make_const(self.graph, "Ten", np.full([10], 10., 181 | dtype=np.float32)) 182 | 183 | ten_tensor = ten_node.output(0) 184 | sgv, _ = gde.copy_with_input_replacements( 185 | # self.o is an identity on top of a tree of add ops 186 | [self.o, self.o.inputs[0].node], 187 | # Drill down to second input to outer add() 188 | {self.o.inputs[0].node.inputs[1]: ten_tensor} 189 | ) 190 | 191 | after_graph = tf.Graph() 192 | with after_graph.as_default(): 193 | tf.import_graph_def(self.graph.to_graph_def(), name="") 194 | with tf.Session() as sess: 195 | val = sess.run(sgv.outputs[0].name) 196 | 197 | print("val is {}".format(val)) 198 | self.assertNear( 199 | np.linalg.norm(val - np.array([11])), 0.0, ERROR_TOLERANCE) 200 | 201 | 202 | def test_copy_with_collection(self): 203 | """Test for issue #36""" 204 | tmp_graph = tf.Graph() 205 | with tmp_graph.as_default(): 206 | c = tf.constant(42, name="FortyTwo") 207 | tmp_graph.add_to_collection("Answers", c) 208 | 209 | g = gde.Graph(tmp_graph) 210 | g2 = gde.Graph() 211 | gde.transform.copy(g, g2) 212 | self.assertTrue("Answers" in g2.get_all_collection_keys()) 213 | 214 | 215 | @staticmethod 216 | def _create_replace_graph(): 217 | """Subroutine of the next few tests. Creates the graph that all these 218 | tests use. Since the tests modify the graph, it needs to be recreated 219 | each time. 220 | 221 | Returns: 222 | (Graph object, c, target tensor to replace, new value, output tensor)""" 223 | tmp_graph = tf.Graph() 224 | with tmp_graph.as_default(): 225 | a = tf.constant(1.0, name="a") 226 | b = tf.Variable(1.0, name="b") 227 | eps = tf.constant(0.001, name="eps") 228 | tf.identity(a + b + eps, name="c") 229 | tf.constant(2.0, name="a_new") 230 | ret = gde.Graph(tmp_graph) 231 | return ret, ret["a"].output(0), ret["a_new"].output(0), ret["c"].output(0) 232 | 233 | def test_graph_replace(self): 234 | g, a, a_new, c = self._create_replace_graph() 235 | c_new = gde.graph_replace(c, {a: a_new}) 236 | 237 | with g.to_tf_graph().as_default(): 238 | with tf.Session() as sess: 239 | sess.run(tf.global_variables_initializer()) 240 | c_val, c_new_val = sess.run([c.name, c_new.name]) 241 | 242 | self.assertNear(c_val, 2.001, ERROR_TOLERANCE) 243 | self.assertNear(c_new_val, 3.001, ERROR_TOLERANCE) 244 | 245 | def test_graph_replace_dict(self): 246 | g, a, a_new, c = self._create_replace_graph() 247 | c_new = gde.graph_replace({"c": c}, {a: a_new}) 248 | self.assertTrue(isinstance(c_new, dict)) 249 | 250 | with g.to_tf_graph().as_default(): 251 | with tf.Session() as sess: 252 | sess.run(tf.global_variables_initializer()) 253 | c_val, c_new_val = sess.run([c.name, 254 | {k: v.name for k, v in c_new.items()}]) 255 | 256 | self.assertTrue(isinstance(c_new_val, dict)) 257 | self.assertNear(c_val, 2.001, ERROR_TOLERANCE) 258 | self.assertNear(c_new_val["c"], 3.001, ERROR_TOLERANCE) 259 | 260 | def test_graph_replace_ordered_dict(self): 261 | g, a, a_new, c = self._create_replace_graph() 262 | c_new = gde.graph_replace(collections.OrderedDict({"c": c}), {a: a_new}) 263 | self.assertTrue(isinstance(c_new, collections.OrderedDict)) 264 | 265 | def test_graph_replace_named_tuple(self): 266 | g, a, a_new, c = self._create_replace_graph() 267 | one_tensor = collections.namedtuple("OneTensor", ["t"]) 268 | c_new = gde.graph_replace(one_tensor(c), {a: a_new}) 269 | self.assertTrue(isinstance(c_new, one_tensor)) 270 | 271 | def test_graph_replace_missing(self): 272 | tmp_graph = tf.Graph() 273 | with tmp_graph.as_default(): 274 | a_tensor = tf.constant(1.0, name="a") 275 | b_tensor = tf.constant(2.0, name="b") 276 | _ = tf.add(a_tensor, 2 * b_tensor, name="c") 277 | _ = tf.constant(2.0, name="d") 278 | g = gde.Graph(tmp_graph) 279 | res = gde.graph_replace([g["b"].output(0), g["c"].output(0)], 280 | {g["a"].output(0): g["d"].output(0)}) 281 | self.assertEqual(res[0].name, "b:0") 282 | self.assertEqual(res[1].name, "c_1:0") 283 | 284 | @unittest.skipIf(tf.version.VERSION[0] == "2", "not supported in TF2.x") 285 | def test_graph_replace_gradients(self): 286 | tmp_graph = tf.Graph() 287 | with tmp_graph.as_default(): 288 | w_tensor = tf.Variable(0.0, name="w") 289 | y_tensor = tf.multiply(tf.multiply(w_tensor, w_tensor, name="mul1"), 290 | w_tensor, name="mul2") 291 | grad_tensor = tf.gradients(y_tensor, w_tensor, name="gradient")[0] 292 | _ = tf.identity(grad_tensor, "grad") 293 | 294 | g = gde.Graph(tmp_graph) 295 | 296 | # Extract the operations. 297 | replacement_ts = {g["w/read"].output(0): g["grad"].output(0)} 298 | 299 | # Should not raise exception. 300 | res = gde.graph_replace(g["grad"].output(0), replacement_ts, 301 | dst_scope="res") 302 | 303 | self.assertNotEqual(res.name, g["grad"].output(0).name) 304 | after_graph = tf.Graph() 305 | with after_graph.as_default(): 306 | tf.import_graph_def(g.to_graph_def(), name="") 307 | gde.util.load_variables_to_tf_graph(g) 308 | with tf.Session() as sess: 309 | sess.run(tf.global_variables_initializer()) 310 | g_val, res_val = sess.run([g["grad"].output(0).name, res.name]) 311 | self.assertNear(g_val, 0.0, ERROR_TOLERANCE) 312 | self.assertNear(res_val, 0.0, ERROR_TOLERANCE) 313 | 314 | @unittest.skipIf(tf.version.VERSION[0] == "2", "not supported in TF2.x") 315 | def test_graph_while_loop(self): 316 | tf_graph = tf.Graph() 317 | with tf_graph.as_default(): 318 | max_index = tf.placeholder(dtype=tf.int32, shape=tuple()) 319 | index_start = tf.constant(1) 320 | sum_start = tf.constant(0) 321 | _, result = tf.while_loop( 322 | cond=lambda i, unused_s: i <= max_index, 323 | body=lambda i, s: (i + 1, s + i), 324 | loop_vars=[index_start, sum_start]) 325 | g = gde.Graph(tf_graph) 326 | result_tensor = g[result.op.name].output(0) 327 | max_index_tensor = g[max_index.op.name].output(0) 328 | 329 | g.frozen = True 330 | copied_graph = gde.Graph() 331 | _, copy_info = gde.copy( 332 | g, dst_graph=copied_graph, dst_scope="imported") 333 | copied_result_tensor = copy_info.transformed(result_tensor) 334 | copied_max_index_tensor = copy_info.transformed(max_index_tensor) 335 | 336 | tf_copied_graph = tf.Graph() 337 | with tf_copied_graph.as_default(): 338 | tf.import_graph_def(copied_graph.to_graph_def(), name="") 339 | with tf.Session() as sess: 340 | n = 10 341 | sum_val = sess.run(copied_result_tensor.name, 342 | feed_dict={copied_max_index_tensor.name: n}) 343 | self.assertEqual(sum_val, 55) 344 | 345 | @unittest.skipIf(tf.version.VERSION[0] == "2", "not supported in TF2.x") 346 | def test_graph_cond(self): 347 | tf_g = tf.Graph() 348 | with tf_g.as_default(): 349 | choice_tensor = tf.placeholder(shape=(), dtype=tf.bool, name="choice") 350 | _ = tf.identity( 351 | tf.cond( 352 | choice_tensor, 353 | lambda: tf.constant(1), 354 | lambda: tf.constant(2) 355 | ), 356 | name="result" 357 | ) 358 | 359 | g = gde.Graph(tf_g) 360 | choice = g["choice"].output(0) 361 | result = g["result"].output(0) 362 | 363 | copied_g = gde.Graph() 364 | _, copy_info = gde.copy( 365 | g, dst_graph=copied_g, dst_scope="imported") 366 | copied_result = copy_info.transformed(result) 367 | copied_choice = copy_info.transformed(choice) 368 | 369 | tf_copied_graph = tf.Graph() 370 | with tf_copied_graph.as_default(): 371 | tf.import_graph_def(copied_g.to_graph_def(), name="") 372 | with tf.Session() as sess: 373 | res = sess.run(copied_result.name, feed_dict={copied_choice.name: True}) 374 | self.assertEqual(res, 1) 375 | res = sess.run(copied_result.name, 376 | feed_dict={copied_choice.name: False}) 377 | self.assertEqual(res, 2) 378 | 379 | 380 | if __name__ == "__main__": 381 | unittest.main() 382 | -------------------------------------------------------------------------------- /graph_def_editor/function_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Objects for representing function graphs undergoing rewrite operations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from collections import Counter 23 | import datetime 24 | from distutils import dir_util 25 | import os 26 | from six import string_types 27 | import tensorflow.compat.v1 as tf 28 | import sys 29 | if sys.version >= "3": 30 | from typing import Tuple, Dict, List, FrozenSet, Iterable, Union, Set, Any 31 | 32 | from graph_def_editor import base_graph, node, util, tensor, variable 33 | 34 | # TODO: Move this protobuf into this project so we don't depend on 35 | # tf.core.framework 36 | from tensorflow.core.framework import function_pb2, op_def_pb2 37 | from tensorflow.python.framework import function_def_to_graph 38 | 39 | 40 | __all__ = [ 41 | "FunctionGraph", 42 | ] 43 | 44 | # Special attribute in which TensorFlow stores frame names for while loops ( 45 | # see node_to_frame_name() for more information 46 | _INPUT_DUMMY_OP_NAME = "__input__" 47 | 48 | 49 | class FunctionGraph(base_graph.BaseGraph): 50 | """Wrapper class for TensorFlow function graphs.""" 51 | 52 | def __init__( 53 | self, 54 | name=None, # type: str 55 | parent_tf_graph=None, # type: tf.Graph 56 | parent_graph=None # type: gde.Graph 57 | ): 58 | """Wrap a tf.GraphDef protocol buffer in a FunctionGraph object. 59 | 60 | Args: 61 | g: a tf.Graph or tf.GraphDef protobuf that represents a 62 | TensorFlow graph. If set to None, generate an empty 63 | tf.GraphDef 64 | name: Optional human-readable name for the graph. If not provided, 65 | the constructor will generate a name. 66 | """ 67 | super(FunctionGraph, self).__init__(name) 68 | (self._func_graph, self._func_graph_def) = \ 69 | _get_func_graph_for_name(parent_tf_graph, name) 70 | output_map = _decode_graph(name, self._func_graph) 71 | output_map_pairs = {} 72 | for op_name, tuples in output_map.items(): 73 | output_map_pairs[op_name] = \ 74 | [(dtype, shape) for (dtype, shape, _) in tuples] 75 | 76 | # Populate fields of object 77 | self._node_to_frame_names = None 78 | self._frame_name_to_nodes = None 79 | self._head_name_to_coloc_group = None # Dict[str, FrozenList[str]] 80 | self._variable_name_to_variable = {} # Dict[str, Variable] 81 | self._collection_name_to_type = None # Dict[str, str], generated on demand 82 | self._input_nodes = [] 83 | self._output_nodes = [] 84 | self._parent_graph = parent_graph 85 | 86 | for input_arg in self._func_graph_def.signature.input_arg: 87 | self._input_nodes.append( 88 | self.add_node(input_arg.name, _INPUT_DUMMY_OP_NAME)) 89 | self[input_arg.name].set_outputs_from_pairs( 90 | output_map_pairs[input_arg.name]) 91 | 92 | # Load nodes in three passes because the g may contain cycles. 93 | for node_def in self._func_graph_def.node_def: 94 | self.add_node_from_node_def(node_def, set_inputs=False) 95 | for node_def in self._func_graph_def.node_def: 96 | self[node_def.name].set_outputs_from_pairs( 97 | output_map_pairs[node_def.name]) 98 | for node_def in self._func_graph_def.node_def: 99 | try: 100 | self[node_def.name].set_inputs_from_strings( 101 | node_def.input, 102 | set_control_inputs=True, 103 | output_map=output_map) 104 | except Exception as ex: 105 | print("can't set inputs for node: {}; reason: {}".format( 106 | node_def.name, ex)) 107 | 108 | for output_tensor in self._func_graph.outputs: 109 | self._output_nodes.append(self.get_node_by_name(output_tensor.op.name)) 110 | 111 | @property 112 | def input_nodes(self): 113 | return self._input_nodes 114 | 115 | @property 116 | def output_nodes(self): 117 | return self._output_nodes 118 | 119 | @property 120 | def parent_graph(self): 121 | return self._parent_graph 122 | 123 | def get_func_graph_for_name(self, graph, func_name): 124 | """Returns the FuncGraph associated to the given func_name if possible.""" 125 | outer_graph = graph 126 | while graph is not None: 127 | # pylint: disable=protected-access 128 | func = graph._get_function(str(func_name)) 129 | if func is not None: 130 | if hasattr(func, "graph"): 131 | return func.graph 132 | # `outer_graph` may not be the same as `ops.get_default_graph()` e.g. 133 | # in the case of nested if ops or when the gradient is being computed 134 | # from inside a Defun. We build the `func_graph` with `outer_graph` 135 | # as its outer graph. 136 | with outer_graph.as_default(): 137 | # This is a _DefinedFunction. 138 | func_graph = ( 139 | function_def_to_graph.function_def_to_graph(func.definition)) 140 | if func_graph is not None: 141 | return func_graph 142 | if hasattr(graph, "outer_graph"): 143 | graph = graph.outer_graph 144 | else: 145 | raise ValueError( 146 | "Function {} does not exist in the graph.".format(func_name)) 147 | 148 | def to_function_graph_def(self, add_shapes=True): 149 | # type: (bool) -> function_pb2.FunctionDef 150 | """ 151 | Args: 152 | add_shapes: If True, add the special "_output_shapes" attribute with 153 | output shape information from this Node's output metadata. 154 | 155 | Returns the `function_pb2.FunctionDef` serialization of this function's 156 | graph in its current form. 157 | """ 158 | ret = function_pb2.FunctionDef() 159 | ret.CopyFrom(self._func_graph_def) 160 | # Leave signature as is, but replace all node_defs 161 | del ret.node_def[:] 162 | ret.signature.CopyFrom(self._func_graph_def.signature) 163 | 164 | input_args = [input_arg.name for input_arg in ret.signature.input_arg] 165 | 166 | for op in self.nodes: 167 | if op.op_type == _INPUT_DUMMY_OP_NAME: 168 | continue 169 | 170 | node_def = ret.node_def.add() 171 | op.to_node_def(node_def, add_shapes) 172 | unique_input_counter = Counter() 173 | 174 | for i in range(len(op.inputs)): 175 | (input_tensor_name, global_input_index_str) = ( 176 | op.inputs[i].name.split(":")) 177 | 178 | global_input_index = int(global_input_index_str) 179 | if input_tensor_name in input_args: 180 | # don't add index for function args 181 | node_def.input[i] = input_tensor_name 182 | else: 183 | input_op_output_args, input_op_output_has_number_attr = ( 184 | self._get_op_def_denormalized_outputs(op.inputs[i].op)) 185 | if (len(input_op_output_args) == 1 and 186 | input_op_output_args[0].type_list_attr): 187 | node_def.input[i] = ( 188 | input_tensor_name + ":" + input_op_output_args[0].name + ":" + 189 | str(global_input_index)) 190 | else: 191 | input_name = ( 192 | input_tensor_name + ":" + 193 | input_op_output_args[global_input_index].name) 194 | node_def.input[i] = ( 195 | input_name + ":" + str(unique_input_counter[input_name])) 196 | if input_op_output_has_number_attr: 197 | # only uniquify input args with var length, 198 | # otherwise it should be 0 199 | unique_input_counter[input_name] += 1 200 | return ret 201 | 202 | def to_tf_function_graph(self): 203 | # type: () -> tf.Graph 204 | """ 205 | Converts this graph into a new TensorFlow `Graph`. Also takes care of 206 | variables. 207 | Note that function_def_to_graph.function_def_to_graph won't work if 208 | function calls into other functions. 209 | 210 | Returns a fresh `tf.Graph` containing all the nodes and variables that 211 | this object represents. 212 | """ 213 | return function_def_to_graph.function_def_to_graph( 214 | self.to_function_graph_def()) 215 | 216 | def increment_version_counter(self): 217 | """ 218 | Mark the structure of this graph as "changed" and invalidate any cached 219 | information about the edges of the graph. 220 | """ 221 | super(FunctionGraph, self).increment_version_counter() 222 | self._node_to_frame_names = None 223 | self._frame_name_to_nodes = None 224 | self._head_name_to_coloc_group = None 225 | self._collection_name_to_type = None 226 | 227 | def frame_name_to_nodes(self, frame_name): 228 | # type: (str) -> Tuple[node.Node] 229 | """ 230 | Performs the inverse mapping of node_to_frame_name(). 231 | 232 | Args: 233 | frame_name: Name of a control flow frame in the graph 234 | 235 | Returns: 236 | All nodes that are tagged with the indicated frame, either as an 237 | innermost frame or as a containing frame. 238 | """ 239 | if self._node_to_frame_names is None: 240 | self._generate_node_to_frame_name() 241 | return self._frame_name_to_nodes[frame_name] 242 | 243 | def get_frame_names(self): 244 | # type: () -> Tuple[str] 245 | """ 246 | Returns: 247 | Tuple of all the unique names of frames that occur in this graph. 248 | """ 249 | if self._node_to_frame_names is None: 250 | self._generate_node_to_frame_name() 251 | return self._frame_name_to_nodes.keys() 252 | 253 | def _get_op_def_denormalized_outputs(self, op): 254 | # type: (Node) -> (List[op_def_pb2.OpDef.ArgDef], bool) 255 | # pylint: disable=protected-access 256 | op_def = self._func_graph._get_op_def(op.op_type) 257 | output_args = [] 258 | 259 | input_op_output_has_number_attr = False 260 | for output_arg in op_def.output_arg: 261 | if output_arg.number_attr: 262 | l = op.get_attr(output_arg.number_attr) 263 | input_op_output_has_number_attr = True 264 | for _ in range(l): 265 | output_args.append(op_def_pb2.OpDef.ArgDef(name=output_arg.name, 266 | type=output_arg.type)) 267 | else: 268 | output_args.append(output_arg) 269 | 270 | return (output_args, input_op_output_has_number_attr) 271 | 272 | def _visualize_node( 273 | self, 274 | gde_node, 275 | format=None, 276 | depth=1, 277 | style=True, 278 | name_regex="", 279 | negative_name_regex="", 280 | add_digraph_func=None, 281 | add_digraph_node_func=None, 282 | add_digraph_edge_func=None, 283 | depth_before=1, 284 | depth_after=2): 285 | """Return GraphViz Digraph rendering of the current and adjacent nodes. 286 | 287 | Args: 288 | gde_node: a node to visualize. 289 | format: GraphViz display format. In addition to that it supports 290 | jupyter_svg, and jupyter_interactive modes. 291 | depth: the maximum depth of the graph to display. 292 | style: whether to apply default styles. 293 | name_regex: only diplay nodes that have name matching this regex. 294 | negative_name_regex: only diplay nodes that have name not matching this 295 | regex. 296 | add_digraph_func: custom override for function for adding subraphs 297 | to the resulting Digraph object. 298 | add_digraph_node_func: custom override for function for adding nodes 299 | (vertices) to the resulting Digraph object. 300 | add_digraph_edge_func: custom override for function for adding edges 301 | to the resulting Digraph object. 302 | depth_before: number of adjacent nodes to show before the current one. 303 | depth_after: number of adjacent nodes to show after the current one. 304 | 305 | Returns: 306 | graphviz.dot.Digraph object with visual representtion for the current 307 | graph. 308 | """ 309 | 310 | # pylint: disable=protected-access 311 | return self._parent_graph._visualize_node( 312 | gde_node=gde_node, 313 | format=format, 314 | depth=depth, 315 | style=style, 316 | name_regex=name_regex, 317 | negative_name_regex=negative_name_regex, 318 | add_digraph_func=add_digraph_func, 319 | add_digraph_node_func=add_digraph_node_func, 320 | add_digraph_edge_func=add_digraph_edge_func, 321 | depth_before=depth_before, 322 | depth_after=depth_after) 323 | 324 | 325 | ################################################################################ 326 | # Stuff below this line is private to this file. 327 | 328 | 329 | def _get_func_graph_for_name(graph, func_name): 330 | """Returns the FuncGraph and FuncDef associated to the given func_name.""" 331 | outer_graph = graph 332 | while graph is not None: 333 | # pylint: disable=protected-access 334 | func = graph._get_function(str(func_name)) 335 | if func is not None: 336 | if hasattr(func, "graph"): 337 | return (func.graph, func.definition) 338 | # `outer_graph` may not be the same as `ops.get_default_graph()` e.g. 339 | # in the case of nested if ops or when the gradient is being computed 340 | # from inside a Defun. We build the `func_graph` with `outer_graph` as its 341 | # outer graph. 342 | with outer_graph.as_default(): 343 | # This is a _DefinedFunction. 344 | func_graph = ( 345 | function_def_to_graph.function_def_to_graph(func.definition)) 346 | if func_graph is not None: 347 | return (func_graph, func.definition) 348 | if hasattr(graph, "outer_graph"): 349 | graph = graph.outer_graph 350 | else: 351 | raise ValueError( 352 | "Function {} does not exist in the graph.".format(func_name)) 353 | 354 | 355 | def _decode_graph(name, func_graph): 356 | # type: (str, tf.Graph) -> Dict[str, List[Tuple[tf.DType, tf.TensorShape, str]]] 357 | """ 358 | Use public TensorFlow APIs to decode the important information that is not 359 | explicitly stored in the GraphDef proto, but which must be inferred from the 360 | GraphDef in conjunction with additional data structures that TensorFlow 361 | generally keeps to itself. 362 | 363 | Args: 364 | name: function name. 365 | func_graph: tf.GraphDef protobuf that represents a function graph. 366 | 367 | Returns: 368 | A map from node name to a list of (type, shape, output_arg_name) tuples 369 | that describes in turn each of the outputs of said node. 370 | """ 371 | # The information in a NodeDef is not sufficient to determine output type 372 | # information. For that kind of type inference, you need access to the 373 | # corresponding OpDef protos. Unfortunately there is not a public API that 374 | # allows for OpDef lookup. So instead we instantiate the graph that 375 | # graph_def describes. This approach makes things easier, but there will be 376 | # a reduction in forwards compatibility, because import_graph_def() does a 377 | # lot of sanity checks that aren't necessary when rewriting a graph_def. 378 | output_map = {} 379 | for op in func_graph.get_operations(): 380 | # pylint: disable=protected-access 381 | op_def = func_graph._get_op_def(op.type) 382 | output_idx = 0 383 | output_map[op.name] = [] 384 | for output_arg_idx in range(len(op_def.output_arg)): 385 | output_arg = op_def.output_arg[output_arg_idx] 386 | output = op.outputs[output_idx] 387 | if output_arg.type_list_attr: 388 | output_map[op.name] = [( 389 | output.dtype, output.shape, op_def.output_arg[0].name) 390 | for output in op.outputs] 391 | break 392 | elif output_arg.number_attr: 393 | output_len = op.node_def.attr[output_arg.number_attr].i 394 | for _ in range(output_len): 395 | output = op.outputs[output_idx] 396 | output_map[op.name].append( 397 | (output.dtype, output.shape, output_arg.name)) 398 | output_idx += 1 399 | else: 400 | output_map[op.name].append( 401 | (output.dtype, output.shape, output_arg.name)) 402 | output_idx += 1 403 | return output_map 404 | 405 | -------------------------------------------------------------------------------- /examples/coco_example.py: -------------------------------------------------------------------------------- 1 | # Coypright 2019 IBM. All Rights Reserved. 2 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """ 18 | Example of using the GraphDef editor and the Graph Transform Tool to prep an 19 | object detection model for easy deployment. 20 | 21 | This script starts with the pre-trained object detection model from the 22 | TensorFlow Models repository; see https://github.com/tensorflow/models/ 23 | blob/master/research/object_detection/g3doc/detection_model_zoo.md. 24 | 25 | Specifically, we use the object detector trained on the COCO dataset with a 26 | MobileNetV1 architecture. 27 | 28 | The original model takes as input batches of equal-sized images, represented 29 | as a single dense numpy array of binary pixel data. The output of the 30 | original model represents the object type as an integer. This script grafts on 31 | pre- and post-processing ops to make the input and output format more amenable 32 | to use in applications. After these ops are added, the resulting graph takes a 33 | single image file as an input and produces string-valued object labels. 34 | 35 | To run this example from the root of the project, type: 36 | PYTHONPATH=$PWD env/bin/python examples/coco_example.py 37 | """ 38 | 39 | from __future__ import absolute_import 40 | from __future__ import division 41 | from __future__ import print_function 42 | 43 | import os 44 | import re 45 | import tensorflow as tf 46 | import graph_def_editor as gde 47 | import numpy as np 48 | # noinspection PyPackageRequirements 49 | import PIL # Pillow 50 | import shutil 51 | import tarfile 52 | from typing import List 53 | import textwrap 54 | import urllib.request 55 | 56 | from tensorflow.tools import graph_transforms 57 | 58 | FLAGS = tf.flags.FLAGS 59 | 60 | 61 | def _indent(s): 62 | return textwrap.indent(str(s), " ") 63 | 64 | 65 | # Parameters of input graph 66 | _LONG_MODEL_NAME = "ssd_mobilenet_v1_coco_2018_01_28" 67 | _MODEL_TARBALL_URL = ("http://download.tensorflow.org/models/object_detection/" 68 | + _LONG_MODEL_NAME + ".tar.gz") 69 | # Path to frozen graph within tarball 70 | _FROZEN_GRAPH_MEMBER = _LONG_MODEL_NAME + "/frozen_inference_graph.pb" 71 | _INPUT_NODE_NAMES = ["image_tensor"] 72 | _OUTPUT_NODE_NAMES = ["detection_boxes", "detection_classes", 73 | "detection_scores", "num_detections"] 74 | 75 | _HASH_TABLE_INIT_OP_NAME = "hash_table_init" 76 | 77 | # Label map for decoding label IDs in the output of the graph 78 | _LABEL_MAP_URL = ("https://raw.githubusercontent.com/tensorflow/models/" 79 | "f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/" 80 | "object_detection/data/mscoco_label_map.pbtxt") 81 | 82 | 83 | # Locations of intermediate files 84 | _TMP_DIR = "/tmp/coco_example" 85 | _SAVED_MODEL_DIR = _TMP_DIR + "/original_model" 86 | _FROZEN_GRAPH_FILE = "{}/frozen_graph.pbtext".format(_TMP_DIR) 87 | _PRE_POST_GRAPH_FILE = "{}/pre_and_post.pbtext".format(_TMP_DIR) 88 | _TF_REWRITES_GRAPH_FILE = "{}/after_tf_rewrites_graph.pbtext".format(_TMP_DIR) 89 | _GDE_REWRITES_GRAPH_FILE = "{}/after_gde_rewrites_graph.pbtext".format(_TMP_DIR) 90 | _AFTER_MODEL_FILES = [ 91 | _FROZEN_GRAPH_FILE, _TF_REWRITES_GRAPH_FILE, _GDE_REWRITES_GRAPH_FILE 92 | ] 93 | # Panda pic from Wikimedia; also used in 94 | # https://github.com/tensorflow/models/blob/master/research/slim/nets ... 95 | # ... /mobilenet/mobilenet_example.ipynb 96 | _PANDA_PIC_URL = ("https://upload.wikimedia.org/wikipedia/commons/f/fe/" 97 | "Giant_Panda_in_Beijing_Zoo_1.JPG") 98 | 99 | 100 | def _clear_dir(path): 101 | # type: (str) -> None 102 | if os.path.isdir(path): 103 | shutil.rmtree(path) 104 | os.mkdir(path) 105 | 106 | 107 | def _protobuf_to_file(pb, path, human_readable_name): 108 | # type: (Any, str, str) -> None 109 | with open(path, "w") as f: 110 | f.write(str(pb)) 111 | print("{} written to {}".format(human_readable_name, path)) 112 | 113 | 114 | def _fetch_or_use_cached(file_name, url): 115 | # type: (str, str) -> str 116 | """ 117 | Check for a cached copy of the indicated file in our temp directory. 118 | 119 | If a copy doesn't exist, download the file. 120 | 121 | Arg: 122 | file_name: Name of the file within the temp dir, not including the temp 123 | dir path 124 | url: Full URL from which to download the file, including remote file 125 | name, which can be different from file_name 126 | 127 | Returns the path of the cached file. 128 | """ 129 | cached_filename = "{}/{}".format(_TMP_DIR, file_name) 130 | if not os.path.exists(cached_filename): 131 | print("Downloading {} to {}".format(url, cached_filename)) 132 | urllib.request.urlretrieve(url, cached_filename) 133 | return cached_filename 134 | 135 | 136 | def _get_frozen_graph(): 137 | # type: () -> tf.GraphDef 138 | """ 139 | Obtains the starting version of the model from the TensorFlow model zoo 140 | 141 | Returns GraphDef 142 | """ 143 | tarball = _fetch_or_use_cached("{}.tar.gz".format(_LONG_MODEL_NAME), 144 | _MODEL_TARBALL_URL) 145 | 146 | print("Original model files at {}".format(tarball)) 147 | with tarfile.open(tarball) as t: 148 | frozen_graph_bytes = t.extractfile(_FROZEN_GRAPH_MEMBER).read() 149 | return tf.GraphDef.FromString(frozen_graph_bytes) 150 | 151 | 152 | def _build_preprocessing_graph_def(): 153 | # type: () -> tf.GraphDef 154 | """ 155 | Build a TensorFlow graph that performs the preprocessing operations that 156 | need to happen before the main graph, then convert to a GraphDef. 157 | 158 | Returns: 159 | Python object representation of the GraphDef for the preprocessing graph. 160 | Input node of the graph is the placeholder "raw_image", and the output is 161 | the node with the name "preprocessed_image". 162 | """ 163 | # At the moment, the only preprocessing we need to perform is converting 164 | # JPEG/PNG/GIF files to numpy arrays. 165 | img_decode_g = tf.Graph() 166 | with img_decode_g.as_default(): 167 | raw_image = tf.placeholder(tf.string, name="raw_image") 168 | 169 | # Downstream code hardcodes RGB 170 | _NUM_CHANNELS = 3 171 | 172 | # The TensorFlow authors, in their infinite wisdom, created two generic 173 | # image-decoder ops. tf.image.decode_imaage() returns a 4D tensor when it 174 | # receives a GIF and a 3D tensor for every other file type. This means 175 | # that you need complicated shape-checking and reshaping logic downstream 176 | # for it to be of any use in an inference context. 177 | # The other op is tf.image.decode_png(). In spite of its name, this op 178 | # actually handles PNG, JPEG, and non-animated GIF files. For now, we use 179 | # this op for simplicity. 180 | decoded_image = tf.image.decode_png(raw_image, _NUM_CHANNELS) 181 | 182 | # Downstream code expects a batch of equal-sized images. For now, we 183 | # generate a single-image batch. 184 | decoded_image_batch = tf.expand_dims(decoded_image, 0, 185 | name="preprocessed_image") 186 | 187 | return img_decode_g.as_graph_def() 188 | 189 | 190 | def _build_postprocessing_graph_def(): 191 | # type: () -> tf.GraphDef 192 | """ 193 | Build the TensorFlow graph that performs postprocessing operations that 194 | should happen after the main graph. 195 | 196 | Returns: 197 | Python object representation of the GraphDef for the postprocessing graph. 198 | The graph has one input placeholder called "detection_classes" and 199 | an output op called "decoded_detection_classes". 200 | The graph will also have an op called "hash_table_init" that initializes 201 | the mapping table. This op MUST be run exactly once before the 202 | "decoded_detection_classes" op will work. 203 | """ 204 | label_file = _fetch_or_use_cached("labels.pbtext", _LABEL_MAP_URL) 205 | 206 | # Category mapping comes in pbtext format. Translate to the format that 207 | # TensorFlow's hash table initializers expect (key and value tensors). 208 | with open(label_file, "r") as f: 209 | raw_data = f.read() 210 | # Parse directly instead of going through the protobuf API dance. 211 | records = raw_data.split("}") 212 | records.pop(-1) # Remove empty record at end 213 | records = [r.replace("\n", "") for r in records] # Strip newlines 214 | regex = re.compile(r"item { name: \".+\" id: (.+) display_name: \"(.+)\"") 215 | keys = [] 216 | values = [] 217 | for r in records: 218 | match = regex.match(r) 219 | keys.append(int(match.group(1))) 220 | values.append(match.group(2)) 221 | 222 | result_decode_g = tf.Graph() 223 | with result_decode_g.as_default(): 224 | # The original graph produces floating-point output for detection class, 225 | # even though the output is always an integer. 226 | float_class = tf.placeholder(tf.float32, shape=[None], 227 | name="detection_classes") 228 | int_class = tf.cast(float_class, tf.int32) 229 | key_tensor = tf.constant(keys, dtype=tf.int32) 230 | value_tensor = tf.constant(values) 231 | table_init = tf.contrib.lookup.KeyValueTensorInitializer( 232 | key_tensor, 233 | value_tensor, 234 | name=_HASH_TABLE_INIT_OP_NAME) 235 | hash_table = tf.contrib.lookup.HashTable( 236 | table_init, 237 | default_value="Unknown" 238 | ) 239 | _ = hash_table.lookup(int_class, name="decoded_detection_classes") 240 | 241 | return result_decode_g.as_graph_def() 242 | 243 | 244 | def _graft_pre_and_post_processing_to_main_graph(g): 245 | # type: (gde.Graph) -> None 246 | """ 247 | Attach pre- and post-processing subgraphs to the main graph. 248 | 249 | Args: 250 | g: GDE representation of the core graph. Modified in place. 251 | """ 252 | # Build the pre- and post-processing subgraphs and import into GDE 253 | pre_g = gde.Graph(_build_preprocessing_graph_def()) 254 | post_g = gde.Graph(_build_postprocessing_graph_def()) 255 | 256 | # Replace the graph's input placeholder with the contents of our 257 | # pre-processing graph. 258 | name_of_input_node = _INPUT_NODE_NAMES[0] 259 | gde.copy(pre_g, g) 260 | gde.reroute_ts(g.get_node_by_name("preprocessed_image").output(0), 261 | g.get_node_by_name(name_of_input_node).output(0)) 262 | g.remove_node_by_name(name_of_input_node) 263 | g.rename_node("raw_image", name_of_input_node) 264 | 265 | # Tack on the postprocessing graph at the original output and rename 266 | # the postprocessed output to the original output's name 267 | # The original graph produces an output called "detection_classes". 268 | # The postprocessing graph goes from "detection_classes" to 269 | # "decoded_detection_classes". 270 | # The graph after modification produces decoded classes under the original 271 | # "detection_classes" name. The original output is renamed to 272 | # "raw_detection_classes". 273 | g.rename_node("detection_classes", "raw_detection_classes") 274 | gde.copy(post_g, g) 275 | gde.reroute_ts(g.get_node_by_name("raw_detection_classes").output(0), 276 | g.get_node_by_name("detection_classes").output(0)) 277 | g.remove_node_by_name("detection_classes") 278 | g.rename_node("decoded_detection_classes", "detection_classes") 279 | 280 | 281 | def _apply_graph_transform_tool_rewrites(g, input_node_names, 282 | output_node_names): 283 | # type: (gde.Graph, List[str], List[str]) -> tf.GraphDef 284 | """ 285 | Use the [Graph Transform Tool]( 286 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ 287 | graph_transforms/README.md) 288 | to perform a series of pre-deployment rewrites. 289 | 290 | Args: 291 | g: GDE representation of the core graph. 292 | input_node_names: Names of placeholder nodes that are used as inputs to 293 | the graph for inference. Placeholders NOT on this list will be 294 | considered dead code. 295 | output_node_names: Names of nodes that produce tensors that are outputs 296 | of the graph for inference purposes. Nodes not necessary to produce 297 | these tensors will be considered dead code. 298 | 299 | Returns: GraphDef representation of rewritten graph. 300 | """ 301 | # Invoke the Graph Transform Tool using the undocumented Python APIs under 302 | # tensorflow.tools.graph_transforms 303 | after_tf_rewrites_graph_def = graph_transforms.TransformGraph( 304 | g.to_graph_def(), 305 | inputs=input_node_names, 306 | outputs=output_node_names, 307 | # Use the set of transforms recommended in the README under "Optimizing 308 | # for Deployment" 309 | transforms=['strip_unused_nodes(type=float, shape="1,299,299,3")', 310 | 'remove_nodes(op=Identity, op=CheckNumerics)', 311 | 'fold_constants(ignore_errors=true)', 312 | 'fold_batch_norms', 313 | 'fold_old_batch_norms'] 314 | ) 315 | return after_tf_rewrites_graph_def 316 | 317 | 318 | def _graph_has_op(g, op_name): 319 | # type: (tf.Graph, str) -> bool 320 | """ 321 | A method that really ought to be part of `tf.Graph`. Returns true of the 322 | indicated graph has an op by the indicated name. 323 | """ 324 | all_ops_in_graph = g.get_operations() 325 | return any(op_name == o.name for o in all_ops_in_graph) 326 | 327 | 328 | def _run_coco_graph(graph_proto, img): 329 | # type: (tf.GraphDef, np.ndarray) -> None 330 | """ 331 | Run an example image through a TensorFlow graph and print a summary of 332 | the results to STDOUT. 333 | 334 | Only works for the graphs used in this example. 335 | 336 | graph_proto: GraphDef protocol buffer message holding serialized graph 337 | img: input image, either as a numpy array or a JPEG binary string 338 | """ 339 | image_tensor_name = _INPUT_NODE_NAMES[0] + ":0" 340 | output_tensor_names = [n + ":0" for n in _OUTPUT_NODE_NAMES] 341 | with tf.Graph().as_default(): 342 | with tf.Session() as sess: 343 | tf.import_graph_def(graph_proto, name="") 344 | 345 | # Initialize hash tables if present. Assumes that the init op is called 346 | # "hash_table_init" 347 | if _graph_has_op(tf.get_default_graph(), _HASH_TABLE_INIT_OP_NAME): 348 | sess.run(_HASH_TABLE_INIT_OP_NAME) 349 | 350 | results = sess.run(output_tensor_names, {image_tensor_name: img}) 351 | 352 | bboxes, classes, scores, num_detections = results 353 | if len(classes.shape) > 1: 354 | # Results are a batch of length 1; unnest. 355 | bboxes = bboxes[0] 356 | classes = classes[0] 357 | scores = scores[0] 358 | num_detections = num_detections[0] 359 | 360 | # The num_detections output tells how much of the other output tensors is 361 | # used. The remaining rows of the tensors contain garbage. Print out the 362 | # non-garbage rows. 363 | print("Rank Label Weight Bounding Box") 364 | for i in range(int(num_detections)): 365 | clazz = classes[i] # "class" is a reserved word in Python 366 | if isinstance(clazz, bytes): 367 | clazz = clazz.decode("UTF-8") 368 | 369 | print("{:<10}{:<20}{:<10f}{}".format( 370 | i + 1, clazz, scores[i], bboxes[i])) 371 | 372 | 373 | def main(_): 374 | # Remove any detritus of previous runs of this script, but leave the temp 375 | # dir in place because the user might have a shell there. 376 | if not os.path.isdir(_TMP_DIR): 377 | os.mkdir(_TMP_DIR) 378 | _clear_dir(_SAVED_MODEL_DIR) 379 | for f in _AFTER_MODEL_FILES: 380 | if os.path.isfile(f): 381 | os.remove(f) 382 | 383 | # We start with a frozen graph for the model. "Frozen" means that all 384 | # variables have been converted to constants. 385 | frozen_graph_def = _get_frozen_graph() 386 | 387 | # Wrap the initial GraphDef in a gde.Graph so we can examine it. 388 | frozen_graph = gde.Graph(frozen_graph_def) 389 | input_node_names = [n.name for n in 390 | gde.filter_ops_by_optype(frozen_graph, "Placeholder")] 391 | # TODO: Devise an automatic way to find the outputs 392 | output_node_names = _OUTPUT_NODE_NAMES + [_HASH_TABLE_INIT_OP_NAME] 393 | print("Input names: {}".format(input_node_names)) 394 | print("Output names: {}".format(output_node_names)) 395 | 396 | _protobuf_to_file(frozen_graph_def, _FROZEN_GRAPH_FILE, "Frozen graph") 397 | 398 | # Graft the preprocessing and postprocessing graphs onto the beginning and 399 | # end of the inference graph. 400 | g = gde.Graph(frozen_graph_def) 401 | _graft_pre_and_post_processing_to_main_graph(g) 402 | after_add_pre_post_graph_def = g.to_graph_def() 403 | _protobuf_to_file(after_add_pre_post_graph_def, _PRE_POST_GRAPH_FILE, 404 | "Graph with pre- and post-processing") 405 | 406 | # Now run through some of TensorFlow's built-in graph rewrites. 407 | after_tf_rewrites_graph_def = _apply_graph_transform_tool_rewrites( 408 | g, input_node_names, output_node_names) 409 | _protobuf_to_file(after_tf_rewrites_graph_def, 410 | _TF_REWRITES_GRAPH_FILE, 411 | "Graph after built-in TensorFlow rewrites") 412 | 413 | # Now run the GraphDef editor's graph prep rewrites 414 | g = gde.Graph(after_tf_rewrites_graph_def) 415 | gde.rewrite.fold_batch_norms(g) 416 | gde.rewrite.fold_old_batch_norms(g) 417 | gde.rewrite.fold_batch_norms_up(g) 418 | after_gde_graph_def = g.to_graph_def(add_shapes=True) 419 | _protobuf_to_file(after_gde_graph_def, 420 | _GDE_REWRITES_GRAPH_FILE, 421 | "Graph after GraphDef Editor rewrites") 422 | 423 | # Dump some statistics about the number of each type of op 424 | print(" Number of ops in frozen graph: {}".format(len( 425 | frozen_graph_def.node))) 426 | print(" Num. ops after adding pre- and post-proc: {}".format(len( 427 | after_add_pre_post_graph_def.node))) 428 | print(" Number of ops after built-in rewrites: {}".format(len( 429 | after_tf_rewrites_graph_def.node))) 430 | print(" Number of ops after GDE rewrites: {}".format(len( 431 | after_gde_graph_def.node))) 432 | 433 | # Run model before and after rewrite and compare results 434 | img_path = _fetch_or_use_cached("panda.jpg", _PANDA_PIC_URL) 435 | 436 | with open(img_path, "rb") as f: 437 | jpg_img = f.read() 438 | np_img_batch = np.expand_dims(np.array(PIL.Image.open(img_path)), axis=0) 439 | 440 | print("Frozen graph results:") 441 | _run_coco_graph(frozen_graph_def, np_img_batch) 442 | print("Results after adding pre/post-processing:") 443 | _run_coco_graph(after_add_pre_post_graph_def, jpg_img) 444 | print("Results after built-in rewrites:") 445 | _run_coco_graph(after_tf_rewrites_graph_def, jpg_img) 446 | print("Results after GDE rewrites:") 447 | _run_coco_graph(after_gde_graph_def, jpg_img) 448 | 449 | 450 | if __name__ == "__main__": 451 | tf.app.run() 452 | -------------------------------------------------------------------------------- /graph_def_editor/base_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google. All Rights Reserved. 2 | # Copyright 2019 IBM. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Base class for Graph.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import datetime 23 | from distutils import dir_util 24 | import os 25 | from six import string_types 26 | import tensorflow.compat.v1 as tf 27 | import sys 28 | if sys.version >= '3': 29 | from typing import Tuple, Dict, FrozenSet, Iterable, Union, Set, Any 30 | 31 | from graph_def_editor import node, util, tensor, variable 32 | import graph_def_editor.visualization.graphviz_wrapper as gvw 33 | 34 | 35 | __all__ = [ 36 | "BaseGraph", 37 | ] 38 | 39 | class BaseGraph(object): 40 | """ 41 | Base class for Graph and FunctionGraph classes. 42 | 43 | Mutable surrogate for a `tf.GraphDef` protocol buffer message. 44 | 45 | Summary of internal data structures: 46 | * _node_name_to_node: Nodes in the graph, stored as a dictionary. Key is name. 47 | * _version: Counter that increments every time the graph is modified 48 | * _collections: Map from collection name to collection contents for all 49 | collections 50 | """ 51 | 52 | def __init__( 53 | self, 54 | name = None, # type: str 55 | ): 56 | """ 57 | Constructor to be called by subclasses only. 58 | 59 | Initializes attributes of this base class. 60 | 61 | Args: 62 | name: Optional human-readable name for the graph. If not provided, 63 | the constructor will generate a name. 64 | """ 65 | # Populate fields of object 66 | self._name = name # str 67 | self._version = 0 # Must happen first; other init code needs self._version 68 | self._frozen = False # bool 69 | self._next_id = 1 # int 70 | self._node_name_to_node = {} # Dict[str, node.Node]; key is node name 71 | self._variable_name_to_variable = {} # Dict[str, Variable] 72 | 73 | @property 74 | def name(self): 75 | """ 76 | Returns human-readable name for this graph. This name may not be unique 77 | across graphs. 78 | """ 79 | return self._name 80 | 81 | def __getitem__(self, name): 82 | # type: (str) -> Union[tensor.Tensor, 'node.Node'] 83 | """ 84 | Convenience method to retrieve a node or tensor of the graph by name 85 | 86 | Args: 87 | name: Name of the node or tensor to return. Case-sensitive. 88 | 89 | Returns the named item as a `gde.Node` or `gde.Tensor` object. If there 90 | is a conflict between node and tensor names, node names win. 91 | """ 92 | if not isinstance(name, string_types): 93 | raise TypeError("name must be a string; got type {}".format(type(name))) 94 | 95 | if self.contains_node(name): 96 | return self._node_name_to_node[name] 97 | elif self.contains_tensor(name): 98 | return self.get_tensor_by_name(name) 99 | else: 100 | raise ValueError("No node or tensor '{}' found in graph".format(name)) 101 | 102 | def get_node_by_name(self, name): 103 | # type: (str) -> node.Node 104 | """ 105 | Retrieve a node in the graph by name. 106 | 107 | Args: 108 | name: Name of the node. Case-sensitive. 109 | 110 | Returns the indicated node as a `gde.Node` object. 111 | """ 112 | if self.contains_node(name): 113 | return self._node_name_to_node[name] 114 | else: 115 | raise ValueError("No node '{}' found in graph".format(name)) 116 | 117 | def contains_node(self, name): 118 | # type: (str) -> bool 119 | """ 120 | Returns true if the graph has a node by the indicated name. Exact string 121 | match. 122 | """ 123 | if not isinstance(name, string_types): 124 | raise ValueError("Node name argument is not a string, but is of type " 125 | "{}".format(type(name))) 126 | return name in self._node_name_to_node.keys() 127 | 128 | def add_node(self, 129 | name, # type: str 130 | op_name, # type: str 131 | uniquify_name = False, # type: bool 132 | debug_info = None # type: tf.compat.v1.NodeDef.ExperimentalDebugInfo 133 | ): 134 | # type: (...) -> node.Node 135 | """ 136 | Add a new, empty node to the graph. 137 | Args: 138 | name: Name for the new op 139 | op_name: Name of the type of operation for the node 140 | uniquify_name: Generate a unique name from this name if the graph 141 | already has a node with the indicated name. If False, raise an 142 | exception if the name is in use. 143 | debug_info: Some internal TensorFlow debug information. 144 | We just pass it through for safety. 145 | 146 | Returns: 147 | `MutableNode` wrapper for the new node. 148 | 149 | Raises: 150 | ValueError if the name is already in use and `uniquify_name` is False 151 | """ 152 | if uniquify_name: 153 | name = self.unique_name(name) 154 | elif self._name_in_use(name): # and not uniquify_name 155 | raise ValueError("Graph already contains a node with name '{}' " 156 | "(Note that this check is case-insensitive)." 157 | .format(name)) 158 | ret = node.Node(self, 159 | self._get_next_id(), 160 | name=name, 161 | op_name=op_name, 162 | debug_info=debug_info) 163 | self._node_name_to_node[name] = ret 164 | self.increment_version_counter() 165 | return ret 166 | 167 | def add_node_from_node_def(self, 168 | node_def, # type: tf.NodeDef 169 | set_inputs = False, # type: bool 170 | set_control_inputs = False # type: bool 171 | ): 172 | # type: (...) -> node.Node 173 | """ 174 | Adds a new node to the graph, populating fields of the node from a 175 | `tf.NodeDef` protocol buffer. 176 | 177 | Equivalent to calling `add_node()`, then populating the relevant fields 178 | of the returned MutableNode object. 179 | 180 | Args: 181 | node_def: Protocol buffer describing parameters of the new node. 182 | set_inputs: If True, populate the node's inputs list from the list of 183 | inputs in the `NodeDef` 184 | set_control_inputs: Also set control inputs. Must be False if 185 | `set_inputs` is False. 186 | 187 | Returns: 188 | `MutableNode` wrapper for the new node 189 | """ 190 | if set_control_inputs and not set_inputs: 191 | raise ValueError("set_inputs must be True if set_control_inputs is True") 192 | ret = self.add_node(name=node_def.name, 193 | op_name=node_def.op, 194 | debug_info=node_def.experimental_debug_info) 195 | if set_inputs: 196 | ret.set_inputs_from_strings(node_def.input, 197 | set_control_inputs=set_control_inputs) 198 | ret.device = node_def.device 199 | ret.clear_attrs() 200 | for key in node_def.attr: 201 | ret.add_attr(key, node_def.attr[key]) 202 | 203 | # Don't need to increment version counter; add_node() already did that. 204 | return ret 205 | 206 | def remove_node_by_name(self, name, check_for_refs = True): 207 | # type: (str, str) -> None 208 | """ 209 | Removes the indicated node from this graph and from any collections in 210 | this graph. 211 | 212 | The caller is responsible for removing all links to the indicated node 213 | prior to making this call. 214 | 215 | Args: 216 | name: name of the node to remove 217 | check_for_refs: Optional. If True, raise an exception if there are any 218 | other nodes in the graph that reference this node. If False, allow 219 | removal of nodes with outstanding references to them. In the latter 220 | case, the caller is responsible for cleaning up the graph afterwards. 221 | """ 222 | n = self.get_node_by_name(name) 223 | if check_for_refs: 224 | for t in n.outputs: 225 | if len(t.consumers()) > 0: 226 | raise ValueError("Removing node '{}' would leave dangling " 227 | "references from nodes {} to tensor '{}'" 228 | "".format(name, [c.name for c in t.consumers()], 229 | t.name)) 230 | # noinspection PyProtectedMember 231 | n._remove_from_graph() 232 | del self._node_name_to_node[name] 233 | self.increment_version_counter() 234 | # Don't need to update collection info because collection membership is 235 | # stored in the node. 236 | # Don't need to update consumers of tensors because that information is 237 | # calculated dynamically by iterating over nodes. 238 | 239 | def rename_node(self, old_name, new_name): 240 | # type: (str, str) -> None 241 | """ 242 | Change the name of a node in the graph. 243 | 244 | Args: 245 | old_name: Name of an existing node 246 | new_name: New name for the node in question. Must not currently be in use. 247 | """ 248 | if self.contains_node(new_name): 249 | raise ValueError("Graph already has a node under name '{}'".format( 250 | new_name)) 251 | n = self.get_node_by_name(old_name) 252 | # noinspection PyProtectedMember 253 | n._change_name(new_name) 254 | del self._node_name_to_node[old_name] 255 | self._node_name_to_node[new_name] = n 256 | self.increment_version_counter() 257 | 258 | def add_variable(self, name): 259 | # type: (str) -> variable.Variable 260 | """ 261 | Adds a new variable to the graph. 262 | 263 | Args: 264 | name: Name of the variable. Must not already be in use. 265 | 266 | Returns the `gde.Variable` object corresponding to the added variable. 267 | """ 268 | if name in self._variable_name_to_variable: 269 | raise ValueError("Variable name '{}' already in use".format(name)) 270 | v = variable.Variable(self) 271 | v.name = name 272 | self._variable_name_to_variable[name] = v 273 | self.increment_version_counter() 274 | return v 275 | 276 | def add_variable_from_variable_def(self, variable_def, 277 | skip_if_present = False): 278 | # type: (Any, bool) -> None 279 | """ 280 | Adds a new variable to the graph and populates the fields of the 281 | corresponding Variable object according to a protocol buffer message. 282 | 283 | Args: 284 | variable_def: `tensorflow.core.framework.variable_pb2.VariableDef` 285 | protobuf object. May be serialized as a `bytes` object. 286 | skip_if_present: If True, silently skips inserting duplicate variables, 287 | as long as they don't conflict with existing variables. 288 | 289 | Returns the `gde.Variable` object corresponding to the added variable. 290 | """ 291 | v = variable.Variable(self) 292 | v.from_proto(variable_def, allow_duplicates=skip_if_present) 293 | if v.name not in self._variable_name_to_variable: 294 | self._variable_name_to_variable[v.name] = v 295 | return self._variable_name_to_variable[v.name] 296 | 297 | @property 298 | def variable_names(self): 299 | return self._variable_name_to_variable.keys() 300 | 301 | def get_variable_by_name(self, name): 302 | # type: (str) -> variable.Variable 303 | """ 304 | Fetch a variable by its variable name. 305 | 306 | Args: 307 | name: Name of a variable in this graph. 308 | 309 | Returns the variable associated with the name. Raises an exception if 310 | there is no variable with the indicated name. 311 | """ 312 | return self._variable_name_to_variable[name] 313 | 314 | def _name_in_use(self, name): 315 | # type: (str) -> bool 316 | """Check whether a name is in use, using the same collision semantics as 317 | TensorFlow: Exact lowercase string match. 318 | 319 | Args: 320 | name: Name of a potential node in the graph. 321 | 322 | Returns True if the indicated name is currently in use, ignoring case. 323 | """ 324 | return name.lower() in [k.lower() for k in self._node_name_to_node.keys()] 325 | 326 | def unique_name(self, name): 327 | # type: (str) -> str 328 | """Emulate the behavior of the method by the same name in `tf.Graph`. 329 | 330 | Does *not* emulate the `name_stack` field of `tf.Graph`. 331 | 332 | Unlike the original method, this version does *not* keep a separate table 333 | of names currently "in use for the purposes of `unique_name()`", but instead 334 | refers directly to internal data structures to find names that are truly 335 | in use. 336 | 337 | Args: 338 | name: The name for an operation. 339 | 340 | Returns: 341 | A variant of `name` that has been made unique by appending a key to it 342 | in the same way that `tf.Graph.unique_name()` would. 343 | """ 344 | # For the sake of checking for names in use, we treat names as case 345 | # insensitive (e.g. foo = Foo). 346 | if not self._name_in_use(name): 347 | return name 348 | 349 | # Generate a unique version by appending "_1", "_2", etc. until we find 350 | # an unused name. Note that this approach will behave slightly 351 | # differently from the original if nodes are deleted. 352 | i = 1 353 | new_name = "{}_{}".format(name, i) 354 | while self._name_in_use(new_name): 355 | i = i + 1 356 | new_name = "{}_{}".format(name, i) 357 | return new_name 358 | 359 | @property 360 | def node_names(self): 361 | # type: () -> Iterable[str] 362 | return self._node_name_to_node.keys() 363 | 364 | @property 365 | def nodes(self): 366 | # type: () -> Tuple[node.Node] 367 | """ 368 | Returns: 369 | A list of all nodes, both immutable and mutable, present in the graph 370 | after the edits that this object is buffering. 371 | """ 372 | return tuple(self._node_name_to_node.values()) 373 | 374 | @property 375 | def tensors(self): 376 | # type: () -> List[tensor.Tensor] 377 | """ 378 | Return a list of all the tensors which are input or output of an op in 379 | the graph. 380 | """ 381 | ts = [] 382 | for op in self.nodes: 383 | ts += op.outputs 384 | return ts 385 | 386 | def contains_tensor(self, tensor_name): 387 | # type: (str) -> bool 388 | """ 389 | Returns true if the graph has a tensor by the indicated name. Exact string 390 | match. 391 | 392 | Args: 393 | tensor_name: TensorFlow-format name ('node name:input num', or 'node 394 | name' as shorthand for 'node name:0') 395 | 396 | Raises ValueError if the tensor name is not properly formatted. 397 | """ 398 | error_msg = "Invalid tensor name '{}': {}" 399 | node_name, output_ix = self._decode_tensor_name(tensor_name, error_msg) 400 | if node_name not in self._node_name_to_node: 401 | return False 402 | else: 403 | n = self[node_name] 404 | if output_ix >= len(n.outputs): 405 | return False 406 | else: 407 | return True 408 | 409 | def get_tensor_by_name(self, tensor_name, error_msg = None): 410 | # type: (str, str) -> tensor.Tensor 411 | """ 412 | Retrieve a tensor by human-readable name. 413 | 414 | Args: 415 | tensor_name: TensorFlow-format name ('node name:input num', or 'node 416 | name' as shorthand for 'node name:0') 417 | error_msg: Optional format string for raising errors. Must be able to 418 | serve as an input to `str.format()` with two arguments: tensor name 419 | string and reason for failure. 420 | 421 | Returns: gde.Tensor object corresponding to the indicated tensor. 422 | 423 | Raises ValueError if the name is invalid or references a tensor that does 424 | not exist. 425 | """ 426 | if error_msg is None: 427 | error_msg = "Invalid tensor name '{}': {}" 428 | node_name, output_ix = self._decode_tensor_name(tensor_name, error_msg) 429 | if node_name not in self._node_name_to_node: 430 | raise ValueError(error_msg.format( 431 | tensor_name, "Node name '{}' not found in graph.".format(node_name) 432 | )) 433 | n = self[node_name] 434 | if output_ix >= len(n.outputs): 435 | raise ValueError(error_msg.format( 436 | tensor_name, "Requested output {}, but node '{}' has {} " 437 | "outputs.".format(output_ix, node_name, len(n.outputs)) 438 | )) 439 | return n.output(output_ix) 440 | 441 | @property 442 | def version(self): 443 | # type: () -> int 444 | """ 445 | Returns a counter that goes up every time this graph is changed. 446 | """ 447 | return self._version 448 | 449 | @property 450 | def frozen(self): 451 | # type: () -> bool 452 | """ 453 | True if the graph is configured to raise an exception on any structural 454 | modification. 455 | """ 456 | return self._frozen 457 | 458 | @frozen.setter 459 | def frozen(self, value): 460 | # type: (bool) -> None 461 | self._frozen = value 462 | 463 | def increment_version_counter(self): 464 | """ 465 | Mark the structure of this graph as "changed" and invalidate any cached 466 | information about the edges of the graph. 467 | """ 468 | if self.frozen: 469 | raise RuntimeError("Detected a change to a frozen graph") 470 | self._version += 1 471 | 472 | def visualize( 473 | self, 474 | format=None, 475 | depth=1, 476 | style=True, 477 | name_regex="", 478 | negative_name_regex="", 479 | add_digraph_func=None, 480 | add_digraph_node_func=None, 481 | add_digraph_edge_func=None): 482 | """Return GraphViz Digraph rendering of the current graph. 483 | 484 | Args: 485 | format: GraphViz display format. In addition to that it supports 486 | jupyter_svg, and jupyter_interactive modes. 487 | depth: the maximum depth of the graph to display. 488 | style: whether to apply default styles. 489 | name_regex: only diplay nodes that have name matching this regex. 490 | negative_name_regex: only diplay nodes that have name not matching this 491 | regex. 492 | add_digraph_func: custom override for function for adding subraphs 493 | to the resulting Digraph object. 494 | add_digraph_node_func: custom override for function for adding nodes 495 | (vertices) to the resulting Digraph object. 496 | add_digraph_edge_func: custom override for function for adding edges 497 | to the resulting Digraph object. 498 | 499 | Returns: 500 | graphviz.dot.Digraph object with visual representtion for the current 501 | graph. 502 | """ 503 | return gvw.visualize( 504 | self, 505 | format=format, 506 | depth=depth, 507 | name=self.name, 508 | style=style, 509 | name_regex=name_regex, 510 | negative_name_regex=negative_name_regex, 511 | add_digraph_func=add_digraph_func, 512 | add_digraph_node_func=add_digraph_node_func, 513 | add_digraph_edge_func=add_digraph_edge_func) 514 | 515 | def _get_next_id(self): 516 | # type: () -> int 517 | """Generates and returns a unique integer ID *within this graph*.""" 518 | ret = self._next_id 519 | self._next_id = ret + 1 520 | return ret 521 | 522 | def _decode_tensor_name(self, tensor_name, error_msg): 523 | # type: (str, str) -> Tuple[str, int] 524 | """ 525 | Args: 526 | tensor_name: TensorFlow-format name ('node name:input num', or 'node 527 | name' as shorthand for 'node name:0') 528 | error_msg: Format string for raising errors. Must be able to 529 | serve as an input to `str.format()` with two arguments: tensor name 530 | string and reason for failure. 531 | 532 | Returns: (node name, output index) tuple identifying the tensor 533 | 534 | Raises ValueError if the name is not properly formatted 535 | """ 536 | if ":" in tensor_name: 537 | node_name, output_ix_str = tensor_name.split(":") 538 | if not output_ix_str.isdigit(): 539 | raise ValueError(error_msg.format( 540 | tensor_name, "Invalid output index string '{}'.".format(output_ix_str) 541 | )) 542 | output_ix = int(output_ix_str) 543 | else: 544 | node_name = tensor_name 545 | output_ix = 0 546 | 547 | return node_name, output_ix 548 | 549 | -------------------------------------------------------------------------------- /graph_def_editor/reroute.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Various function for graph rerouting.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import itertools 22 | 23 | from graph_def_editor import node, subgraph, util 24 | 25 | _allowed_symbols = [ 26 | "swap_ts", 27 | "reroute_ts", 28 | "swap_inputs", 29 | "reroute_inputs", 30 | "swap_outputs", 31 | "reroute_outputs", 32 | "swap_ios", 33 | "reroute_ios", 34 | "remove_control_inputs", 35 | "add_control_inputs", 36 | ] 37 | 38 | 39 | def _check_ts_compatibility(ts0, ts1): 40 | """Make sure the shape and dtype of two lists of tensors are compatible. 41 | 42 | Args: 43 | ts0: an object convertible to a list of `gde.Tensor`. 44 | ts1: an object convertible to a list of `gde.Tensor`. 45 | Raises: 46 | ValueError: if any pair of tensors (same index in ts0 and ts1) have 47 | a dtype or a shape which is not compatible. 48 | """ 49 | ts0 = util.make_list_of_t(ts0) 50 | ts1 = util.make_list_of_t(ts1) 51 | if len(ts0) != len(ts1): 52 | raise ValueError("ts0 and ts1 have different sizes: {} != {}".format( 53 | len(ts0), len(ts1))) 54 | for t0, t1 in zip(ts0, ts1): 55 | # check dtype 56 | dtype0, dtype1 = t0.dtype, t1.dtype 57 | if not dtype0.is_compatible_with(dtype1): 58 | raise ValueError("Dtypes {} and {} of tensors {} and {} are not " 59 | "compatible.".format(dtype0, dtype1, t0.name, t1.name)) 60 | # check shape 61 | shape0, shape1 = t0.shape, t1.shape 62 | if not shape0.is_compatible_with(shape1): 63 | raise ValueError("Shapes {} and {} of tensors {} and {} are not " 64 | "compatible.".format(shape0, shape1, t0.name, t1.name)) 65 | 66 | 67 | class _RerouteMode(object): 68 | """Enums for reroute's mode. 69 | 70 | swap: the end of tensors a and b are swapped. 71 | a2b: the end of the tensor a are also rerouted to the end of the tensor b 72 | (the end of b is left dangling). 73 | b2a: the end of the tensor b are also rerouted to the end of the tensor a 74 | (the end of a is left dangling). 75 | """ 76 | swap, a2b, b2a = range(3) 77 | 78 | @classmethod 79 | def check(cls, mode): 80 | """Check swap mode. 81 | 82 | Args: 83 | mode: an integer representing one of the modes. 84 | Returns: 85 | A tuple `(a2b, b2a)` boolean indicating what rerouting needs doing. 86 | Raises: 87 | ValueError: if mode is outside the enum range. 88 | """ 89 | if mode == cls.swap: 90 | return True, True 91 | elif mode == cls.b2a: 92 | return False, True 93 | elif mode == cls.a2b: 94 | return True, False 95 | else: 96 | raise ValueError("Unknown _RerouteMode: {}".format(mode)) 97 | 98 | 99 | def _reroute_t(t0, t1, consumers1, can_modify=None, cannot_modify=None): 100 | """Reroute the end of the tensors (t0,t1). 101 | 102 | Warning: this function is directly manipulating the internals of the 103 | `gde.Graph`. 104 | 105 | Args: 106 | t0: a `gde.Tensor`. 107 | t1: a `gde.Tensor`. 108 | consumers1: The consumers of t1 which needs to be rerouted. 109 | can_modify: iterable of operations which can be modified. Any operation 110 | outside within_ops will be left untouched by this function. 111 | cannot_modify: iterable of operations which cannot be modified. 112 | Any operation within cannot_modify will be left untouched by this 113 | function. 114 | Returns: 115 | The number of individual modifications made by the function. 116 | """ 117 | nb_update_inputs = 0 118 | if can_modify is not None: 119 | consumers1 &= can_modify 120 | if cannot_modify is not None: 121 | consumers1 -= cannot_modify 122 | consumers1_indices = {} 123 | for consumer1 in consumers1: 124 | consumers1_indices[consumer1] = [i for i, t in enumerate(consumer1.inputs) 125 | if t is t1] 126 | for consumer1 in consumers1: 127 | for i in consumers1_indices[consumer1]: 128 | consumer1.replace_input(i, t0) 129 | nb_update_inputs += 1 130 | return nb_update_inputs 131 | 132 | 133 | def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None): 134 | """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1. 135 | 136 | This function is the back-bone of the Graph-Editor. It is essentially a thin 137 | wrapper on top of `gde.Node.replace_input`. 138 | 139 | Given a pair of tensor t0, t1 in ts0 x ts1, this function re-route the end 140 | of t0 and t1 in three possible ways: 141 | 1) The reroute mode is "a<->b" or "b<->a": the tensors' end are swapped. After 142 | this operation, the previous consumers of t0 are now consumers of t1 and 143 | vice-versa. 144 | 2) The reroute mode is "a->b": the tensors' end of t0 are re-routed to the 145 | tensors's end of t1 (which are left dangling). After this operation, the 146 | previous consumers of t0 are still consuming t0 but the previous consumers of 147 | t1 are not also consuming t0. The tensor t1 has no consumer. 148 | 3) The reroute mode is "b->a": this mode is the symmetric of the "a->b" mode. 149 | 150 | Note that this function is re-routing the end of two tensors, not the start. 151 | Re-routing the start of two tensors is not supported by this library. The 152 | reason for that is the following: TensorFlow, by design, creates a strong bond 153 | between an op and its output tensor. This Graph editor follows this design and 154 | treats an operation A and its generating tensors {t_i} as an entity which 155 | cannot be broken. In other words, an op cannot be detached from any of its 156 | output tensors, ever. But it is possible to detach an op from its input 157 | tensors, which is what this function concerns itself with. 158 | 159 | Warning: this function is directly manipulating the internals of the `gde.Graph`. 160 | 161 | Args: 162 | ts0: an object convertible to a list of `gde.Tensor`. 163 | ts1: an object convertible to a list of `gde.Tensor`. 164 | mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and 165 | "a->b" or "b->a" for one direction re-routing. 166 | can_modify: iterable of operations which can be modified. Any operation 167 | outside within_ops will be left untouched by this function. 168 | cannot_modify: iterable of operations which cannot be modified. 169 | Any operation within cannot_modify will be left untouched by this 170 | function. 171 | Returns: 172 | The number of individual modifications made by the function. 173 | Raises: 174 | TypeError: if `ts0` or `ts1` cannot be converted to a list of `gde.Tensor`. 175 | TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be 176 | converted to a list of `gde.Node`. 177 | """ 178 | a2b, b2a = _RerouteMode.check(mode) 179 | ts0 = util.make_list_of_t(ts0) 180 | ts1 = util.make_list_of_t(ts1) 181 | _check_ts_compatibility(ts0, ts1) 182 | if cannot_modify is not None: 183 | cannot_modify = frozenset(util.make_list_of_op(cannot_modify)) 184 | if can_modify is not None: 185 | can_modify = frozenset(util.make_list_of_op(can_modify)) 186 | nb_update_inputs = 0 187 | precomputed_consumers = [] 188 | # precompute consumers to avoid issue with repeated tensors: 189 | for t0, t1 in zip(ts0, ts1): 190 | consumers0 = set(t0.consumers()) 191 | consumers1 = set(t1.consumers()) 192 | precomputed_consumers.append((consumers0, consumers1)) 193 | for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers): 194 | if t0 is t1: 195 | continue # Silently ignore identical tensors. 196 | consumers0, consumers1 = consumers 197 | if a2b: 198 | nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify, 199 | cannot_modify) 200 | if b2a: 201 | nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify, 202 | cannot_modify) 203 | return nb_update_inputs 204 | 205 | 206 | def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): 207 | """For each tensor's pair, swap the end of (t0,t1). 208 | 209 | B0 B1 B0 B1 210 | | | => X 211 | A0 A1 A0 A1 212 | 213 | Args: 214 | ts0: an object convertible to a list of `gde.Tensor`. 215 | ts1: an object convertible to a list of `gde.Tensor`. 216 | can_modify: iterable of operations which can be modified. Any operation 217 | outside within_ops will be left untouched by this function. 218 | cannot_modify: iterable of operations which cannot be modified. 219 | Any operation within cannot_modify will be left untouched by this 220 | function. 221 | Returns: 222 | The number of individual modifications made by the function. 223 | Raises: 224 | TypeError: if ts0 or ts1 cannot be converted to a list of `gde.Tensor`. 225 | TypeError: if can_modify or cannot_modify is not None and cannot be 226 | converted to a list of `gde.Node`. 227 | """ 228 | return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify) 229 | 230 | 231 | def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None): 232 | """For each tensor's pair, replace the end of t1 by the end of t0. 233 | 234 | B0 B1 B0 B1 235 | | | => |/ 236 | A0 A1 A0 A1 237 | 238 | The end of the tensors in ts1 are left dangling. 239 | 240 | Args: 241 | ts0: an object convertible to a list of `gde.Tensor`. 242 | ts1: an object convertible to a list of `gde.Tensor`. 243 | can_modify: iterable of operations which can be modified. Any operation 244 | outside within_ops will be left untouched by this function. 245 | cannot_modify: iterable of operations which cannot be modified. Any 246 | operation within cannot_modify will be left untouched by this function. 247 | Returns: 248 | The number of individual modifications made by the function. 249 | Raises: 250 | TypeError: if ts0 or ts1 cannot be converted to a list of `gde.Tensor`. 251 | TypeError: if can_modify or cannot_modify is not None and cannot be 252 | converted to a list of `gde.Node`. 253 | """ 254 | return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify) 255 | 256 | 257 | def _reroute_sgv_remap(sgv0, sgv1, mode): 258 | """Remap in place the inputs of two subgraph views to mimic the reroute. 259 | 260 | This function is meant to used by reroute_inputs only. 261 | 262 | Args: 263 | sgv0: the first subgraph to have its inputs remapped. 264 | sgv1: the second subgraph to have its inputs remapped. 265 | mode: reroute mode, see _reroute_ts(...). 266 | Raises: 267 | TypeError: if svg0 or svg1 are not SubGraphView. 268 | ValueError: if sgv0 and sgv1 do not belong to the same graph. 269 | """ 270 | a2b, b2a = _RerouteMode.check(mode) 271 | if not isinstance(sgv0, subgraph.SubGraphView): 272 | raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0))) 273 | if not isinstance(sgv1, subgraph.SubGraphView): 274 | raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1))) 275 | util.check_graphs(sgv0, sgv1) 276 | sgv0_ = sgv0.copy() 277 | sgv1_ = sgv1.copy() 278 | # pylint: disable=protected-access 279 | if a2b and b2a: 280 | (sgv0_._input_ts, sgv1_._input_ts) = (sgv1_._input_ts, sgv0_._input_ts) 281 | (sgv0_._passthrough_ts, sgv1_._passthrough_ts) = (sgv1_._passthrough_ts, 282 | sgv0_._passthrough_ts) 283 | elif a2b: 284 | sgv1_._input_ts = sgv0_._input_ts[:] 285 | sgv1_._passthrough_ts = sgv0_._passthrough_ts[:] 286 | elif b2a: 287 | sgv0_._input_ts = sgv1_._input_ts[:] 288 | sgv0_._passthrough_ts = sgv1_._passthrough_ts[:] 289 | # pylint: enable=protected-access 290 | 291 | # Update the passthrough outputs as well. 292 | def update_passthrough_outputs(a, b): 293 | # pylint: disable=protected-access 294 | for i, t in enumerate(b._output_ts): 295 | if t in a._passthrough_ts: 296 | ii = a._input_ts.index(t) 297 | b._output_ts[i] = b._input_ts[ii] 298 | # pylint: enable=protected-access 299 | 300 | if a2b: 301 | update_passthrough_outputs(sgv0_, sgv1_) 302 | if b2a: 303 | update_passthrough_outputs(sgv1_, sgv0_) 304 | 305 | # in-place 306 | # pylint: disable=protected-access 307 | sgv0._assign_from(sgv0_) 308 | sgv1._assign_from(sgv1_) 309 | # pylint: enable=protected-access 310 | 311 | 312 | def _reroute_sgv_inputs(sgv0, sgv1, mode): 313 | """Re-route all the inputs of two subgraphs. 314 | 315 | Args: 316 | sgv0: the first subgraph to have its inputs swapped. This argument is 317 | converted to a subgraph using the same rules than the function 318 | subgraph.make_view. 319 | sgv1: the second subgraph to have its inputs swapped. This argument is 320 | converted to a subgraph using the same rules than the function 321 | subgraph.make_view. 322 | mode: reroute mode, see _reroute_ts(...). 323 | Returns: 324 | A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. 325 | Note that the function argument sgv0 and sgv1 are also modified in place. 326 | Raises: 327 | StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 328 | the same rules than the function subgraph.make_view. 329 | """ 330 | sgv0 = subgraph.make_view(sgv0) 331 | sgv1 = subgraph.make_view(sgv1) 332 | util.check_graphs(sgv0, sgv1) 333 | can_modify = sgv0.ops + sgv1.ops 334 | # also allow consumers of passthrough to be modified: 335 | can_modify += util.get_consuming_ops(sgv0.passthroughs) 336 | can_modify += util.get_consuming_ops(sgv1.passthroughs) 337 | _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify) 338 | _reroute_sgv_remap(sgv0, sgv1, mode) 339 | return sgv0, sgv1 340 | 341 | 342 | def _reroute_sgv_outputs(sgv0, sgv1, mode): 343 | """Re-route all the outputs of two operations. 344 | 345 | Args: 346 | sgv0: the first subgraph to have its outputs swapped. This argument is 347 | converted to a subgraph using the same rules than the function 348 | subgraph.make_view. 349 | sgv1: the second subgraph to have its outputs swapped. This argument is 350 | converted to a subgraph using the same rules than the function 351 | subgraph.make_view. 352 | mode: reroute mode, see _reroute_ts(...). 353 | Returns: 354 | A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. 355 | Note that the function argument sgv0 and sgv1 are also modified in place. 356 | Raises: 357 | StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 358 | the same rules than the function subgraph.make_view. 359 | """ 360 | sgv0 = subgraph.make_view(sgv0) 361 | sgv1 = subgraph.make_view(sgv1) 362 | util.check_graphs(sgv0, sgv1) 363 | cannot_modify = sgv0.ops + sgv1.ops 364 | _reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify) 365 | return sgv0, sgv1 366 | 367 | 368 | def _reroute_sgv(sgv0, sgv1, mode): 369 | """Re-route both the inputs and the outputs of the two subgraph views. 370 | 371 | This involves swapping all the inputs/outputs of the two subgraph views. 372 | 373 | Args: 374 | sgv0: the first subgraph to be swapped. This argument is converted to a 375 | subgraph using the same rules than the function subgraph.make_view. 376 | sgv1: the second subgraph to be swapped. This argument is converted to a 377 | subgraph using the same rules than the function subgraph.make_view. 378 | mode: reroute mode, see _reroute_ts(...). 379 | Returns: 380 | A tuple `(sgv0, sgv1)` of subgraph views with their outputs and inputs 381 | swapped. 382 | Note that the function argument sgv0 and sgv1 are also modified in place. 383 | Raises: 384 | StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 385 | the same rules than the function subgraph.make_view. 386 | """ 387 | _reroute_sgv_outputs(sgv0, sgv1, mode) 388 | _reroute_sgv_inputs(sgv0, sgv1, mode) 389 | return sgv0, sgv1 390 | 391 | 392 | def swap_inputs(sgv0, sgv1): 393 | """Swap all the inputs of sgv0 and sgv1 (see reroute_inputs).""" 394 | return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.swap) 395 | 396 | 397 | def reroute_inputs(sgv0, sgv1): 398 | """Re-route all the inputs of two subgraphs. 399 | 400 | Args: 401 | sgv0: the first subgraph to have its inputs swapped. This argument is 402 | converted to a subgraph using the same rules than the function 403 | subgraph.make_view. 404 | sgv1: the second subgraph to have its inputs swapped. This argument is 405 | converted to a subgraph using the same rules than the function 406 | subgraph.make_view. 407 | Returns: 408 | A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. 409 | Note that the function argument sgv0 and sgv1 are also modified in place. 410 | Raises: 411 | StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 412 | the same rules than the function subgraph.make_view. 413 | """ 414 | return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) 415 | 416 | 417 | def swap_outputs(sgv0, sgv1): 418 | """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs).""" 419 | return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) 420 | 421 | 422 | def reroute_outputs(sgv0, sgv1): 423 | """Re-route all the outputs of two operations. 424 | 425 | Args: 426 | sgv0: the first subgraph to have its outputs swapped. This argument is 427 | converted to a subgraph using the same rules than the function 428 | subgraph.make_view. 429 | sgv1: the second subgraph to have its outputs swapped. This argument is 430 | converted to a subgraph using the same rules than the function 431 | subgraph.make_view. 432 | Returns: 433 | A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. 434 | Note that the function argument sgv0 and sgv1 are also modified in place. 435 | Raises: 436 | StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 437 | the same rules than the function subgraph.make_view. 438 | """ 439 | return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) 440 | 441 | 442 | def swap_ios(sgv0, sgv1): 443 | """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute_sgv).""" 444 | return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) 445 | 446 | 447 | def reroute_ios(sgv0, sgv1): 448 | """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).""" 449 | return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) 450 | 451 | 452 | def remove_control_inputs(op, cops): 453 | """Remove the control inputs cops from co. 454 | 455 | Warning: this function is directly manipulating the internals of the 456 | `gde.Graph`. 457 | 458 | Args: 459 | op: a `gde.Node` from which to remove the control inputs. 460 | cops: an object convertible to a list of `gde.Node`. 461 | Raises: 462 | TypeError: if op is not a `gde.Node`. 463 | ValueError: if any cop in cops is not a control input of op. 464 | """ 465 | if not isinstance(op, node.Node): 466 | raise TypeError("Expected a gde.Node, got: {}", type(op)) 467 | cops = util.make_list_of_op(cops, allow_graph=False) 468 | for cop in cops: 469 | if cop not in op.control_inputs: 470 | raise ValueError("{} is not a control_input of {}".format(op.name, 471 | cop.name)) 472 | control_inputs = [cop for cop in op.control_inputs if cop not in cops] 473 | op.set_control_inputs(control_inputs) 474 | 475 | 476 | def add_control_inputs(op, cops): 477 | """Add the control inputs cops to op. 478 | 479 | Warning: this function is directly manipulating the internals of the `gde.Graph`. 480 | 481 | Args: 482 | op: a `gde.Node` to which the control inputs are added. 483 | cops: an object convertible to a list of `gde.Node`. 484 | Raises: 485 | TypeError: if op is not a `gde.Node` 486 | ValueError: if any cop in cops is already a control input of op. 487 | """ 488 | if not isinstance(op, node.Node): 489 | raise TypeError("Expected a gde.Node, got: {}", type(op)) 490 | cops = util.make_list_of_op(cops, allow_graph=False) 491 | for cop in cops: 492 | if cop in op.control_inputs: 493 | raise ValueError("{} is already a control_input of {}".format(cop.name, 494 | op.name)) 495 | op.set_control_inputs(itertools.chain(op.control_inputs, cops)) 496 | op.graph.increment_version_counter() 497 | --------------------------------------------------------------------------------