├── .gitignore
├── LICENSE
├── README.md
├── docker
│   ├── Dockerfile
│   ├── build.bash
│   └── run.bash
├── onnxgraphqt
│   ├── __init__.py
│   ├── __main__.py
│   ├── data
│   │   ├── icon
│   │   │   ├── icon.png
│   │   │   ├── sam4onnx.png
│   │   │   ├── sbi4onnx.png
│   │   │   ├── scc4onnx.png
│   │   │   ├── scs4onnx.png
│   │   │   ├── sio4onnx.png
│   │   │   ├── sna4onnx.png
│   │   │   ├── snc4onnx.png
│   │   │   ├── snd4onnx.png
│   │   │   ├── sne4onnx.png
│   │   │   ├── soc4onnx.png
│   │   │   ├── sog4onnx.png
│   │   │   └── sor4onnx.png
│   │   ├── mobilenetv2-12-int8.onnx
│   │   ├── mobilenetv2-7.onnx
│   │   ├── onnx_opsets.json
│   │   └── splash.png
│   ├── graph
│   │   ├── __init__.py
│   │   ├── autolayout
│   │   │   ├── __init__.py
│   │   │   └── sugiyama_layout.py
│   │   ├── onnx_node.py
│   │   └── onnx_node_graph.py
│   ├── main_window.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── color.py
│   │   ├── dtype.py
│   │   ├── operators.py
│   │   ├── opset.py
│   │   ├── style.py
│   │   └── widgets.py
│   └── widgets
│       ├── __init__.py
│       ├── custom_node_item.py
│       ├── custom_properties.py
│       ├── custom_properties_bin.py
│       ├── splash_screen.py
│       ├── widgets_add_node.py
│       ├── widgets_change_channel.py
│       ├── widgets_change_input_ouput_shape.py
│       ├── widgets_change_opset.py
│       ├── widgets_combine_network.py
│       ├── widgets_constant_shrink.py
│       ├── widgets_delete_node.py
│       ├── widgets_extract_network.py
│       ├── widgets_generate_operator.py
│       ├── widgets_inference_test.py
│       ├── widgets_initialize_batchsize.py
│       ├── widgets_menubar.py
│       ├── widgets_message_box.py
│       ├── widgets_modify_attrs.py
│       ├── widgets_node_search.py
│       └── widgets_rename_op.py
├── pyproject.toml
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .vscode/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 fateshelled
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OnnxGraphQt
2 |
3 | ONNX model visualizer. You can also edit the model structure through the GUI!
4 |
5 | 
6 | 
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Requirements
14 | - [NodeGraphQt](https://github.com/jchanvfx/NodeGraphQt)
15 | - PySide2
16 | - Qt.py
17 | - NumPy
18 | - Pillow
19 | - onnx
20 | - onnx-simplifier
21 | - onnx_graphsurgeon
22 | - [simple-onnx-processing-tools](https://github.com/PINTO0309/simple-onnx-processing-tools)
23 | - networkx
24 | - grandalf
25 |
26 | ## Install
27 | ```bash
28 | sudo apt install python3-pyside2*
29 |
30 | git clone https://github.com/fateshelled/OnnxGraphQt
31 | cd OnnxGraphQt
32 | python3 -m pip install -U nvidia-pyindex
33 | python3 -m pip install -U Qt.py
34 | # If you want to use InferenceTest, install onnxruntime or onnxruntime-gpu
35 | # python3 -m pip install -U onnxruntime
36 | # python3 -m pip install -U onnxruntime-gpu
37 | python3 -m pip install -U -r requirements.txt
38 |
39 | # Install OnnxGraphQt
40 | python3 -m pip install .
41 | ```
42 |
43 | ## Run with Docker
44 | ```bash
45 | git clone https://github.com/fateshelled/OnnxGraphQt
46 | cd OnnxGraphQt
47 | # build docker image
48 | ./docker/build.bash
49 | # run
50 | ./docker/run.bash
51 | ```
52 |
53 | ## Usage
54 | ```bash
55 | # Open empty graph
56 | onnxgraphqt
57 |
58 | # Open with onnx model
59 | onnxgraphqt onnxgraphqt/data/mobilenetv2-7.onnx
60 |
61 | ```
62 |
63 | 
64 |
65 |
66 | ### Open Onnx Model
67 | Open a model from the menu bar (File - Open), or drag and drop an ONNX file from your file manager onto the main window.
68 |
69 | A sample model is available at `OnnxGraphQt/onnxgraphqt/data/mobilenetv2-7.onnx`.
70 |
71 | 
72 |
73 |
74 | ### Export
75 | Export the model to an ONNX file or a JSON file.
76 |
77 | ### Node detail
78 | Double-click a node to see more information about it.
79 |
80 | 
81 |
82 | ### Node Search
83 | The node search window can be opened from the menu bar (View - Search).
84 | You can search for nodes by name, type, or input/output name.
85 |
86 | 
87 |
88 | ### [simple-onnx-processing-tools](https://github.com/PINTO0309/simple-onnx-processing-tools)
89 |
90 | Please refer to each tool's GitHub repository for detailed usage.
91 |
92 | - Generate Operator [[sog4onnx](https://github.com/PINTO0309/sog4onnx)]
93 | - Add Node [[sna4onnx](https://github.com/PINTO0309/sna4onnx)]
94 | - Combine Network [[snc4onnx](https://github.com/PINTO0309/snc4onnx)]
95 | - Extract Network [[sne4onnx](https://github.com/PINTO0309/sne4onnx)]
96 | - Rename Operator [[sor4onnx](https://github.com/PINTO0309/sor4onnx)]
97 | - Modify Attributes and Constant [[sam4onnx](https://github.com/PINTO0309/sam4onnx)]
98 | - Input Channel Conversion [[scc4onnx](https://github.com/PINTO0309/scc4onnx)]
99 | - Initialize Batchsize [[sbi4onnx](https://github.com/PINTO0309/sbi4onnx)]
100 | - Change Opset [[soc4onnx](https://github.com/PINTO0309/soc4onnx)]
101 | - Constant Value Shrink [[scs4onnx](https://github.com/PINTO0309/scs4onnx)]
102 | - Delete Node [[snd4onnx](https://github.com/PINTO0309/snd4onnx)]
103 | - Inference Test [[sit4onnx](https://github.com/PINTO0309/sit4onnx)]
104 | - Change the INPUT and OUTPUT shape [[sio4onnx](https://github.com/PINTO0309/sio4onnx)]
105 |
106 |
107 | ## ToDo
108 | - [ ] Add Simple Structure Checker [[ssc4onnx](https://github.com/PINTO0309/ssc4onnx)]
109 |
110 |
111 | ## References
112 | - https://github.com/jchanvfx/NodeGraphQt
113 | - https://github.com/lutzroeder/netron
114 | - https://github.com/PINTO0309/simple-onnx-processing-tools
115 | - https://fdwr.github.io/LostOnnxDocs/OperatorFormulas.html
116 | - https://github.com/onnx/onnx/blob/main/docs/Operators.md
117 |
118 |
119 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
FROM ubuntu:20.04

ARG USERNAME=onnxgraphqt
ARG HOME=/home/${USERNAME}
ARG GID=1000
ARG UID=1000

# Create a non-root user with the given UID/GID and passwordless sudo.
RUN groupadd -f -g ${GID} ${USERNAME} && \
    useradd -m -s /bin/bash -u ${UID} -g ${GID} -G sudo ${USERNAME} && \
    echo ${USERNAME}:${USERNAME} | chpasswd && \
    echo "${USERNAME} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers

# Build-time only: suppress debconf prompts during package installation.
# Declared as ARG (not ENV) so it does not persist into the running container,
# and written in key=value form (the space-separated ENV form is deprecated).
ARG DEBIAN_FRONTEND=noninteractive
# apt-get (rather than apt) is the package tool with a stable scripting CLI.
RUN apt-get update && \
    apt-get install -y \
        git \
        bash-completion \
        python3-dev \
        python3-pip \
        python3-pyside2* && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install OnnxGraphQt and its Python dependencies as the unprivileged user.
USER ${USERNAME}
WORKDIR ${HOME}
RUN git clone https://github.com/fateshelled/OnnxGraphQt && \
    cd OnnxGraphQt && \
    python3 -m pip install -U pip && \
    python3 -m pip install -U nvidia-pyindex && \
    python3 -m pip install -U Qt.py && \
    python3 -m pip install -U onnxruntime && \
    python3 -m pip install -U -r requirements.txt && \
    python3 -m pip install . && \
    rm -rf ~/.cache/pip

WORKDIR ${HOME}/OnnxGraphQt
CMD [ "python3", "-m", "onnxgraphqt" ]
37 | CMD [ "python3", "-m", "onnxgraphqt" ]
38 |
--------------------------------------------------------------------------------
/docker/build.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Build the "onnxgraphqt" Docker image, using the directory containing this
# script (where the Dockerfile lives) as the build context.

# Quote every expansion so paths containing spaces survive word splitting, and
# use && so a failed cd aborts instead of silently yielding the caller's cwd.
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) || exit 1
docker build -t onnxgraphqt "$SCRIPT_DIR"
--------------------------------------------------------------------------------
/docker/run.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Run the onnxgraphqt image with access to the host's X server.

# Permit local (non-network) clients, such as a container talking over the
# /tmp/.X11-unix socket, to connect. Plain `xhost +` would instead disable
# access control for every host on the network.
xhost +local:

XAUTH=/tmp/.docker.xauth
if [ ! -f "$XAUTH" ]
then
    # Export the cookie(s) for the current display (falling back to :0) and
    # rewrite the address family to ffff (FamilyWild) so the cookie matches
    # whatever hostname the container reports.
    xauth_list=$(xauth nlist "${DISPLAY:-:0}" | sed -e 's/^..../ffff/')
    if [ -n "$xauth_list" ]
    then
        # Quoted: nmerge needs one cookie per line; an unquoted echo would
        # join multiple cookies into a single line.
        echo "$xauth_list" | xauth -f "$XAUTH" nmerge -
    else
        touch "$XAUTH"
    fi
    chmod a+r "$XAUTH"
fi

docker run --net=host --rm -it \
    -v="/tmp/.X11-unix:/tmp/.X11-unix:rw" \
    -v="${XAUTH}:${XAUTH}:rw" \
    -e="XAUTHORITY=${XAUTH}" \
    -e="DISPLAY=${DISPLAY}" \
    -e=TERM=xterm-256color \
    -e=QT_X11_NO_MITSHM=1 \
    onnxgraphqt:latest
23 |
--------------------------------------------------------------------------------
/onnxgraphqt/__init__.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import PySide2
3 |
# Point Qt at the plugin directory shipped inside the installed PySide2 package.
#
# QT_PLUGIN_PATH must name the plugin *root*: Qt looks for plugins in
# <root>/<plugin type>/ (platforms/, imageformats/, ...). Setting it to the
# platforms/ directory itself, as before, made Qt search platforms/platforms/.
# QT_QPA_PLATFORM_PLUGIN_PATH, in contrast, names the platforms/ directory.
#
# NOTE(review): with a distro-packaged PySide2 (the README installs
# python3-pyside2 via apt) this directory may not exist; confirm Qt then falls
# back to its default plugin locations.
_qt_plugin_root = os.path.join(os.path.dirname(PySide2.__file__), "Qt", "plugins")
plugin_path = os.path.join(_qt_plugin_root, "platforms")  # kept: module-level name
os.environ["QT_PLUGIN_PATH"] = _qt_plugin_root
os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = plugin_path
8 |
9 | from .main_window import main
10 |
--------------------------------------------------------------------------------
/onnxgraphqt/__main__.py:
--------------------------------------------------------------------------------
1 | from .main_window import main
2 |
# Entry point for `python -m onnxgraphqt` (the Docker image's CMD):
# delegate to main() imported from main_window.
if __name__ == "__main__":
    main()
5 |
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/icon.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sam4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sam4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sbi4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sbi4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/scc4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/scc4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/scs4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/scs4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sio4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sio4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sna4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sna4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/snc4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/snc4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/snd4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/snd4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sne4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sne4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/soc4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/soc4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sog4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sog4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/icon/sor4onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sor4onnx.png
--------------------------------------------------------------------------------
/onnxgraphqt/data/mobilenetv2-12-int8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/mobilenetv2-12-int8.onnx
--------------------------------------------------------------------------------
/onnxgraphqt/data/mobilenetv2-7.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/mobilenetv2-7.onnx
--------------------------------------------------------------------------------
/onnxgraphqt/data/splash.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/splash.png
--------------------------------------------------------------------------------
/onnxgraphqt/graph/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/graph/__init__.py
--------------------------------------------------------------------------------
/onnxgraphqt/graph/autolayout/__init__.py:
--------------------------------------------------------------------------------
1 | from .sugiyama_layout import sugiyama_layout
--------------------------------------------------------------------------------
/onnxgraphqt/graph/autolayout/sugiyama_layout.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import grandalf as grand
3 | from grandalf.layouts import SugiyamaLayout
4 |
__all__ = ["sugiyama_layout"]


def sugiyama_layout(g: nx.Graph, component_spacing: float = 100.0):
    """Compute node positions for ``g`` using grandalf's Sugiyama (layered) layout.

    Every connected component of ``g`` is laid out. The first component keeps
    exactly the coordinates grandalf assigns to it. Each subsequent component
    is shifted horizontally so that its left-most node sits one node width
    plus ``component_spacing`` to the right of the right-most node placed so
    far. Previously only the first component was laid out, and nodes in other
    components received no position at all.

    Args:
        g: networkx graph whose nodes should be positioned.
        component_spacing: extra horizontal gap inserted between consecutive
            connected components.

    Returns:
        dict mapping each node of ``g`` to an ``(x, y)`` tuple; empty when the
        converted graph has no components.
    """
    # Undocumented grandalf helper; the "nextworkx" spelling is grandalf's own.
    gg = grand.utils.convert_nextworkx_graph_to_grandalf(g)

    class defaultview(object):
        # Node footprint grandalf uses when spacing vertices.
        w, h = 400, 100

    for v in gg.V():
        v.view = defaultview()

    pos = {}
    right_edge = None  # largest x assigned so far (None until the first component)
    for component in gg.C:
        sug = SugiyamaLayout(component)
        sug.init_all(optimize=False)
        sug.draw()

        # sV holds the component's real vertices (not layout dummy vertices).
        vertices = list(component.sV)
        shift = 0  # int, so the first component's coordinates are untouched
        if right_edge is not None:
            left_edge = min(v.view.xy[0] for v in vertices)
            shift = right_edge + defaultview.w + component_spacing - left_edge

        for v in vertices:
            x = v.view.xy[0] + shift
            pos[v.data] = (x, v.view.xy[1])
            right_edge = x if right_edge is None else max(right_edge, x)
    return pos
22 |
--------------------------------------------------------------------------------
/onnxgraphqt/graph/onnx_node.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import (
3 | Dict, List, Any, Optional, Union
4 | )
5 | from collections import OrderedDict
6 |
7 | from NodeGraphQt import BaseNode
8 | from NodeGraphQt.constants import NodeEnum, LayoutDirectionEnum, NodePropWidgetEnum
9 |
10 | from onnxgraphqt.utils.color import (
11 | COLOR_BG,
12 | COLOR_FONT,
13 | COLOR_GRID,
14 | INPUT_NODE_COLOR,
15 | OUTPUT_NODE_COLOR,
16 | NODE_BORDER_COLOR,
17 | get_node_color,
18 | )
19 | from onnxgraphqt.utils.widgets import set_font, GRAPH_FONT_SIZE
20 | from onnxgraphqt.widgets.custom_node_item import CustomNodeItem
21 |
22 |
@dataclass
class OnnxNodeIO:
    """Description of one input or output tensor of an ONNX node.

    ONNXNode.set_onnx_inputs / set_onnx_outputs flatten each instance into a
    ``[name, dtype, shape, values]`` list for display as a node property.
    No validation is performed on the fields.
    """

    # Tensor name.
    name: str
    # Element type, stored as a string.
    dtype: str
    # Tensor shape.
    shape: List[int]
    # Associated values; contents are caller-defined (TODO confirm, e.g. constant data).
    values: List[Any]
29 |
30 |
class ONNXNode(BaseNode):
    """Graph node wrapping a single ONNX operator.

    The node's name, op type, attributes and input/output tensor descriptions
    are kept as Python attributes and mirrored into NodeGraphQt properties.
    Attribute and I/O properties carry a trailing underscore in their property
    name: attribute ``axis`` is stored as property ``axis_``, and the I/O lists
    as ``inputs_`` / ``outputs_``.
    """

    # unique node identifier.
    __identifier__ = 'nodes.node'
    # initial default node name.
    NODE_NAME = 'onnxnode'

    def __init__(self):
        super().__init__(qgraphics_item=CustomNodeItem)
        self.set_layout_direction(LayoutDirectionEnum.VERTICAL.value)
        self.attrs = OrderedDict()
        self.node_name = ""
        self.op = ""
        self.onnx_inputs = []
        self.onnx_outputs = []

        # One multi-connection port on each side, with port labels hidden.
        self.input_port = self.add_input('multi in', multi_input=True, display_name=False)
        self.output_port = self.add_output('multi out', multi_output=True, display_name=False)
        self.set_font()

    def _upsert_property(self, prop_name, value, widget_type, push_undo):
        """Set ``prop_name`` if it already exists, otherwise create it.

        ``widget_type`` is only used on creation; ``push_undo`` only on update.
        """
        if self.has_property(prop_name):
            self.set_property(prop_name, value, push_undo=push_undo)
            return
        self.create_property(prop_name, value, widget_type=widget_type)

    @staticmethod
    def _io_rows(ios):
        """Flatten OnnxNodeIO-like objects into ``[name, dtype, shape, values]`` lists."""
        return [[io.name, io.dtype, io.shape, io.values] for io in ios]

    def set_attrs(self, attrs: OrderedDict, push_undo=False):
        """Store ``attrs`` and mirror each entry into a ``<key>_`` property.

        A newly created ``dtype_`` property uses a read-only label widget;
        every other attribute gets a line edit.
        """
        self.attrs = attrs
        for attr_name, attr_value in attrs.items():
            widget = (NodePropWidgetEnum.QLABEL if attr_name == "dtype"
                      else NodePropWidgetEnum.QLINE_EDIT)
            self._upsert_property(attr_name + "_", attr_value, widget, push_undo)

    def get_attrs(self) -> OrderedDict:
        """Return the current ``<key>_`` property values, keyed by attribute name."""
        return OrderedDict(
            (attr_name, self.get_property(attr_name + "_")) for attr_name in self.attrs
        )

    def set_node_name(self, node_name: str, push_undo=False):
        """Set the ONNX node name and its ``node_name`` property."""
        self.node_name = node_name
        self._upsert_property("node_name", node_name,
                              NodePropWidgetEnum.QLINE_EDIT, push_undo)

    def get_node_name(self):
        """Return the ``node_name`` property, caching it in ``self.node_name``."""
        current = self.get_property("node_name")
        self.node_name = current
        return current

    def set_op(self, op: str, push_undo=False):
        """Set the operator type; it is also used as the node's displayed title."""
        self.op = op
        self._view.name = op
        self._upsert_property("op", op, NodePropWidgetEnum.QLABEL, push_undo)

    def set_onnx_inputs(self, onnx_inputs: List[OnnxNodeIO], push_undo=False):
        """Store the input tensor descriptions and mirror them into ``inputs_``."""
        self.onnx_inputs = onnx_inputs
        self._upsert_property("inputs_", self._io_rows(onnx_inputs),
                              NodePropWidgetEnum.QLINE_EDIT, push_undo)

    def set_onnx_outputs(self, onnx_outputs: List[OnnxNodeIO], push_undo=False):
        """Store the output tensor descriptions and mirror them into ``outputs_``."""
        self.onnx_outputs = onnx_outputs
        self._upsert_property("outputs_", self._io_rows(onnx_outputs),
                              NodePropWidgetEnum.QLINE_EDIT, push_undo)

    def set_color(self, push_undo=False):
        """Apply font/border colours and an op-dependent fill colour (all fully opaque)."""
        opaque = [255]
        self.view.text_color = COLOR_FONT + opaque
        fill = get_node_color(self.op)
        self.set_property('border_color', NODE_BORDER_COLOR + opaque, push_undo)
        self.set_property('color', fill + opaque, push_undo)

    def set_font(self, font_size=GRAPH_FONT_SIZE, bold=False):
        """Apply the graph font to the node's title text."""
        set_font(self.view.text_item, font_size=font_size, bold=bold)
110 |
111 |
class ONNXInput(BaseNode):
    """Node representing an input of the ONNX graph.

    It has a single multi-connection output port and no input port. Its name,
    shape, dtype and ``output_names`` are kept as Python attributes and
    mirrored into NodeGraphQt properties. Each getter re-reads the property
    and caches the value on the instance.
    """

    # unique node identifier.
    __identifier__ = 'nodes.node'
    # initial default node name.
    NODE_NAME = 'input'

    def __init__(self):
        super().__init__(qgraphics_item=CustomNodeItem)
        self.set_layout_direction(LayoutDirectionEnum.VERTICAL.value)
        self.node_name = ""
        self.shape = []
        self.dtype = ""
        self.output_names = []

        # (property name, initial value, widget) - registered in this order.
        property_specs = (
            ("node_name", self.node_name, NodePropWidgetEnum.QLINE_EDIT),
            ("shape", self.shape, NodePropWidgetEnum.QLINE_EDIT),
            ("dtype", self.dtype, NodePropWidgetEnum.QLINE_EDIT),
            ("output_names", self.output_names, NodePropWidgetEnum.QTEXT_EDIT),
        )
        for prop_name, initial, widget in property_specs:
            self.create_property(prop_name, initial, widget_type=widget)

        # Only an output port: this node feeds other nodes.
        self.output_port = self.add_output('multi out', multi_output=True, display_name=False)
        self.set_color()
        self.set_font()
        self._view.name = "input"

    def get_node_name(self):
        """Return the ``node_name`` property, caching it in ``self.node_name``."""
        current = self.get_property("node_name")
        self.node_name = current
        return current

    def set_node_name(self, node_name: str, push_undo=False):
        """Set ``self.node_name`` and the ``node_name`` property."""
        self.node_name = node_name
        self.set_property("node_name", node_name, push_undo=push_undo)

    def get_shape(self):
        """Return the ``shape`` property, caching it in ``self.shape``."""
        current = self.get_property("shape")
        self.shape = current
        return current

    def set_shape(self, shape, push_undo=False):
        """Set ``self.shape`` and the ``shape`` property."""
        self.shape = shape
        self.set_property("shape", shape, push_undo=push_undo)

    def get_dtype(self):
        """Return the ``dtype`` property, caching it in ``self.dtype``."""
        current = self.get_property("dtype")
        self.dtype = current
        return current

    def set_dtype(self, dtype, push_undo=False):
        """Store ``str(dtype)`` in ``self.dtype`` and the ``dtype`` property."""
        as_text = str(dtype)
        self.dtype = as_text
        self.set_property("dtype", as_text, push_undo=push_undo)

    def get_output_names(self) -> List[str]:
        """Return the ``output_names`` property, caching it in ``self.output_names``."""
        current = self.get_property("output_names")
        self.output_names = current
        return current

    def set_output_names(self, output_names, push_undo=False):
        """Set ``self.output_names`` and the ``output_names`` property."""
        self.output_names = output_names
        self.set_property("output_names", output_names, push_undo=push_undo)

    def set_color(self, push_undo=False):
        """Apply font, border and input-node fill colours (all fully opaque)."""
        opaque = [255]
        self.view.text_color = COLOR_FONT + opaque
        self.set_property('border_color', NODE_BORDER_COLOR + opaque, push_undo)
        self.set_property('color', INPUT_NODE_COLOR + opaque, push_undo)

    def set_font(self, font_size=GRAPH_FONT_SIZE, bold=False):
        """Apply the graph font to the node's title text."""
        set_font(self.view.text_item, font_size=font_size, bold=bold)
173 |
174 |
class ONNXOutput(BaseNode):
    """Node representing an output of the ONNX graph.

    It has a single multi-connection input port and no output port. Its name,
    shape, dtype and ``input_names`` are kept as Python attributes and
    mirrored into NodeGraphQt properties. Each getter re-reads the property
    and caches the value on the instance.
    """

    # unique node identifier.
    __identifier__ = 'nodes.node'
    # initial default node name.
    NODE_NAME = 'output'

    def __init__(self):
        super(ONNXOutput, self).__init__(qgraphics_item=CustomNodeItem)
        self.set_layout_direction(LayoutDirectionEnum.VERTICAL.value)
        self.node_name = ""
        self.shape = []
        self.dtype: str = ""
        self.input_names = []
        self.create_property("node_name", self.node_name, widget_type=NodePropWidgetEnum.QLINE_EDIT)
        self.create_property("shape", self.shape, widget_type=NodePropWidgetEnum.QLINE_EDIT)
        self.create_property("dtype", self.dtype, widget_type=NodePropWidgetEnum.QLINE_EDIT)
        self.create_property("input_names", self.input_names, widget_type=NodePropWidgetEnum.QTEXT_EDIT)
        # create node inputs (an output node has no output port).
        self.input_port = self.add_input('multi in', multi_input=True, display_name=False)
        self.set_color()
        self.set_font()
        self._view.name = "output"

    def get_node_name(self):
        """Return the ``node_name`` property, caching it in ``self.node_name``."""
        self.node_name = self.get_property("node_name")
        return self.node_name

    def set_node_name(self, node_name: str, push_undo=False):
        """Set ``self.node_name`` and the ``node_name`` property."""
        self.node_name = node_name
        self.set_property("node_name", self.node_name, push_undo=push_undo)

    def get_shape(self):
        """Return the ``shape`` property, caching it in ``self.shape``."""
        self.shape = self.get_property("shape")
        return self.shape

    def set_shape(self, shape, push_undo=False):
        """Set ``self.shape`` and the ``shape`` property."""
        self.shape = shape
        # push_undo passed by keyword, consistent with the other node classes.
        self.set_property("shape", self.shape, push_undo=push_undo)

    def get_dtype(self):
        """Return the ``dtype`` property, caching it in ``self.dtype``."""
        self.dtype = self.get_property("dtype")
        return self.dtype

    def set_dtype(self, dtype, push_undo=False):
        """Store ``str(dtype)`` in ``self.dtype`` and the ``dtype`` property."""
        self.dtype = str(dtype)
        self.set_property("dtype", self.dtype, push_undo=push_undo)

    def get_input_names(self):
        """Return the ``input_names`` property, caching it in ``self.input_names``."""
        self.input_names = self.get_property("input_names")
        return self.input_names

    def set_input_names(self, input_names, push_undo=False):
        """Set ``self.input_names`` and the ``input_names`` property."""
        self.input_names = input_names
        self.set_property("input_names", self.input_names, push_undo=push_undo)

    def set_color(self, push_undo=False):
        """Apply font, border and output-node fill colours (all fully opaque)."""
        self.view.text_color = COLOR_FONT + [255]
        self.set_property('border_color', NODE_BORDER_COLOR + [255], push_undo=push_undo)
        # Output nodes use the dedicated OUTPUT_NODE_COLOR, not INPUT_NODE_COLOR.
        self.set_property('color', OUTPUT_NODE_COLOR + [255], push_undo=push_undo)

    def set_font(self, font_size=GRAPH_FONT_SIZE, bold=False):
        """Apply the graph font to the node's title text."""
        set_font(self.view.text_item, font_size=font_size, bold=bold)
236 |
--------------------------------------------------------------------------------
/onnxgraphqt/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/utils/__init__.py
--------------------------------------------------------------------------------
/onnxgraphqt/utils/color.py:
--------------------------------------------------------------------------------
# Colours are [R, G, B] lists with 0-255 components; node code appends an
# alpha channel (e.g. `COLOR_FONT + [255]`) before handing them to NodeGraphQt.

# Neutral tones, light to dark.
COLOR_WHITE = [235, 235, 235]
COLOR_PALEGRAY = [230, 230, 230]
COLOR_LIGHTGRAY = [200, 200, 203]
COLOR_GRAY = [127, 135, 143]
COLOR_DARKGRAY = [50, 55, 60]
COLOR_BLACK = [20, 20, 20]

# Accent tones.
COLOR_RED = [112, 41, 33]
COLOR_GREEN = [51, 85, 51]
COLOR_BLUE = [51, 85, 136]
COLOR_BROWN = [89, 66, 59]

# Role aliases; each name refers to the same list object as its palette entry.
COLOR_BG, COLOR_FONT, COLOR_GRID = COLOR_PALEGRAY, COLOR_BLACK, COLOR_GRAY
INPUT_NODE_COLOR = OUTPUT_NODE_COLOR = COLOR_LIGHTGRAY
DEFAULT_COLOR = COLOR_GRAY

# Node chrome.
NODE_BG_COLOR, NODE_BORDER_COLOR = COLOR_WHITE, COLOR_DARKGRAY
NODE_SELECTED_BORDER_COLOR = [240, 50, 0]
23 |
24 | NODE_COLORS = {
25 | # Generic
26 | 'Identity': COLOR_DARKGRAY,
27 | # Constant
28 | 'Constant': COLOR_DARKGRAY,
29 | 'ConstantOfShape': COLOR_DARKGRAY,
30 | # Math
31 | 'Add': COLOR_DARKGRAY,
32 | 'Sub': COLOR_DARKGRAY,
33 | 'Mul': COLOR_DARKGRAY,
34 | 'Div': COLOR_DARKGRAY,
35 | 'Sqrt': COLOR_DARKGRAY,
36 | 'Reciprocal': COLOR_DARKGRAY,
37 | 'Pow': COLOR_DARKGRAY,
38 | 'Exp': COLOR_DARKGRAY,
39 | 'Log': COLOR_DARKGRAY,
40 | 'Abs': COLOR_DARKGRAY,
41 | 'Neg': COLOR_DARKGRAY,
42 | 'Ceil': COLOR_DARKGRAY,
43 | 'Floor': COLOR_DARKGRAY,
44 | 'Clip': COLOR_RED,
45 | 'Erf': COLOR_DARKGRAY,
46 | 'IsNan': COLOR_DARKGRAY,
47 | 'IsInf': COLOR_DARKGRAY,
48 | 'Sign': COLOR_DARKGRAY,
49 | # Logical
50 | 'Greater': COLOR_DARKGRAY,
51 | 'Less': COLOR_DARKGRAY,
52 | 'Equal': COLOR_DARKGRAY,
53 | 'Not': COLOR_DARKGRAY,
54 | 'And': COLOR_DARKGRAY,
55 | 'Or': COLOR_DARKGRAY,
56 | 'Xor': COLOR_DARKGRAY,
57 | # Trigonometric
58 | 'Sin': COLOR_DARKGRAY,
59 | 'Cos': COLOR_DARKGRAY,
60 | 'Tan': COLOR_DARKGRAY,
61 | 'Asin': COLOR_DARKGRAY,
62 | 'Acos': COLOR_DARKGRAY,
63 | 'Atan': COLOR_DARKGRAY,
64 | 'Sinh': COLOR_DARKGRAY,
65 | 'Cosh': COLOR_DARKGRAY,
66 | 'Tanh': COLOR_DARKGRAY,
67 | 'Acosh': COLOR_DARKGRAY,
68 | 'Asinh': COLOR_DARKGRAY,
69 | 'Atanh': COLOR_DARKGRAY,
70 | 'Max': COLOR_DARKGRAY,
71 | # Reduction
72 | 'Sum': COLOR_DARKGRAY,
73 | 'Mean': COLOR_DARKGRAY,
74 | 'Max': COLOR_DARKGRAY,
75 | 'Min': COLOR_DARKGRAY,
76 | # Activation
77 | 'Sigmoid': COLOR_RED,
78 | 'HardSigmoid': COLOR_RED,
79 | 'Tanh': COLOR_RED,
80 | 'ScaledTanh': COLOR_RED,
81 | 'Relu': COLOR_RED,
82 | 'LeakyRelu': COLOR_RED,
83 | 'PRelu': COLOR_RED,
84 | 'ThresholdedRelu': COLOR_RED,
85 | 'Elu': COLOR_RED,
86 | 'Selu': COLOR_RED,
87 | 'Softmax': COLOR_RED,
88 | 'LogSoftmax': COLOR_RED,
89 | 'Hardmax': COLOR_RED,
90 | 'Softsign': COLOR_RED,
91 | 'Softplus': COLOR_RED,
92 | 'Affine': COLOR_RED,
93 | 'Shrink': COLOR_RED,
94 | # Random
95 | 'RandomNormal': COLOR_DARKGRAY,
96 | 'RandomNormalLike': COLOR_DARKGRAY,
97 | 'RandomUniform': COLOR_DARKGRAY,
98 | 'RandomUniformLike': COLOR_DARKGRAY,
99 | 'Multinomial': COLOR_DARKGRAY,
100 | # Multiplication
101 | 'EyeLike': COLOR_BLUE,
102 | 'Gemm': COLOR_BLUE,
103 | 'MatMul': COLOR_BLUE,
104 | 'Conv': COLOR_BLUE,
105 | 'ConvTranspose': COLOR_BLUE,
106 | # Conversion
107 | 'Cast': COLOR_DARKGRAY,
108 | # Reorganization
109 | 'Transpose': COLOR_GREEN,
110 | 'Expand': COLOR_DARKGRAY,
111 | 'Tile': COLOR_DARKGRAY,
112 | 'Split': COLOR_DARKGRAY,
113 | 'Slice': COLOR_BROWN,
114 | 'DynamicSlice': COLOR_BROWN,
115 | 'Concat': COLOR_BROWN,
116 | 'Gather': COLOR_GREEN,
117 | 'GatherElements': COLOR_GREEN,
118 | 'ScatterElements': COLOR_DARKGRAY,
119 | 'Pad': COLOR_DARKGRAY,
120 | 'SpaceToDepth': COLOR_DARKGRAY,
121 | 'DepthToSpace': COLOR_DARKGRAY,
122 | 'Shape': COLOR_DARKGRAY,
123 | 'Size': COLOR_DARKGRAY,
124 | 'Reshape': COLOR_BROWN,
125 | 'Flatten': COLOR_DARKGRAY,
126 | 'Squeeze': COLOR_DARKGRAY,
127 | 'Unsqueeze': COLOR_GREEN,
128 | 'OneHot': COLOR_DARKGRAY,
129 | 'TopK': COLOR_DARKGRAY,
130 | 'Where': COLOR_DARKGRAY,
131 | 'Compress': COLOR_DARKGRAY,
132 | 'Reverse': COLOR_DARKGRAY,
133 | # Pooling
134 | 'GlobalAveragePool': COLOR_GREEN,
135 | 'AveragePool': COLOR_GREEN,
136 | 'GlobalMaxPool': COLOR_GREEN,
137 | 'MaxPool': COLOR_GREEN,
138 | 'MaxUnpool': COLOR_GREEN,
139 | 'LpPool': COLOR_GREEN,
140 | 'GlobalLpPool': COLOR_GREEN,
141 | 'MaxRoiPool': COLOR_GREEN,
142 | # Reduce
143 | 'ReduceSum': COLOR_DARKGRAY,
144 | 'ReduceMean': COLOR_DARKGRAY,
145 | 'ReduceProd': COLOR_DARKGRAY,
146 | 'ReduceLogSum': COLOR_DARKGRAY,
147 | 'ReduceLogSumExp': COLOR_DARKGRAY,
148 | 'ReduceSumSquare': COLOR_DARKGRAY,
149 | 'ReduceL1': COLOR_DARKGRAY,
150 | 'ReduceL2': COLOR_DARKGRAY,
151 | 'ReduceMax': COLOR_DARKGRAY,
152 | 'ReduceMin': COLOR_DARKGRAY,
153 | 'ArgMax': COLOR_DARKGRAY,
154 | 'ArgMin': COLOR_DARKGRAY,
155 | # Imaging
156 | 'Upsample': COLOR_DARKGRAY,
157 | # Flow
158 | 'If': COLOR_DARKGRAY,
159 | 'Loop': COLOR_DARKGRAY,
160 | 'Scan': COLOR_DARKGRAY,
161 | # Normalization
162 | 'InstanceNormalization': COLOR_DARKGRAY,
163 | 'BatchNormalization': COLOR_DARKGRAY,
164 | 'LRN': COLOR_DARKGRAY,
165 | 'MeanVarianceNormalization': COLOR_DARKGRAY,
166 | 'LpNormalization': COLOR_DARKGRAY,
167 | # collation
168 | 'Nonzero': COLOR_DARKGRAY,
169 | # NGram
170 | 'TfldfVectorizer': COLOR_DARKGRAY,
171 | # Aggregate
172 | 'RNN': COLOR_DARKGRAY,
173 | 'GRU': COLOR_DARKGRAY,
174 | 'LSTM': COLOR_DARKGRAY,
175 | # Training
176 | 'Dropout': COLOR_DARKGRAY,
177 | # Quantize
178 | 'QuantizeLinear': COLOR_DARKGRAY,
179 | 'QLinearConv': COLOR_BLUE,
180 | 'DequantizeLinear': COLOR_DARKGRAY,
181 | 'QLinearGlobalAveragePool': COLOR_GREEN,
182 | 'QLinearAdd': COLOR_DARKGRAY,
183 | 'QLinearMatMul': COLOR_BLUE,
184 | }
185 |
def get_node_color(op_name):
    """Look up the display color registered for an ONNX operator type.

    Operator names that have no entry in NODE_COLORS fall back to DEFAULT_COLOR.
    """
    if op_name in NODE_COLORS:
        return NODE_COLORS[op_name]
    return DEFAULT_COLOR
188 |
class PrintColor:
    """ANSI escape sequences for colored / styled terminal output.

    Every public attribute is a two-element list ``[escape_sequence, replacement]``:
    ``remove_PrintColor`` deletes ``escape_sequence`` from a message and
    ``replace_PrintColor`` substitutes ``replacement`` for it (currently "" for
    every entry). Both helpers iterate all non-dunder attributes, so every
    attribute added here must follow the same two-element shape.
    """
    # Foreground colors.
    BLACK = ['\033[30m', ""]
    RED = ['\033[31m', ""]
    GREEN = ['\033[32m', ""]
    YELLOW = ['\033[33m', ""]
    BLUE = ['\033[34m', ""]
    MAGENTA = ['\033[35m', ""]
    CYAN = ['\033[36m', ""]
    WHITE = ['\033[37m', ""]
    COLOR_DEFAULT = ['\033[39m', ""]
    # Text attributes.
    BOLD = ['\033[1m', ""]
    UNDERLINE = ['\033[4m', ""]
    INVISIBLE = ['\033[08m', ""]
    REVERCE = ['\033[07m', ""]
    # Correctly spelled alias of REVERCE; the misspelled name is kept so
    # existing callers keep working.
    REVERSE = REVERCE
    # Background colors.
    BG_BLACK = ['\033[40m', ""]
    BG_RED = ['\033[41m', ""]
    BG_GREEN = ['\033[42m', ""]
    BG_YELLOW = ['\033[43m', ""]
    BG_BLUE = ['\033[44m', ""]
    BG_MAGENTA = ['\033[45m', ""]
    BG_CYAN = ['\033[46m', ""]
    BG_WHITE = ['\033[47m', ""]
    BG_DEFAULT = ['\033[49m', ""]
    # Reset all attributes.
    RESET = ['\033[0m', ""]
213 |
214 |
def remove_PrintColor(message:str)->str:
    """Strip terminal color escape sequences from *message*.

    Removes every escape sequence declared on ``PrintColor`` and every
    256-color foreground sequence ``ESC[38;5;<n>m`` for n in 0..255.

    Args:
        message: text that may contain ANSI color codes.

    Returns:
        The text with those sequences removed.
    """
    ret = message
    for key, v in vars(PrintColor).items():
        # Skip class machinery such as __module__, __dict__, __doc__.
        if key[:2] == "__":
            continue
        ret = ret.replace(v[0], '')
    # The 256-color palette has indices 0..255; stopping at 32 left codes
    # like ESC[38;5;196m in the output.
    for i in range(256):
        ret = ret.replace(f'\033[38;5;{i}m', '')
    return ret
225 |
def replace_PrintColor(message: str)->str:
    """Substitute each ``PrintColor`` escape sequence in *message*.

    Every ``PrintColor`` entry ``[sequence, replacement]`` is replaced in
    declaration order. 256-color foreground sequences ``ESC[38;5;<n>m``
    (n in 0..255), which ``PrintColor`` does not list, are removed.

    Args:
        message: text that may contain ANSI color codes.

    Returns:
        The text after substitution.
    """
    ret = message
    for key, v in vars(PrintColor).items():
        # Skip class machinery such as __module__, __dict__, __doc__.
        if key[:2] == "__":
            continue
        ret = ret.replace(v[0], v[1])
    # Cover the full 0..255 palette, not just the first 32 indices.
    for i in range(256):
        ret = ret.replace(f'\033[38;5;{i}m', '')
    return ret
236 |
if __name__ == "__main__":
    # Manual demo: strip / replace the color codes in a captured, ANSI-colored
    # log sample, and print every PrintColor entry rendered in its own style.
    import inspect  # NOTE(review): not used in this block
    # Captured console output mixing ESC[38;5;11m warnings with ESC[33m / ESC[32m
    # WARNING / INFO prefixes, each terminated by ESC[0m.
    text = "\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\nNote: Producer node(s) of first tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\nAttributes: {'perm': [3, 2, 1, 0]}]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\x1b[0m\n\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\nNote: Producer node(s) of first tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\nAttributes: {'perm': [3, 2, 1, 0]}]\x1b[0m\n\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\nNote: Producer node(s) of first tensor:\n[]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\x1b[0m\n\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\nNote: Producer node(s) of first tensor:\n[]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\x1b[0m\n\x1b[33mWARNING:\x1b[0m The input shape of the next OP does not match the output shape. Be sure to open the .onnx file to verify the certainty of the geometry.\n\x1b[33mWARNING:\x1b[0m onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Transpose, node name: input_order_convert_transpose_0): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 1: (224) vs (3)\n\x1b[32mINFO:\x1b[0m Finish!\n"
    # 1) All recognised codes removed.
    print(remove_PrintColor(text))
    print("------------------------------")
    # 2) Every PrintColor entry printed in its own style, then reset.
    for key, v in vars(PrintColor).items():
        if key[:2] == "__":
            continue
        print(f"{v[0]}{key}{PrintColor.RESET[0]}")
    print("------------------------------")
    print()
    # 3) Codes swapped for their configured replacements.
    print(replace_PrintColor(text))
--------------------------------------------------------------------------------
/onnxgraphqt/utils/dtype.py:
--------------------------------------------------------------------------------
1 | import onnx
2 | import numpy as np
3 |
# Data-type names offered for selection, in display order.
# NOTE(review): 'str' has no entry in DTYPES_TO_NUMPY_TYPES, so callers
# presumably special-case it — confirm before indexing that table with these.
AVAILABLE_DTYPES = ['float32', 'float64', 'int32', 'int64', 'str']
11 |
# Python built-in scalar types -> ONNX TensorProto element types.
# Python float maps to 32-bit FLOAT; int maps to 64-bit INT64.
DTYPES_TO_ONNX_DTYPES = dict([
    (float, onnx.TensorProto.FLOAT),
    (int, onnx.TensorProto.INT64),
    (str, onnx.TensorProto.STRING),
])
17 |
# Data-type name -> numpy scalar class (each name is also the numpy attribute name).
DTYPES_TO_NUMPY_TYPES = {
    dtype_name: getattr(np, dtype_name)
    for dtype_name in ('float32', 'float64', 'int32', 'int64')
}
24 |
# numpy element type -> ONNX TensorProto element type. Keys cover both numpy
# dtype instances (e.g. ``arr.dtype``) and numpy scalar classes (e.g.
# ``np.float32``) so that either form can be used for lookup.
NUMPY_TYPES_TO_ONNX_DTYPES = dict([
    # dtype instances
    (np.dtype('float32'), onnx.TensorProto.FLOAT),
    (np.dtype('float64'), onnx.TensorProto.DOUBLE),
    (np.dtype('int32'), onnx.TensorProto.INT32),
    (np.dtype('int64'), onnx.TensorProto.INT64),
    # scalar classes
    (np.float32, onnx.TensorProto.FLOAT),
    (np.float64, onnx.TensorProto.DOUBLE),
    (np.int32, onnx.TensorProto.INT32),
    (np.int64, onnx.TensorProto.INT64),
])
35 |
# numpy dtype instance -> the matching numpy scalar class.
NUMPY_TYPES_TO_CLASSES = {
    np.dtype(scalar_cls): scalar_cls
    for scalar_cls in (np.float32, np.float64, np.int32, np.int64)
}
42 |
--------------------------------------------------------------------------------
/onnxgraphqt/utils/operators.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from dataclasses import dataclass
4 | from typing import List, Dict
5 | from json import JSONDecoder
6 |
@dataclass
class OperatorAttribute:
    """One attribute of an operator schema, as read from the opset JSON file.

    The values come straight from ``JSONDecoder.decode``, so they are JSON
    values (never Python ``type`` objects).
    """
    name: str
    # Type descriptor as stored in the JSON — presumably a type name string
    # such as "int"; TODO confirm against data/onnx_opsets.json.
    value_type: str
    # Default value as stored in the JSON (annotated str; verify the file
    # never stores numbers/lists here).
    default_value: str
12 |
@dataclass
class OperatorVersion:
    """One versioned schema of an operator, as parsed by ``_load_json``.

    ``inputs`` and ``outputs`` hold an int count, or ``math.inf`` when the
    JSON marks the arity as 'inf' (variadic). They are annotated ``float``
    because that type admits both (an int is acceptable wherever a float is
    expected).
    """
    # Integer opset key under which this schema is listed in the JSON.
    since_opset: int
    inputs: float
    outputs: float
    attrs: List[OperatorAttribute]
19 |
@dataclass
class Operator:
    """An operator name together with all of its versioned schemas.

    ``versions`` keeps the order in which the JSON file lists the
    ``since_opset`` keys for this operator.
    """
    name: str
    versions: List[OperatorVersion]
24 |
# Bundled operator table: <this module's dir>/../data/onnx_opsets.json
# (i.e. onnxgraphqt/data/onnx_opsets.json). The path is not normalised.
_DEFAULT_ONNX_OPSETS_JSON_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'onnx_opsets.json')
def _parse_io_count(value):
    """Convert a JSON input/output count to ``int``, or ``math.inf`` for 'inf'."""
    return math.inf if value == 'inf' else int(value)


def _load_json(json_path=_DEFAULT_ONNX_OPSETS_JSON_PATH)->List[Operator]:
    """Load operator schemas from an opset JSON file.

    Expected layout::

        {op_name: {since_opset: {"inputs": N | "inf",
                                 "outputs": N | "inf",
                                 "attrs": {attr_name: [value_type, default]}}}}

    Args:
        json_path: path of the JSON file; defaults to the bundled
            data/onnx_opsets.json.

    Returns:
        One ``Operator`` per top-level key, in file order.

    Raises:
        OSError: the file cannot be opened.
        ValueError: the JSON is malformed, or a count / opset key is not an
            integer (json.JSONDecodeError is a ValueError subclass).
        KeyError: a version entry lacks "inputs", "outputs" or "attrs".
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale
    # encoding (e.g. cp1252 / cp932 on Windows).
    with open(json_path, mode='r', encoding='utf-8') as f:
        json_dict: Dict = JSONDecoder().decode(f.read())

    ret = []
    for op_name, v1 in json_dict.items():
        versions = []
        for since_opset, v2 in v1.items():
            inputs = _parse_io_count(v2["inputs"])
            outputs = _parse_io_count(v2["outputs"])
            attrs = [
                OperatorAttribute(name=attr_name,
                                  value_type=attr_value_type,
                                  default_value=default_value)
                for attr_name, (attr_value_type, default_value) in v2["attrs"].items()
            ]
            versions.append(
                OperatorVersion(
                    since_opset=int(since_opset),  # JSON object keys are strings
                    inputs=inputs, outputs=outputs, attrs=attrs)
            )
        ret.append(Operator(name=op_name, versions=versions))
    return ret
70 |
def _get_latest_opset_version(opsets:List[Operator])->int:
    """Return the largest ``since_opset`` across every operator version.

    The result never goes below 1, which is also the value for an empty list.
    """
    candidates = [1]
    candidates.extend(version.since_opset
                      for operator_def in opsets
                      for version in operator_def.versions)
    return max(candidates)
78 |
# Parsed once at import time from the bundled opset JSON.
onnx_opsets = _load_json()
# Highest since_opset appearing in the table (at least 1).
latest_opset = _get_latest_opset_version(onnx_opsets)
# Operator names in file order.
opnames = [entry.name for entry in onnx_opsets]
82 |
if __name__ == "__main__":
    # Dump the parsed opset table for manual inspection.
    print(_DEFAULT_ONNX_OPSETS_JSON_PATH)
    for operator_def in onnx_opsets:
        print(operator_def.name)
        for version in operator_def.versions:
            print(version)
        print()
--------------------------------------------------------------------------------
/onnxgraphqt/utils/opset.py:
--------------------------------------------------------------------------------
# Default ONNX opset version for the application — presumably used when no
# opset is specified explicitly; confirm at call sites.
DEFAULT_OPSET = 16
--------------------------------------------------------------------------------
/onnxgraphqt/utils/style.py:
--------------------------------------------------------------------------------
1 | from NodeGraphQt import NodeGraph
2 | from NodeGraphQt.base.menu import NodeGraphMenu
3 |
def set_context_menu_style(graph:NodeGraph, text_color, bg_color, selected_color, disabled_text_color=None):
    """Apply a generated QMenu stylesheet to the graph's "graph" context menu.

    Args:
        graph: node graph whose "graph" context menu is restyled.
        text_color, bg_color, selected_color, disabled_text_color: RGB
            sequences forwarded unchanged to get_context_menu_stylesheet.
    """
    graph_menu: NodeGraphMenu = graph.get_context_menu("graph")
    qmenu = graph_menu.qmenu  # the underlying Qt menu widget
    qmenu.setStyleSheet(
        get_context_menu_stylesheet(text_color, bg_color, selected_color, disabled_text_color)
    )
8 |
def get_context_menu_stylesheet(text_color, bg_color, selected_color, disabled_text_color=None):
    """Build a Qt stylesheet string for a QMenu.

    Args:
        text_color: RGB sequence (first three components used) for item text.
        bg_color: RGB sequence for the menu background.
        selected_color: RGB sequence for the highlighted item (alpha 200).
        disabled_text_color: RGB sequence for disabled items; when None, the
            per-channel midpoint of text_color and bg_color is used.

    Returns:
        str: one ``selector { property:value; ... }`` block per selector.
    """
    if disabled_text_color is None:
        disabled_text_color = [int(0.5 * abs(text_color[ch] + bg_color[ch])) for ch in (0, 1, 2)]

    rgb = 'rgb({0},{1},{2})'
    text_rgb = rgb.format(*text_color)
    rules = (
        ('QMenu', (
            ('color', text_rgb),
            ('background-color', rgb.format(*bg_color)),
            ('border', '1px solid rgba({0},{1},{2},30)'.format(*text_color)),
            ('border-radius', '3px'),
        )),
        ('QMenu::item', (
            ('padding', '5px 18px 2px'),
            ('background-color', 'transparent'),
        )),
        ('QMenu::item:selected', (
            ('color', text_rgb),
            ('background-color', 'rgba({0},{1},{2},200)'.format(*selected_color)),
        )),
        ('QMenu::item:disabled', (
            ('color', rgb.format(*disabled_text_color)),
        )),
        ('QMenu::separator', (
            ('height', '1px'),
            ('background', 'rgba({0},{1},{2}, 50)'.format(*text_color)),
            ('margin', '4px 8px'),
        )),
    )

    blocks = []
    for selector, declarations in rules:
        body = ''.join(' {}:{};\n'.format(prop, val) for prop, val in declarations)
        blocks.append('{} {{\n{}}}\n'.format(selector, body))
    return ''.join(blocks)
--------------------------------------------------------------------------------
/onnxgraphqt/utils/widgets.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import math
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 | from NodeGraphQt.constants import PipeEnum
5 | from NodeGraphQt.qgraphics.pipe import PIPE_STYLES
6 |
7 |
# Font pixel sizes (passed to set_font / QFont.setPixelSize) — presumably for
# regular UI text, larger UI text, and text drawn in the graph view
# respectively; confirm at call sites.
BASE_FONT_SIZE = 16
LARGE_FONT_SIZE = 18
GRAPH_FONT_SIZE = 32
# Pen width for pipes; pipe_paint uses it instead of NodeGraphQt's PipeEnum.WIDTH.
PIPE_WIDTH = 3.0
12 |
def set_font(widget: QtWidgets.QWidget, font_size:int=None, bold=False):
    """Modify *widget*'s current font in place.

    Any object exposing ``font()`` / ``setFont()`` works (this module also
    passes a QPainter). A falsy ``font_size`` (None or 0) keeps the current
    pixel size; ``bold`` is always applied, so the default un-bolds the font.
    """
    font = widget.font()
    font.setBold(bold)
    if font_size:
        font.setPixelSize(font_size)
    widget.setFont(font)
19 |
def iconButton_paintEvent(button: QtWidgets.QPushButton, pixmap: QtGui.QPixmap, event: QtGui.QPaintEvent):
    """Paint *button* normally, then draw *pixmap* as an icon on its left side.

    The icon is centred horizontally inside a 30 px slot starting 5 px from the
    button's left edge, and centred vertically on the button.

    Args:
        button: the push button being painted.
        pixmap: icon to draw (already scaled by the caller).
        event: the paint event forwarded to QPushButton.paintEvent.
    """
    # Let Qt draw the regular button frame and contents first.
    QtWidgets.QPushButton.paintEvent(button, event)
    pos_x = 5 + int((30 - pixmap.width()) * 0.5 + 0.5)
    # drawPixmap(int, int, QPixmap) expects integer coordinates; round like
    # pos_x instead of passing the float produced by true division.
    pos_y = int((button.height() - pixmap.height()) * 0.5 + 0.5)
    painter = QtGui.QPainter(button)
    try:
        painter.setRenderHint(QtGui.QPainter.Antialiasing, True)
        painter.setRenderHint(QtGui.QPainter.SmoothPixmapTransform, True)
        painter.drawPixmap(pos_x, pos_y, pixmap)
    finally:
        # End painting deterministically rather than relying on garbage collection.
        painter.end()
28 |
def createIconButton(text:str, icon_path: str, icon_size:List[int]=None, font_size:int=None) -> QtWidgets.QPushButton:
    """Create an expanding push button with an icon on the left and right-aligned text.

    Args:
        text: label text, shown right-aligned.
        icon_path: image file loaded into a QPixmap.
        icon_size: [width, height] the icon is scaled into (aspect ratio
            kept); None means [25, 25].
        font_size: optional label font pixel size (see set_font).

    Returns:
        The configured QPushButton.
    """
    # Avoid a shared mutable default list; same default size as before.
    if icon_size is None:
        icon_size = [25, 25]
    button = QtWidgets.QPushButton()
    button.setContentsMargins(QtCore.QMargins(5, 5, 5, 5))
    button.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
    pixmap = QtGui.QPixmap(icon_path).scaled(icon_size[0], icon_size[1], QtCore.Qt.AspectRatioMode.KeepAspectRatio, QtCore.Qt.SmoothTransformation)
    # Override this instance's paintEvent so the icon is drawn on top of the
    # default button rendering.
    button.paintEvent = lambda event: iconButton_paintEvent(button, pixmap, event)
    button_layout = QtWidgets.QVBoxLayout()
    # QLayout.setMargin() is obsolete in Qt 5 (removed in Qt 6);
    # setContentsMargins is the supported equivalent.
    button_layout.setContentsMargins(0, 0, 0, 0)
    button.setLayout(button_layout)
    label = QtWidgets.QLabel(text=text)
    label.setMargin(0)
    label.setAlignment(QtCore.Qt.AlignRight)
    set_font(label, font_size=font_size)
    button_layout.addWidget(label)
    return button
48 |
49 |
def pipe_paint(pipe, painter, option, widget, text=""):
    """
    Draws the connection line between nodes.

    Paint routine for a NodeGraphQt pipe item: strokes the pipe path with
    PIPE_WIDTH, draws a direction arrow at the path midpoint and, optionally,
    a text label next to the arrow.

    Args:
        pipe: pipe graphics item being painted; its private ``_color``,
            ``_active``, ``_highlight`` and ``_arrow`` attributes are read.
        painter (QtGui.QPainter): painter used for drawing the item.
        option (QtGui.QStyleOptionGraphicsItem):
            used to describe the parameters needed to draw.
        widget (QtWidgets.QWidget): not used.
        text (str): label drawn at the arrow position; only drawn when the
            arrow itself is drawn (both ports connected, path long enough).
    """
    color = QtGui.QColor(*pipe._color)
    pen_style = PIPE_STYLES.get(pipe.style)
    pen_width = PIPE_WIDTH #PipeEnum.WIDTH.value
    # Active pipes get the active color and a slightly thicker pen.
    if pipe._active:
        color = QtGui.QColor(*PipeEnum.ACTIVE_COLOR.value)
        if pen_style == QtCore.Qt.DashDotDotLine:
            pen_width += 1
        else:
            pen_width += 0.35
    # Highlighted pipes are always drawn with the default line style.
    elif pipe._highlight:
        color = QtGui.QColor(*PipeEnum.HIGHLIGHT_COLOR.value)
        pen_style = PIPE_STYLES.get(PipeEnum.DRAW_TYPE_DEFAULT.value)

    # Disabled pipes are dotted; the active color is kept if also active.
    if pipe.disabled():
        if not pipe._active:
            color = QtGui.QColor(*PipeEnum.DISABLED_COLOR.value)
        pen_width += 0.2
        pen_style = PIPE_STYLES.get(PipeEnum.DRAW_TYPE_DOTTED.value)

    pen = QtGui.QPen(color, pen_width, pen_style)
    pen.setCapStyle(QtCore.Qt.RoundCap)
    pen.setJoinStyle(QtCore.Qt.MiterJoin)

    # Every exit path below must call painter.restore() to balance this save().
    painter.save()
    painter.setPen(pen)
    painter.setRenderHint(painter.Antialiasing, True)
    painter.drawPath(pipe.path())

    # draw arrow
    if pipe.input_port and pipe.output_port:
        cen_x = pipe.path().pointAtPercent(0.5).x()
        cen_y = pipe.path().pointAtPercent(0.5).y()
        # Points just before/after the midpoint give the local direction.
        loc_pt = pipe.path().pointAtPercent(0.49)
        tgt_pt = pipe.path().pointAtPercent(0.51)
        dist = math.hypot(tgt_pt.x() - cen_x, tgt_pt.y() - cen_y)
        # Path too short for a visible arrow: restore painter state and stop.
        if dist < 0.5:
            painter.restore()
            return

        color.setAlpha(255)
        if pipe._highlight:
            painter.setBrush(QtGui.QBrush(color.lighter(150)))
        elif pipe._active or pipe.disabled():
            painter.setBrush(QtGui.QBrush(color.darker(200)))
        else:
            painter.setBrush(QtGui.QBrush(color.darker(130)))

        # Arrow outline width; scaled down together with the arrow on short paths.
        pen_width = 0.6
        if dist < 1.0:
            pen_width *= (1.0 + dist)

        pen = QtGui.QPen(color, pen_width)
        pen.setCapStyle(QtCore.Qt.RoundCap)
        pen.setJoinStyle(QtCore.Qt.MiterJoin)
        painter.setPen(pen)

        # Move the arrow polygon to the midpoint and rotate it along the path
        # (-90 degrees aligns the polygon's own axis with the path direction).
        transform = QtGui.QTransform()
        transform.translate(cen_x, cen_y)
        radians = math.atan2(tgt_pt.y() - loc_pt.y(),
                             tgt_pt.x() - loc_pt.x())
        degrees = math.degrees(radians) - 90
        transform.rotate(degrees)
        if dist < 1.0:
            transform.scale(dist, dist)
        painter.drawPolygon(transform.map(pipe._arrow))
        # Optional label in a 200x100 box anchored at the arrow midpoint.
        if text:
            painter.setPen(QtCore.Qt.black)
            set_font(painter, font_size=20)
            painter.drawText(QtCore.QRectF(cen_x, cen_y, 200, 100), text)


    # QPaintDevice: Cannot destroy paint device that is being painted.
    painter.restore()
133 |
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/widgets/__init__.py
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/custom_node_item.py:
--------------------------------------------------------------------------------
1 |
2 | from PySide2 import QtCore, QtWidgets, QtGui
3 | from NodeGraphQt.constants import NodeEnum
4 | from NodeGraphQt.qgraphics.node_base import NodeItem
5 |
6 | from onnxgraphqt.utils.color import NODE_BG_COLOR, NODE_SELECTED_BORDER_COLOR
7 | from onnxgraphqt.utils.widgets import set_font
8 |
9 |
class CustomNodeItem(NodeItem):
    """NodeItem that paints a colored header containing a separate display name.

    The header text comes from ``set_display_name`` rather than the node name
    passed to the constructor.
    """
    def __init__(self, name='node', parent=None):
        super().__init__(name, parent)
        # Text drawn in the header; empty until set_display_name() is called.
        self._display_name = ""

    def set_display_name(self, name: str):
        """Set the text drawn in the node header."""
        self._display_name = name

    def _paint_vertical(self, painter, option, widget):
        """Paint background, colored header, header text, selection overlay and border.

        Overrides NodeItem._paint_vertical (presumably called by NodeItem's
        paint() in vertical layout mode — confirm against NodeGraphQt).
        """
        painter.save()
        painter.setPen(QtCore.Qt.NoPen)
        painter.setBrush(QtCore.Qt.NoBrush)

        # base background.
        margin = 1.0
        radius = 4.0
        header_height = 20.0
        rect = self.boundingRect()
        # Inset the bounding rect by `margin` on every side.
        rect = QtCore.QRectF(rect.left() + margin,
                             rect.top() + margin,
                             rect.width() - (margin * 2),
                             rect.height() - (margin * 2))

        # Opaque background; assumes NODE_BG_COLOR is an RGB list (list + list).
        painter.setBrush(QtGui.QColor(*NODE_BG_COLOR + [255]))
        painter.drawRoundedRect(rect, radius, radius)

        # header
        header_rect = QtCore.QRectF(rect.x() + margin, rect.y() + margin,
                                    rect.width() - margin * 2, header_height)
        painter.setBrush(QtGui.QColor(*self.color))
        painter.drawRoundedRect(header_rect, radius, radius)

        # header text
        # Choose black or white text from the header's approximate Rec. 709
        # luma so the name stays readable on light and dark colors.
        r, g, b = self.color[:3]
        if 0.2126*r + 0.715*g + 0.0722*b > 128:
            painter.setPen(QtCore.Qt.black)
        else:
            painter.setPen(QtCore.Qt.white)
        if self.selected:
            set_font(painter, bold=True)
        painter.drawText(QtCore.QRectF(5, 5, rect.width(), rect.height()), self._display_name)
        set_font(painter, bold=False)

        # light overlay on background when selected.
        if self.selected:
            painter.setBrush(
                QtGui.QColor(*NodeEnum.SELECTED_COLOR.value)
            )
            painter.drawRoundedRect(rect, radius, radius)

        # # top & bottom edge background.
        # padding = 2.0
        # height = 10
        # if self.selected:
        #     painter.setBrush(QtGui.QColor(*NodeEnum.SELECTED_COLOR.value))
        # else:
        #     painter.setBrush(QtGui.QColor(0, 0, 0, 80))
        # for y in [rect.y() + padding, rect.height() - height - 1]:
        #     edge_rect = QtCore.QRectF(rect.x() + padding, y,
        #                               rect.width() - (padding * 2), height)
        #     painter.drawRoundedRect(edge_rect, 3.0, 3.0)

        # node border
        # Thicker border in NODE_SELECTED_BORDER_COLOR when selected.
        border_width = 0.8
        border_color = QtGui.QColor(*self.border_color)
        if self.selected:
            border_width = 1.2
            border_color = QtGui.QColor(
                *list(NODE_SELECTED_BORDER_COLOR + [255])
            )
        border_rect = QtCore.QRectF(rect.left(), rect.top(),
                                    rect.width(), rect.height())

        pen = QtGui.QPen(border_color, border_width)
        # Cosmetic (zoom-independent width) pen only when get_zoom() is
        # negative — presumably when zoomed out; confirm get_zoom semantics.
        pen.setCosmetic(self.viewer().get_zoom() < 0.0)
        painter.setBrush(QtCore.Qt.NoBrush)
        painter.setPen(pen)
        painter.drawRoundedRect(border_rect, radius, radius)

        painter.restore()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/custom_properties.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from PySide2 import QtWidgets, QtCore, QtGui
3 | from NodeGraphQt.constants import NodePropWidgetEnum
4 | from NodeGraphQt.widgets.dialogs import FileDialog
5 | from NodeGraphQt.custom_widgets.properties_bin.prop_widgets_base import (
6 | PropLineEdit,
7 | PropTextEdit
8 | )
9 | from NodeGraphQt.custom_widgets.properties_bin.node_property_factory import NodePropertyWidgetFactory
10 | from NodeGraphQt.custom_widgets.properties_bin.node_property_widgets import _PropertiesContainer
11 |
12 |
class CustomPropWindow(QtWidgets.QWidget):
    """
    Grid of (label, property widget) rows.

    Column 0 holds the label and column 1 the widget. Each widget's tool tip
    is set to its property name, which is how get_widget() finds it again.
    """

    def __init__(self, parent=None):
        super(CustomPropWindow, self).__init__(parent)
        self.__layout = QtWidgets.QGridLayout()
        # Give the widget column all of the spare horizontal space.
        self.__layout.setColumnStretch(1, 1)

        layout = QtWidgets.QVBoxLayout(self)
        layout.setAlignment(QtCore.Qt.AlignTop)
        layout.addLayout(self.__layout)

    def __repr__(self):
        # The old format string had no placeholder, so repr() returned ''.
        return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))

    def add_widget(self, name, widget, value, label=None):
        """
        Add a property widget to the window.

        Args:
            name (str): property name; also stored as the widget's tool tip.
            widget (BaseProperty): property widget.
            value (object): initial value passed to widget.set_value();
                None leaves the widget's current value untouched.
            label (str): custom label to display (defaults to ``name``).
        """
        widget.setToolTip(name)
        if value is not None:
            widget.set_value(value)
        if label is None:
            label = name
        row = self.__layout.rowCount()
        # NOTE(review): skips one grid row whenever rows already exist —
        # presumably for spacing; confirm it is intentional.
        if row > 0:
            row += 1

        # NOTE(review): AlignCenter already includes AlignHCenter, so OR-ing
        # AlignRight combines two horizontal flags — confirm the intended result.
        label_flags = QtCore.Qt.AlignCenter | QtCore.Qt.AlignRight
        # Text-edit and list widgets can span several lines; pin their label to the top.
        if widget.__class__.__name__ == 'PropTextEdit':
            label_flags = label_flags | QtCore.Qt.AlignTop
        elif widget.__class__.__name__ == 'PropList':
            label_flags = QtCore.Qt.AlignTop | QtCore.Qt.AlignRight

        self.__layout.addWidget(QtWidgets.QLabel(label), row, 0, label_flags)
        self.__layout.addWidget(widget, row, 1)

    def get_widget(self, name):
        """
        Returns the property widget from the name.

        Args:
            name (str): property name (matched against the widget tool tip).

        Returns:
            QtWidgets.QWidget: the first matching property widget, or None.
        """
        for row in range(self.__layout.rowCount()):
            item = self.__layout.itemAtPosition(row, 1)
            if item and name == item.widget().toolTip():
                return item.widget()
        return None
71 |
72 |
class PropList(QtWidgets.QWidget):
    """Read-only, vertically stacked description of node inputs/outputs.

    Each ``add_item`` call appends the item's name and, when known, its dtype
    and shape, plus a read-only text box when a value is present.
    """

    def __init__(self, parent=None):
        super(PropList, self).__init__(parent)
        self.__layout = QtWidgets.QVBoxLayout()
        self.__layout.setAlignment(QtCore.Qt.AlignTop)

        outer_layout = QtWidgets.QVBoxLayout(self)
        outer_layout.setAlignment(QtCore.Qt.AlignTop)
        outer_layout.addLayout(self.__layout)
        outer_layout.setContentsMargins(0, 0, 0, 0)

    def add_item(self, name, dtype, shape, value):
        """Append one entry.

        Args:
            name: displayed name.
            dtype: element type; None shows only the name.
            shape: iterable shown as a list next to ``dtype``.
            value: None or -1 means "no value"; anything else is shown via
                str() in a read-only line edit.
        """
        if dtype is None:
            self.__layout.addWidget(QtWidgets.QLabel(f"{name}"))
            return

        # NOTE(review): if ``value`` can be a multi-element numpy array,
        # ``value == -1`` yields an array and this truth test raises — confirm
        # the value types produced by callers.
        has_value = not (value is None or value == -1)
        name_label = QtWidgets.QLabel(f"{name}")
        type_shape_label = QtWidgets.QLabel(f"{dtype} {list(shape)}")

        if not has_value:
            self.__layout.addWidget(name_label)
            self.__layout.addWidget(type_shape_label)
            return

        value_edit = QtWidgets.QLineEdit()
        value_edit.setText(str(value))
        value_edit.setReadOnly(True)
        group = QtWidgets.QVBoxLayout()
        for child in (name_label, type_shape_label, value_edit):
            group.addWidget(child)
        self.__layout.addLayout(group)
106 |
107 |
class CustomNodePropWidget(QtWidgets.QWidget):
    """
    Read-only node properties widget for display a Node object.

    Shows every property of ``node`` whose tab is 'Properties' in a single
    CustomPropWindow. The "inputs_" / "outputs_" properties are expanded
    into one PropList row per (name, dtype, shape, value) entry.

    Args:
        parent (QtWidgets.QWidget): parent object.
        node (NodeGraphQt.BaseNode): node.
    """

    #: signal (node_id, prop_name, prop_value)
    property_changed = QtCore.Signal(str, str, object)
    property_closed = QtCore.Signal(str)

    def __init__(self, parent=None, node=None):
        super(CustomNodePropWidget, self).__init__(parent)
        self.__node_id = node.id

        self.prop_window = CustomPropWindow(self)

        self._layout = QtWidgets.QVBoxLayout(self)
        self._layout.addWidget(self.prop_window)

        self._read_node(node)

    def __repr__(self):
        # The old format string had no placeholder, so repr() returned ''.
        return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))

    def _on_close(self):
        """
        Emit property_closed with this widget's node id.

        NOTE(review): nothing in this class calls it; presumably kept for
        parity with NodeGraphQt's properties widget.
        """
        self.property_closed.emit(self.__node_id)

    def _on_property_changed(self, name, value):
        """
        slot function called when a property widget has changed.

        Args:
            name (str): property name.
            value (object): new value.
        """
        self.property_changed.emit(self.__node_id, name, value)

    def _read_node(self, node):
        """
        Populate widget from a node.

        Args:
            node (NodeGraphQt.BaseNode): node class.
        """
        model = node.model

        # Only properties in the 'Properties' tab are displayed.
        properties = [
            (prop_name, prop_val)
            for prop_name, prop_val in model.custom_properties.items()
            if model.get_tab_name(prop_name) == 'Properties'
        ]

        # property widget factory.
        widget_factory = NodePropertyWidgetFactory()

        for prop_name, value in properties:
            wid_type = model.get_widget_type(prop_name).value
            # 0 is presumably NodePropWidgetEnum.HIDDEN — such properties are skipped.
            if wid_type == 0:
                continue
            label = prop_name.replace('_', ' ')
            if prop_name in ("inputs_", "outputs_"):
                # One row per tensor; numbered only when there is more than one.
                io_len = len(value)
                for i, (name, dtype, shape, v) in enumerate(value):
                    widget = PropList()
                    widget.add_item(name, dtype, shape, v)
                    item_label = label if io_len == 1 else label + f"[{i+1}]"
                    self.prop_window.add_widget(prop_name, widget, None, item_label)
            else:
                widget = widget_factory.get_widget(wid_type)
                # Text-entry widgets are made read-only: this view is display-only.
                if isinstance(widget, (PropLineEdit, PropTextEdit)):
                    widget.setReadOnly(True)
                self.prop_window.add_widget(prop_name, widget, value, label)

    def node_id(self):
        """
        Returns the node id linked to the widget.

        Returns:
            str: node id
        """
        return self.__node_id

    def add_widget(self, name, widget, tab='Properties'):
        """
        add new node property widget.

        This widget has a single property window and no tabs, so ``tab`` is
        accepted for API compatibility but ignored. (Previously this method
        referenced attributes that are never created and always raised.)

        Args:
            name (str): property name.
            widget (BaseProperty): property widget.
            tab (str): tab name (ignored).
        """
        self.prop_window.add_widget(name, widget, None)
        widget.value_changed.connect(self._on_property_changed)

    def get_widget(self, name):
        """
        get property widget.

        Args:
            name (str): property name.

        Returns:
            QtWidgets.QWidget: property widget, or None if not found.
        """
        return self.prop_window.get_widget(name)
235 |
if __name__ == '__main__':
    # Manual GUI demo: show CustomNodePropWidget for a node carrying one
    # property of each widget type, and print emitted signals.
    import sys
    from NodeGraphQt import BaseNode, NodeGraph
    from NodeGraphQt.constants import NodePropWidgetEnum


    class TestNode(BaseNode):
        NODE_NAME = 'test node'

        def __init__(self):
            super(TestNode, self).__init__()
            self.create_property('label_test', 'foo bar',
                                 widget_type=NodePropWidgetEnum.QLABEL)
            self.create_property('line_edit', 'hello',
                                 widget_type=NodePropWidgetEnum.QLINE_EDIT)
            self.create_property('color_picker', (0, 0, 255),
                                 widget_type=NodePropWidgetEnum.COLOR_PICKER)
            self.create_property('integer', 10,
                                 widget_type=NodePropWidgetEnum.QSPIN_BOX)
            self.create_property('list', 'foo',
                                 items=['foo', 'bar'],
                                 widget_type=NodePropWidgetEnum.QCOMBO_BOX)
            self.create_property('range', 50,
                                 range=(45, 55),
                                 widget_type=NodePropWidgetEnum.SLIDER)
            # Lives in the 'text' tab, so CustomNodePropWidget (which shows
            # only the 'Properties' tab) will not display it.
            self.create_property('text_edit', 'test text',
                                 widget_type=NodePropWidgetEnum.QTEXT_EDIT,
                                 tab='text')


    def prop_changed(node_id, prop_name, prop_value):
        print('-' * 100)
        print(node_id, prop_name, prop_value)


    def prop_close(node_id):
        print('=' * 100)
        print(node_id)


    # The QApplication must exist before NodeGraph or any widget is created.
    app = QtWidgets.QApplication(sys.argv)

    graph = NodeGraph()
    graph.register_node(TestNode)

    test_node = graph.create_node('nodeGraphQt.nodes.TestNode')

    node_prop = CustomNodePropWidget(node=test_node)
    node_prop.property_changed.connect(prop_changed)
    node_prop.property_closed.connect(prop_close)
    node_prop.show()

    app.exec_()
289 |
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/custom_properties_bin.py:
--------------------------------------------------------------------------------
1 | from PySide2 import QtWidgets, QtCore, QtGui
2 | from NodeGraphQt.custom_widgets.properties_bin.node_property_widgets import _PropertiesList
3 |
4 | from onnxgraphqt.widgets.custom_properties import CustomNodePropWidget
5 |
6 |
class CustomPropertiesBinWidget(QtWidgets.QWidget):
    """
    The :class:`NodeGraphQt.PropertiesBinWidget` is a list widget for displaying
    and editing a nodes properties.

    This variant shows a :class:`CustomNodePropWidget` per node and keeps at
    most one node at a time: ``add_node`` drops the last row whenever one
    exists before inserting the new node at the top.

    .. image:: _images/prop_bin.png
        :width: 950px

    .. code-block:: python
        :linenos:

        from NodeGraphQt import NodeGraph, PropertiesBinWidget

        # create node graph.
        graph = NodeGraph()

        # create properties bin widget.
        properties_bin = PropertiesBinWidget(parent=None, node_graph=graph)
        properties_bin.show()

    Args:
        parent (QtWidgets.QWidget): parent of the new widget.
        node_graph (NodeGraphQt.NodeGraph): node graph. When ``None`` the
            widget is created without being wired to a graph.
    """

    #: Signal emitted (node_id, prop_name, prop_value)
    property_changed = QtCore.Signal(str, str, object)

    def __init__(self, parent=None, node_graph=None):
        super(CustomPropertiesBinWidget, self).__init__(parent)
        self.setWindowTitle('Properties Bin')
        self._prop_list = _PropertiesList()
        self.resize(450, 400)

        # While True, edits reported by the node property widgets are not
        # re-emitted through ``property_changed``.
        self._block_signal = False
        # Kept for attribute compatibility; not read by this class.
        self._lock = False

        layout = QtWidgets.QVBoxLayout(self)
        layout.addWidget(self._prop_list, 1)

        # wire up node graph (optional so the widget can be built standalone).
        if node_graph is not None:
            node_graph.add_properties_bin(self)
            node_graph.node_double_clicked.connect(self.add_node)

    def __repr__(self):
        return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))

    def __on_prop_close(self, node_id):
        """Remove every row whose id item matches ``node_id`` exactly."""
        items = self._prop_list.findItems(node_id, QtCore.Qt.MatchExactly)
        for item in items:
            self._prop_list.removeRow(item.row())

    def __on_property_widget_changed(self, node_id, prop_name, prop_value):
        """Forward a property edit from a node widget unless signals are blocked."""
        if not self._block_signal:
            self.property_changed.emit(node_id, prop_name, prop_value)

    def add_node(self, node):
        """
        Add node to the properties bin.

        Any node already shown is replaced, so the bin holds one node at most.

        Args:
            node (NodeGraphQt.NodeObject): node object.
        """
        rows = self._prop_list.rowCount()
        if rows >= 1:
            self._prop_list.removeRow(rows - 1)

        itm_find = self._prop_list.findItems(node.id, QtCore.Qt.MatchExactly)
        if itm_find:
            self._prop_list.removeRow(itm_find[0].row())

        self._prop_list.insertRow(0)
        prop_widget = CustomNodePropWidget(node=node)
        prop_widget.property_closed.connect(self.__on_prop_close)
        # Without this connection ``property_changed`` would never fire.
        prop_widget.property_changed.connect(self.__on_property_widget_changed)
        self._prop_list.setCellWidget(0, 0, prop_widget)

        # The item's text (the node id) is what findItems() searches later.
        item = QtWidgets.QTableWidgetItem(node.id)
        self._prop_list.setItem(0, 0, item)
        self._prop_list.selectRow(0)

    def remove_node(self, node):
        """
        Remove node from the properties bin.

        Args:
            node (str or NodeGraphQt.BaseNode): node id or node object.
        """
        node_id = node if isinstance(node, str) else node.id
        self.__on_prop_close(node_id)

    def prop_widget(self, node):
        """
        Returns the node property widget.

        Args:
            node (str or NodeGraphQt.NodeObject): node id or node object.

        Returns:
            CustomNodePropWidget: node property widget, or ``None`` when the
            node is not currently shown in the bin.
        """
        node_id = node if isinstance(node, str) else node.id
        itm_find = self._prop_list.findItems(node_id, QtCore.Qt.MatchExactly)
        if itm_find:
            item = itm_find[0]
            return self._prop_list.cellWidget(item.row(), 0)
        return None
108 |
109 |
if __name__ == '__main__':
    # Manual demo: show a properties bin for one node that uses several
    # property widget types, printing every change the bin reports.
    import sys
    from NodeGraphQt import BaseNode, NodeGraph
    from NodeGraphQt.constants import NodePropWidgetEnum

    class TestNode(BaseNode):
        NODE_NAME = 'test node'

        def __init__(self):
            super(TestNode, self).__init__()
            demo_props = (
                ('label_test', 'foo bar',
                 dict(widget_type=NodePropWidgetEnum.QLABEL)),
                ('text_edit', 'hello',
                 dict(widget_type=NodePropWidgetEnum.QLINE_EDIT)),
                ('color_picker', (0, 0, 255),
                 dict(widget_type=NodePropWidgetEnum.COLOR_PICKER)),
                ('integer', 10,
                 dict(widget_type=NodePropWidgetEnum.QSPIN_BOX)),
                ('list', 'foo',
                 dict(items=['foo', 'bar'],
                      widget_type=NodePropWidgetEnum.QCOMBO_BOX)),
                ('range', 50,
                 dict(range=(45, 55),
                      widget_type=NodePropWidgetEnum.SLIDER)),
            )
            for prop_name, initial_value, options in demo_props:
                self.create_property(prop_name, initial_value, **options)

    def prop_changed(node_id, prop_name, prop_value):
        print('-' * 100)
        print(node_id, prop_name, prop_value)

    app = QtWidgets.QApplication(sys.argv)

    graph = NodeGraph()
    graph.register_node(TestNode)

    prop_bin = CustomPropertiesBinWidget(node_graph=graph)
    prop_bin.property_changed.connect(prop_changed)

    demo_node = graph.create_node('nodeGraphQt.nodes.TestNode')
    prop_bin.add_node(demo_node)
    prop_bin.show()

    app.exec_()
155 |
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/splash_screen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 |
# Absolute, normalized path of the bundled splash image (<package>/data/splash.png).
# abspath() keeps it independent of the working directory even if __file__ is relative.
DEFAULT_SPLASH_IMAGE = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data", "splash.png")
)
6 |
def create_screen(image_path=DEFAULT_SPLASH_IMAGE):
    """Build an always-on-top splash screen showing ``image_path``.

    The splash is disabled, presumably so that clicking it does not dismiss
    it early. It is not shown yet; the caller calls ``show()``.

    Args:
        image_path (str): image file displayed by the splash screen.

    Returns:
        QtWidgets.QSplashScreen: the created splash screen.
    """
    screen = QtWidgets.QSplashScreen(QtGui.QPixmap(image_path),
                                     QtCore.Qt.WindowStaysOnTopHint)
    screen.setEnabled(False)
    return screen
12 |
def create_screen_progressbar(image_path=DEFAULT_SPLASH_IMAGE, maximum=10):
    """Build a splash screen (see ``create_screen``) with a progress bar on it.

    The bar is 20px high, spans the full image width and sits 50px above the
    bottom edge of the image.

    Args:
        image_path (str): image file displayed by the splash screen.
        maximum (int): maximum value of the progress bar (default 10).

    Returns:
        tuple[QtWidgets.QSplashScreen, QtWidgets.QProgressBar]: the splash
        screen and its child progress bar.
    """
    # Reuse create_screen() instead of duplicating the splash construction.
    splash = create_screen(image_path)
    pixmap = splash.pixmap()
    progress_bar = QtWidgets.QProgressBar(splash)
    progress_bar.setMaximum(maximum)
    progress_bar.setGeometry(0, pixmap.height() - 50, pixmap.width(), 20)
    return splash, progress_bar
21 |
if __name__ == "__main__":
    # Manual demo: show the splash with a progress bar for ~1s, wait 1s,
    # then hand over to an empty main window.
    import signal
    import sys

    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication(sys.argv)
    splash, progress_bar = create_screen_progressbar()
    splash.show()
    time.sleep(0.01)
    white = QtGui.QColor.fromRgb(255, 255, 255)
    splash.showMessage("loading...", alignment=QtCore.Qt.AlignBottom, color=white)

    for step in range(1, 11):
        progress_bar.setValue(step)
        # Keep processing Qt events for ~0.1s per step so the UI repaints.
        deadline = time.time() + 0.1
        while time.time() < deadline:
            app.processEvents()
    time.sleep(1)

    window = QtWidgets.QMainWindow()
    window.show()

    splash.finish(window)

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_add_node.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 | from ast import literal_eval
5 |
6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
7 | from onnxgraphqt.utils.operators import onnx_opsets, opnames, OperatorVersion
8 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
9 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
10 |
11 |
# Choices listed in the dtype combo boxes of the add-node variable rows.
# Those combo boxes are editable, so other dtype strings can still be typed.
AVAILABLE_DTYPES = ['float32', 'float64', 'int32', 'int64', 'str']
19 |
20 |
# Record returned by AddNodeWidgets.get_properties().
AddNodeProperties = namedtuple(
    "AddNodeProperties",
    (
        "connection_src_op_output_names",
        "connection_dest_op_input_names",
        "add_op_type",
        "add_op_name",
        "add_op_input_variables",
        "add_op_output_variables",
        "add_op_attributes",
    ),
)
31 |
32 |
class AddNodeWidgets(QtWidgets.QDialog):
    """Non-modal dialog collecting the parameters needed to add a new op.

    Sections:
      * ``connection_src_op_output_names``: existing op/output -> new op/input.
      * ``connection_dest_op_input_names``: new op/output -> existing op/input.
      * ``add_op_name`` / ``add_op_type``: name and operator type of the new op.
      * optional input/output variables (name, dtype, shape) and optional
        attributes (name, value).

    ``get_properties()`` returns the entered values as ``AddNodeProperties``.
    ``accept()`` only closes the dialog once the required fields are filled.
    """
    _DEFAULT_WINDOW_WIDTH = 500
    _MAX_VARIABLES_COUNT = 5  # rows available for inputs (and, separately, outputs)
    _MAX_ATTRIBUTES_COUNT = 10  # rows available for attributes
    # _LABEL_FONT_SIZE = 16
    # Exceptions ast.literal_eval can raise on malformed text.
    _LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    def __init__(self, current_opset: int, graph: OnnxGraph = None, parent=None) -> None:
        """
        Args:
            current_opset (int): model opset. An operator is listed only if one
                of its versions has ``since_opset <= current_opset``.
            graph (OnnxGraph): optional graph whose input and node names are
                offered in the connection combo boxes.
            parent (QtWidgets.QWidget): parent widget.
        """
        super().__init__(parent)
        set_font(self, font_size=BASE_FONT_SIZE)
        self.setModal(False)
        self.setWindowTitle("add node")
        self.current_opset = current_opset
        self.graph = graph
        self.initUI()
        self.updateUI(self.graph)

    def initUI(self):
        """Create all widgets/layouts, then sync the rows with the default op type."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)

        base_layout = QtWidgets.QVBoxLayout()
        base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)

        # src_output: fields 0/1 name existing ops/outputs -> editable combos.
        labels_src_output = ["src_op", "src_op output", "add_op", "add_op input"]
        layout_src_output = QtWidgets.QVBoxLayout()
        lbl_src_output = QtWidgets.QLabel("connection_src_op_output_names")
        set_font(lbl_src_output, font_size=LARGE_FONT_SIZE, bold=True)
        layout_src_output.addWidget(lbl_src_output)
        layout_src_output_form = QtWidgets.QFormLayout()
        self.src_output_names = {}
        for i, label in enumerate(labels_src_output):
            if i in (0, 1):
                field = QtWidgets.QComboBox()
                field.setEditable(True)
            else:
                field = QtWidgets.QLineEdit()
            self.src_output_names[i] = field
            layout_src_output_form.addRow(label, field)
        layout_src_output.addLayout(layout_src_output_form)

        # dest_input: fields 2/3 name existing ops/inputs -> editable combos.
        labels_dest_input = ["add_op", "add_op output", "dst_op", "dst_op input"]
        layout_dest_input = QtWidgets.QVBoxLayout()
        lbl_dest_input = QtWidgets.QLabel("connection_dest_op_input_names")
        set_font(lbl_dest_input, font_size=LARGE_FONT_SIZE, bold=True)
        layout_dest_input.addWidget(lbl_dest_input)
        layout_dest_input_form = QtWidgets.QFormLayout()
        self.dest_input_names = {}
        for i, label in enumerate(labels_dest_input):
            if i in (0, 1):
                field = QtWidgets.QLineEdit()
            else:
                field = QtWidgets.QComboBox()
                field.setEditable(True)
            self.dest_input_names[i] = field
            layout_dest_input_form.addRow(label, field)
        layout_dest_input.addLayout(layout_dest_input_form)

        # Form layout
        layout_op = QtWidgets.QFormLayout()
        layout_op.setLabelAlignment(QtCore.Qt.AlignRight)
        self.add_op_type = QtWidgets.QComboBox()
        for op in onnx_opsets:
            for version in op.versions:
                if version.since_opset <= self.current_opset:
                    # Attach the version that matched (not op.versions[0]) so the
                    # inputs/outputs/attrs shown belong to an opset-compatible version.
                    self.add_op_type.addItem(op.name, version)
                    break
        self.add_op_type.setEditable(True)
        self.add_op_type.currentIndexChanged.connect(self.add_op_type_currentIndexChanged)
        self.add_op_name = QtWidgets.QLineEdit()
        self.add_op_name.setPlaceholderText("Name of op to be added")
        lbl_add_op_name = QtWidgets.QLabel("add_op_name")
        set_font(lbl_add_op_name, font_size=LARGE_FONT_SIZE, bold=True)
        lbl_add_op_type = QtWidgets.QLabel("add_op_type")
        set_font(lbl_add_op_type, font_size=LARGE_FONT_SIZE, bold=True)
        layout_op.addRow(lbl_add_op_name, self.add_op_name)
        layout_op.addRow(lbl_add_op_type, self.add_op_type)

        # variables
        layout_valiables = QtWidgets.QVBoxLayout()
        self.visible_input_valiables_count = 1
        self.visible_output_valiables_count = 1

        self.add_input_valiables = {}
        self.add_output_valiables = {}
        for i in range(self._MAX_VARIABLES_COUNT):
            self.create_variables_widget(i, is_input=True)
        for i in range(self._MAX_VARIABLES_COUNT):
            self.create_variables_widget(i, is_input=False)

        self.btn_add_input_valiables = QtWidgets.QPushButton("+")
        self.btn_del_input_valiables = QtWidgets.QPushButton("-")
        self.btn_add_input_valiables.clicked.connect(self.btn_add_input_valiables_clicked)
        self.btn_del_input_valiables.clicked.connect(self.btn_del_input_valiables_clicked)
        self.btn_add_output_valiables = QtWidgets.QPushButton("+")
        self.btn_del_output_valiables = QtWidgets.QPushButton("-")
        self.btn_add_output_valiables.clicked.connect(self.btn_add_output_valiables_clicked)
        self.btn_del_output_valiables.clicked.connect(self.btn_del_output_valiables_clicked)

        layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
        lbl_add_input_valiables = QtWidgets.QLabel("add op input variables [optional]")
        set_font(lbl_add_input_valiables, font_size=LARGE_FONT_SIZE, bold=True)
        layout_valiables.addWidget(lbl_add_input_valiables)
        for widgets in self.add_input_valiables.values():
            layout_valiables.addWidget(widgets["base"])
        self.set_visible_input_valiables()
        layout_btn_input = QtWidgets.QHBoxLayout()
        layout_btn_input.addWidget(self.btn_add_input_valiables)
        layout_btn_input.addWidget(self.btn_del_input_valiables)
        layout_valiables.addLayout(layout_btn_input)

        layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
        lbl_add_output_valiables = QtWidgets.QLabel("add op output variables [optional]")
        set_font(lbl_add_output_valiables, font_size=LARGE_FONT_SIZE, bold=True)
        layout_valiables.addWidget(lbl_add_output_valiables)
        for widgets in self.add_output_valiables.values():
            layout_valiables.addWidget(widgets["base"])
        self.set_visible_output_valiables()
        layout_btn_output = QtWidgets.QHBoxLayout()
        layout_btn_output.addWidget(self.btn_add_output_valiables)
        layout_btn_output.addWidget(self.btn_del_output_valiables)
        layout_valiables.addLayout(layout_btn_output)

        # add_attributes
        layout_attributes = QtWidgets.QVBoxLayout()
        layout_attributes.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
        lbl_add_atrributes = QtWidgets.QLabel("add op attributes [optional]")
        set_font(lbl_add_atrributes, font_size=LARGE_FONT_SIZE, bold=True)
        layout_attributes.addWidget(lbl_add_atrributes)
        self.visible_add_attributes_count = 3
        self.add_op_attributes = {}
        for index in range(self._MAX_ATTRIBUTES_COUNT):
            row = {"base": QtWidgets.QWidget()}
            row["layout"] = QtWidgets.QHBoxLayout(row["base"])
            row["layout"].setContentsMargins(0, 0, 0, 0)
            row["name"] = QtWidgets.QLineEdit()
            row["name"].setPlaceholderText("name")
            row["value"] = QtWidgets.QLineEdit()
            row["value"].setPlaceholderText("value")
            row["layout"].addWidget(row["name"])
            row["layout"].addWidget(row["value"])
            self.add_op_attributes[index] = row
            layout_attributes.addWidget(row["base"])
        self.btn_add_op_attributes = QtWidgets.QPushButton("+")
        self.btn_del_op_attributes = QtWidgets.QPushButton("-")
        self.btn_add_op_attributes.clicked.connect(self.btn_add_op_attributes_clicked)
        self.btn_del_op_attributes.clicked.connect(self.btn_del_op_attributes_clicked)
        self.set_visible_add_op_attributes()
        layout_btn_attributes = QtWidgets.QHBoxLayout()
        layout_btn_attributes.addWidget(self.btn_add_op_attributes)
        layout_btn_attributes.addWidget(self.btn_del_op_attributes)
        layout_attributes.addLayout(layout_btn_attributes)

        # add layout
        base_layout.addLayout(layout_src_output)
        base_layout.addSpacerItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 10))
        base_layout.addLayout(layout_dest_input)
        base_layout.addSpacerItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 10))
        base_layout.addLayout(layout_op)
        base_layout.addLayout(layout_valiables)
        base_layout.addLayout(layout_attributes)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)
        # Populate variable/attribute rows from the initially selected op type.
        self.add_op_type_currentIndexChanged(self.add_op_type.currentIndex())

    def updateUI(self, graph: OnnxGraph):
        """Offer the graph's input names, then its node names, in the
        existing-op combo boxes (src fields 0/1 and dest fields 2/3)."""
        if not graph:
            return
        names = [*graph.node_inputs.keys(), *graph.nodes.keys()]
        for combo in (self.src_output_names[0], self.src_output_names[1],
                      self.dest_input_names[2], self.dest_input_names[3]):
            combo.addItems(names)

    def create_variables_widget(self, index: int, is_input=True) -> None:
        """Create variable row ``index`` (name / dtype / shape editors).

        The row is stored as a dict with keys "base", "layout", "name", "dtype"
        and "shape" in ``self.add_input_valiables`` (``is_input=True``) or
        ``self.add_output_valiables`` (``is_input=False``).
        """
        row = {"base": QtWidgets.QWidget()}
        row["layout"] = QtWidgets.QHBoxLayout(row["base"])
        row["layout"].setContentsMargins(0, 0, 0, 0)
        row["name"] = QtWidgets.QLineEdit()
        row["name"].setPlaceholderText("name")
        row["dtype"] = QtWidgets.QComboBox()
        row["dtype"].addItems(AVAILABLE_DTYPES)
        row["dtype"].setEditable(True)
        row["dtype"].setFixedSize(100, 20)
        row["dtype"].setPlaceholderText("dtype. e.g. `float32`")
        row["shape"] = QtWidgets.QLineEdit()
        row["shape"].setPlaceholderText("shape. e.g. `[1, 2, 3]`")
        row["layout"].addWidget(row["name"])
        row["layout"].addWidget(row["dtype"])
        row["layout"].addWidget(row["shape"])
        target = self.add_input_valiables if is_input else self.add_output_valiables
        target[index] = row

    @staticmethod
    def _clamp_count(count: int, maximum: int) -> int:
        """Clamp a visible-row count into ``[0, maximum]``."""
        return min(max(0, count), maximum)

    @staticmethod
    def _apply_visibility(rows: dict, visible_count: int, maximum: int, btn_add, btn_del) -> None:
        """Show rows with key < ``visible_count``, hide the rest, and enable "+"
        only below ``maximum`` and "-" only above zero."""
        for key, widgets in rows.items():
            widgets["base"].setVisible(key < visible_count)
        btn_add.setEnabled(visible_count < maximum)
        btn_del.setEnabled(visible_count > 0)

    def set_visible_input_valiables(self):
        self._apply_visibility(self.add_input_valiables, self.visible_input_valiables_count,
                               self._MAX_VARIABLES_COUNT,
                               self.btn_add_input_valiables, self.btn_del_input_valiables)

    def set_visible_output_valiables(self):
        self._apply_visibility(self.add_output_valiables, self.visible_output_valiables_count,
                               self._MAX_VARIABLES_COUNT,
                               self.btn_add_output_valiables, self.btn_del_output_valiables)

    def set_visible_add_op_attributes(self):
        self._apply_visibility(self.add_op_attributes, self.visible_add_attributes_count,
                               self._MAX_ATTRIBUTES_COUNT,
                               self.btn_add_op_attributes, self.btn_del_op_attributes)

    def btn_add_input_valiables_clicked(self, e):
        self.visible_input_valiables_count = self._clamp_count(
            self.visible_input_valiables_count + 1, self._MAX_VARIABLES_COUNT)
        self.set_visible_input_valiables()

    def btn_del_input_valiables_clicked(self, e):
        self.visible_input_valiables_count = self._clamp_count(
            self.visible_input_valiables_count - 1, self._MAX_VARIABLES_COUNT)
        self.set_visible_input_valiables()

    def btn_add_output_valiables_clicked(self, e):
        self.visible_output_valiables_count = self._clamp_count(
            self.visible_output_valiables_count + 1, self._MAX_VARIABLES_COUNT)
        self.set_visible_output_valiables()

    def btn_del_output_valiables_clicked(self, e):
        self.visible_output_valiables_count = self._clamp_count(
            self.visible_output_valiables_count - 1, self._MAX_VARIABLES_COUNT)
        self.set_visible_output_valiables()

    def btn_add_op_attributes_clicked(self, e):
        self.visible_add_attributes_count = self._clamp_count(
            self.visible_add_attributes_count + 1, self._MAX_ATTRIBUTES_COUNT)
        self.set_visible_add_op_attributes()

    def btn_del_op_attributes_clicked(self, e):
        self.visible_add_attributes_count = self._clamp_count(
            self.visible_add_attributes_count - 1, self._MAX_ATTRIBUTES_COUNT)
        self.set_visible_add_op_attributes()

    def add_op_type_currentIndexChanged(self, selected_index: int):
        """Sync variable/attribute rows with the selected operator version.

        Row counts are clamped to the rows this dialog provides, so operators
        with more inputs/outputs/attributes no longer cause KeyErrors. Attribute
        rows are pre-filled with each attribute's name and default value; the
        remaining rows are cleared.
        """
        selected_operator: OperatorVersion = self.add_op_type.currentData()
        if selected_operator is None:
            # No operator spec attached (e.g. empty list, or - presumably - text
            # typed into the editable combo); keep the current rows unchanged.
            return
        self.visible_input_valiables_count = self._clamp_count(
            selected_operator.inputs, self._MAX_VARIABLES_COUNT)
        self.visible_output_valiables_count = self._clamp_count(
            selected_operator.outputs, self._MAX_VARIABLES_COUNT)
        attrs = list(selected_operator.attrs)[:self._MAX_ATTRIBUTES_COUNT]
        self.visible_add_attributes_count = len(attrs)

        for i, row in self.add_op_attributes.items():
            if i < len(attrs):
                row["name"].setText(attrs[i].name)
                row["value"].setText(attrs[i].default_value)
            else:
                row["name"].setText("")
                row["value"].setText("")
        self.set_visible_input_valiables()
        self.set_visible_output_valiables()
        self.set_visible_add_op_attributes()

    @staticmethod
    def _collect_variables(rows: dict, visible_count: int):
        """Return ``{name: [dtype, shape]}`` for the first ``visible_count`` rows
        whose three fields are all filled, or ``None`` when there are none.

        Raises:
            ValueError, SyntaxError, ...: any ``ast.literal_eval`` error when a
            shape is not a valid Python literal.
        """
        variables = {}
        for i in range(visible_count):
            row = rows[i]
            name = row["name"].text()
            dtype = row["dtype"].currentText()
            shape = row["shape"].text()
            if name and dtype and shape:
                variables[name] = [dtype, literal_eval(shape)]
        return variables or None

    def get_properties(self) -> AddNodeProperties:
        """Collect the form values.

        Attribute values are parsed with ``ast.literal_eval`` when possible and
        otherwise kept as plain strings. Empty variable/attribute sections are
        returned as ``None``.

        Raises:
            ValueError, SyntaxError, ...: when a variable shape is malformed.
        """
        add_op_input_variables = self._collect_variables(
            self.add_input_valiables, self.visible_input_valiables_count)
        add_op_output_variables = self._collect_variables(
            self.add_output_valiables, self.visible_output_valiables_count)

        add_op_attributes = {}
        for i in range(self.visible_add_attributes_count):
            name = self.add_op_attributes[i]["name"].text()
            value = self.add_op_attributes[i]["value"].text()
            if name and value:
                try:
                    # For literal (numbers, lists, ...)
                    add_op_attributes[name] = literal_eval(value)
                except self._LITERAL_EVAL_ERRORS:
                    # For str
                    add_op_attributes[name] = value
        add_op_attributes = add_op_attributes or None

        src_output_names = [
            self.src_output_names[0].currentText().strip(),
            self.src_output_names[1].currentText().strip(),
            self.src_output_names[2].text().strip(),
            self.src_output_names[3].text().strip(),
        ]
        dest_input_names = [
            self.dest_input_names[0].text().strip(),
            self.dest_input_names[1].text().strip(),
            self.dest_input_names[2].currentText().strip(),
            self.dest_input_names[3].currentText().strip(),
        ]

        return AddNodeProperties(
            connection_src_op_output_names=[src_output_names],
            connection_dest_op_input_names=[dest_input_names],
            add_op_type=self.add_op_type.currentText(),
            add_op_name=self.add_op_name.text(),
            add_op_input_variables=add_op_input_variables,
            add_op_output_variables=add_op_output_variables,
            add_op_attributes=add_op_attributes,
        )

    def accept(self) -> None:
        """Validate the form and close it only when it is valid.

        Problems are printed and shown in an error message box, and the dialog
        stays open.
        """
        try:
            props = self.get_properties()
        except self._LITERAL_EVAL_ERRORS as e:
            # A malformed shape used to escape this slot as an uncaught exception.
            msg = f"- [add_op_input_variables/add_op_output_variables] invalid shape: {e}"
            print(msg)
            MessageBox.error([msg], "add node", parent=self)
            return
        print(props)
        err_msgs = []
        for src_op, src_op_output, add_op, add_op_input in props.connection_src_op_output_names:
            if not src_op:
                err_msgs.append("- [connection_src_op_output_names] src_op is not set.")
            if not src_op_output:
                err_msgs.append("- [connection_src_op_output_names] src_op output is not set.")
            if not add_op:
                err_msgs.append("- [connection_src_op_output_names] add_op is not set.")
            if not add_op_input:
                err_msgs.append("- [connection_src_op_output_names] add_op input is not set.")
        for add_op, add_op_output, dst_op, dst_op_input in props.connection_dest_op_input_names:
            if not add_op:
                err_msgs.append("- [connection_dest_op_input_names] add_op is not set")
            if not add_op_output:
                err_msgs.append("- [connection_dest_op_input_names] add_op_output is not set")
            if not dst_op:
                err_msgs.append("- [connection_dest_op_input_names] dst_op is not set")
            if not dst_op_input:
                err_msgs.append("- [connection_dest_op_input_names] dst_op input is not set")
        if not props.add_op_name:
            err_msgs.append("- [add_op_name] not set")
        if props.add_op_type not in opnames:
            err_msgs.append("- [add_op_type] not support")
        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "add node", parent=self)
            return
        return super().accept()
443 |
444 |
445 |
if __name__ == "__main__":
    # Manual demo: open the add-node dialog for opset 16 without a graph.
    # `signal` is already imported at module level.
    signal.signal(signal.SIGINT, signal.SIG_DFL)  # let CTRL+C terminate the app

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    dialog = AddNodeWidgets(current_opset=16)
    dialog.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_change_channel.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 | from ast import literal_eval
5 |
6 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
7 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
8 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
9 |
10 |
# Record describing a change-channel request: the input ops whose dimension
# order changes, and the inputs whose channel order changes.
ChangeChannelProperties = namedtuple(
    "ChangeChannelProperties",
    ("input_op_names_and_order_dims", "channel_change_inputs"),
)
16 |
17 |
18 | class ChangeChannelWidgets(QtWidgets.QDialog):
19 | _DEFAULT_WINDOW_WIDTH = 300
20 | _QLINE_EDIT_WIDTH = 170
21 |
def __init__(self, graph: OnnxGraph=None, parent=None) -> None:
    """Create the non-modal "change channel" dialog.

    Args:
        graph (OnnxGraph): optional graph. When given, both sections' maximum
            row counts equal the number of graph inputs; otherwise they are 4.
        parent (QtWidgets.QWidget): parent widget.
    """
    super().__init__(parent)
    self.setModal(False)
    self.setWindowTitle("change channel")
    self.graph = graph

    row_limit = len(self.graph.inputs) if graph else 4
    self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT = row_limit
    self._MAX_CHANNEL_CHANGE_INPUTS_COUNT = row_limit
    self.initUI()
    self.updateUI(self.graph)
36 |
def initUI(self):
    """Build the dialog.

    The layout is an ``input_op_names_and_order_dims`` section and a
    ``channel_change_inputs`` section, each with "+"/"-" buttons, followed by
    OK/Cancel buttons.
    """
    self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
    set_font(self, font_size=BASE_FONT_SIZE)

    base_layout = QtWidgets.QVBoxLayout()
    base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)

    self.layout_order_dims = QtWidgets.QVBoxLayout()
    self.layout_change_channel = QtWidgets.QVBoxLayout()
    self.visible_change_order_dim_inputs_count = self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT
    self.visible_channel_change_inputs_count = self._MAX_CHANNEL_CHANGE_INPUTS_COUNT

    self.input_op_names_and_order_dims = {}  # transpose NCHW <-> NHWC
    self.channel_change_inputs = {}  # RGB <-> BGR
    self.create_order_dims_widgets()
    self.create_channel_change_widgets()

    lbl_channel_change = QtWidgets.QLabel("channel_change_inputs")
    # Each piece ends with "\n"; previously the pieces were glued together
    # without any separator (e.g. "OP namesspecified").
    lbl_channel_change.setToolTip(
        'Change the channel order of RGB and BGR.\n'
        'If the original model is RGB, it is transposed to BGR.\n'
        'If the original model is BGR, it is transposed to RGB.\n'
        'It can be selectively specified from among the OP names\n'
        'specified in input_op_names_and_order_dims.\n'
        'OP names not specified in input_op_names_and_order_dims are ignored.\n'
        'Multiple times can be specified as many times as the number\n'
        'of OP names specified in input_op_names_and_order_dims.\n'
        'channel_change_inputs = {"op_name": dimension_number_representing_the_channel}\n'
        'dimension_number_representing_the_channel must specify\n'
        'the dimension position after the change in input_op_names_and_order_dims.\n'
        'For example, dimension_number_representing_the_channel is 1 for NCHW and 3 for NHWC.')
    set_font(lbl_channel_change, font_size=LARGE_FONT_SIZE, bold=True)
    self.layout_change_channel.addWidget(lbl_channel_change)
    # Iterate channel_change_inputs with its own limit (the two limits were swapped).
    for i in range(self._MAX_CHANNEL_CHANGE_INPUTS_COUNT):
        self.layout_change_channel.addWidget(self.channel_change_inputs[i]["base"])

    lbl_order_dims = QtWidgets.QLabel("input_op_names_and_order_dims")
    lbl_order_dims.setToolTip("Specify the name of the input_op to be dimensionally changed and\n" +
                              "the order of the dimensions after the change.\n" +
                              "The name of the input_op to be dimensionally changed\n" +
                              "can be specified multiple times.")
    set_font(lbl_order_dims, font_size=LARGE_FONT_SIZE, bold=True)
    self.layout_order_dims.addWidget(lbl_order_dims)
    for i in range(self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT):
        self.layout_order_dims.addWidget(self.input_op_names_and_order_dims[i]["base"])

    # add button
    self.btn_add_order_dims = QtWidgets.QPushButton("+")
    self.btn_del_order_dims = QtWidgets.QPushButton("-")
    self.btn_add_order_dims.clicked.connect(self.btn_add_order_dims_clicked)
    self.btn_del_order_dims.clicked.connect(self.btn_del_order_dims_clicked)
    self.layout_btn_order_dims = QtWidgets.QHBoxLayout()
    self.layout_btn_order_dims.addWidget(self.btn_add_order_dims)
    self.layout_btn_order_dims.addWidget(self.btn_del_order_dims)

    self.btn_add_channel_change = QtWidgets.QPushButton("+")
    self.btn_del_channel_change = QtWidgets.QPushButton("-")
    self.btn_add_channel_change.clicked.connect(self.btn_add_channel_change_clicked)
    self.btn_del_channel_change.clicked.connect(self.btn_del_channel_change_clicked)
    self.layout_btn_channel_change = QtWidgets.QHBoxLayout()
    self.layout_btn_channel_change.addWidget(self.btn_add_channel_change)
    self.layout_btn_channel_change.addWidget(self.btn_del_channel_change)

    self.layout_order_dims.addLayout(self.layout_btn_order_dims)
    self.layout_change_channel.addLayout(self.layout_btn_channel_change)

    # add layout
    base_layout.addLayout(self.layout_order_dims)
    base_layout.addSpacing(15)
    base_layout.addLayout(self.layout_change_channel)

    # Dialog button
    btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                     QtWidgets.QDialogButtonBox.Cancel)
    btn.accepted.connect(self.accept)
    btn.rejected.connect(self.reject)
    base_layout.addWidget(btn)

    self.setLayout(base_layout)
    self.set_visible_order_dims()
    self.set_visible_channel_change()
118 |
def updateUI(self, graph: OnnxGraph):
    """Refill every op-name combo box from ``graph`` and refresh row visibility.

    Each row's combo box receives all input names of ``graph`` (in iteration
    order) and row ``i`` then calls ``setCurrentIndex(i)``. When ``graph`` is
    falsy the combo boxes are left untouched; visibility is refreshed in
    either case.
    """
    if graph:
        row_groups = (
            (self.input_op_names_and_order_dims, self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT),
            (self.channel_change_inputs, self._MAX_CHANNEL_CHANGE_INPUTS_COUNT),
        )
        for rows, row_count in row_groups:
            for row_index in range(row_count):
                combo = rows[row_index]["name"]
                combo.clear()
                for input_name in graph.inputs.keys():
                    combo.addItem(input_name)
                combo.setCurrentIndex(row_index)
    self.set_visible_order_dims()
    self.set_visible_channel_change()
134 |
def btn_add_order_dims_clicked(self, e):
    """Show one more input_op_names_and_order_dims row, capped at the maximum."""
    upper_bound = self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT
    requested = max(self.visible_change_order_dim_inputs_count + 1, 0)
    self.visible_change_order_dim_inputs_count = min(requested, upper_bound)
    self.set_visible_order_dims()
138 |
def btn_del_order_dims_clicked(self, e):
    """Hide the last visible order-dims row; the count never drops below zero."""
    remaining = max(self.visible_change_order_dim_inputs_count - 1, 0)
    self.visible_change_order_dim_inputs_count = min(remaining, self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT)
    self.set_visible_order_dims()
142 |
def btn_add_channel_change_clicked(self, e):
    """Show one more channel_change_inputs row, capped at the maximum."""
    upper_bound = self._MAX_CHANNEL_CHANGE_INPUTS_COUNT
    requested = max(self.visible_channel_change_inputs_count + 1, 0)
    self.visible_channel_change_inputs_count = min(requested, upper_bound)
    self.set_visible_channel_change()
146 |
def btn_del_channel_change_clicked(self, e):
    """Hide the last visible channel-change row; the count never drops below zero."""
    remaining = max(self.visible_channel_change_inputs_count - 1, 0)
    self.visible_channel_change_inputs_count = min(remaining, self._MAX_CHANNEL_CHANGE_INPUTS_COUNT)
    self.set_visible_channel_change()
150 |
def create_order_dims_widgets(self):
    """Create one widget row per possible input_op_names_and_order_dims entry.

    Each row is stored as a dict with the keys "base" (container widget),
    "layout" (its horizontal layout), "name" (editable op-name combo box) and
    "value" (line edit for the list of dims).
    """
    for index in range(self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT):
        container = QtWidgets.QWidget()
        row_layout = QtWidgets.QHBoxLayout(container)
        row_layout.setContentsMargins(0, 0, 0, 0)

        op_name_box = QtWidgets.QComboBox()
        op_name_box.setEditable(True)
        op_name_box.setFixedWidth(self._QLINE_EDIT_WIDTH)

        dims_edit = QtWidgets.QLineEdit()
        dims_edit.setPlaceholderText("List of dims.")

        row_layout.addWidget(op_name_box)
        row_layout.addWidget(dims_edit)

        self.input_op_names_and_order_dims[index] = {
            "base": container,
            "layout": row_layout,
            "name": op_name_box,
            "value": dims_edit,
        }
164 |
def create_channel_change_widgets(self):
    """Create one widget row per possible channel_change_inputs entry.

    Each row is stored as a dict with the keys "base" (container widget),
    "layout" (its horizontal layout), "name" (editable op-name combo box) and
    "value" (line edit for the channel dimension).
    """
    for index in range(self._MAX_CHANNEL_CHANGE_INPUTS_COUNT):
        container = QtWidgets.QWidget()
        row_layout = QtWidgets.QHBoxLayout(container)
        row_layout.setContentsMargins(0, 0, 0, 0)

        op_name_box = QtWidgets.QComboBox()
        op_name_box.setEditable(True)
        op_name_box.setFixedWidth(self._QLINE_EDIT_WIDTH)

        dim_edit = QtWidgets.QLineEdit()
        dim_edit.setPlaceholderText("dim")

        row_layout.addWidget(op_name_box)
        row_layout.addWidget(dim_edit)

        self.channel_change_inputs[index] = {
            "base": container,
            "layout": row_layout,
            "name": op_name_box,
            "value": dim_edit,
        }
178 |
def set_visible_order_dims(self):
    """Show the first N order-dims rows and update the +/- button states.

    N is ``visible_change_order_dim_inputs_count``. The "-" button is disabled
    at exactly one row, "+" at the maximum, and both when the maximum is 1.
    """
    shown = self.visible_change_order_dim_inputs_count
    limit = self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT
    for row_index, row_widgets in self.input_op_names_and_order_dims.items():
        row_widgets["base"].setVisible(row_index < shown)

    if limit == 1:
        can_add, can_delete = False, False
    elif shown == 1:
        can_add, can_delete = True, False
    elif shown >= limit:
        can_add, can_delete = False, True
    else:
        can_add, can_delete = True, True
    self.btn_add_order_dims.setEnabled(can_add)
    self.btn_del_order_dims.setEnabled(can_delete)
194 |
def set_visible_channel_change(self):
    """Show the first N channel-change rows and update the +/- button states.

    N is ``visible_channel_change_inputs_count``. The "-" button is disabled
    at exactly one row, "+" at the maximum, and both when the maximum is 1.
    """
    shown = self.visible_channel_change_inputs_count
    limit = self._MAX_CHANNEL_CHANGE_INPUTS_COUNT
    for row_index, row_widgets in self.channel_change_inputs.items():
        row_widgets["base"].setVisible(row_index < shown)

    if limit == 1:
        can_add, can_delete = False, False
    elif shown == 1:
        can_add, can_delete = True, False
    elif shown >= limit:
        can_add, can_delete = False, True
    else:
        can_add, can_delete = True, True
    self.btn_add_channel_change.setEnabled(can_add)
    self.btn_del_channel_change.setEnabled(can_delete)
210 |
211 |
def get_properties(self) -> ChangeChannelProperties:
    """Collect the visible rows into a ChangeChannelProperties.

    Only the first ``visible_*_count`` rows of each section are read. A row
    contributes an entry only when both its op name and its value are
    non-blank after stripping surrounding whitespace; the value text is
    parsed with ``ast.literal_eval``.

    Raises:
        ValueError, TypeError, SyntaxError, MemoryError, RecursionError:
            propagated from ``literal_eval`` for malformed value text.
    """
    def collect(rows, visible_count):
        # Build {op_name: parsed_value} from the first visible_count rows.
        entries = {}
        for i in range(visible_count):
            name = rows[i]["name"].currentText().strip()
            # Strip so a blank-looking field is skipped rather than handed to
            # literal_eval, which rejects whitespace-only text (and, before
            # Python 3.10, any leading indentation).
            value = rows[i]["value"].text().strip()
            if name and value:
                entries[name] = literal_eval(value)
        return entries

    return ChangeChannelProperties(
        input_op_names_and_order_dims=collect(
            self.input_op_names_and_order_dims,
            self.visible_change_order_dim_inputs_count,
        ),
        channel_change_inputs=collect(
            self.channel_change_inputs,
            self.visible_channel_change_inputs_count,
        ),
    )
229 |
def accept(self) -> None:
    """Validate the dialog entries and close the dialog only when they are usable.

    Every problem found is printed and shown together in one error box; the
    dialog then stays open so the user can correct the input.
    """
    err_msgs = []
    try:
        props = self.get_properties()
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        # get_properties parses the value fields with ast.literal_eval, which
        # raises these for malformed text. Report it instead of letting the
        # exception escape the Qt slot without any feedback to the user.
        err_msgs.append(f"Failed to parse a value: {e}")
    else:
        print(props)
        if len(props.channel_change_inputs) < 1 and len(props.input_op_names_and_order_dims) < 1:
            err_msgs.append("At least one of input_op_names_and_order_dims or channel_change_inputs must be specified.")

        for key, val in props.input_op_names_and_order_dims.items():
            # The UI asks for a "List of dims." here; reject anything else early.
            if not isinstance(val, (list, tuple)) or any(type(d) is not int for d in val):
                err_msgs.append(f'input_op_names_and_order_dims value must be a list of integers. {key}: {val}')

        for key, val in props.channel_change_inputs.items():
            if type(val) is not int:
                err_msgs.append(f'channel_change_inputs value must be integer. {key}: {val}')

    if err_msgs:
        for m in err_msgs:
            print(m)
        MessageBox.error(err_msgs, "change channel", parent=self)
        return
    return super().accept()
251 |
252 |
253 |
if __name__ == "__main__":
    import signal
    import os

    # Restore the default SIGINT handler so CTRL+C terminates the Qt app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Configured before the QApplication below is constructed.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = ChangeChannelWidgets()
    window.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_change_input_ouput_shape.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 | import os
5 | from ast import literal_eval
6 |
7 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
8 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
9 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
10 |
11 |
# Values produced by ChangeInputOutputShapeWidget.get_properties. A
# names/shapes pair is None when its "change ... shape" checkbox is unchecked.
ChangeInputOutputShapeProperties = namedtuple(
    "ChangeInputOutputShapeProperties",
    ("input_names", "input_shapes", "output_names", "output_shapes"),
)
19 |
20 |
class ChangeInputOutputShapeWidget(QtWidgets.QDialog):
    """Dialog for editing the shapes of a graph's inputs and outputs.

    One row is created per graph input and per graph output, pre-filled with
    ``str(get_shape())``. Row text is parsed with ``ast.literal_eval`` when
    the dialog is accepted. Either side can be excluded via its checkbox.
    """
    _DEFAULT_WINDOW_WIDTH = 500

    def __init__(self, graph: OnnxGraph=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("change input output shape")
        self.graph = graph
        self.initUI(graph)

    def initUI(self, graph: OnnxGraph=None):
        """Create the header, checkbox and shape rows for inputs and outputs.

        A ``None`` graph produces a dialog without shape rows instead of
        raising AttributeError.
        """
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()
        layout = QtWidgets.QVBoxLayout()

        graph_inputs = graph.inputs if graph is not None else {}
        graph_outputs = graph.outputs if graph is not None else {}

        # Inputs
        label_input_header = QtWidgets.QLabel("Inputs")
        set_font(label_input_header, font_size=LARGE_FONT_SIZE, bold=True)
        self.cb_change_input_shape = QtWidgets.QCheckBox("change input shape")
        self.cb_change_input_shape.setChecked(True)
        self.cb_change_input_shape.stateChanged.connect(self.cb_change_input_shape_changed)
        layout.addWidget(label_input_header)
        layout.addWidget(self.cb_change_input_shape)
        self.input_widgets = self._add_shape_rows(layout, graph_inputs)

        layout.addSpacing(15)

        # Outputs
        label_output_header = QtWidgets.QLabel("Outputs")
        set_font(label_output_header, font_size=LARGE_FONT_SIZE, bold=True)
        self.cb_change_output_shape = QtWidgets.QCheckBox("change output shape")
        self.cb_change_output_shape.setChecked(True)
        self.cb_change_output_shape.stateChanged.connect(self.cb_change_output_shape_changed)
        layout.addWidget(label_output_header)
        layout.addWidget(self.cb_change_output_shape)
        self.output_widgets = self._add_shape_rows(layout, graph_outputs)

        base_layout.addLayout(layout)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    @staticmethod
    def _add_shape_rows(layout, nodes) -> dict:
        """Append a "name : shape" row to ``layout`` for every entry of ``nodes``.

        ``nodes`` maps a name to an object providing ``get_shape()``. Returns
        a dict mapping each name to ``{"label": QLabel, "shape": QLineEdit}``.
        """
        rows = {}
        for name in nodes.keys():
            shape = nodes[name].get_shape()
            label = QtWidgets.QLabel(name)
            shape_edit = QtWidgets.QLineEdit()
            shape_edit.setText(str(shape))
            child_layout = QtWidgets.QHBoxLayout()
            child_layout.addWidget(label)
            child_layout.addWidget(QtWidgets.QLabel(": "))
            child_layout.addWidget(shape_edit)
            layout.addLayout(child_layout)
            rows[name] = {"label": label, "shape": shape_edit}
        return rows

    def cb_change_input_shape_changed(self, e):
        """Enable the input shape fields only while the input checkbox is checked."""
        change = self.cb_change_input_shape.isChecked()
        for widgets in self.input_widgets.values():
            widgets["shape"].setEnabled(change)

    def cb_change_output_shape_changed(self, e):
        """Enable the output shape fields only while the output checkbox is checked."""
        change = self.cb_change_output_shape.isChecked()
        for widgets in self.output_widgets.values():
            widgets["shape"].setEnabled(change)

    @staticmethod
    def _parse_shapes(rows: dict, errors: list):
        """Parse every row's shape text; return ``(names, shapes)``.

        Rows that are blank or not a valid Python literal are skipped and a
        description is appended to ``errors``.
        """
        names = []
        shapes = []
        for widgets in rows.values():
            name = widgets["label"].text()
            str_shape = widgets["shape"].text().strip()
            if str_shape == "":
                errors.append(f"{name}: not entered")
                continue
            try:
                shape = literal_eval(str_shape)
            except Exception as e:
                # Exception (not BaseException) so CTRL+C / SystemExit still propagate.
                print(e)
                errors.append(f"{name}: {e}")
                continue
            names.append(name)
            shapes.append(shape)
        return names, shapes

    def get_properties(self) -> "tuple[ChangeInputOutputShapeProperties, list[str]]":
        """Parse the shape fields.

        Returns:
            ``(props, errors)``. In ``props`` a side's names/shapes are
            ``None`` when its checkbox is unchecked; ``errors`` lists every
            row that could not be converted.
        """
        input_names = None
        input_shapes = None
        output_names = None
        output_shapes = None
        errors = []

        if self.cb_change_input_shape.isChecked():
            input_names, input_shapes = self._parse_shapes(self.input_widgets, errors)
        if self.cb_change_output_shape.isChecked():
            output_names, output_shapes = self._parse_shapes(self.output_widgets, errors)

        return ChangeInputOutputShapeProperties(
            input_names=input_names,
            input_shapes=input_shapes,
            output_names=output_names,
            output_shapes=output_shapes,
        ), errors

    def accept(self) -> None:
        """Validate the shapes and close the dialog only when they are usable."""
        props, errors = self.get_properties()
        print(props)
        if errors:
            err_msgs = ["shape convert error"] + errors
        else:
            err_msgs = []
            if props.input_names is None and props.output_names is None:
                err_msgs.append("input shape or output shape must be change.")
            # Every skipped row is reported in ``errors``, so these mismatches
            # are a safety net. Count rows rather than self.graph entries so a
            # dialog created without a graph still works.
            if props.input_names is not None and len(props.input_names) != len(self.input_widgets):
                err_msgs.append("every input shape must be specified.")
            if props.output_names is not None and len(props.output_names) != len(self.output_widgets):
                err_msgs.append("every output shape must be specified.")

        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "change input output shape", parent=self)
            return
        return super().accept()
188 |
189 |
if __name__ == "__main__":
    import signal
    import os
    import onnx
    import onnx_graphsurgeon as gs
    from onnxgraphqt.graph.onnx_node_graph import ONNXNodeGraph

    # Restore the default SIGINT handler so CTRL+C terminates the Qt app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    # Configured before the QApplication below is constructed.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app = QtWidgets.QApplication([])

    # Load the bundled sample model and wrap it in an ONNXNodeGraph.
    model_path = os.path.join(os.path.dirname(__file__), "../data/mobilenetv2-7.onnx")
    onnx_model = onnx.load(model_path)
    onnx_graph = gs.import_onnx(onnx_model)
    graph_kwargs = dict(
        name=onnx_graph.name,
        opset=onnx_graph.opset,
        doc_string=onnx_graph.doc_string,
        import_domains=onnx_graph.import_domains,
        producer_name=onnx_model.producer_name,
        producer_version=onnx_model.producer_version,
        ir_version=onnx_model.ir_version,
        model_version=onnx_model.model_version,
    )
    graph = ONNXNodeGraph(**graph_kwargs)
    graph.load_onnx_graph(onnx_graph)

    window = ChangeInputOutputShapeWidget(graph.to_data())
    window.show()
    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_change_opset.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 |
5 | from onnxgraphqt.utils.opset import DEFAULT_OPSET
6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
8 |
9 |
# Value produced by ChangeOpsetWidget.get_properties: the opset as entered text.
ChangeOpsetProperties = namedtuple("ChangeOpsetProperties", ("opset",))
14 |
15 |
class ChangeOpsetWidget(QtWidgets.QDialog):
    """Dialog asking for the opset number the model should be changed to."""
    _DEFAULT_WINDOW_WIDTH = 300

    def __init__(self, current_opset, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("change opset")
        self.current_opset = current_opset
        self.initUI()

    def initUI(self):
        """Build the opset line edit (pre-filled with the current opset) and OK/Cancel."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()
        base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)

        # Form layout
        layout = QtWidgets.QFormLayout()
        layout.setLabelAlignment(QtCore.Qt.AlignRight)
        self.ledit_opset = QtWidgets.QLineEdit()
        self.ledit_opset.setText(str(self.current_opset))
        self.ledit_opset.setPlaceholderText("opset")

        label = QtWidgets.QLabel("opset number to be changed")
        set_font(label, font_size=LARGE_FONT_SIZE, bold=True)
        layout.addRow(label, self.ledit_opset)

        base_layout.addLayout(layout)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    def get_properties(self) -> ChangeOpsetProperties:
        """Return the entered opset text with surrounding whitespace removed.

        The value stays a string; ``accept`` checks that it is a positive integer.
        """
        return ChangeOpsetProperties(
            opset=self.ledit_opset.text().strip()
        )

    def accept(self) -> None:
        """Close the dialog only when the opset is a positive integer."""
        props = self.get_properties()
        print(props)
        opset = str(props.opset)
        err_msgs = []
        if not opset.isdecimal():
            err_msgs.append("opset must be unsigned integer")
        elif int(opset) < 1:
            # isdecimal() accepts "0", but ONNX opset versions start at 1.
            err_msgs.append("opset must be 1 or greater")

        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "change opset", parent=self)
            return
        return super().accept()
77 |
78 |
if __name__ == "__main__":
    import signal
    import os

    # Restore the default SIGINT handler so CTRL+C terminates the Qt app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Configured before the QApplication below is constructed.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = ChangeOpsetWidget(current_opset=DEFAULT_OPSET)
    window.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_combine_network.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import List, Dict
3 | import signal
4 | from PySide2 import QtCore, QtWidgets, QtGui
5 | from ast import literal_eval
6 |
7 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
8 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
9 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
10 |
11 |
# Values produced by CombineNetworkWidgets.get_properties.
CombineNetworkProperties = namedtuple(
    "CombineNetworkProperties",
    (
        "combine_with_current_graph",
        "srcop_destop",
        "op_prefixes_after_merging",
        "input_onnx_file_paths",
        "output_of_onnx_file_in_the_process_of_fusion",
    ),
)
20 |
21 |
class CombineNetworkWidgets(QtWidgets.QDialog):
    """Dialog collecting the parameters used to combine ONNX graphs.

    The user supplies the srcop_destop list, up to ``_MAX_INPUTS_FILE`` ONNX
    files, optional op-name prefixes and whether the intermediate files are
    written. The current graph can take part only when it has nodes.
    """
    _DEFAULT_WINDOW_WIDTH = 560
    _MAX_INPUTS_FILE = 3
    # Prefix slot 0 is the current graph; slot i (i >= 1) is the i-th file.
    _MAX_OP_PREFIXES = _MAX_INPUTS_FILE + 1

    def __init__(self, graph: OnnxGraph=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("combine graphs")
        self.graph = graph
        self.initUI()

    def initUI(self):
        """Build all widgets and layouts of the dialog."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # src dst
        layout_src_dst = QtWidgets.QFormLayout()
        lbl_src_dst = QtWidgets.QLabel("srcop_destop")
        set_font(lbl_src_dst, font_size=LARGE_FONT_SIZE, bold=True)
        self.src_op_dst_op = QtWidgets.QTextEdit()
        self.src_op_dst_op.setMinimumHeight(20*7)
        self.src_op_dst_op.setPlaceholderText('\n'.join([
            'e.g. ',
            ' [',
            ' [',
            ' "model1_out1_opname","model2_in1_opname",',
            ' "model1_out2_opname","model2_in2_opname",',
            ' ],',
            ' ]',
        ]))
        layout_src_dst.addRow(lbl_src_dst, self.src_op_dst_op)

        # input_onnx_file_paths
        layout_inputs_file_paths = QtWidgets.QVBoxLayout()
        lbl_inputs = QtWidgets.QLabel("input_onnx_file_paths")
        set_font(lbl_inputs, font_size=LARGE_FONT_SIZE, bold=True)
        layout_inputs_file_paths.addWidget(lbl_inputs)

        def make_clicked_handler(index):
            # Factory so each button's handler keeps its own index.
            def handler(checked: bool):
                return self.btn_inputs_file_path_clicked(index, checked)
            return handler

        self.inputs_file_paths = {}
        for index in range(self._MAX_INPUTS_FILE):
            base = QtWidgets.QWidget()
            row_layout = QtWidgets.QHBoxLayout(base)
            row_layout.setContentsMargins(0, 0, 0, 0)
            label = QtWidgets.QLabel(f"file{index+1}")
            value = QtWidgets.QLineEdit()
            button = QtWidgets.QPushButton("select")
            button.setObjectName(f"button{index}")
            button.clicked.connect(make_clicked_handler(index))
            row_layout.addWidget(label)
            row_layout.addWidget(value)
            row_layout.addWidget(button)
            layout_inputs_file_paths.addWidget(base)
            self.inputs_file_paths[index] = {
                "base": base,
                "layout": row_layout,
                "label": label,
                "value": value,
                "button": button,
            }

        self.combine_with_current_graph = QtWidgets.QCheckBox("combine_with_current_graph")
        if self.graph is not None and len(self.graph.nodes) > 0:
            self.combine_with_current_graph.setChecked(True)
        else:
            self.combine_with_current_graph.setChecked(False)
            self.combine_with_current_graph.setEnabled(False)
        self.combine_with_current_graph.clicked.connect(self.combine_with_current_graph_clicked)
        layout_inputs_file_paths.addWidget(self.combine_with_current_graph)
        layout_inputs_file_paths.setAlignment(self.combine_with_current_graph, QtCore.Qt.AlignRight)

        # op_prefixes_after_merging
        layout_op_prefixes_base = QtWidgets.QVBoxLayout()
        lbl_op_prefixes = QtWidgets.QLabel("op_prefixes_after_merging")
        set_font(lbl_op_prefixes, font_size=LARGE_FONT_SIZE, bold=True)
        layout_op_prefixes_base.addWidget(lbl_op_prefixes)

        layout_op_prefixes = QtWidgets.QFormLayout()
        self.op_prefixes = {}
        for index in range(self._MAX_OP_PREFIXES):
            label = QtWidgets.QLabel(f"file{index}" if index > 0 else "current graph")
            value = QtWidgets.QLineEdit()
            layout_op_prefixes.addRow(label, value)
            self.op_prefixes[index] = {"label": label, "value": value}
        if not self.combine_with_current_graph.isChecked():
            # Slot 0 is ignored by get_properties unless the current graph is
            # combined, so hide it whenever the checkbox starts unchecked
            # (this includes the case where no graph was passed in).
            self.op_prefixes[0]["label"].setVisible(False)
            self.op_prefixes[0]["value"].setVisible(False)
        layout_op_prefixes_base.addLayout(layout_op_prefixes)

        # output_of_onnx_file_in_the_process_of_fusion
        layout_output_in_process = QtWidgets.QVBoxLayout()
        self.output_in_process = QtWidgets.QCheckBox("output_of_onnx_file_in_the_process_of_fusion")
        self.output_in_process.setChecked(False)
        layout_output_in_process.addWidget(self.output_in_process)
        layout_output_in_process.setAlignment(self.output_in_process, QtCore.Qt.AlignRight)

        base_layout.addLayout(layout_src_dst)
        base_layout.addLayout(layout_inputs_file_paths)
        base_layout.addLayout(layout_op_prefixes_base)
        base_layout.addLayout(layout_output_in_process)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    def btn_inputs_file_path_clicked(self, index: int, e: bool):
        """Let the user pick the ONNX/JSON file for input slot ``index``."""
        file_name, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
            self,
            caption=f"Open ONNX Model File({index+1})",
            filter="*.onnx *.json")
        if not file_name:
            return
        self.inputs_file_paths[index]["value"].setText(file_name)
        print(file_name)

    def combine_with_current_graph_clicked(self, e: bool):
        """Show the "current graph" prefix row only while combining with it."""
        checked = self.combine_with_current_graph.isChecked()
        self.op_prefixes[0]["label"].setVisible(checked)
        self.op_prefixes[0]["value"].setVisible(checked)

    def get_properties(self) -> CombineNetworkProperties:
        """Collect the dialog values into a CombineNetworkProperties.

        Blank file paths and prefixes are dropped; ``op_prefixes_after_merging``
        is None when no prefix was entered.

        Raises:
            ValueError, TypeError, SyntaxError, MemoryError, RecursionError:
                propagated from ``ast.literal_eval`` for a malformed
                srcop_destop text.
        """
        combine_with_current_graph = self.combine_with_current_graph.isChecked()
        output_of_onnx_file_in_the_process_of_fusion = self.output_in_process.isChecked()

        srcop_destop = []
        # Strip so whitespace-only text counts as "not entered" instead of
        # failing in literal_eval.
        srcop_destop_text = self.src_op_dst_op.toPlainText().strip()
        if srcop_destop_text:
            srcop_destop = literal_eval(srcop_destop_text)

        input_files = []
        for widgets in self.inputs_file_paths.values():
            file = widgets["value"].text().strip()
            if file:
                input_files.append(file)

        op_prefix = []
        for index, widgets in self.op_prefixes.items():
            if index == 0 and not combine_with_current_graph:
                continue
            prefix = widgets["value"].text().strip()
            if prefix:
                op_prefix.append(prefix)
        if not op_prefix:
            op_prefix = None

        return CombineNetworkProperties(
            combine_with_current_graph=combine_with_current_graph,
            srcop_destop=srcop_destop,
            op_prefixes_after_merging=op_prefix,
            input_onnx_file_paths=input_files,
            output_of_onnx_file_in_the_process_of_fusion=output_of_onnx_file_in_the_process_of_fusion,
        )

    def accept(self) -> None:
        """Validate the entries and close the dialog only when they are consistent."""
        err_msgs = []
        try:
            props = self.get_properties()
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            # Malformed srcop_destop text: report it instead of letting the
            # exception escape the Qt slot.
            err_msgs.append(f"- srcop_destop could not be parsed: {e}")
        else:
            print(props)
            n_files = len(props.input_onnx_file_paths)
            if n_files == 0:
                err_msgs.append("- input_onnx_file_paths must be specified")

            # The checks below require len(srcop_destop) == (number of graphs - 1),
            # where the current graph counts as one extra graph when combined.
            if props.combine_with_current_graph:
                expected_srcop_destop = n_files
                srcop_destop_msg = "- The number of srcop_destop must match the number of input_onnx_file_paths."
                expected_prefixes = n_files + 1
                prefixes_msg = "- The number of op_prefixes_after_merging must match (number of input_onnx_file_paths + 1)."
            else:
                expected_srcop_destop = n_files - 1
                srcop_destop_msg = "- The number of srcop_destop must be (number of input_onnx_file_paths - 1)."
                expected_prefixes = n_files
                prefixes_msg = "- The number of op_prefixes_after_merging must match number of input_onnx_file_paths."

            if not isinstance(props.srcop_destop, (list, tuple)):
                err_msgs.append("- srcop_destop must be a list.")
            elif len(props.srcop_destop) != expected_srcop_destop:
                err_msgs.append(srcop_destop_msg)

            if props.op_prefixes_after_merging and len(props.op_prefixes_after_merging) != expected_prefixes:
                err_msgs.append(prefixes_msg)

        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "combine network", parent=self)
            return
        return super().accept()
222 |
223 |
224 |
if __name__ == "__main__":
    import signal
    import os

    # Restore the default SIGINT handler so CTRL+C terminates the Qt app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Configured before the QApplication below is constructed.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = CombineNetworkWidgets()
    window.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_constant_shrink.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 | from ast import literal_eval
5 |
6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
8 |
9 |
# Values produced by ConstantShrinkWidgets.get_properties.
ConstantShrinkProperties = namedtuple(
    "ConstantShrinkProperties",
    (
        "mode",
        "forced_extraction_op_names",
        "forced_extraction_constant_names",
        "disable_auto_downcast",
    ),
)
17 |
# Modes offered by the constant-shrink dialog's combo box.
MODE = ["shrink", "npy"]
22 |
class ConstantShrinkWidgets(QtWidgets.QDialog):
    """Dialog collecting the constant-shrink options.

    Options: the mode (one of ``MODE``), the forced-extraction op and
    constant names (entered as Python list literals) and the auto-downcast
    flag.
    """
    _DEFAULT_WINDOW_WIDTH = 500

    def __init__(self, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("constant shrink")
        self.initUI()

    def initUI(self):
        """Build the form rows, the auto_downcast checkbox and OK/Cancel."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # Form layout
        layout = QtWidgets.QFormLayout()
        layout.setLabelAlignment(QtCore.Qt.AlignRight)
        lbl_mode = QtWidgets.QLabel("mode")
        set_font(lbl_mode, font_size=LARGE_FONT_SIZE, bold=True)
        self.cmb_mode = QtWidgets.QComboBox()
        for m in MODE:
            self.cmb_mode.addItem(m)

        lbl_forced_extraction_op_names = QtWidgets.QLabel("forced_extraction_op_names")
        set_font(lbl_forced_extraction_op_names, font_size=LARGE_FONT_SIZE, bold=True)
        self.tb_forced_extraction_op_names = QtWidgets.QLineEdit()
        self.tb_forced_extraction_op_names.setPlaceholderText("e.g. ['aaa','bbb','ccc']")

        lbl_forced_extraction_constant_names = QtWidgets.QLabel("forced_extraction_constant_names")
        set_font(lbl_forced_extraction_constant_names, font_size=LARGE_FONT_SIZE, bold=True)
        self.tb_forced_extraction_constant_names = QtWidgets.QLineEdit()
        self.tb_forced_extraction_constant_names.setPlaceholderText("e.g. ['aaa','bbb','ccc']")

        layout.addRow(lbl_mode, self.cmb_mode)
        layout.addRow(lbl_forced_extraction_op_names, self.tb_forced_extraction_op_names)
        layout.addRow(lbl_forced_extraction_constant_names, self.tb_forced_extraction_constant_names)

        layout2 = QtWidgets.QVBoxLayout()
        self.check_auto_downcast = QtWidgets.QCheckBox("auto_downcast")
        self.check_auto_downcast.setChecked(True)
        layout2.addWidget(self.check_auto_downcast)
        layout2.setAlignment(self.check_auto_downcast, QtCore.Qt.AlignRight)

        base_layout.addLayout(layout)
        base_layout.addLayout(layout2)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    def get_properties(self) -> ConstantShrinkProperties:
        """Collect the dialog values into a ConstantShrinkProperties.

        Empty name fields become ``[]``; non-empty ones are parsed with
        ``ast.literal_eval``. ``disable_auto_downcast`` is the inverse of the
        auto_downcast checkbox.

        Raises:
            ValueError, TypeError, SyntaxError, MemoryError, RecursionError:
                propagated from ``literal_eval`` for malformed name text.
        """
        op_names_text = self.tb_forced_extraction_op_names.text().strip()
        constant_names_text = self.tb_forced_extraction_constant_names.text().strip()
        return ConstantShrinkProperties(
            mode=self.cmb_mode.currentText(),
            forced_extraction_op_names=literal_eval(op_names_text) if op_names_text else [],
            forced_extraction_constant_names=literal_eval(constant_names_text) if constant_names_text else [],
            disable_auto_downcast=not self.check_auto_downcast.isChecked(),
        )

    def accept(self) -> None:
        """Validate the options and close the dialog only when they are usable."""
        err_msgs = []
        try:
            props = self.get_properties()
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            # Previously this failure was only printed, so the user saw nothing.
            err_msgs.append(f"- forced_extraction names could not be parsed: {e}")
        else:
            print(props)
            if props.mode not in MODE:
                err_msgs.append(f"- mode is select from {' or '.join(MODE)}")
            for field_name in ("forced_extraction_op_names", "forced_extraction_constant_names"):
                names = getattr(props, field_name)
                if not isinstance(names, (list, tuple)) or not all(isinstance(n, str) for n in names):
                    err_msgs.append(f"- {field_name} must be a list of strings. e.g. ['aaa','bbb','ccc']")

        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "constant shrink", parent=self)
            return
        return super().accept()
127 |
128 |
if __name__ == "__main__":
    import signal
    import sys
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Qt requires this attribute to be set before the QApplication is created.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = ConstantShrinkWidgets()
    window.show()

    # Propagate the event loop's exit code to the calling process.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_delete_node.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import List
3 | import signal
4 | from PySide2 import QtCore, QtWidgets, QtGui
5 |
6 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
7 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
8 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
9 |
10 |
# Value returned by DeleteNodeWidgets.get_properties(): the list of node names
# the user asked to remove.
DeleteNodeProperties = namedtuple("DeleteNodeProperties", "remove_node_names")
15 |
class DeleteNodeWidgets(QtWidgets.QDialog):
    """Non-modal dialog that collects the names of the nodes to delete.

    ``_MAX_REMOVE_NODE_NAMES_COUNT`` rows are created up front, each an editable
    combo box listing the keys of ``graph.nodes``; "+" / "-" show or hide rows.
    Names from ``selected_nodes`` that exist in the graph are preselected.
    """
    # _DEFAULT_WINDOW_WIDTH = 400
    _MAX_REMOVE_NODE_NAMES_COUNT = 5

    def __init__(self, graph: OnnxGraph=None, selected_nodes: List[str]=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("delete node")
        self.graph = graph
        # None default instead of a shared mutable [] default.
        self.selected_nodes = selected_nodes if selected_nodes is not None else []
        self.initUI()
        self.updateUI(self.graph, self.selected_nodes)

    def initUI(self):
        """Build the name rows, the +/- buttons and the OK/Cancel buttons."""
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()
        base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)

        # attributes
        # NOTE(review): assigning self.layout shadows QWidget.layout(); kept in
        # case other code reads this attribute.
        self.layout = QtWidgets.QVBoxLayout()
        lbl = QtWidgets.QLabel("remove_node_names ")
        set_font(lbl, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout.addWidget(lbl)
        self.visible_remove_node_names_count = 1
        self.remove_node_names = {}
        for index in range(self._MAX_REMOVE_NODE_NAMES_COUNT):
            self.remove_node_names[index] = {}
            self.remove_node_names[index]["base"] = QtWidgets.QWidget()
            self.remove_node_names[index]["layout"] = QtWidgets.QHBoxLayout(self.remove_node_names[index]["base"])
            self.remove_node_names[index]["layout"].setContentsMargins(0, 0, 0, 0)
            self.remove_node_names[index]["name"] = QtWidgets.QComboBox()
            self.remove_node_names[index]["name"].setEditable(True)
            self.remove_node_names[index]["layout"].addWidget(self.remove_node_names[index]["name"])
            self.layout.addWidget(self.remove_node_names[index]["base"])
        self.btn_add = QtWidgets.QPushButton("+")
        self.btn_del = QtWidgets.QPushButton("-")
        self.btn_add.clicked.connect(self.btn_add_clicked)
        self.btn_del.clicked.connect(self.btn_del_clicked)
        self.set_visible()
        layout_btn = QtWidgets.QHBoxLayout()
        layout_btn.addWidget(self.btn_add)
        layout_btn.addWidget(self.btn_del)
        self.layout.addLayout(layout_btn)

        # add layout
        base_layout.addLayout(self.layout)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    def updateUI(self, graph: OnnxGraph=None, selected_nodes: List[str]=None):
        """Refill every row from ``graph`` and preselect ``selected_nodes``.

        Selected names that are not keys of ``graph.nodes`` are skipped (with no
        graph, nothing is preselected); at most ``_MAX_REMOVE_NODE_NAMES_COUNT``
        names are used, in order, filling rows from the top. One extra empty row
        is shown after them, capped at the maximum row count.
        """
        if selected_nodes is None:
            selected_nodes = []
        if graph:
            for index in range(self._MAX_REMOVE_NODE_NAMES_COUNT):
                combo = self.remove_node_names[index]["name"]
                combo.clear()
                for name in graph.nodes.keys():
                    combo.addItem(name)
                combo.setCurrentIndex(-1)

        # Filter first, then fill rows densely: indexing selected_nodes by row
        # number would run past its end whenever a name is skipped.
        known_names = graph.nodes if graph else {}
        preselected = [node for node in selected_nodes if node in known_names]
        preselected = preselected[:self._MAX_REMOVE_NODE_NAMES_COUNT]
        for index, node in enumerate(preselected):
            self.remove_node_names[index]["name"].setCurrentText(node)
        self.visible_remove_node_names_count = min(len(preselected) + 1, self._MAX_REMOVE_NODE_NAMES_COUNT)
        self.set_visible()

    def set_visible(self):
        """Show the first ``visible_remove_node_names_count`` rows and update +/-."""
        for key, widgets in self.remove_node_names.items():
            widgets["base"].setVisible(key < self.visible_remove_node_names_count)
        if self.visible_remove_node_names_count <= 1:
            self.btn_add.setEnabled(True)
            self.btn_del.setEnabled(False)
        elif self.visible_remove_node_names_count >= self._MAX_REMOVE_NODE_NAMES_COUNT:
            self.btn_add.setEnabled(False)
            self.btn_del.setEnabled(True)
        else:
            self.btn_add.setEnabled(True)
            self.btn_del.setEnabled(True)

    def btn_add_clicked(self, e):
        """Show one more row, up to the maximum."""
        self.visible_remove_node_names_count = min(max(1, self.visible_remove_node_names_count + 1), self._MAX_REMOVE_NODE_NAMES_COUNT)
        self.set_visible()

    def btn_del_clicked(self, e):
        """Hide the last row; at least one row always stays visible."""
        self.visible_remove_node_names_count = min(max(1, self.visible_remove_node_names_count - 1), self._MAX_REMOVE_NODE_NAMES_COUNT)
        self.set_visible()

    def get_properties(self)->DeleteNodeProperties:
        """Return the non-blank names of the visible rows as ``DeleteNodeProperties``."""
        remove_node_names = []
        for i in range(self.visible_remove_node_names_count):
            name = self.remove_node_names[i]["name"].currentText()
            if name.strip():
                remove_node_names.append(name)

        return DeleteNodeProperties(
            remove_node_names=remove_node_names
        )

    def accept(self) -> None:
        """Require at least one node name; otherwise show an error and stay open."""
        props = self.get_properties()
        print(props)
        err_msgs = []
        if not props.remove_node_names:
            err_msgs.append("- remove_node_names is not set.")

        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "delete node", parent=self)
            return
        return super().accept()
144 |
145 |
146 |
if __name__ == "__main__":
    import signal
    import sys
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Qt requires this attribute to be set before the QApplication is created.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = DeleteNodeWidgets()
    window.show()

    # Propagate the event loop's exit code to the calling process.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_extract_network.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 |
5 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
6 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
8 |
9 |
# Value returned by ExtractNetworkWidgets.get_properties(): the names entered
# as the inputs and outputs of the sub-network to extract.
ExtractNetworkProperties = namedtuple("ExtractNetworkProperties", "input_op_names output_op_names")
15 |
class ExtractNetworkWidgets(QtWidgets.QDialog):
    """Non-modal dialog that collects input and output op names for extraction.

    All ``_MAX_INPUT_OP_NAMES_COUNT`` input rows and ``_MAX_OUTPUT_OP_NAMES_COUNT``
    output rows are built up front; each row is an editable combo box and the
    "+" / "-" buttons of a section show or hide its rows. The "-" button is
    disabled when exactly one row is visible, "+" once the maximum is reached.
    """
    _DEFAULT_WINDOW_WIDTH = 400
    _MAX_INPUT_OP_NAMES_COUNT = 5
    _MAX_OUTPUT_OP_NAMES_COUNT = 5

    def __init__(self, graph: OnnxGraph=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("extract network")
        self.graph = graph
        self.initUI()
        self.updateUI(self.graph)

    @staticmethod
    def _new_name_row() -> dict:
        """Create one row: a margin-less container holding an editable combo box."""
        container = QtWidgets.QWidget()
        row_layout = QtWidgets.QHBoxLayout(container)
        row_layout.setContentsMargins(0, 0, 0, 0)
        combo = QtWidgets.QComboBox()
        combo.setEditable(True)
        row_layout.addWidget(combo)
        return {"base": container, "layout": row_layout, "name": combo}

    def _build_section(self, title: str, row_count: int):
        """Return ``(layout, rows)``: a bold title followed by ``row_count`` name rows."""
        section = QtWidgets.QVBoxLayout()
        title_label = QtWidgets.QLabel(title)
        set_font(title_label, font_size=LARGE_FONT_SIZE, bold=True)
        section.addWidget(title_label)
        rows = {}
        for row_index in range(row_count):
            rows[row_index] = self._new_name_row()
            section.addWidget(rows[row_index]["base"])
        return section, rows

    @staticmethod
    def _append_plus_minus(section, on_add, on_del):
        """Append a "+" / "-" button pair to ``section`` and return both buttons."""
        plus = QtWidgets.QPushButton("+")
        minus = QtWidgets.QPushButton("-")
        plus.clicked.connect(on_add)
        minus.clicked.connect(on_del)
        pair = QtWidgets.QHBoxLayout()
        pair.addWidget(plus)
        pair.addWidget(minus)
        section.addLayout(pair)
        return plus, minus

    def initUI(self):
        """Build the input section, the output section and the OK/Cancel buttons."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # inputs
        self.layout_inputs, self.widgets_inputs = self._build_section(
            "input_op_names", self._MAX_INPUT_OP_NAMES_COUNT)
        self.visible_input_op_names_count = 1
        self.btn_add_inputs, self.btn_del_inputs = self._append_plus_minus(
            self.layout_inputs, self.btn_add_inputs_clicked, self.btn_del_inputs_clicked)

        # outputs
        self.layout_outputs, self.widgets_outputs = self._build_section(
            "output_op_names", self._MAX_OUTPUT_OP_NAMES_COUNT)
        self.visible_output_op_names_count = 1
        self.btn_add_outputs, self.btn_del_outputs = self._append_plus_minus(
            self.layout_outputs, self.btn_add_outputs_clicked, self.btn_del_outputs_clicked)

        # OK / Cancel
        dialog_buttons = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel)
        dialog_buttons.accepted.connect(self.accept)
        dialog_buttons.rejected.connect(self.reject)

        base_layout.addLayout(self.layout_inputs)
        base_layout.addSpacing(10)
        base_layout.addLayout(self.layout_outputs)
        base_layout.addSpacing(10)
        base_layout.addWidget(dialog_buttons)

        self.setLayout(base_layout)

        self.set_btn_visible()

    def updateUI(self, graph: OnnxGraph=None):
        """Refill every combo box with the keys of ``graph.node_inputs`` (no selection)."""
        if not graph:
            return
        # NOTE(review): the output combos are filled from graph.node_inputs,
        # the same source as the input combos; confirm whether a list of
        # output names was intended here.
        candidates = list(graph.node_inputs.keys())
        sections = (
            (self.widgets_inputs, self._MAX_INPUT_OP_NAMES_COUNT),
            (self.widgets_outputs, self._MAX_OUTPUT_OP_NAMES_COUNT),
        )
        for rows, row_count in sections:
            for row_index in range(row_count):
                combo = rows[row_index]["name"]
                combo.clear()
                combo.addItems(candidates)
                combo.setCurrentIndex(-1)

    @staticmethod
    def _sync_rows(rows, visible_count, maximum, plus, minus):
        """Show the first ``visible_count`` rows and enable/disable the +/- pair."""
        for row_index, row in rows.items():
            row["base"].setVisible(row_index < visible_count)
        plus.setEnabled(visible_count == 1 or visible_count < maximum)
        minus.setEnabled(visible_count != 1)

    def set_btn_visible(self):
        """Apply the visible row counts of both sections to rows and buttons."""
        self._sync_rows(self.widgets_inputs, self.visible_input_op_names_count,
                        self._MAX_INPUT_OP_NAMES_COUNT,
                        self.btn_add_inputs, self.btn_del_inputs)
        self._sync_rows(self.widgets_outputs, self.visible_output_op_names_count,
                        self._MAX_OUTPUT_OP_NAMES_COUNT,
                        self.btn_add_outputs, self.btn_del_outputs)

    def btn_add_inputs_clicked(self, e):
        wanted = max(0, self.visible_input_op_names_count + 1)
        self.visible_input_op_names_count = min(wanted, self._MAX_INPUT_OP_NAMES_COUNT)
        self.set_btn_visible()

    def btn_del_inputs_clicked(self, e):
        wanted = max(0, self.visible_input_op_names_count - 1)
        self.visible_input_op_names_count = min(wanted, self._MAX_INPUT_OP_NAMES_COUNT)
        self.set_btn_visible()

    def btn_add_outputs_clicked(self, e):
        wanted = max(0, self.visible_output_op_names_count + 1)
        self.visible_output_op_names_count = min(wanted, self._MAX_OUTPUT_OP_NAMES_COUNT)
        self.set_btn_visible()

    def btn_del_outputs_clicked(self, e):
        wanted = max(0, self.visible_output_op_names_count - 1)
        self.visible_output_op_names_count = min(wanted, self._MAX_OUTPUT_OP_NAMES_COUNT)
        self.set_btn_visible()

    @staticmethod
    def _visible_names(rows, visible_count):
        """Return the combo texts of the first ``visible_count`` rows, skipping blank ones."""
        texts = (rows[i]["name"].currentText() for i in range(visible_count))
        return [text for text in texts if text.strip()]

    def get_properties(self)->ExtractNetworkProperties:
        """Return the non-blank names of the visible rows as ``ExtractNetworkProperties``."""
        return ExtractNetworkProperties(
            input_op_names=self._visible_names(self.widgets_inputs, self.visible_input_op_names_count),
            output_op_names=self._visible_names(self.widgets_outputs, self.visible_output_op_names_count),
        )

    def accept(self) -> None:
        """Require at least one input and one output name; otherwise show errors."""
        props = self.get_properties()
        print(props)
        err_msgs = []
        if not props.input_op_names:
            err_msgs.append("- input_op_names is not set")
        if not props.output_op_names:
            err_msgs.append("- output_op_names is not set")

        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "extract network", parent=self)
            return
        return super().accept()
204 |
205 |
206 |
if __name__ == "__main__":
    import signal
    import sys
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Qt requires this attribute to be set before the QApplication is created.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = ExtractNetworkWidgets()
    window.show()

    # Propagate the event loop's exit code to the calling process.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_generate_operator.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import List, Dict
3 | import signal
4 | from PySide2 import QtCore, QtWidgets, QtGui
5 | from ast import literal_eval
6 | import numpy as np
7 |
8 | from onnxgraphqt.utils.opset import DEFAULT_OPSET
9 | from onnxgraphqt.utils.operators import onnx_opsets, opnames, OperatorVersion, latest_opset
10 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
11 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
12 |
13 |
# Entries of the dtype combo box of each input/output variable row in
# GenerateOperatorWidgets (the combo box is editable, so other text is possible).
AVAILABLE_DTYPES = "float32 float64 int32 int64 str".split()
21 |
22 |
# Value returned by GenerateOperatorWidgets.get_properties().
GenerateOperatorProperties = namedtuple(
    "GenerateOperatorProperties",
    (
        "op_type",
        "opset",
        "op_name",
        "input_variables",
        "output_variables",
        "attributes",
    ),
)
32 |
33 |
34 | class GenerateOperatorWidgets(QtWidgets.QDialog):
35 | _DEFAULT_WINDOW_WIDTH = 500
36 | _MAX_INPUT_VARIABLES_COUNT = 5
37 | _MAX_OUTPUT_VARIABLES_COUNT = 5
38 | _MAX_ATTRIBUTES_COUNT = 10
39 |
40 | def __init__(self, opset=DEFAULT_OPSET, parent=None) -> None:
41 | super().__init__(parent)
42 | self.setModal(False)
43 | self.setWindowTitle("generate operator")
44 | self.initUI(opset)
45 |
46 | def initUI(self, opset:int):
47 | # self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
48 | set_font(self, font_size=BASE_FONT_SIZE)
49 |
50 | base_layout = QtWidgets.QVBoxLayout()
51 | base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)
52 |
53 | # Form layout
54 | layout = QtWidgets.QFormLayout()
55 | layout.setLabelAlignment(QtCore.Qt.AlignRight)
56 | self.cmb_optype = QtWidgets.QComboBox()
57 | self.cmb_optype.setEditable(True)
58 | lbl_op_type = QtWidgets.QLabel("op_type")
59 | set_font(lbl_op_type, font_size=LARGE_FONT_SIZE, bold=True)
60 | layout.addRow(lbl_op_type, self.cmb_optype)
61 |
62 | self.cmb_opset = QtWidgets.QComboBox()
63 | self.cmb_opset.setEditable(True)
64 | for i in range(1,latest_opset + 1):
65 | self.cmb_opset.addItem(str(i), i)
66 | lbl_opset = QtWidgets.QLabel("opset")
67 | set_font(lbl_opset, font_size=LARGE_FONT_SIZE, bold=True)
68 | layout.addRow(lbl_opset, self.cmb_opset)
69 |
70 | self.tb_opname = QtWidgets.QLineEdit()
71 | self.tb_opname.setText("")
72 | lbl_op_name = QtWidgets.QLabel("op_name")
73 | set_font(lbl_op_name, font_size=LARGE_FONT_SIZE, bold=True)
74 | layout.addRow(lbl_op_name, self.tb_opname)
75 |
76 | # variables
77 | self.layout_valiables = QtWidgets.QVBoxLayout()
78 | self.visible_input_valiables_count = 1
79 | self.visible_output_valiables_count = 1
80 |
81 | self.add_input_valiables = {}
82 | self.add_output_valiables = {}
83 | for i in range(self._MAX_INPUT_VARIABLES_COUNT):
84 | self.create_variables_widget(i, is_input=True)
85 | for i in range(self._MAX_OUTPUT_VARIABLES_COUNT):
86 | self.create_variables_widget(i, is_input=False)
87 |
88 | self.btn_add_input_valiables = QtWidgets.QPushButton("+")
89 | self.btn_del_input_valiables = QtWidgets.QPushButton("-")
90 | self.btn_add_input_valiables.clicked.connect(self.btn_add_input_valiables_clicked)
91 | self.btn_del_input_valiables.clicked.connect(self.btn_del_input_valiables_clicked)
92 | self.btn_add_output_valiables = QtWidgets.QPushButton("+")
93 | self.btn_del_output_valiables = QtWidgets.QPushButton("-")
94 | self.btn_add_output_valiables.clicked.connect(self.btn_add_output_valiables_clicked)
95 | self.btn_del_output_valiables.clicked.connect(self.btn_del_output_valiables_clicked)
96 |
97 | self.layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
98 | lbl_inp_val = QtWidgets.QLabel("input valiables [optional]")
99 | set_font(lbl_inp_val, font_size=LARGE_FONT_SIZE, bold=True)
100 | self.layout_valiables.addWidget(lbl_inp_val)
101 | for key, widgets in self.add_input_valiables.items():
102 | self.layout_valiables.addWidget(widgets["base"])
103 | layout_btn_input = QtWidgets.QHBoxLayout()
104 | layout_btn_input.addWidget(self.btn_add_input_valiables)
105 | layout_btn_input.addWidget(self.btn_del_input_valiables)
106 | self.layout_valiables.addLayout(layout_btn_input)
107 |
108 | self.layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
109 | lbl_out_val = QtWidgets.QLabel("output valiables [optional]")
110 | set_font(lbl_out_val, font_size=LARGE_FONT_SIZE, bold=True)
111 | self.layout_valiables.addWidget(lbl_out_val)
112 | for key, widgets in self.add_output_valiables.items():
113 | self.layout_valiables.addWidget(widgets["base"])
114 | layout_btn_output = QtWidgets.QHBoxLayout()
115 | layout_btn_output.addWidget(self.btn_add_output_valiables)
116 | layout_btn_output.addWidget(self.btn_del_output_valiables)
117 | self.layout_valiables.addLayout(layout_btn_output)
118 |
119 | # add_attributes
120 | self.layout_attributes = QtWidgets.QVBoxLayout()
121 | self.layout_attributes.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
122 | lbl_atrributes = QtWidgets.QLabel("atrributes [optional]")
123 | set_font(lbl_atrributes, font_size=LARGE_FONT_SIZE, bold=True)
124 | self.layout_attributes.addWidget(lbl_atrributes)
125 | self.visible_attributes_count = 3
126 | self.attributes = {}
127 | for index in range(self._MAX_ATTRIBUTES_COUNT):
128 | self.attributes[index] = {}
129 | self.attributes[index]["base"] = QtWidgets.QWidget()
130 | self.attributes[index]["layout"] = QtWidgets.QHBoxLayout(self.attributes[index]["base"])
131 | self.attributes[index]["layout"].setContentsMargins(0, 0, 0, 0)
132 | self.attributes[index]["name"] = QtWidgets.QLineEdit()
133 | self.attributes[index]["name"].setPlaceholderText("name")
134 | self.attributes[index]["value"] = QtWidgets.QLineEdit()
135 | self.attributes[index]["value"].setPlaceholderText("value")
136 | self.attributes[index]["layout"].addWidget(self.attributes[index]["name"])
137 | self.attributes[index]["layout"].addWidget(self.attributes[index]["value"])
138 | self.layout_attributes.addWidget(self.attributes[index]["base"])
139 | self.btn_add_attributes = QtWidgets.QPushButton("+")
140 | self.btn_del_attributes = QtWidgets.QPushButton("-")
141 | self.btn_add_attributes.clicked.connect(self.btn_add_attributes_clicked)
142 | self.btn_del_attributes.clicked.connect(self.btn_del_attributes_clicked)
143 | layout_btn_attributes = QtWidgets.QHBoxLayout()
144 | layout_btn_attributes.addWidget(self.btn_add_attributes)
145 | layout_btn_attributes.addWidget(self.btn_del_attributes)
146 | self.layout_attributes.addLayout(layout_btn_attributes)
147 |
148 | # add layout
149 | base_layout.addLayout(layout)
150 | base_layout.addLayout(self.layout_valiables)
151 | base_layout.addLayout(self.layout_attributes)
152 |
153 | # Dialog button
154 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
155 | QtWidgets.QDialogButtonBox.Cancel)
156 | btn.accepted.connect(self.accept)
157 | btn.rejected.connect(self.reject)
158 | base_layout.addWidget(btn)
159 |
160 | self.setLayout(base_layout)
161 |
162 | self.cmb_optype.currentIndexChanged.connect(self.cmb_optype_currentIndexChanged)
163 | self.cmb_opset.currentIndexChanged.connect(self.cmb_opset_currentIndexChanged)
164 | self.cmb_opset.setCurrentIndex(opset-1)
165 |
166 | def create_variables_widget(self, index:int, is_input=True)->QtWidgets.QBoxLayout:
167 | if is_input:
168 | self.add_input_valiables[index] = {}
169 | self.add_input_valiables[index]["base"] = QtWidgets.QWidget()
170 | self.add_input_valiables[index]["layout"] = QtWidgets.QHBoxLayout(self.add_input_valiables[index]["base"])
171 | self.add_input_valiables[index]["layout"].setContentsMargins(0, 0, 0, 0)
172 | self.add_input_valiables[index]["name"] = QtWidgets.QLineEdit()
173 | self.add_input_valiables[index]["name"].setPlaceholderText("name")
174 | self.add_input_valiables[index]["dtype"] = QtWidgets.QComboBox()
175 | for dtype in AVAILABLE_DTYPES:
176 | self.add_input_valiables[index]["dtype"].addItem(dtype)
177 | self.add_input_valiables[index]["dtype"].setEditable(True)
178 | self.add_input_valiables[index]["dtype"].setFixedSize(100, 20)
179 | self.add_input_valiables[index]["shape"] = QtWidgets.QLineEdit()
180 | self.add_input_valiables[index]["shape"].setPlaceholderText("shape. e.g. `[1, 2, 3]`")
181 | self.add_input_valiables[index]["layout"].addWidget(self.add_input_valiables[index]["name"])
182 | self.add_input_valiables[index]["layout"].addWidget(self.add_input_valiables[index]["dtype"])
183 | self.add_input_valiables[index]["layout"].addWidget(self.add_input_valiables[index]["shape"])
184 | else:
185 | self.add_output_valiables[index] = {}
186 | self.add_output_valiables[index]["base"] = QtWidgets.QWidget()
187 | self.add_output_valiables[index]["layout"] = QtWidgets.QHBoxLayout(self.add_output_valiables[index]["base"])
188 | self.add_output_valiables[index]["layout"].setContentsMargins(0, 0, 0, 0)
189 | self.add_output_valiables[index]["name"] = QtWidgets.QLineEdit()
190 | self.add_output_valiables[index]["name"].setPlaceholderText("name")
191 | self.add_output_valiables[index]["dtype"] = QtWidgets.QComboBox()
192 | for dtype in AVAILABLE_DTYPES:
193 | self.add_output_valiables[index]["dtype"].addItem(dtype)
194 | self.add_output_valiables[index]["dtype"].setEditable(True)
195 | self.add_output_valiables[index]["dtype"].setFixedSize(100, 20)
196 | self.add_output_valiables[index]["dtype"].setPlaceholderText("dtype. e.g. `float32`")
197 | self.add_output_valiables[index]["shape"] = QtWidgets.QLineEdit()
198 | self.add_output_valiables[index]["shape"].setPlaceholderText("shape. e.g. `[1, 2, 3]`")
199 | self.add_output_valiables[index]["layout"].addWidget(self.add_output_valiables[index]["name"])
200 | self.add_output_valiables[index]["layout"].addWidget(self.add_output_valiables[index]["dtype"])
201 | self.add_output_valiables[index]["layout"].addWidget(self.add_output_valiables[index]["shape"])
202 |
203 | def set_visible_input_valiables(self):
204 | for key, widgets in self.add_input_valiables.items():
205 | widgets["base"].setVisible(key < self.visible_input_valiables_count)
206 | if self.visible_input_valiables_count == 0:
207 | self.btn_add_input_valiables.setEnabled(True)
208 | self.btn_del_input_valiables.setEnabled(False)
209 | elif self.visible_input_valiables_count >= self._MAX_INPUT_VARIABLES_COUNT:
210 | self.btn_add_input_valiables.setEnabled(False)
211 | self.btn_del_input_valiables.setEnabled(True)
212 | else:
213 | self.btn_add_input_valiables.setEnabled(True)
214 | self.btn_del_input_valiables.setEnabled(True)
215 |
216 | def set_visible_output_valiables(self):
217 | for key, widgets in self.add_output_valiables.items():
218 | widgets["base"].setVisible(key < self.visible_output_valiables_count)
219 | if self.visible_output_valiables_count == 0:
220 | self.btn_add_output_valiables.setEnabled(True)
221 | self.btn_del_output_valiables.setEnabled(False)
222 | elif self.visible_output_valiables_count >= self._MAX_OUTPUT_VARIABLES_COUNT:
223 | self.btn_add_output_valiables.setEnabled(False)
224 | self.btn_del_output_valiables.setEnabled(True)
225 | else:
226 | self.btn_add_output_valiables.setEnabled(True)
227 | self.btn_del_output_valiables.setEnabled(True)
228 |
229 | def set_visible_add_op_attributes(self):
230 | for key, widgets in self.attributes.items():
231 | widgets["base"].setVisible(key < self.visible_attributes_count)
232 | if self.visible_attributes_count == 0:
233 | self.btn_add_attributes.setEnabled(True)
234 | self.btn_del_attributes.setEnabled(False)
235 | elif self.visible_attributes_count >= self._MAX_ATTRIBUTES_COUNT:
236 | self.btn_add_attributes.setEnabled(False)
237 | self.btn_del_attributes.setEnabled(True)
238 | else:
239 | self.btn_add_attributes.setEnabled(True)
240 | self.btn_del_attributes.setEnabled(True)
241 |
242 | def btn_add_input_valiables_clicked(self, e):
243 | self.visible_input_valiables_count = min(max(0, self.visible_input_valiables_count + 1), self._MAX_INPUT_VARIABLES_COUNT)
244 | self.set_visible_input_valiables()
245 |
246 | def btn_del_input_valiables_clicked(self, e):
247 | self.visible_input_valiables_count = min(max(0, self.visible_input_valiables_count - 1), self._MAX_INPUT_VARIABLES_COUNT)
248 | self.set_visible_input_valiables()
249 |
250 | def btn_add_output_valiables_clicked(self, e):
251 | self.visible_output_valiables_count = min(max(0, self.visible_output_valiables_count + 1), self._MAX_OUTPUT_VARIABLES_COUNT)
252 | self.set_visible_output_valiables()
253 |
254 | def btn_del_output_valiables_clicked(self, e):
255 | self.visible_output_valiables_count = min(max(0, self.visible_output_valiables_count - 1), self._MAX_OUTPUT_VARIABLES_COUNT)
256 | self.set_visible_output_valiables()
257 |
258 | def btn_add_attributes_clicked(self, e):
259 | self.visible_attributes_count = min(max(0, self.visible_attributes_count + 1), self._MAX_ATTRIBUTES_COUNT)
260 | self.set_visible_add_op_attributes()
261 |
262 | def btn_del_attributes_clicked(self, e):
263 | self.visible_attributes_count = min(max(0, self.visible_attributes_count - 1), self._MAX_ATTRIBUTES_COUNT)
264 | self.set_visible_add_op_attributes()
265 |
266 | def cmb_optype_currentIndexChanged(self, selected_index:int):
267 | selected_operator: OperatorVersion = self.cmb_optype.currentData()
268 | if selected_operator:
269 | self.visible_input_valiables_count = selected_operator.inputs
270 | self.visible_output_valiables_count = selected_operator.outputs
271 | self.visible_attributes_count = min(max(0, len(selected_operator.attrs)), self._MAX_ATTRIBUTES_COUNT)
272 |
273 | for i, att in enumerate(selected_operator.attrs):
274 | self.attributes[i]["name"].setText(att.name)
275 | self.attributes[i]["value"].setText(att.default_value)
276 | for j in range(len(selected_operator.attrs), self._MAX_ATTRIBUTES_COUNT):
277 | self.attributes[j]["name"].setText("")
278 | self.attributes[j]["value"].setText("")
279 | self.set_visible_input_valiables()
280 | self.set_visible_output_valiables()
281 | self.set_visible_add_op_attributes()
282 |
283 | def cmb_opset_currentIndexChanged(self, selected_index:int):
284 | current_opset:int = self.cmb_opset.currentData()
285 | current_optype = self.cmb_optype.currentText()
286 | current_optype_index = 0
287 | self.cmb_optype.clear()
288 | for i, op in enumerate(onnx_opsets):
289 | for v in op.versions:
290 | if v.since_opset <= current_opset:
291 | if op.name == current_optype:
292 | current_optype_index = self.cmb_optype.count()
293 | self.cmb_optype.addItem(op.name, v)
294 | break
295 | self.cmb_optype.setCurrentIndex(current_optype_index)
296 |
def get_properties(self)->GenerateOperatorProperties:
    """Collect the dialog's current values into a ``GenerateOperatorProperties``.

    ``opset`` is the parsed integer. It is ``""`` when the text is empty, is not a
    Python literal, is not an int, is a bool, or is negative, so that ``accept()``
    reports it as invalid rather than an exception escaping.

    A variable row is used only when name, dtype and shape are all set. An
    attribute row is used only when name and value are set. Empty collections
    are returned as ``None``.

    Raises:
        ValueError, SyntaxError: a shape or attribute value is not a valid
            Python literal (``ast.literal_eval``).
    """
    def collect_variables(rows, count):
        # {name: [dtype, shape]} for every visible, fully filled-in row; None if empty.
        variables = {}
        for i in range(count):
            name = rows[i]["name"].text()
            dtype = rows[i]["dtype"].currentText()
            shape = rows[i]["shape"].text()
            if name and dtype and shape:
                variables[name] = [dtype, literal_eval(shape)]
        return variables or None

    op_type = self.cmb_optype.currentText()
    opset = self.cmb_opset.currentText()
    if opset:
        try:
            opset = literal_eval(opset)
        except (ValueError, SyntaxError):
            opset = ""
        # bool is a subclass of int, and a negative number is not a valid opset.
        if not isinstance(opset, int) or isinstance(opset, bool) or opset < 0:
            opset = ""
    op_name = self.tb_opname.text()

    input_variables = collect_variables(self.add_input_valiables, self.visible_input_valiables_count)
    output_variables = collect_variables(self.add_output_valiables, self.visible_output_valiables_count)

    attributes = {}
    for i in range(self.visible_attributes_count):
        name = self.attributes[i]["name"].text()
        value = self.attributes[i]["value"].text()
        if name and value:
            attributes[name] = literal_eval(value)

    return GenerateOperatorProperties(
        op_type=op_type,
        opset=opset,
        op_name=op_name,
        input_variables=input_variables,
        output_variables=output_variables,
        attributes=attributes or None,
    )
344 |
def accept(self) -> None:
    """Validate the form and close the dialog only if every check passes.

    Each failed check is printed and all of them are shown together in one
    error box. In that case the dialog stays open.
    """
    props = self.get_properties()
    print(props)
    checks = (
        (props.op_type in opnames, "- op_type is invalid."),
        (isinstance(props.opset, int), "- opset must be unsigned integer."),
        (bool(props.op_name), "- op_name is not set."),
    )
    err_msgs = [message for passed, message in checks if not passed]
    if err_msgs:
        for message in err_msgs:
            print(message)
        MessageBox.error(err_msgs, "generate operator", parent=self)
        return
    return super().accept()
366 |
367 |
368 |
if __name__ == "__main__":
    import signal

    # Restore the default SIGINT handler so Ctrl+C terminates the Qt event loop.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Application attributes are set before the QApplication is created.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    dialog = GenerateOperatorWidgets()
    dialog.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_inference_test.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from typing import List, Union
3 | import tempfile
4 | import signal
5 | from PySide2 import QtCore, QtWidgets, QtGui
6 | from ast import literal_eval
7 | import onnx
8 | import onnxruntime as ort
9 |
10 | import os
11 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
12 | from onnxgraphqt.utils.color import replace_PrintColor
13 |
# onnxruntime provider name (as listed by ort.get_available_providers()) ->
# value(s) offered in the dialog and passed to sit4onnx --onnx_execution_provider.
ONNX_PROVIDER_TABLE = dict(
    TensorrtExecutionProvider=["tensorrt"],
    CUDAExecutionProvider=["cuda"],
    OpenVINOExecutionProvider=["openvino_cpu", "openvino_gpu"],
    CPUExecutionProvider=["cpu"],
)
20 |
class InferenceProcess(QtCore.QThread):
    """Worker thread that runs ``python3 -m sit4onnx`` and streams its output.

    Signals:
        signal: emitted for every output line (after ``replace_PrintColor``),
            once with a newline when the process has finished, or with an error
            message if the process could not be started.
        btn_signal: ``False`` when a run starts and ``True`` when it ends
            (also when it fails).
    """
    signal = QtCore.Signal(str)
    btn_signal = QtCore.Signal(bool)

    def __init__(self, parent=None) -> None:
        super().__init__(parent)
        self.onnx_file_path: str = ""
        self.batch_size: int = 1
        self.fixes_shapes: Union[List[int], None] = None
        self.test_loop_count: int = 1
        self.onnx_execution_provider: str = ""

    def set_properties(self,
                       onnx_file_path: str,
                       batch_size: int, fixes_shapes: Union[List[int], None],
                       test_loop_count: int, onnx_execution_provider: str):
        """Store the arguments for the next ``run()``."""
        self.onnx_file_path = onnx_file_path
        self.batch_size = batch_size
        self.fixes_shapes = fixes_shapes
        self.test_loop_count = test_loop_count
        self.onnx_execution_provider = onnx_execution_provider

    def _build_command(self) -> List[str]:
        """Return the sit4onnx invocation as an argv list (executed without a shell)."""
        cmd = [
            "python3", "-m", "sit4onnx",  # fix for docker
            "--input_onnx_file_path", str(self.onnx_file_path),
            "--batch_size", str(self.batch_size),
            "--test_loop_count", str(self.test_loop_count),
            "--onnx_execution_provider", str(self.onnx_execution_provider),
        ]
        if self.fixes_shapes is not None:
            # This value used to be pasted into a shell command line, which split
            # it on whitespace. Split it the same way so sit4onnx receives the
            # same tokens, but without any shell interpretation.
            cmd += ["--fixes_shapes", *str(self.fixes_shapes).split()]
        return cmd

    def run(self):
        self.btn_signal.emit(False)
        try:
            try:
                # An argv list (no shell=True) keeps paths with spaces intact and
                # prevents shell metacharacters in user-entered values from running.
                proc = subprocess.Popen(self._build_command(),
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.STDOUT)
            except OSError as e:
                txt = f"failed to start sit4onnx: {e}"
                print(txt)
                self.signal.emit(txt)
                return
            with proc.stdout:
                for line in proc.stdout:
                    txt = line.decode("utf-8", errors="replace").rstrip("\r\n")
                    print(txt)
                    self.signal.emit(replace_PrintColor(txt))
            proc.wait()
            txt = "\n"
            print(txt)
            self.signal.emit(txt)
        finally:
            # Always re-enable the trigger button, even if starting/reading failed.
            self.btn_signal.emit(True)
64 |
class InferenceTestWidgets(QtWidgets.QDialog):
    """Non-modal dialog that runs an ``InferenceProcess`` on the given model.

    The model is saved to a temporary file (a ``NamedTemporaryFile`` name plus
    ``".onnx"``) when the dialog is created. That file is removed when the
    dialog object is finalized.
    """
    _DEFAULT_WINDOW_WIDTH = 600

    def __init__(self, onnx_model: onnx.ModelProto=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("inference test")
        # The NamedTemporaryFile only provides a unique base name; the model is
        # written to a separate file with an ".onnx" suffix.
        self.fp = tempfile.NamedTemporaryFile()
        self.tmp_filename = self.fp.name + ".onnx"
        onnx.save(onnx_model, self.tmp_filename)
        self.initUI()

    def __del__(self) -> None:
        # Best-effort cleanup. Exceptions raised in __del__ are only reported as
        # "Exception ignored", so a failed __init__ (no tmp_filename) or an
        # already-missing file must not raise here.
        tmp_filename = getattr(self, "tmp_filename", None)
        if tmp_filename is None:
            return
        print(f"delete {tmp_filename}")
        try:
            os.remove(tmp_filename)
        except FileNotFoundError:
            pass

    def initUI(self):
        """Build the parameter form, the output console and the "inference" button."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # add layout
        footer_layout = QtWidgets.QVBoxLayout()
        main_layout = QtWidgets.QFormLayout()
        base_layout.addLayout(main_layout)
        base_layout.addLayout(footer_layout)

        # Form layout
        main_layout.setLabelAlignment(QtCore.Qt.AlignRight)

        self.tb_batch_size = QtWidgets.QLineEdit()
        self.tb_batch_size.setText("1")
        self.tb_batch_size.setAlignment(QtCore.Qt.AlignRight)
        lbl_batch_size = QtWidgets.QLabel("Batch Size: ")
        set_font(lbl_batch_size, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_batch_size, self.tb_batch_size)

        self.tb_fixed_shapes = QtWidgets.QLineEdit()
        self.tb_fixed_shapes.setAlignment(QtCore.Qt.AlignRight)
        lbl_fixed_shapes = QtWidgets.QLabel("Fixed Shapes: ")
        set_font(lbl_fixed_shapes, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_fixed_shapes, self.tb_fixed_shapes)

        self.tb_test_loop_count = QtWidgets.QLineEdit()
        self.tb_test_loop_count.setText("10")
        self.tb_test_loop_count.setAlignment(QtCore.Qt.AlignRight)
        lbl_test_loop_count = QtWidgets.QLabel("Test Loop Count: ")
        set_font(lbl_test_loop_count, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_test_loop_count, self.tb_test_loop_count)

        # Only offer providers that onnxruntime reports as available and that
        # have a known sit4onnx name.
        self.cmb_onnx_execution_provider = QtWidgets.QComboBox()
        self.cmb_onnx_execution_provider.setEditable(False)
        for provider in ort.get_available_providers():
            for p in ONNX_PROVIDER_TABLE.get(provider, []):
                self.cmb_onnx_execution_provider.addItem(f"{p} ", p)
        lbl_onnx_execution_provider = QtWidgets.QLabel("Execution Provider: ")
        set_font(lbl_onnx_execution_provider, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_onnx_execution_provider, self.cmb_onnx_execution_provider)

        # textbox
        self.tb_console = QtWidgets.QTextBrowser()
        self.tb_console.setReadOnly(True)
        self.tb_console.setStyleSheet(f"font-size: {BASE_FONT_SIZE}px; color: #FFFFFF; background-color: #505050;")
        footer_layout.addWidget(self.tb_console)

        # Dialog button
        self.btn_infer = QtWidgets.QPushButton("inference")
        self.btn_infer.clicked.connect(self.btn_infer_clicked)
        footer_layout.addWidget(self.btn_infer)

        # inferenceProcess: its output goes to the console, and it toggles the button.
        self.inference_process = InferenceProcess()
        self.inference_process.signal.connect(self.update_text)
        self.inference_process.btn_signal.connect(self.btn_infer.setEnabled)

        self.setLayout(base_layout)

    def update_text(self, txt: str):
        """Append *txt* to the output console."""
        self.tb_console.append(txt)

    def btn_infer_clicked(self) -> None:
        """Parse the form and start the inference thread.

        Batch size and loop count must be Python literals; otherwise a message is
        written to the console and nothing is started. An empty or unparsable
        "Fixed Shapes" entry means no fixed shapes.
        """
        try:
            batch_size = literal_eval(self.tb_batch_size.text())
            test_loop_count = literal_eval(self.tb_test_loop_count.text())
        except (ValueError, SyntaxError, TypeError) as e:
            self.update_text(f"invalid batch size or test loop count: {e}")
            return
        try:
            fixes_shapes = literal_eval(self.tb_fixed_shapes.text())
        except (ValueError, SyntaxError, TypeError):
            # literal_eval("") raises SyntaxError: an empty field means "not set".
            fixes_shapes = None
        onnx_execution_provider = self.cmb_onnx_execution_provider.currentData()
        self.inference_process.set_properties(self.tmp_filename,
                                              batch_size,
                                              fixes_shapes,
                                              test_loop_count,
                                              onnx_execution_provider)
        self.inference_process.start()
170 |
171 |
if __name__ == "__main__":
    import signal
    import os
    from onnxgraphqt.utils.color import PrintColor

    # Restore the default SIGINT handler so Ctrl+C terminates the Qt event loop.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])

    model_path = os.path.join(os.path.dirname(__file__), "../data/mobilenetv2-7.onnx")
    model = onnx.load_model(model_path)
    window = InferenceTestWidgets(model)
    window.show()

    # Demo of the colour-code conversion in the console, one line per colour.
    for color_name in ("BLACK", "RED", "GREEN", "YELLOW", "BLUE", "MAGENTA", "CYAN", "WHITE"):
        color = getattr(PrintColor, color_name)
        window.update_text(f"{color[1]}{color_name}{PrintColor.RESET[1]}")

    app.exec_()
    del window
199 |
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_initialize_batchsize.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 |
5 | from onnxgraphqt.utils.opset import DEFAULT_OPSET
6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
8 |
9 |
# Result of InitializeBatchsizeWidget: the string entered to initialize the batch size.
InitializeBatchsizeProperties = namedtuple(
    "InitializeBatchsizeProperties",
    "initialization_character_string",
)
14 |
15 |
class InitializeBatchsizeWidget(QtWidgets.QDialog):
    """Non-modal dialog asking for the string used to initialize the batch size."""

    _DEFAULT_WINDOW_WIDTH = 350

    def __init__(self, current_batchsize="-1", parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("initialize batchsize")
        self.initUI()
        self.updateUI(current_batchsize)

    def initUI(self):
        """Create the prompt label, the input line edit and the OK/Cancel buttons."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        prompt = QtWidgets.QLabel("Input string to initialize batch size.")
        set_font(prompt, font_size=LARGE_FONT_SIZE, bold=True)

        self.ledit_character = QtWidgets.QLineEdit()
        self.ledit_character.setText("-1")
        self.ledit_character.setPlaceholderText("initialization_character_string")

        input_layout = QtWidgets.QVBoxLayout()
        input_layout.addWidget(prompt)
        input_layout.addWidget(self.ledit_character)

        buttons = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel
        )
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)

        root_layout = QtWidgets.QVBoxLayout()
        root_layout.addLayout(input_layout)
        root_layout.addWidget(buttons)
        self.setLayout(root_layout)

    def updateUI(self, current_batchsize):
        """Show *current_batchsize* (converted with str()) in the input field."""
        self.ledit_character.setText(str(current_batchsize))

    def get_properties(self)->InitializeBatchsizeProperties:
        """Return the entered string with surrounding whitespace removed."""
        return InitializeBatchsizeProperties(
            initialization_character_string=self.ledit_character.text().strip()
        )

    def accept(self) -> None:
        """Close the dialog unless the input is empty, in which case show an error."""
        props = self.get_properties()
        print(props)
        if props.initialization_character_string != "":
            return super().accept()
        err_msgs = ["- initialization_character_string is not set."]
        for message in err_msgs:
            print(message)
        MessageBox.error(err_msgs, "initialize batchsize", parent=self)
        return
79 |
80 |
if __name__ == "__main__":
    import signal

    # Restore the default SIGINT handler so Ctrl+C terminates the Qt event loop.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    dialog = InitializeBatchsizeWidget()
    dialog.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_menubar.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union, List
2 | from dataclasses import dataclass
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 |
5 |
@dataclass
class SubMenu:
    """A clickable menu entry."""

    # Label shown in the menu; also the key in MenuBarWidget.menu_actions.
    name: str
    # Callback connected to the entry's QAction.
    func: Callable
    # Passed to QtGui.QIcon when truthy; a falsy value means "no icon".
    icon: str
11 |
class Separator:
    """Marker placed in ``Menu.contents`` to insert a separator line."""
14 |
@dataclass
class Menu:
    """A top-level menu: its title and its entries in display order."""

    # Title shown in the menu bar.
    name: str
    # Entries in display order: SubMenu items and Separator markers.
    contents: List[Union[SubMenu, Separator]]
19 |
class MenuBarWidget(QtWidgets.QMenuBar):
    """Menu bar built from a list of ``Menu`` descriptions.

    Each ``SubMenu`` becomes a QAction stored in ``menu_actions`` under its name
    (a later entry with the same name replaces an earlier one). Items that are
    neither ``Separator`` nor ``SubMenu`` are ignored.
    """

    def __init__(self, menu_list: List[Menu], parent=None) -> None:
        super().__init__(parent)

        self.menu_actions = {}
        for menu in menu_list:
            qmenu = self.addMenu(menu.name)
            for item in menu.contents:
                if isinstance(item, Separator):
                    qmenu.addSeparator()
                    continue
                if not isinstance(item, SubMenu):
                    continue
                action = qmenu.addAction(item.name, item.func)
                self.menu_actions[item.name] = action
                if item.icon:
                    action.setIcon(QtGui.QIcon(item.icon))
34 |
35 |
if __name__ == "__main__":
    import sys
    app = QtWidgets.QApplication(sys.argv)

    # MenuBarWidget requires a menu description (calling it with no arguments
    # raised TypeError), so build a small demo menu.
    demo_menus = [
        Menu("File", [
            SubMenu("Open", lambda: print("Open"), ""),
            Separator(),
            SubMenu("Exit", app.quit, ""),
        ]),
    ]
    m = MenuBarWidget(demo_menus)
    m.show()

    app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_message_box.py:
--------------------------------------------------------------------------------
1 | from PySide2 import QtCore, QtWidgets, QtGui
2 | from typing import Union, List
3 |
4 |
class MessageBox(QtWidgets.QMessageBox):
    """QMessageBox whose text may be a single string or a list of lines.

    The ``info``/``question``/``warn``/``error`` class methods build a box with
    the matching icon and a "[INFO] "/"[QUESTION] "/"[WARN] "/"[ERROR] " title
    prefix, show it with ``exec_()`` and return that call's result.
    """

    def __init__(self,
                 text: Union[str, List[str]],
                 title: str,
                 default_button=QtWidgets.QMessageBox.Ok,
                 icon=QtWidgets.QMessageBox.Icon.Information,
                 parent=None) -> None:
        """Configure the box without showing it.

        Args:
            text: message text; a list is joined with newlines (one item per line).
            title: window title.
            default_button: value passed to ``setStandardButtons``; several buttons
                may be combined with ``|`` (as ``question`` does).
            icon: icon shown next to the text.
            parent: parent widget.
        """
        super().__init__(parent)
        self.setText("\n".join(text) if isinstance(text, list) else text)
        self.setWindowTitle(title)
        self.setStandardButtons(default_button)
        self.setIcon(icon)

    # The helpers use ``cls`` so subclasses get instances of themselves.

    @classmethod
    def info(cls,
             text: Union[str, List[str]],
             title: str,
             default_button=QtWidgets.QMessageBox.Ok,
             parent=None):
        return cls(text, "[INFO] " + title, default_button, icon=cls.Icon.Information, parent=parent).exec_()

    @classmethod
    def question(cls,
                 text: Union[str, List[str]],
                 title: str,
                 default_button=QtWidgets.QMessageBox.Yes|QtWidgets.QMessageBox.No,
                 parent=None):
        return cls(text, "[QUESTION] " + title, default_button, icon=cls.Icon.Question, parent=parent).exec_()

    @classmethod
    def warn(cls,
             text: Union[str, List[str]],
             title: str,
             default_button=QtWidgets.QMessageBox.Ok,
             parent=None):
        return cls(text, "[WARN] " + title, default_button, icon=cls.Icon.Warning, parent=parent).exec_()

    @classmethod
    def error(cls,
              text: Union[str, List[str]],
              title: str,
              default_button=QtWidgets.QMessageBox.Ok,
              parent=None):
        return cls(text, "[ERROR] " + title, default_button, icon=cls.Icon.Critical, parent=parent).exec_()
53 |
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_modify_attrs.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from typing import List, Dict
3 | import signal
4 | from PySide2 import QtCore, QtWidgets, QtGui
5 | from ast import literal_eval
6 | import numpy as np
7 | from sam4onnx.onnx_attr_const_modify import (
8 | ATTRIBUTE_DTYPES_TO_NUMPY_TYPES,
9 | CONSTANT_DTYPES_TO_NUMPY_TYPES
10 | )
11 |
12 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
13 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
14 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
15 |
16 |
# Values collected by ModifyAttrsWidgets.get_properties().
ModifyAttrsProperties = namedtuple(
    "ModifyAttrsProperties",
    "op_name attributes input_constants delete_attributes",
)
24 |
25 |
def get_dtype_str(list_or_scalar)->str:
    """Return the dtype name used to preselect a value's dtype combo box.

    For a list, tuple or ndarray the value is flattened with ``np.ravel`` (so
    numpy's type promotion applies, e.g. ``[1, 2.5]`` -> float) and its first
    element is inspected. numpy scalars are converted to Python scalars first.

    Returns:
        ``"bool"``, ``"int32"``, ``"float32"``, ``"complex64"``, ``"str"`` for
        anything else, or ``""`` for ``None`` and empty sequences.
    """
    value = list_or_scalar
    if isinstance(value, (list, tuple, np.ndarray)):
        flat = np.ravel(value)
        if flat.size == 0:
            return ""
        value = flat[0]
    if isinstance(value, np.generic):
        value = value.item()

    # bool must be tested before int: bool is a subclass of int, so the int
    # check would otherwise claim every bool.
    if isinstance(value, bool):
        return "bool"
    if isinstance(value, int):
        return "int32"
    if isinstance(value, float):
        return "float32"
    if isinstance(value, complex):
        return "complex64"
    if value is None:
        return ""
    return "str"
46 |
47 |
48 | class ModifyAttrsWidgets(QtWidgets.QDialog):
49 | _DEFAULT_WINDOW_WIDTH = 500
50 | _MAX_ATTRIBUTES_COUNT = 5
51 | _MAX_DELETE_ATTRIBUTES_COUNT = 5
52 | _MAX_CONST_COUNT = 4
53 |
54 | def __init__(self, graph: OnnxGraph=None, selected_node: str="", parent=None) -> None:
55 | super().__init__(parent)
56 | self.setModal(False)
57 | self.setWindowTitle("modify attributes")
58 | self.graph = graph
59 | self.initUI()
60 | self.updateUI(self.graph, selected_node)
61 |
62 | def initUI(self):
63 | # self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
64 | set_font(self, font_size=BASE_FONT_SIZE)
65 |
66 | base_layout = QtWidgets.QVBoxLayout()
67 | base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)
68 |
69 | # Form layout
70 | layout = QtWidgets.QFormLayout()
71 | layout.setLabelAlignment(QtCore.Qt.AlignRight)
72 | self.cmb_opname = QtWidgets.QComboBox()
73 | self.cmb_opname.setEditable(True)
74 | lbl_opname = QtWidgets.QLabel("opname")
75 | set_font(lbl_opname, font_size=LARGE_FONT_SIZE, bold=True)
76 | layout.addRow(lbl_opname, self.cmb_opname)
77 |
78 | # attributes
79 | self.layout_attributes = QtWidgets.QVBoxLayout()
80 | lbl_attributes = QtWidgets.QLabel("attributes")
81 | set_font(lbl_attributes, font_size=LARGE_FONT_SIZE, bold=True)
82 | self.layout_attributes.addWidget(lbl_attributes)
83 | self.visible_attributes_count = 3
84 | self.edit_attributes = {}
85 | for index in range(self._MAX_ATTRIBUTES_COUNT):
86 | self.edit_attributes[index] = {}
87 | self.edit_attributes[index]["base"] = QtWidgets.QWidget()
88 | self.edit_attributes[index]["layout"] = QtWidgets.QHBoxLayout(self.edit_attributes[index]["base"])
89 | self.edit_attributes[index]["layout"].setContentsMargins(0, 0, 0, 0)
90 | self.edit_attributes[index]["name"] = QtWidgets.QComboBox()
91 | self.edit_attributes[index]["name"].setEditable(True)
92 | self.edit_attributes[index]["name"].setFixedWidth(200)
93 | self.edit_attributes[index]["value"] = QtWidgets.QLineEdit()
94 | self.edit_attributes[index]["value"].setPlaceholderText("value")
95 | self.edit_attributes[index]["dtype"] = QtWidgets.QComboBox()
96 | for key, dtype in ATTRIBUTE_DTYPES_TO_NUMPY_TYPES.items():
97 | self.edit_attributes[index]["dtype"].addItem(key, dtype)
98 | self.edit_attributes[index]["layout"].addWidget(self.edit_attributes[index]["name"])
99 | self.edit_attributes[index]["layout"].addWidget(self.edit_attributes[index]["value"])
100 | self.edit_attributes[index]["layout"].addWidget(self.edit_attributes[index]["dtype"])
101 | self.layout_attributes.addWidget(self.edit_attributes[index]["base"])
102 | self.btn_add_attributes = QtWidgets.QPushButton("+")
103 | self.btn_del_attributes = QtWidgets.QPushButton("-")
104 | self.btn_add_attributes.clicked.connect(self.btn_add_attributes_clicked)
105 | self.btn_del_attributes.clicked.connect(self.btn_del_attributes_clicked)
106 | layout_btn_attributes = QtWidgets.QHBoxLayout()
107 | layout_btn_attributes.addWidget(self.btn_add_attributes)
108 | layout_btn_attributes.addWidget(self.btn_del_attributes)
109 | self.layout_attributes.addLayout(layout_btn_attributes)
110 |
111 | # input_const
112 | self.layout_const = QtWidgets.QVBoxLayout()
113 | lbl_input_constants = QtWidgets.QLabel("input_constants")
114 | set_font(lbl_input_constants, font_size=LARGE_FONT_SIZE, bold=True)
115 | self.layout_const.addWidget(lbl_input_constants)
116 | self.visible_const_count = 3
117 | self.edit_const = {}
118 | for index in range(self._MAX_CONST_COUNT):
119 | self.edit_const[index] = {}
120 | self.edit_const[index]["base"] = QtWidgets.QWidget()
121 | self.edit_const[index]["layout"] = QtWidgets.QHBoxLayout(self.edit_const[index]["base"])
122 | self.edit_const[index]["layout"].setContentsMargins(0, 0, 0, 0)
123 | self.edit_const[index]["name"] = QtWidgets.QComboBox()
124 | self.edit_const[index]["name"].setEditable(True)
125 | self.edit_const[index]["name"].setFixedWidth(200)
126 | self.edit_const[index]["value"] = QtWidgets.QLineEdit()
127 | self.edit_const[index]["value"].setPlaceholderText("value")
128 | self.edit_const[index]["dtype"] = QtWidgets.QComboBox()
129 | for key, dtype in CONSTANT_DTYPES_TO_NUMPY_TYPES.items():
130 | self.edit_const[index]["dtype"].addItem(key, dtype)
131 | self.edit_const[index]["layout"].addWidget(self.edit_const[index]["name"])
132 | self.edit_const[index]["layout"].addWidget(self.edit_const[index]["value"])
133 | self.edit_const[index]["layout"].addWidget(self.edit_const[index]["dtype"])
134 | self.layout_const.addWidget(self.edit_const[index]["base"])
135 | self.btn_add_const = QtWidgets.QPushButton("+")
136 | self.btn_del_const = QtWidgets.QPushButton("-")
137 | self.btn_add_const.clicked.connect(self.btn_add_const_clicked)
138 | self.btn_del_const.clicked.connect(self.btn_del_const_clicked)
139 | layout_btn_const = QtWidgets.QHBoxLayout()
140 | layout_btn_const.addWidget(self.btn_add_const)
141 | layout_btn_const.addWidget(self.btn_del_const)
142 | self.layout_const.addLayout(layout_btn_const)
143 |
144 | # delete_attributes
145 | self.layout_delete_attributes = QtWidgets.QVBoxLayout()
146 | lbl_delete_attributes = QtWidgets.QLabel("delete_attributes")
147 | set_font(lbl_delete_attributes, font_size=LARGE_FONT_SIZE, bold=True)
148 | self.layout_delete_attributes.addWidget(lbl_delete_attributes)
149 | self.visible_delete_attributes_count = 3
150 | self.delete_attributes = {}
151 | for index in range(self._MAX_DELETE_ATTRIBUTES_COUNT):
152 | self.delete_attributes[index] = {}
153 | self.delete_attributes[index]["base"] = QtWidgets.QWidget()
154 | self.delete_attributes[index]["layout"] = QtWidgets.QHBoxLayout(self.delete_attributes[index]["base"])
155 | self.delete_attributes[index]["layout"].setContentsMargins(0, 0, 0, 0)
156 | self.delete_attributes[index]["name"] = QtWidgets.QComboBox()
157 | self.delete_attributes[index]["name"].setPlaceholderText("name")
158 | self.delete_attributes[index]["name"].setEditable(True)
159 | self.delete_attributes[index]["layout"].addWidget(self.delete_attributes[index]["name"])
160 | self.layout_delete_attributes.addWidget(self.delete_attributes[index]["base"])
161 | self.btn_add_delete_attributes = QtWidgets.QPushButton("+")
162 | self.btn_del_delete_attributes = QtWidgets.QPushButton("-")
163 | self.btn_add_delete_attributes.clicked.connect(self.btn_add_delete_attributes_clicked)
164 | self.btn_del_delete_attributes.clicked.connect(self.btn_del_delete_attributes_clicked)
165 | layout_btn_delete_attributes = QtWidgets.QHBoxLayout()
166 | layout_btn_delete_attributes.addWidget(self.btn_add_delete_attributes)
167 | layout_btn_delete_attributes.addWidget(self.btn_del_delete_attributes)
168 | self.layout_delete_attributes.addLayout(layout_btn_delete_attributes)
169 |
170 | # Dialog button
171 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
172 | QtWidgets.QDialogButtonBox.Cancel)
173 | btn.accepted.connect(self.accept)
174 | btn.rejected.connect(self.reject)
175 |
176 | # add layout
177 | base_layout.addLayout(layout)
178 | base_layout.addSpacing(10)
179 | base_layout.addLayout(self.layout_attributes)
180 | base_layout.addSpacing(10)
181 | base_layout.addLayout(self.layout_const)
182 | base_layout.addSpacing(10)
183 | base_layout.addLayout(self.layout_delete_attributes)
184 | base_layout.addSpacing(10)
185 | base_layout.addWidget(btn)
186 |
187 | self.set_visible_attributes()
188 | self.set_visible_const()
189 | self.set_visible_delete_attributes()
190 | self.setLayout(base_layout)
191 |
192 | def updateUI(self, graph: OnnxGraph, selected_node: str=""):
193 |
194 | self.cmb_opname.clear()
195 | if self.graph:
196 | for op_name in self.graph.nodes.keys():
197 | self.cmb_opname.addItem(op_name)
198 | self.cmb_opname.setEditable(True)
199 | self.cmb_opname.setCurrentIndex(-1)
200 |
201 | if self.graph:
202 | def edit_attributes_name_currentIndexChanged(attr_index, current_index):
203 | op_name = self.cmb_opname.currentText()
204 | attr_name = self.edit_attributes[attr_index]["name"].currentText()
205 | attrs = self.graph.nodes[op_name].attrs
206 | if attr_name:
207 | value = attrs[attr_name]
208 | self.edit_attributes[attr_index]["value"].setText(str(value))
209 | dtype = get_dtype_str(value)
210 | self.edit_attributes[attr_index]["dtype"].setCurrentText(dtype)
211 |
212 | def edit_const_name_currentIndexChanged(attr_index, current_index):
213 | # op_name = self.cmb_opname.currentText()
214 | input_name = self.edit_const[attr_index]["name"].currentText()
215 | node_input = self.graph.node_inputs.get(input_name)
216 | if node_input:
217 | self.edit_const[attr_index]["value"].setText(str(node_input.values))
218 | dtype = get_dtype_str(node_input.values)
219 | self.edit_const[attr_index]["dtype"].setCurrentText(dtype)
220 |
221 | def cmb_opname_currentIndexChanged(current_index):
222 | op_name = self.cmb_opname.currentText()
223 |
224 | for index in range(self._MAX_ATTRIBUTES_COUNT):
225 | self.edit_attributes[index]["name"].clear()
226 | for attr_name in self.graph.nodes[op_name].attrs.keys():
227 | self.edit_attributes[index]["name"].addItem(attr_name)
228 | self.edit_attributes[index]["name"].setCurrentIndex(-1)
229 | def on_change(edit_attr_index):
230 | def func(selected_index):
231 | return edit_attributes_name_currentIndexChanged(edit_attr_index, selected_index)
232 | return func
233 | self.edit_attributes[index]["name"].currentIndexChanged.connect(on_change(index))
234 | self.edit_attributes[index]["value"].setText("")
235 |
236 | for index in range(self._MAX_DELETE_ATTRIBUTES_COUNT):
237 | self.delete_attributes[index]["name"].clear()
238 | for attr_name in self.graph.nodes[op_name].attrs.keys():
239 | self.delete_attributes[index]["name"].addItem(attr_name)
240 | self.delete_attributes[index]["name"].setCurrentIndex(-1)
241 |
242 | for index in range(self._MAX_CONST_COUNT):
243 | self.edit_const[index]["name"].clear()
244 | for name, val in self.graph.node_inputs.items():
245 | self.edit_const[index]["name"].addItem(name)
246 | self.edit_const[index]["name"].setCurrentIndex(-1)
247 | def on_change(edit_const_index):
248 | def func(selected_index):
249 | return edit_const_name_currentIndexChanged(edit_const_index, selected_index)
250 | return func
251 | self.edit_const[index]["name"].currentIndexChanged.connect(on_change(index))
252 | self.edit_const[index]["value"].setText("")
253 | self.cmb_opname.currentIndexChanged.connect(cmb_opname_currentIndexChanged)
254 |
255 | if selected_node:
256 | if selected_node in self.graph.nodes.keys():
257 | self.cmb_opname.setCurrentText(selected_node)
258 |
259 | def set_visible_attributes(self):
260 | for key, widgets in self.edit_attributes.items():
261 | widgets["base"].setVisible(key < self.visible_attributes_count)
262 | if self.visible_attributes_count == 1:
263 | self.btn_add_attributes.setEnabled(True)
264 | self.btn_del_attributes.setEnabled(False)
265 | elif self.visible_attributes_count >= self._MAX_ATTRIBUTES_COUNT:
266 | self.btn_add_attributes.setEnabled(False)
267 | self.btn_del_attributes.setEnabled(True)
268 | else:
269 | self.btn_add_attributes.setEnabled(True)
270 | self.btn_del_attributes.setEnabled(True)
271 |
272 | def set_visible_delete_attributes(self):
273 | for key, widgets in self.delete_attributes.items():
274 | widgets["base"].setVisible(key < self.visible_delete_attributes_count)
275 | if self.visible_delete_attributes_count == 1:
276 | self.btn_add_delete_attributes.setEnabled(True)
277 | self.btn_del_delete_attributes.setEnabled(False)
278 | elif self.visible_delete_attributes_count >= self._MAX_DELETE_ATTRIBUTES_COUNT:
279 | self.btn_add_delete_attributes.setEnabled(False)
280 | self.btn_del_delete_attributes.setEnabled(True)
281 | else:
282 | self.btn_add_delete_attributes.setEnabled(True)
283 | self.btn_del_delete_attributes.setEnabled(True)
284 |
285 | def set_visible_const(self):
286 | for key, widgets in self.edit_const.items():
287 | widgets["base"].setVisible(key < self.visible_const_count)
288 | if self.visible_const_count == 1:
289 | self.btn_add_const.setEnabled(True)
290 | self.btn_del_const.setEnabled(False)
291 | elif self.visible_const_count >= self._MAX_CONST_COUNT:
292 | self.btn_add_const.setEnabled(False)
293 | self.btn_del_const.setEnabled(True)
294 | else:
295 | self.btn_add_const.setEnabled(True)
296 | self.btn_del_const.setEnabled(True)
297 |
298 | def btn_add_attributes_clicked(self, e):
299 | self.visible_attributes_count = min(max(0, self.visible_attributes_count + 1), self._MAX_ATTRIBUTES_COUNT)
300 | self.set_visible_attributes()
301 |
302 | def btn_del_attributes_clicked(self, e):
303 | self.visible_attributes_count = min(max(0, self.visible_attributes_count - 1), self._MAX_ATTRIBUTES_COUNT)
304 | self.set_visible_attributes()
305 |
306 | def btn_add_const_clicked(self, e):
307 | self.visible_const_count = min(max(0, self.visible_const_count + 1), self._MAX_CONST_COUNT)
308 | self.set_visible_const()
309 |
310 | def btn_del_const_clicked(self, e):
311 | self.visible_const_count = min(max(0, self.visible_const_count - 1), self._MAX_CONST_COUNT)
312 | self.set_visible_const()
313 |
314 | def btn_add_delete_attributes_clicked(self, e):
315 | self.visible_delete_attributes_count = min(max(0, self.visible_delete_attributes_count + 1), self._MAX_DELETE_ATTRIBUTES_COUNT)
316 | self.set_visible_delete_attributes()
317 |
318 | def btn_del_delete_attributes_clicked(self, e):
319 | self.visible_delete_attributes_count = min(max(0, self.visible_delete_attributes_count - 1), self._MAX_DELETE_ATTRIBUTES_COUNT)
320 | self.set_visible_delete_attributes()
321 |
def get_properties(self) -> ModifyAttrsProperties:
    """Collect the dialog's current inputs into a ModifyAttrsProperties.

    Only the first ``visible_*_count`` rows of each section are read. A row is
    skipped when its name (or, where present, its value) field is empty.

    - attributes: values whose dtype is ``np.str_`` are kept as raw text. Other
      values are parsed with ``literal_eval``, and list results are converted
      with ``np.asarray(..., dtype=dtype)``. Scalar results are kept as parsed.
    - delete_attributes: the non-empty names, in row order.
    - input_constants: ``literal_eval`` of the value text, always converted with
      ``np.asarray(..., dtype=dtype)``.

    Exceptions from ``literal_eval`` / ``np.asarray`` on malformed text are not
    caught here; they propagate to the caller.
    """
    op_name = self.cmb_opname.currentText()

    attributes = {}
    for idx in range(self.visible_attributes_count):
        row = self.edit_attributes[idx]
        attr_name = row["name"].currentText()
        raw_value = row["value"].text()
        attr_dtype = row["dtype"].currentData()
        if not (attr_name and raw_value):
            continue
        if attr_dtype == np.str_:
            # String attributes are passed through verbatim (no literal parsing).
            attributes[attr_name] = raw_value
            continue
        parsed = literal_eval(raw_value)
        attributes[attr_name] = (
            np.asarray(parsed, dtype=attr_dtype) if isinstance(parsed, list) else parsed
        )

    delete_attributes = [
        attr_name
        for attr_name in (
            self.delete_attributes[idx]["name"].currentText()
            for idx in range(self.visible_delete_attributes_count)
        )
        if attr_name
    ]

    input_constants = {}
    for idx in range(self.visible_const_count):
        row = self.edit_const[idx]
        const_name = row["name"].currentText()
        raw_value = row["value"].text()
        const_dtype = row["dtype"].currentData()
        if const_name and raw_value:
            input_constants[const_name] = np.asarray(literal_eval(raw_value), dtype=const_dtype)

    return ModifyAttrsProperties(
        op_name=op_name,
        attributes=attributes,
        input_constants=input_constants,
        delete_attributes=delete_attributes,
    )
361 |
def accept(self) -> None:
    """Validate the dialog's inputs and close it only if they are usable.

    The dialog stays open, and an error box listing the problems is shown, when:
    - a value field cannot be parsed or converted by ``get_properties()``, or
    - attributes are edited or deleted without an op name.
    """
    err_msgs = []
    try:
        props = self.get_properties()
    except (ValueError, SyntaxError, TypeError, OverflowError) as err:
        # literal_eval / np.asarray reject malformed value text. Report it in
        # the dialog instead of letting the exception escape the Qt slot.
        err_msgs.append(f"- invalid value: {err}")
        props = None

    if props is not None:
        print(props)
        edit_attr = len(props.attributes) > 0
        delete_attr = len(props.delete_attributes) > 0
        # Input constants may be edited without an op name; attribute edits/deletions may not.
        if not props.op_name and (edit_attr or delete_attr):
            err_msgs.append("- op_name and attributes must always be specified at the same time.")

    if err_msgs:
        for m in err_msgs:
            print(m)
        MessageBox.error(err_msgs, "modify attrs", parent=self)
        return
    return super().accept()
380 |
381 |
382 |
if __name__ == "__main__":
    import signal
    import sys

    # Restore the default SIGINT handler so CTRL+C terminates the app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Set before the QApplication is created; Qt only honours it at startup.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = ModifyAttrsWidgets()
    window.show()

    # Propagate the event loop's exit status to the shell.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_node_search.py:
--------------------------------------------------------------------------------
1 | import signal
2 | from PySide2 import QtCore, QtWidgets, QtGui
3 |
4 | from onnxgraphqt.graph.onnx_node_graph import ONNXNodeGraph
5 | from onnxgraphqt.graph.onnx_node import ONNXInput, ONNXOutput, ONNXNode, OnnxNodeIO
6 |
7 |
class NodeSearchWidget(QtWidgets.QDialog):
    """Non-modal dialog listing the nodes of an ONNXNodeGraph, filterable by keywords.

    Each row holds four read-only cells:
    - the node name;
    - its type ("Input"/"Output" for graph inputs/outputs, otherwise ``node.op``);
    - the comma-joined input names;
    - the comma-joined output names.

    Double-clicking a row passes the matching node to
    ``graph.fit_to_selection_node``. If the parent has a ``properties_bin``, the
    node is also passed to ``properties_bin.add_node``.
    """

    _DEFAULT_WINDOW_WIDTH = 500
    _DEFAULT_WINDOW_HEIGHT = 600

    def __init__(self, graph: ONNXNodeGraph = None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("node search")
        self.graph = graph
        # Every row built by update(); search() re-attaches the matching subset.
        self.all_row_items = []
        self.initUI()

    def initUI(self):
        """Create the search bar and result view, then fill it from ``self.graph``."""
        parent_widget = self.parentWidget()
        if parent_widget is not None:
            # Open just to the right of the parent window, aligned with its top edge.
            x = parent_widget.x() + parent_widget.width()
            y = parent_widget.y()
        else:
            x, y = 0, 0
        self.setGeometry(x, y, self._DEFAULT_WINDOW_WIDTH, self._DEFAULT_WINDOW_HEIGHT)

        base_layout = QtWidgets.QVBoxLayout()

        # Search bar: free-text keywords plus a "search" button.
        layout = QtWidgets.QHBoxLayout()
        self.tb = QtWidgets.QLineEdit()
        self.tb.setBaseSize(300, 50)
        self.btn = QtWidgets.QPushButton("search")
        self.btn.clicked.connect(self.btn_clicked)
        layout.addWidget(self.tb)
        layout.addWidget(self.btn)
        base_layout.addLayout(layout)

        # Result table: one row per node, four columns.
        self.model = QtGui.QStandardItemModel(0, 4)
        for column, header in enumerate(("name", "type", "input_names", "output_names")):
            self.model.setHeaderData(column, QtCore.Qt.Horizontal, header)

        self.view = QtWidgets.QTreeView()
        self.view.setSortingEnabled(True)
        self.view.doubleClicked.connect(self.viewClicked)
        self.view.setModel(self.model)
        base_layout.addWidget(self.view)

        self.setLayout(base_layout)
        self.update(self.graph)
        self.search("")
        self.view.sortByColumn(0, QtCore.Qt.SortOrder.AscendingOrder)

    def viewClicked(self, index: QtCore.QModelIndex):
        """Focus the graph on the node whose row was double-clicked.

        Does nothing when there is no graph, the row has no name cell, or no
        node with that name is found (e.g. the graph changed since the list was built).
        """
        if self.graph is None:
            return
        name_index = self.model.index(index.row(), 0, index.parent())
        node_name = self.model.data(name_index)
        if node_name is None:
            return
        nodes = (
            self.graph.get_input_node_by_name(node_name)
            + self.graph.get_node_by_name(node_name)
            + self.graph.get_output_node_by_name(node_name)
        )
        if not nodes:
            return
        target = nodes[0]
        self.graph.fit_to_selection_node(target)
        parent = self.parent()
        if hasattr(parent, "properties_bin"):
            parent.properties_bin.add_node(target)

    def btn_clicked(self, e):
        """Filter the rows by the keywords typed into the search bar."""
        self.search(self.tb.text())

    def update(self, graph: ONNXNodeGraph = None):
        """Rebuild every row from ``graph`` and show them all.

        When ``graph`` is None, the list is only cleared and ``self.graph`` is
        left unchanged. Otherwise ``graph`` becomes ``self.graph``, so later
        double-clicks resolve names against the same graph that was listed.

        NOTE: this name shadows QWidget.update() (repaint request) for Python callers.
        """
        self.all_row_items = []
        self._clear_rows()

        if graph is None:
            return
        self.graph = graph

        for node in graph.all_nodes():
            row = self._make_row(node)
            if row is None:
                # Not an ONNX node/input/output; it gets no row (and leaves no gap).
                continue
            self.model.appendRow(row)
            self.all_row_items.append(row)

    @staticmethod
    def _make_row(node):
        """Return the four read-only cells for ``node``, or None if it is not an ONNX item."""
        if isinstance(node, ONNXNode):
            type_name = node.op
            input_names = [io.name for io in node.onnx_inputs]
            output_names = [io.name for io in node.onnx_outputs]
        elif isinstance(node, ONNXInput):
            type_name = "Input"
            input_names = []
            output_names = list(node.get_output_names())
        elif isinstance(node, ONNXOutput):
            type_name = "Output"
            input_names = list(node.get_input_names())
            output_names = []
        else:
            return None

        texts = (
            node.get_node_name(),
            type_name,
            ", ".join(input_names),
            ", ".join(output_names),
        )
        row = []
        for text in texts:
            item = QtGui.QStandardItem(text)
            item.setEditable(False)
            row.append(item)
        return row

    def _clear_rows(self):
        """Detach all rows from the model with takeRow (not removeRows) so search() can re-attach them."""
        for _ in range(self.model.rowCount()):
            self.model.takeRow(0)

    def search(self, word):
        """Show only rows where every whitespace-separated keyword in ``word`` is a substring of some cell.

        Matching is case-sensitive. An empty or blank ``word`` shows every row.
        """
        self._clear_rows()
        keywords = word.split()
        for row in self.all_row_items:
            cell_texts = [item.text() for item in row]
            if all(any(kw in text for text in cell_texts) for kw in keywords):
                self.model.appendRow(row)
168 |
if __name__ == "__main__":
    import signal
    import sys

    # Restore the default SIGINT handler so CTRL+C terminates the app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Set before the QApplication is created; Qt only honours it at startup.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    # NOTE(review): placeholder ONNXNodeGraph arguments for a standalone demo; confirm against its constructor.
    window = NodeSearchWidget(graph=ONNXNodeGraph("", 10, "", "", "", "", 0, 0))
    window.show()

    # Propagate the event loop's exit status to the shell.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_rename_op.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 |
5 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
7 |
8 |
# Result of RenameOpWidget.get_properties():
#   old_new -- two-item list [old_substring, new_substring].
RenameOpProperties = namedtuple("RenameOpProperties", "old_new")
13 |
14 |
class RenameOpWidget(QtWidgets.QDialog):
    """Dialog asking for an 'old' substring and its 'new' replacement in op names.

    ``get_properties()`` returns the two stripped strings as
    ``RenameOpProperties(old_new=[old, new])``. ``accept()`` refuses to close
    while 'old' is empty.
    """

    _DEFAULT_WINDOW_WIDTH = 420

    def __init__(self, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("rename op")
        self.initUI()

    def initUI(self):
        """Lay out the instruction label, the old/new line edits, and OK/Cancel buttons."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        # Heading that explains the operation.
        header = QtWidgets.QVBoxLayout()
        title = QtWidgets.QLabel("Replace substring 'old' in opname with 'new'.")
        set_font(title, font_size=LARGE_FONT_SIZE, bold=True)
        header.addWidget(title)

        # Side-by-side inputs: substring to find, and its replacement.
        self.ledit_old = QtWidgets.QLineEdit()
        self.ledit_old.setPlaceholderText('old. e.g. "onnx::"')
        self.ledit_new = QtWidgets.QLineEdit()
        self.ledit_new.setPlaceholderText('new. e.g. "" ')
        edits = QtWidgets.QHBoxLayout()
        for line_edit in (self.ledit_old, self.ledit_new):
            edits.addWidget(line_edit)

        buttons = QtWidgets.QDialogButtonBox(
            QtWidgets.QDialogButtonBox.Ok | QtWidgets.QDialogButtonBox.Cancel
        )
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)

        root = QtWidgets.QVBoxLayout()
        root.addLayout(header)
        root.addLayout(edits)
        root.addWidget(buttons)
        self.setLayout(root)

    def get_properties(self) -> RenameOpProperties:
        """Return the whitespace-stripped old/new texts as RenameOpProperties."""
        old_text, new_text = (edit.text().strip() for edit in (self.ledit_old, self.ledit_new))
        return RenameOpProperties(old_new=[old_text, new_text])

    def accept(self) -> None:
        """Close the dialog only when the 'old' substring is non-empty; otherwise show an error."""
        props = self.get_properties()
        print(props)
        old_text = props.old_new[0]
        problems = [] if old_text else ["substring old must be set."]
        if problems:
            for message in problems:
                print(message)
            MessageBox.error(problems, "rename op", parent=self)
            return
        return super().accept()
80 |
81 |
if __name__ == "__main__":
    import signal
    import sys

    # Restore the default SIGINT handler so CTRL+C terminates the app.
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # Set before the QApplication is created; Qt only honours it at startup.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = RenameOpWidget()
    window.show()

    # Propagate the event loop's exit status to the shell.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "onnxgraphqt"
7 | version = "0.0.1"
8 | authors = [
9 | { name="fateshelled", email="53618876+fateshelled@users.noreply.github.com" },
10 | ]
11 | description = "ONNX model visualizer"
12 | readme = "README.md"
13 | license = {text = "MIT LICENSE"}
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | ]
17 | dependencies = [
18 | "PySide2",
19 | "numpy",
20 | "pillow",
21 | "onnx",
22 | "onnx-simplifier",
23 | "protobuf==3.20.0",
24 | #"onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com",
25 | "onnx_graphsurgeon",
26 | #"git+https://github.com/jchanvfx/NodeGraphQt.git@v0.5.2#egg=NodeGraphQt",
27 | "simple-onnx-processing-tools",
28 | "grandalf",
29 | "networkx",
30 | ]
31 |
32 | [project.entry-points.console_scripts]
33 | onnxgraphqt = "onnxgraphqt:main"
34 |
35 | [tool.setuptools.packages.find]
36 | exclude = ["docker", "build", "tmp"]
37 |
38 | [project.urls]
39 | "Homepage" = "https://github.com/fateshelled/OnnxGraphQt"
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | PySide2
2 | numpy
3 | pillow
4 | onnx
5 | onnx-simplifier
6 | protobuf==3.20.0
7 | onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
8 | git+https://github.com/jchanvfx/NodeGraphQt.git@v0.5.2#egg=NodeGraphQt
9 | simple-onnx-processing-tools
10 | grandalf
11 | networkx
12 |
--------------------------------------------------------------------------------