├── README.md ├── arch.png ├── net.cfg └── nnplot.py /README.md: -------------------------------------------------------------------------------- 1 | # nnplot 2 | Python script for plotting various neural network architectures. 3 | 4 | For now, edit the "io" and "layers" entries in a JSON config file such as net.cfg, then pass its path as the first argument: `python nnplot.py net.cfg`. 5 | 6 | Example output: 7 | ![Example network architecture plot](arch.png?raw=true "nnplot example output") 8 | -------------------------------------------------------------------------------- /arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/osh/nnplot/cadf6a182f169d4c9bd403c300fd7282fa513972/arch.png -------------------------------------------------------------------------------- /net.cfg: -------------------------------------------------------------------------------- 1 | { 2 | "io": [ 3 | [ "Input", [ 2, 128 ] ], 4 | [ "Output", [ 2, 128 ] ] 5 | ], 6 | "layers": [ 7 | [ "Conv2D(1,1,40)\nLinear", 0.5, 40, 5 ], 8 | [ "Dense(44)\nRelu", null, 10, 2 ], 9 | [ "Dense(2*88)\nRelu", null, 15, 3 ], 10 | [ "Conv2D(1,1,81)\nLinear", 0.5, 40, 5 ] 11 | ] 12 | } 13 | -------------------------------------------------------------------------------- /nnplot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # NN Plot 4 | # Tim O'Shea (c) 2016 5 | # 6 | # 7 | # 8 | import matplotlib.pyplot as plt 9 | import matplotlib.patches as patches 10 | import matplotlib.lines as lines 11 | 12 | # load the config def 13 | import sys,json 14 | fn = sys.argv[1] 15 | print fn 16 | cfg = json.loads(open(fn).read()) 17 | io = cfg["io"] 18 | layers = cfg["layers"] 19 | 20 | # ... 
21 | maxy = max(map(lambda x: x[3], layers)) 22 | print maxy 23 | fig = plt.figure() 24 | ax = fig.add_subplot(111) 25 | for i,l in enumerate(layers): 26 | ll=0.25 + i 27 | w = 0.5 28 | h = layers[i][3] 29 | lr=(maxy-h)/2.0 30 | 31 | print (ll,lr,w,h) 32 | ax.add_patch( 33 | patches.Rectangle( (ll, lr), w, h, fill=False) 34 | ) 35 | ax.annotate(l[0], xy=(ll+w/2.0, lr+h/2.0), 36 | ha='center', va='center', 37 | rotation=90) 38 | 39 | if(i < len(layers)-1): 40 | 41 | h_n = layers[i+1][3] 42 | lr_n=(maxy-h_n)/2.0 43 | i_x = ll+w 44 | o_x = ll+1 45 | 46 | for j in range(0,layers[i][2]): 47 | for k in range(0,layers[i+1][2]): 48 | j_rel = (j*1.0/(layers[i][2]-1)) 49 | k_rel = (k*1.0/(layers[i+1][2]-1)) 50 | if(l[1]==None or abs(j_rel - k_rel)