"""RelayViz: visualize a TVM Relay program graph.

I wrote this script to address a pain point of mine: I found it difficult to
make sense of what's going on in a Relay program
(https://docs.tvm.ai/dev/relay_intro.html).

Example output (``example.png`` in this repository):
https://raw.githubusercontent.com/hcho3/relayviz/208a25715d1b94d8d1d5fd6a704e8d163375afb1/example.png

Running this script builds a small ResNet workload via ``tvm.relay.testing``,
prints one line per graph node, renders the graph to SVG with Graphviz, and
prints the path returned by ``Digraph.render()``.
"""
import tvm
import tvm.relay as relay
import tvm.relay.testing as testing
from graphviz import Digraph

batch_size = 1
num_class = 1000
image_shape = (3, 28, 28)


def _traverse_expr(node, node_dict):
    """Assign the next sequential index to ``node`` unless already indexed.

    Used as the callback for ``relay.analysis.post_order_visit``. Primitive
    operator nodes (``Op``) are skipped: they are rendered as part of the
    label of the ``Call`` that uses them rather than as separate graph nodes.

    Args:
        node: The Relay expression being visited.
        node_dict: Mapping from visited expression to its integer index;
            updated in place.
    """
    if node in node_dict:
        return
    if isinstance(node, relay.op.op.Op):
        return
    node_dict[node] = len(node_dict)


def _type_str(ty):
    """Format a Relay type for display.

    Tensor types use the ``Tensor[(shape), dtype]`` form; any other type
    falls back to its string form, and a missing type (``None``) is shown
    as ``?`` instead of crashing.
    """
    if ty is None:
        return '?'
    if isinstance(ty, relay.TensorType):
        return f'Tensor[{tuple(ty.shape)}, {ty.dtype}]'
    return str(ty)


def _add_node(dot, node, node_idx, node_dict):
    """Log ``node`` and add it, plus edges from its operands, to ``dot``.

    Edges point from an operand to the expression that consumes it.

    Raises:
        RuntimeError: If ``node`` is an expression kind this tool does not
            know how to draw (e.g. ``Let`` or ``If``).
    """
    key = str(node_idx)
    # GlobalVar is checked before Var so it never reaches the Var branch,
    # which reads ``type_annotation``.
    if isinstance(node, relay.GlobalVar):
        print(f'node_idx: {node_idx}, GlobalVar(name={node.name_hint})')
        dot.node(key, f'@{node.name_hint}')
    elif isinstance(node, relay.Var):
        type_str = _type_str(node.type_annotation)
        print(f'node_idx: {node_idx}, Var(name={node.name_hint}, type={type_str})')
        dot.node(key, f'{node.name_hint}:\n{type_str}')
    elif isinstance(node, relay.Constant):
        type_str = f'Tensor[{tuple(node.data.shape)}, {node.data.dtype}]'
        print(f'node_idx: {node_idx}, Constant(type={type_str})')
        dot.node(key, f'Constant:\n{type_str}')
    elif isinstance(node, relay.Call):
        args = [node_dict[arg] for arg in node.args]
        if isinstance(node.op, relay.op.op.Op):
            print(f'node_idx: {node_idx}, Call(op_name={node.op.name}, args={args})')
            dot.node(key, f'Call(op={node.op.name})')
        else:
            # The callee is itself an expression (Function, GlobalVar, Var),
            # so it was indexed by the traversal; draw an edge from it.
            callee = node_dict[node.op]
            print(f'node_idx: {node_idx}, Call(callee={callee}, args={args})')
            dot.node(key, 'Call')
            dot.edge(str(callee), key)
        for arg in args:
            dot.edge(str(arg), key)
    elif isinstance(node, relay.Tuple):
        fields = [node_dict[field] for field in node.fields]
        print(f'node_idx: {node_idx}, Tuple(fields={fields})')
        dot.node(key, 'Tuple')
        for field in fields:
            dot.edge(str(field), key)
    elif isinstance(node, relay.Function):
        print(f'node_idx: {node_idx}, Function(body={node_dict[node.body]})')
        dot.node(key, 'Function')
        dot.edge(str(node_dict[node.body]), key)
    elif isinstance(node, relay.TupleGetItem):
        print(f'node_idx: {node_idx}, TupleGetItem(tuple={node_dict[node.tuple_value]}, idx={node.index})')
        dot.node(key, f'TupleGetItem(idx={node.index})')
        dot.edge(str(node_dict[node.tuple_value]), key)
    else:
        raise RuntimeError(f'Unknown node type. node_idx: {node_idx}, node: {type(node)}')


def visualize(expr):
    """Build a Graphviz ``Digraph`` (SVG format) of the Relay expression ``expr``.

    Nodes are numbered in the order ``post_order_visit`` reports them, so
    operands are numbered before the expressions that use them. The graph
    is laid out bottom-to-top (``rankdir='BT'``), which places inputs at the
    bottom.

    Returns:
        The populated ``graphviz.Digraph``; the caller decides how to render it.
    """
    dot = Digraph(format='svg')
    dot.attr(rankdir='BT')
    dot.attr('node', shape='box')

    node_dict = {}
    relay.analysis.post_order_visit(expr, lambda node: _traverse_expr(node, node_dict))
    for node, node_idx in node_dict.items():
        _add_node(dot, node, node_idx, node_dict)
    return dot


def main():
    """Visualize the ``main`` function of a small ResNet test workload."""
    # The generated parameter values are not needed to draw the graph.
    mod, _params = testing.resnet.get_workload(
        num_layers=8, batch_size=batch_size, num_classes=num_class,
        image_shape=image_shape)
    dot = visualize(mod['main'])
    print(dot.render())


if __name__ == '__main__':
    main()