├── .gitignore ├── LICENSE ├── NetPLOT_demo.gif ├── Netplot.png ├── README.md ├── build └── lib │ └── netplot.py ├── dist ├── netplot-0.1.2-py3-none-any.whl └── netplot-0.1.2.tar.gz ├── requirements.txt ├── screenshot demo ├── ModelPlot 3D with grid.png └── model_summary.png ├── setup.py └── src ├── __init__.py ├── __pycache__ └── netplot.cpython-39.pyc ├── netplot.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt └── netplot.py /.gitignore: -------------------------------------------------------------------------------- 1 | /venv/ 2 | /nn_architecture_plot.iml 3 | /.idea/ 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Souvik Pratiher 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /NetPLOT_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/NetPLOT_demo.gif -------------------------------------------------------------------------------- /Netplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/Netplot.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![NETPLOT](https://github.com/Spratiher9/Netplot/blob/5a7b0807114bd858deeb99e17c893b749ab95b93/Netplot.png) 2 | ## Netplot🚀 [![Downloads](https://static.pepy.tech/personalized-badge/netplot?period=total&units=international_system&left_color=black&right_color=orange&left_text=PYPI%20Downloads)](https://pepy.tech/project/netplot) 3 | An ultra-lightweight 3D renderer for TensorFlow/Keras neural network architectures. 4 | The library currently renders with Matplotlib. In the future the visualization may move to Plotly 5 | for a more interactive view of the neural network architecture. 
6 | 7 | **Note:** *For now the rendering is working in Jupyter only Google Colab support is in works.* 8 | 9 | For more details visit [NetPlot](https://pypi.org/project/netplot/0.1.2/) 10 | 11 | ### How to use it 12 | 13 | ![NetPLOT DEMO Notebook](https://github.com/Spratiher9/Netplot/blob/1e16251651d4c947c7a33fd7bac2f7701d7d162b/NetPLOT_demo.gif) 14 | 15 | ### Install with Pip 16 | 17 | ```python 18 | pip install netplot 19 | ``` 20 | 21 | ### Notebook Codelets 22 | 23 | ```python 24 | from netplot import ModelPlot 25 | import tensorflow as tf 26 | import numpy as np 27 | ``` 28 | 29 | ```python 30 | %matplotlib notebook 31 | ``` 32 | 33 | ```python 34 | X_input = tf.keras.layers.Input(shape=(32,32,3)) 35 | X = tf.keras.layers.Conv2D(4, 3, activation='relu')(X_input) 36 | X = tf.keras.layers.MaxPool2D(2,2)(X) 37 | X = tf.keras.layers.Conv2D(16, 3, activation='relu')(X) 38 | X = tf.keras.layers.MaxPool2D(2,2)(X) 39 | X = tf.keras.layers.Conv2D(8, 3, activation='relu')(X) 40 | X = tf.keras.layers.MaxPool2D(2,2)(X) 41 | X = tf.keras.layers.Flatten()(X) 42 | X = tf.keras.layers.Dense(10, activation='relu')(X) 43 | X = tf.keras.layers.Dense(2, activation='softmax')(X) 44 | 45 | model = tf.keras.models.Model(inputs=X_input, outputs=X) 46 | ``` 47 | ```python 48 | modelplot = ModelPlot(model=model, grid=True, connection=True, linewidth=0.1) 49 | modelplot.show() 50 | ``` 51 | ![Keras Model Summarized](https://github.com/Spratiher9/Netplot/blob/master/screenshot%20demo/model_summary.png) 52 | ![Keras Model Visualized](https://github.com/Spratiher9/Netplot/blob/master/screenshot%20demo/ModelPlot%203D%20with%20grid.png) 53 | -------------------------------------------------------------------------------- /build/lib/netplot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | class ModelPlot(object): 6 | """Plot the layers of a Keras-style model (an object with .layers and .summary()) as 3D Matplotlib scatter points, one dot per unit.""" 7 | 8 | def __init__(self, model, grid=True, 
connection=False, linewidth=0.1):  # grid=False hides grid and axes in show(); connection=True links Dense/Flatten units; linewidth is the width of those links 9 | super(ModelPlot, self).__init__() 10 | 11 | self.model = model 12 | self.grid = grid 13 | self.connection = connection 14 | self.linewidth = linewidth 15 | 16 | def _layer(self, shape, name): 17 | """Map one layer's (shape, class name) to a dict with the plot 'shape', 'name', 'color' and 'marker'. 1-D and 2-D shapes are padded with 1s to 3-D; for MaxPooling2D/AveragePooling2D a 3-D shape is reduced to a single channel; a 3-channel shape gets color 'rgb' (one color letter per channel).""" 18 | lay_shape = None 19 | lay_name = None 20 | lay_color = None 21 | lay_marker = None 22 | 23 | if len(shape) == 1: 24 | lay_shape = (shape[0], 1, 1) 25 | elif len(shape) == 2: 26 | lay_shape = (shape[0], shape[1], 1) 27 | else: 28 | if name == 'MaxPooling2D' or name == 'AveragePooling2D': 29 | lay_shape = (shape[0], shape[1], 1) 30 | else: 31 | lay_shape = shape 32 | 33 | lay_name = name 34 | 35 | if len(lay_shape) == 3 and lay_shape[-1] == 3: 36 | lay_color = 'rgb' 37 | lay_marker = 'o' 38 | else: 39 | if lay_name == 'InputLayer': 40 | lay_color = 'r' 41 | lay_marker = 'o' 42 | elif lay_name == 'Conv2D': 43 | lay_color = 'y' 44 | lay_marker = '^' 45 | elif lay_name == 'MaxPooling2D' or lay_name == 'AveragePooling2D': 46 | lay_color = 'c' 47 | lay_marker = '.' 48 | else: 49 | lay_color = 'g' 50 | lay_marker = '.' 
51 | 52 | return {'shape': lay_shape, 'name': lay_name, 'color': lay_color, 'marker': lay_marker} 53 | 54 | def _model2layer(self): 55 | """Fetch (shape, class name) for every layer of self.model. InputLayer uses input_shape[0], MaxPooling2D uses input_shape and all other layers use output_shape; the first (presumably batch) dimension is dropped in each case. NOTE(review): the class name is parsed out of str(layer.with_name_scope), which depends on that repr's exact format - type(layer).__name__ may be more robust; AveragePooling2D is not special-cased here although _layer groups it with MaxPooling2D - confirm intended.""" 56 | layers = [] 57 | 58 | for i in self.model.layers: 59 | name = str(i.with_name_scope).split('.')[-1][:-3] 60 | if name == 'InputLayer': 61 | shape = i.input_shape[0][1:] 62 | elif name == 'MaxPooling2D': 63 | shape = i.input_shape[1:] 64 | else: 65 | shape = i.output_shape[1:] 66 | layers.append((tuple(shape), name)) 67 | 68 | return layers 69 | 70 | def _shape2array(self, shape, layers_len, xy_max): 71 | """Expand a 3-D (x, y, z) shape into coordinate grids: one [arr_x, arr_y, arr_z] entry per channel, each holding y rows. Channels are placed 2 apart in depth starting at layers_len, and a further 4 is added after the layer. xy_max is updated in place with the largest x and y seen. Returns (single_layer, layers_len, xy_max).""" 72 | x = shape[0] 73 | y = shape[1] 74 | z = shape[2] 75 | 76 | single_layer = [] 77 | 78 | if xy_max[0] < x: 79 | xy_max[0] = x 80 | if xy_max[1] < y: 81 | xy_max[1] = y 82 | 83 | for k in range(z): 84 | arr_x, arr_y, arr_z = [], [], [] 85 | 86 | for i in range(y): 87 | ox = [j for j in range(x)] 88 | arr_x.append(ox) 89 | 90 | for i in range(y): 91 | oy = [j for j in (np.ones(x, dtype=int) * i)] 92 | arr_y.append(oy) 93 | 94 | for i in range(y): 95 | oz = [j for j in (np.ones(y, dtype=int) * layers_len)]  # NOTE(review): these rows have length y while the arr_x/arr_y rows have length x, so for x != y the flattened z list differs in length from x/y - np.ones(x, ...) 
looks intended; confirm 96 | arr_z.append(oz) 97 | 98 | layers_len += 2 99 | single_layer.append([arr_x, arr_y, arr_z]) 100 | 101 | layers_len += 4 102 | 103 | return single_layer, layers_len, xy_max 104 | 105 | def _dense(self, ax, x1=1, x2=1, y1=1, y2=1, x11=1, x21=1, y11=1, y21=1, z1=1, z2=1, c='r'): 106 | """Draw one line from every point (i, k), i in [x1, x2] and k in [y1, y2], at depth z1 to every point (j, l), j in [x11, x21] and l in [y11, y21], at depth z2, in color c with width self.linewidth. Depth is drawn on the plot's y axis and k/l on its z axis; the number of lines is the product of the four range sizes.""" 107 | for i in np.arange(x1, x2 + 1, 1): 108 | for j in np.arange(x11, x21 + 1, 1): 109 | for k in np.arange(y1, y2 + 1, 1): 110 | for l in np.arange(y11, y21 + 1, 1): 111 | ax.plot([i, j], [z1, z2], [k, l], c=c, linewidth=self.linewidth) 112 | 113 | def _plot_dots(self, layers_array, layers_name, layers_color, layers_marker, ax, xy_max): 114 | """Scatter every unit of every layer, centring each layer on the largest x/y extent in xy_max. When self.connection is set, each Dense/Flatten channel plane is also joined to the previously visited Dense/Flatten plane with _dense.""" 115 | temp = True 116 | last_a, last_b, last_c = [0, 0], [0, 0], [0, 0] 117 | 118 | for layer, name, color_in, marker in zip(layers_array, 
layers_name, layers_color, layers_marker): 119 | line_x, line_y, line_z = [], [], [] 120 | color_count = 0 121 | 122 | for j in layer: 123 | my_x, my_y, my_z = [], [], [] 124 | temp_list_l = [] 125 | 126 | for k in j[0]: 127 | k = [a + ((xy_max[0] - len(k)) / 2) for a in k]  # shift the row so the layer is centred on xy_max[0] 128 | my_x += k 129 | 130 | line_x.append([k[0], k[-1]]) 131 | 132 | for l in j[1]: 133 | l = [b + ((xy_max[1] - (j[1][-1][-1] + 1)) / 2) for b in l]  # j[1][-1][-1] + 1 is the row count y; centre the rows on xy_max[1] 134 | my_y += l 135 | temp_list_l.append(l[0]) 136 | 137 | line_y.append([temp_list_l[0], temp_list_l[-1]]) 138 | 139 | for k in j[2]: 140 | my_z += k 141 | 142 | line_z.append([k[0], k[-1]]) 143 | 144 | if color_in == 'rgb': 145 | color = color_in[color_count]  # one of 'r', 'g', 'b' per channel 146 | color_count += 1 147 | else: 148 | color = color_in 149 | 150 | ax.scatter(my_x, my_z, my_y, c=color, marker=marker, s=20)  # depth goes on the plot's y axis, rows on its z axis 151 | 152 | if self.connection: 153 | if name == 'Dense' or name == 'Flatten': 154 | for c in line_z: 155 | a, b, c = line_x[0], line_y[0], c 156 | if temp:  # first Dense/Flatten plane visited: remember it, nothing to connect to yet 157 | temp = False 158 | last_a, last_b, last_c = a, b, c 159 | continue 160 | 161 | if color_in == 'rgb': 162 | color = color_in[color_count] 163 | color_count += 1 164 | 165 | else: 166 | color = color_in 167 | 168 | self._dense(ax, a[0], a[1], b[0], b[1], last_a[0], last_a[1], last_b[0], last_b[1], c[0], 169 | last_c[0], c=color) 170 | last_a, last_b, last_c = a, b, c 171 | 172 | def show(self): 173 | fig = plt.figure(figsize=(10, 9)) 174 | ax = fig.add_subplot(111, projection='3d') 175 | 176 | layers_len = 0 177 | layers_array = [] 178 | layers_name = [] 179 | layers_marker = [] 180 | layers_color = [] 181 | xy_max = [0, 0] 182 | 183 | # collect (shape, class name) for every model layer 184 | layers = self._model2layer() 185 | 186 | # build 
coordinate grids and plot style for each layer 187 | for lay in layers: 188 | layer_dict = self._layer(lay[0], lay[1]) 189 | single_layer, layers_len, xy_max = self._shape2array(layer_dict['shape'], layers_len, xy_max) 190 | 191 | layers_array.append(single_layer) 192 | layers_name.append(layer_dict['name']) 193 | 
layers_color.append(layer_dict['color']) 194 | layers_marker.append(layer_dict['marker']) 195 | 196 | # plot dots and lines 197 | self._plot_dots(layers_array, layers_name, layers_color, layers_marker, ax, xy_max) 198 | 199 | # grid=False: hide the grid and turn the axes off 200 | if self.grid == False: 201 | ax.grid(False) 202 | plt.axis('off') 203 | print(self.model.summary())  # NOTE(review): summary() prints its own table; if it returns None this also prints 'None' - confirm and drop the outer print 204 | plt.show() 205 | -------------------------------------------------------------------------------- /dist/netplot-0.1.2-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/dist/netplot-0.1.2-py3-none-any.whl -------------------------------------------------------------------------------- /dist/netplot-0.1.2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/dist/netplot-0.1.2.tar.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/requirements.txt -------------------------------------------------------------------------------- /screenshot demo/ModelPlot 3D with grid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/screenshot demo/ModelPlot 3D with grid.png -------------------------------------------------------------------------------- /screenshot demo/model_summary.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/screenshot demo/model_summary.png 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open("README.md", "r") as fh:  # NOTE(review): no encoding is given and README.md contains non-ASCII text (the 🚀 in its heading), so reading may fail under a non-UTF-8 default locale - consider encoding="utf-8" 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="netplot", # This is the name of the package 9 | version="0.1.2", # Release version of the package 10 | author="Souvik Pratiher", # Full name of the author 11 | description="Ultralight 3D renderer of neural network architecture for TF/Keras Models", 12 | long_description=long_description, # Long description read from the README file 13 | long_description_content_type="text/markdown", 14 | url = "https://github.com/Spratiher9/Netplot", 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ], # Trove classifiers used to filter the project on PyPI 20 | python_requires='>=3.6', # Minimum Python version required 21 | py_modules=["netplot"], # Top-level module installed (netplot.py) 22 | package_dir={'':'src'}, # Source root: modules are looked up under src/ 23 | install_requires=[ 24 | "cycler==0.10.0", 25 | "kiwisolver==1.3.2", 26 | "matplotlib==3.4.3", 27 | "numpy==1.21.2", 28 | "Pillow==8.3.2", 29 | "pyparsing==2.4.7", 30 | "python-dateutil==2.8.2", 31 | "six==1.16.0" 32 | ] # Pinned runtime dependencies; NOTE(review): tensorflow is not listed although the package targets TF/Keras models - confirm users are expected to install it themselves 33 | ) 34 | 35 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from src.netplot import *  # NOTE(review): resolves only when the repository root is importable; setup.py installs netplot as a top-level module from src/ (py_modules), so this initializer appears not to be part of the installed distribution 2 | -------------------------------------------------------------------------------- /src/__pycache__/netplot.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spratiher9/Netplot/4c07b6b86450d0fc7dd36f08e0c24ca67dd5330e/src/__pycache__/netplot.cpython-39.pyc 
-------------------------------------------------------------------------------- /src/netplot.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: netplot 3 | Version: 0.1.2 4 | Summary: Ultralight 3D renderer of neural network architecture for TF/Keras Models 5 | Home-page: https://github.com/Spratiher9/Netplot 6 | Author: Souvik Pratiher 7 | License: UNKNOWN 8 | Platform: UNKNOWN 9 | Classifier: Programming Language :: Python :: 3 10 | Classifier: License :: OSI Approved :: MIT License 11 | Classifier: Operating System :: OS Independent 12 | Requires-Python: >=3.6 13 | Description-Content-Type: text/markdown 14 | License-File: LICENSE 15 | 16 | ## Netplot 17 | An ultra-lightweight 3D renderer for TensorFlow/Keras neural network architectures. 18 | The library currently renders with Matplotlib. In the future the visualization may move to Plotly 19 | for a more interactive view of the neural network architecture. 
20 | 21 | **Note:** *For now the rendering is working in Jupyter only Google Colab support is in works.* 22 | 23 | For more details visit [NetPlot](https://github.com/Spratiher9/Netplot.git) 24 | 25 | ### Install with Pip 26 | 27 | ```python 28 | pip install netplot 29 | ``` 30 | 31 | ### Usage guide 32 | 33 | ```python 34 | from netplot import ModelPlot 35 | import tensorflow as tf 36 | import numpy as np 37 | ``` 38 | 39 | ```python 40 | %matplotlib notebook 41 | ``` 42 | 43 | ```python 44 | X_input = tf.keras.layers.Input(shape=(32,32,3)) 45 | X = tf.keras.layers.Conv2D(4, 3, activation='relu')(X_input) 46 | X = tf.keras.layers.MaxPool2D(2,2)(X) 47 | X = tf.keras.layers.Conv2D(16, 3, activation='relu')(X) 48 | X = tf.keras.layers.MaxPool2D(2,2)(X) 49 | X = tf.keras.layers.Conv2D(8, 3, activation='relu')(X) 50 | X = tf.keras.layers.MaxPool2D(2,2)(X) 51 | X = tf.keras.layers.Flatten()(X) 52 | X = tf.keras.layers.Dense(10, activation='relu')(X) 53 | X = tf.keras.layers.Dense(2, activation='softmax')(X) 54 | 55 | model = tf.keras.models.Model(inputs=X_input, outputs=X) 56 | ``` 57 | ```python 58 | modelplot = ModelPlot(model=model, grid=True, connection=True, linewidth=0.1) 59 | modelplot.show() 60 | ``` 61 | ![Keras Model Summarized](https://github.com/Spratiher9/Netplot/blob/master/screenshot%20demo/model_summary.png) 62 | ![Keras Model Visualized](https://github.com/Spratiher9/Netplot/blob/master/screenshot%20demo/ModelPlot%203D%20with%20grid.png) 63 | 64 | 65 | -------------------------------------------------------------------------------- /src/netplot.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.py 4 | src/netplot.py 5 | src/netplot.egg-info/PKG-INFO 6 | src/netplot.egg-info/SOURCES.txt 7 | src/netplot.egg-info/dependency_links.txt 8 | src/netplot.egg-info/requires.txt 9 | src/netplot.egg-info/top_level.txt 
-------------------------------------------------------------------------------- /src/netplot.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/netplot.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.3.2 3 | matplotlib==3.4.3 4 | numpy==1.21.2 5 | Pillow==8.3.2 6 | pyparsing==2.4.7 7 | python-dateutil==2.8.2 8 | six==1.16.0 9 | -------------------------------------------------------------------------------- /src/netplot.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | netplot 2 | -------------------------------------------------------------------------------- /src/netplot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | class ModelPlot(object): 6 | """Plot the layers of a Keras-style model (an object with .layers and .summary()) as 3D Matplotlib scatter points, one dot per unit.""" 7 | 8 | def __init__(self, model, grid=True, connection=False, linewidth=0.1):  # grid=False hides grid and axes in show(); connection=True links Dense/Flatten units; linewidth is the width of those links 9 | super(ModelPlot, self).__init__() 10 | 11 | self.model = model 12 | self.grid = grid 13 | self.connection = connection 14 | self.linewidth = linewidth 15 | 16 | def _layer(self, shape, name): 17 | """Map one layer's (shape, class name) to a dict with the plot 'shape', 'name', 'color' and 'marker'. 1-D and 2-D shapes are padded with 1s to 3-D; for MaxPooling2D/AveragePooling2D a 3-D shape is reduced to a single channel; a 3-channel shape gets color 'rgb' (one color letter per channel).""" 18 | lay_shape = None 19 | lay_name = None 20 | lay_color = None 21 | lay_marker = None 22 | 23 | if len(shape) == 1: 24 | lay_shape = (shape[0], 1, 1) 25 | elif len(shape) == 2: 26 | lay_shape = (shape[0], shape[1], 1) 27 | else: 28 | if name == 'MaxPooling2D' or name == 'AveragePooling2D': 29 | lay_shape = (shape[0], shape[1], 1) 30 | else: 31 | lay_shape = shape 32 | 33 | lay_name = name 34 | 35 | if len(lay_shape) == 3 and lay_shape[-1] == 3: 36 | lay_color = 'rgb' 37 | lay_marker = 'o' 38 | else: 39 | if lay_name == 'InputLayer': 40 | lay_color = 'r' 41 | lay_marker = 'o' 42 | elif 
lay_name == 'Conv2D': 43 | lay_color = 'y' 44 | lay_marker = '^' 45 | elif lay_name == 'MaxPooling2D' or lay_name == 'AveragePooling2D': 46 | lay_color = 'c' 47 | lay_marker = '.' 48 | else: 49 | lay_color = 'g' 50 | lay_marker = '.' 51 | 52 | return {'shape': lay_shape, 'name': lay_name, 'color': lay_color, 'marker': lay_marker} 53 | 54 | def _model2layer(self): 55 | """Fetch (shape, class name) for every layer of self.model. InputLayer uses input_shape[0], MaxPooling2D uses input_shape and all other layers use output_shape; the first (presumably batch) dimension is dropped in each case. NOTE(review): the class name is parsed out of str(layer.with_name_scope), which depends on that repr's exact format - type(layer).__name__ may be more robust; AveragePooling2D is not special-cased here although _layer groups it with MaxPooling2D - confirm intended.""" 56 | layers = [] 57 | 58 | for i in self.model.layers: 59 | name = str(i.with_name_scope).split('.')[-1][:-3] 60 | if name == 'InputLayer': 61 | shape = i.input_shape[0][1:] 62 | elif name == 'MaxPooling2D': 63 | shape = i.input_shape[1:] 64 | else: 65 | shape = i.output_shape[1:] 66 | layers.append((tuple(shape), name)) 67 | 68 | return layers 69 | 70 | def _shape2array(self, shape, layers_len, xy_max): 71 | """Expand a 3-D (x, y, z) shape into coordinate grids: one [arr_x, arr_y, arr_z] entry per channel, each holding y rows. Channels are placed 2 apart in depth starting at layers_len, and a further 4 is added after the layer. xy_max is updated in place with the largest x and y seen. Returns (single_layer, layers_len, xy_max).""" 72 | x = shape[0] 73 | y = shape[1] 74 | z = shape[2] 75 | 76 | single_layer = [] 77 | 78 | if xy_max[0] < x: 79 | xy_max[0] = x 80 | if xy_max[1] < y: 81 | xy_max[1] = y 82 | 83 | for k in range(z): 84 | arr_x, arr_y, arr_z = [], [], [] 85 | 86 | for i in range(y): 87 | ox = [j for j in range(x)] 88 | arr_x.append(ox) 89 | 90 | for i in range(y): 91 | oy = [j for j in (np.ones(x, dtype=int) * i)] 92 | arr_y.append(oy) 
93 | 94 | for i in range(y): 95 | oz = [j for j in (np.ones(y, dtype=int) * layers_len)]  # NOTE(review): these rows have length y while the arr_x/arr_y rows have length x, so for x != y the flattened z list differs in length from x/y - np.ones(x, ...) looks intended; confirm 96 | arr_z.append(oz) 97 | 98 | layers_len += 2 99 | single_layer.append([arr_x, arr_y, arr_z]) 100 | 101 | layers_len += 4 102 | 103 | return single_layer, layers_len, xy_max 104 | 105 | def _dense(self, ax, x1=1, x2=1, y1=1, y2=1, x11=1, x21=1, y11=1, y21=1, z1=1, z2=1, c='r'): 106 | """Draw one line from every point (i, k), i in [x1, x2] and k in [y1, y2], at depth z1 to every point (j, l), j in [x11, x21] and l in [y11, y21], at depth z2, in color c with width self.linewidth. Depth is drawn on the plot's y axis and k/l on its z axis; the number of lines is the product of the four range sizes.""" 107 | for i in np.arange(x1, x2 + 1, 1): 108 | for j in np.arange(x11, x21 + 1, 1): 109 | for k in np.arange(y1, y2 + 1, 1): 110 | for l in np.arange(y11, y21 + 1, 1): 111 | ax.plot([i, j], [z1, z2], [k, l], c=c, linewidth=self.linewidth) 112 | 113 | def _plot_dots(self, layers_array, 
layers_name, layers_color, layers_marker, ax, xy_max): 114 | """Scatter every unit of every layer, centring each layer on the largest x/y extent in xy_max. When self.connection is set, each Dense/Flatten channel plane is also joined to the previously visited Dense/Flatten plane with _dense.""" 115 | temp = True 116 | last_a, last_b, last_c = [0, 0], [0, 0], [0, 0] 117 | 118 | for layer, name, color_in, marker in zip(layers_array, layers_name, layers_color, layers_marker): 119 | line_x, line_y, line_z = [], [], [] 120 | color_count = 0 121 | 122 | for j in layer: 123 | my_x, my_y, my_z = [], [], [] 124 | temp_list_l = [] 125 | 126 | for k in j[0]: 127 | k = [a + ((xy_max[0] - len(k)) / 2) for a in k]  # shift the row so the layer is centred on xy_max[0] 128 | my_x += k 129 | 130 | line_x.append([k[0], k[-1]]) 131 | 132 | for l in j[1]: 133 | l = [b + ((xy_max[1] - (j[1][-1][-1] + 1)) / 2) for b in l]  # j[1][-1][-1] + 1 is the row count y; centre the rows on xy_max[1] 134 | my_y += l 135 | temp_list_l.append(l[0]) 136 | 137 | line_y.append([temp_list_l[0], temp_list_l[-1]]) 138 | 139 | for k in j[2]: 140 | my_z += k 141 | 142 | line_z.append([k[0], k[-1]]) 143 | 144 | if color_in == 'rgb': 145 | color = color_in[color_count]  # one of 'r', 'g', 'b' per channel 146 | color_count += 1 147 | else: 148 | color = color_in 149 | 150 | ax.scatter(my_x, my_z, my_y, c=color, marker=marker, s=20)  # depth goes on the plot's y axis, rows on its z axis 151 | 152 | if self.connection: 153 | if name == 'Dense' or name == 'Flatten': 154 | for c in line_z: 155 | a, b, c = line_x[0], line_y[0], c 156 | if temp:  # first Dense/Flatten plane visited: remember it, nothing to connect to yet 157 | temp = False 158 | last_a, last_b, last_c = a, b, c 159 | continue 160 | 161 | if color_in == 'rgb': 162 | color = color_in[color_count] 163 | color_count += 1 164 | 165 | else: 166 | color = color_in 167 | 168 | self._dense(ax, a[0], a[1], b[0], b[1], last_a[0], last_a[1], last_b[0], last_b[1], c[0], 169 | last_c[0], 
c=color) 170 | last_a, last_b, last_c = a, b, c 171 | 172 | def show(self): 173 | fig = plt.figure(figsize=(10, 9)) 174 | ax = fig.add_subplot(111, projection='3d') 175 | 176 | layers_len = 0 177 | layers_array = [] 178 | layers_name = [] 179 | layers_marker = [] 180 | layers_color = [] 181 | xy_max = [0, 0] 182 | 183 | # collect (shape, class name) for every model layer 184 | layers = self._model2layer() 185 | 186 | # build coordinate grids and plot style for each layer 187 | for lay in layers: 188 | layer_dict = 
self._layer(lay[0], lay[1]) 189 | single_layer, layers_len, xy_max = self._shape2array(layer_dict['shape'], layers_len, xy_max) 190 | 191 | layers_array.append(single_layer) 192 | layers_name.append(layer_dict['name']) 193 | layers_color.append(layer_dict['color']) 194 | layers_marker.append(layer_dict['marker']) 195 | 196 | # plot dots and lines 197 | self._plot_dots(layers_array, layers_name, layers_color, layers_marker, ax, xy_max) 198 | 199 | # grid=False: hide the grid and turn the axes off 200 | if self.grid == False: 201 | ax.grid(False) 202 | plt.axis('off') 203 | print(self.model.summary())  # NOTE(review): summary() prints its own table; if it returns None this also prints 'None' - confirm and drop the outer print 204 | plt.show() 205 | --------------------------------------------------------------------------------