├── .gitignore ├── LICENSE ├── README.md ├── docker ├── Dockerfile ├── build.bash └── run.bash ├── onnxgraphqt ├── __init__.py ├── __main__.py ├── data │ ├── icon │ │ ├── icon.png │ │ ├── sam4onnx.png │ │ ├── sbi4onnx.png │ │ ├── scc4onnx.png │ │ ├── scs4onnx.png │ │ ├── sio4onnx.png │ │ ├── sna4onnx.png │ │ ├── snc4onnx.png │ │ ├── snd4onnx.png │ │ ├── sne4onnx.png │ │ ├── soc4onnx.png │ │ ├── sog4onnx.png │ │ └── sor4onnx.png │ ├── mobilenetv2-12-int8.onnx │ ├── mobilenetv2-7.onnx │ ├── onnx_opsets.json │ └── splash.png ├── graph │ ├── __init__.py │ ├── autolayout │ │ ├── __init__.py │ │ └── sugiyama_layout.py │ ├── onnx_node.py │ └── onnx_node_graph.py ├── main_window.py ├── utils │ ├── __init__.py │ ├── color.py │ ├── dtype.py │ ├── operators.py │ ├── opset.py │ ├── style.py │ └── widgets.py └── widgets │ ├── __init__.py │ ├── custom_node_item.py │ ├── custom_properties.py │ ├── custom_properties_bin.py │ ├── splash_screen.py │ ├── widgets_add_node.py │ ├── widgets_change_channel.py │ ├── widgets_change_input_ouput_shape.py │ ├── widgets_change_opset.py │ ├── widgets_combine_network.py │ ├── widgets_constant_shrink.py │ ├── widgets_delete_node.py │ ├── widgets_extract_network.py │ ├── widgets_generate_operator.py │ ├── widgets_inference_test.py │ ├── widgets_initialize_batchsize.py │ ├── widgets_menubar.py │ ├── widgets_message_box.py │ ├── widgets_modify_attrs.py │ ├── widgets_node_search.py │ └── widgets_rename_op.py ├── pyproject.toml └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 
28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 fateshelled 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OnnxGraphQt 2 | 3 | ONNX model visualizer. You can edit the model structure with a GUI! 4 | 5 | [![License](https://img.shields.io/github/license/fateshelled/OnnxGraphQt)](https://github.com/fateshelled/OnnxGraphQt/blob/main/LICENSE) 6 | [![Stars](https://img.shields.io/github/stars/fateshelled/OnnxGraphQt)](https://github.com/fateshelled/OnnxGraphQt/stargazers) 7 | 8 |

9 | 10 |

11 | 12 | 13 | ## Requirements 14 | - [NodeGraphQt](https://github.com/jchanvfx/NodeGraphQt) 15 | - PySide2 16 | - Qt.py 17 | - Numpy 18 | - Pillow 19 | - onnx 20 | - onnx-simplifier 21 | - onnx_graphsurgeon 22 | - [simple-onnx-processing-tools](https://github.com/PINTO0309/simple-onnx-processing-tools) 23 | - networkx 24 | - grandalf 25 | 26 | ## Install 27 | ```bash 28 | sudo apt install python3-pyside2* 29 | 30 | git clone https://github.com/fateshelled/OnnxGraphQt 31 | cd OnnxGraphQt 32 | python3 -m pip install -U nvidia-pyindex 33 | python3 -m pip install -U Qt.py 34 | # If you want to use InferenceTest, install onnxruntime or onnxruntime-gpu 35 | # python3 -m pip install -U onnxruntime 36 | # python3 -m pip install -U onnxruntime-gpu 37 | python3 -m pip install -U -r requirements.txt 38 | 39 | # Install OnnxGraphQt 40 | python3 -m pip install . 41 | ``` 42 | 43 | ## Run with Docker 44 | ```bash 45 | git clone https://github.com/fateshelled/OnnxGraphQt 46 | cd OnnxGraphQt 47 | # build docker image 48 | ./docker/build.bash 49 | # run 50 | ./docker/run.bash 51 | ``` 52 | 53 | ## Usage 54 | ```bash 55 | # Open empty graph 56 | onnxgraphqt 57 | 58 | # Open with onnx model 59 | onnxgraphqt onnxgraphqt/data/mobilenetv2-7.onnx 60 | 61 | ``` 62 | 63 | ![mobilenetv2-7.onnx](https://user-images.githubusercontent.com/53618876/193456965-07b0ccbe-5cfe-4cd8-a233-8dc897dd2446.png) 64 | 65 | 66 | ### Open ONNX Model 67 | Open the file dialog from the menubar (File - Open), or drag and drop a model file from the file manager onto the main window. 68 | 69 | A sample model is available at `OnnxGraphQt/onnxgraphqt/data/mobilenetv2-7.onnx`. 70 | 71 | ![file open](https://user-images.githubusercontent.com/53618876/193456986-919c08b1-1382-426e-8b80-5dbe0e6e146d.png) 72 | 73 | 74 | ### Export 75 | Export the model to an ONNX file or a JSON file. 76 | 77 | ### Node detail 78 | Double-click a node for more information.
79 | 80 | ![node information](https://user-images.githubusercontent.com/53618876/193457001-1738f4e0-948a-47f5-acdc-63bd4e4f09c8.png) 81 | 82 | ### Node Search 83 | The node search window can be opened from the menubar (View - Search). 84 | You can search for nodes by name, type, input name, or output name. 85 | 86 | ![search](https://user-images.githubusercontent.com/53618876/173082166-0cb05288-8033-451d-8fd0-23a2836d301f.png) 87 | 88 | ### [simple-onnx-processing-tools](https://github.com/PINTO0309/simple-onnx-processing-tools) 89 | 90 | Please refer to each tool's GitHub repository for detailed usage. 91 | 92 | - Generate Operator [[sog4onnx](https://github.com/PINTO0309/sog4onnx)] 93 | - Add Node [[sna4onnx](https://github.com/PINTO0309/sna4onnx)] 94 | - Combine Network [[snc4onnx](https://github.com/PINTO0309/snc4onnx)] 95 | - Extract Network [[sne4onnx](https://github.com/PINTO0309/sne4onnx)] 96 | - Rename Operator [[sor4onnx](https://github.com/PINTO0309/sor4onnx)] 97 | - Modify Attributes and Constant [[sam4onnx](https://github.com/PINTO0309/sam4onnx)] 98 | - Input Channel Conversion [[scc4onnx](https://github.com/PINTO0309/scc4onnx)] 99 | - Initialize Batchsize [[sbi4onnx](https://github.com/PINTO0309/sbi4onnx)] 100 | - Change Opset [[soc4onnx](https://github.com/PINTO0309/soc4onnx)] 101 | - Constant Value Shrink [[scs4onnx](https://github.com/PINTO0309/scs4onnx)] 102 | - Delete Node [[snd4onnx](https://github.com/PINTO0309/snd4onnx)] 103 | - Inference Test [[sit4onnx](https://github.com/PINTO0309/sit4onnx)] 104 | - Change the INPUT and OUTPUT shape [[sio4onnx](https://github.com/PINTO0309/sio4onnx)] 105 | 106 | 107 | ## ToDo 108 | - [ ] Add Simple Structure Checker [[ssc4onnx](https://github.com/PINTO0309/ssc4onnx)] 109 | 110 | 111 | ## References 112 | - https://github.com/jchanvfx/NodeGraphQt 113 | - https://github.com/lutzroeder/netron 114 | - https://github.com/PINTO0309/simple-onnx-processing-tools 115 | - https://fdwr.github.io/LostOnnxDocs/OperatorFormulas.html 116 | -
https://github.com/onnx/onnx/blob/main/docs/Operators.md 117 | 118 | 119 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | ARG USERNAME=onnxgraphqt 4 | ARG HOME=/home/${USERNAME} 5 | ARG GID=1000 6 | ARG UID=1000 7 | 8 | RUN groupadd -f -g ${GID} ${USERNAME} && \ 9 | useradd -m -s /bin/bash -u ${UID} -g ${GID} -G sudo ${USERNAME} && \ 10 | echo ${USERNAME}:${USERNAME} | chpasswd && \ 11 | echo "${USERNAME} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers 12 | 13 | ENV DEBIAN_FRONTEND noninteractive 14 | RUN apt update && \ 15 | apt install -y \ 16 | git \ 17 | bash-completion \ 18 | python3-dev \ 19 | python3-pip \ 20 | python3-pyside2* && \ 21 | apt clean && \ 22 | rm -rf /var/lib/apt/lists/* 23 | 24 | USER ${USERNAME} 25 | WORKDIR ${HOME} 26 | RUN git clone https://github.com/fateshelled/OnnxGraphQt && \ 27 | cd OnnxGraphQt && \ 28 | python3 -m pip install -U pip && \ 29 | python3 -m pip install -U nvidia-pyindex && \ 30 | python3 -m pip install -U Qt.py && \ 31 | python3 -m pip install -U onnxruntime && \ 32 | python3 -m pip install -U -r requirements.txt && \ 33 | python3 -m pip install . && \ 34 | rm -rf ~/.cache/pip 35 | 36 | WORKDIR ${HOME}/OnnxGraphQt 37 | CMD [ "python3", "-m", "onnxgraphqt" ] 38 | -------------------------------------------------------------------------------- /docker/build.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_DIR=$(cd $(dirname $0); pwd) 4 | docker build -t onnxgraphqt $SCRIPT_DIR -------------------------------------------------------------------------------- /docker/run.bash: -------------------------------------------------------------------------------- 1 | xhost + 2 | XAUTH=/tmp/.docker.xauth 3 | if [ ! -f $XAUTH ] 4 | then 5 | xauth_list=$(xauth nlist :0 | sed -e 's/^..../ffff/') 6 | if [ ! 
-z "$xauth_list" ] 7 | then 8 | echo $xauth_list | xauth -f $XAUTH nmerge - 9 | else 10 | touch $XAUTH 11 | fi 12 | chmod a+r $XAUTH 13 | fi 14 | 15 | docker run --net=host --rm -it \ 16 | -v=/tmp/.X11-unix:/tmp/.X11-unix:rw \ 17 | -v=${XAUTH}:${XAUTH}:rw \ 18 | -e="XAUTHORITY=${XAUTH}" \ 19 | -e="DISPLAY=${DISPLAY}" \ 20 | -e=TERM=xterm-256color \ 21 | -e=QT_X11_NO_MITSHM=1 \ 22 | onnxgraphqt:latest 23 | -------------------------------------------------------------------------------- /onnxgraphqt/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import PySide2 3 | 4 | plugin_path = os.path.join(os.path.dirname(PySide2.__file__), 5 | "Qt", "plugins", "platforms") 6 | os.environ["QT_PLUGIN_PATH"] = plugin_path 7 | os.environ["QT_QPA_PLATFORM_PLUGIN_PATH"] = plugin_path 8 | 9 | from .main_window import main 10 | -------------------------------------------------------------------------------- /onnxgraphqt/__main__.py: -------------------------------------------------------------------------------- 1 | from .main_window import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/icon.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sam4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sam4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sbi4onnx.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sbi4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/scc4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/scc4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/scs4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/scs4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sio4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sio4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sna4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sna4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/snc4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/snc4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/snd4onnx.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/snd4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sne4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sne4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/soc4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/soc4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sog4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sog4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/icon/sor4onnx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/icon/sor4onnx.png -------------------------------------------------------------------------------- /onnxgraphqt/data/mobilenetv2-12-int8.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/mobilenetv2-12-int8.onnx -------------------------------------------------------------------------------- 
# /onnxgraphqt/data/mobilenetv2-7.onnx:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/mobilenetv2-7.onnx
# --------------------------------------------------------------------------------
# /onnxgraphqt/data/splash.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/data/splash.png
# --------------------------------------------------------------------------------
# /onnxgraphqt/graph/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/graph/__init__.py
# --------------------------------------------------------------------------------
# /onnxgraphqt/graph/autolayout/__init__.py:
# --------------------------------------------------------------------------------
# from .sugiyama_layout import sugiyama_layout
# --------------------------------------------------------------------------------
# /onnxgraphqt/graph/autolayout/sugiyama_layout.py:
# --------------------------------------------------------------------------------
import networkx as nx
import grandalf as grand
from grandalf.layouts import SugiyamaLayout

__all__ = ["sugiyama_layout"]


def sugiyama_layout(g: nx.Graph):
    """Compute a layered (Sugiyama) layout for the nodes of *g*.

    *g* is converted to a grandalf graph and each of its connected components
    is laid out with grandalf's ``SugiyamaLayout``. The first component keeps
    exactly the coordinates grandalf assigns to it; every later component is
    shifted along x so it sits to the right of the components placed before
    it. This way nodes outside the first component also receive positions.

    Args:
        g: networkx graph to lay out.

    Returns:
        dict mapping each node of *g* to an ``(x, y)`` tuple of layout
        coordinates; an empty dict when *g* has no nodes.
    """
    # Undocumented grandalf helper (the "nextworkx" spelling is upstream's).
    gg = grand.utils.convert_nextworkx_graph_to_grandalf(g)

    # Give every vertex a ``view`` with a nominal box size (w, h) before
    # layout; SugiyamaLayout writes the computed coordinates into ``view.xy``.
    class defaultview(object):
        w, h = 400, 100

    for v in gg.V():
        v.view = defaultview()

    pos = {}
    # Right-hand x limit of the components placed so far (None before the
    # first one), used to keep later components from overlapping earlier ones.
    right_edge = None
    for component in gg.C:
        sug = SugiyamaLayout(component)
        sug.init_all(optimize=False)
        sug.draw()

        xs = [v.view.xy[0] for v in component.sV]
        # The first component is not moved, so its positions are exactly what
        # grandalf produced.
        dx = 0 if right_edge is None else right_edge - min(xs)
        for v in component.sV:
            pos[v.data] = (v.view.xy[0] + dx, v.view.xy[1])
        # Leave roughly one node width of empty space before the next
        # component (assumes view.xy is near a node's centre - TODO confirm).
        right_edge = max(xs) + dx + 2 * defaultview.w
    return pos
-------------------------------------------------------------------------------- /onnxgraphqt/graph/onnx_node.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import ( 3 | Dict, List, Any, Optional, Union 4 | ) 5 | from collections import OrderedDict 6 | 7 | from NodeGraphQt import BaseNode 8 | from NodeGraphQt.constants import NodeEnum, LayoutDirectionEnum, NodePropWidgetEnum 9 | 10 | from onnxgraphqt.utils.color import ( 11 | COLOR_BG, 12 | COLOR_FONT, 13 | COLOR_GRID, 14 | INPUT_NODE_COLOR, 15 | OUTPUT_NODE_COLOR, 16 | NODE_BORDER_COLOR, 17 | get_node_color, 18 | ) 19 | from onnxgraphqt.utils.widgets import set_font, GRAPH_FONT_SIZE 20 | from onnxgraphqt.widgets.custom_node_item import CustomNodeItem 21 | 22 | 23 | @dataclass 24 | class OnnxNodeIO: 25 | name: str 26 | dtype: str 27 | shape: List[int] 28 | values: List 29 | 30 | 31 | class ONNXNode(BaseNode): 32 | # unique node identifier. 33 | __identifier__ = 'nodes.node' 34 | # initial default node name. 35 | NODE_NAME = 'onnxnode' 36 | 37 | def __init__(self): 38 | super(ONNXNode, self).__init__(qgraphics_item=CustomNodeItem) 39 | self.set_layout_direction(LayoutDirectionEnum.VERTICAL.value) 40 | self.attrs = OrderedDict() 41 | self.node_name = "" 42 | self.op = "" 43 | self.onnx_inputs = [] 44 | self.onnx_outputs = [] 45 | 46 | # create node inputs. 47 | self.input_port = self.add_input('multi in', multi_input=True, display_name=False) 48 | # create node outputs. 
49 | self.output_port = self.add_output('multi out', multi_output=True, display_name=False) 50 | self.set_font() 51 | 52 | def set_attrs(self, attrs:OrderedDict, push_undo=False): 53 | self.attrs = attrs 54 | for key, val in self.attrs.items(): 55 | if self.has_property(key + "_"): 56 | self.set_property(key + "_", val, push_undo=push_undo) 57 | else: 58 | if key == "dtype": 59 | self.create_property(key + "_", val, widget_type=NodePropWidgetEnum.QLABEL) 60 | else: 61 | self.create_property(key + "_", val, widget_type=NodePropWidgetEnum.QLINE_EDIT) 62 | 63 | def get_attrs(self)->OrderedDict: 64 | d = [(key, self.get_property(key + "_")) for key in self.attrs.keys()] 65 | return OrderedDict(d) 66 | 67 | def set_node_name(self, node_name:str, push_undo=False): 68 | self.node_name = node_name 69 | if not self.has_property("node_name"): 70 | self.create_property("node_name", self.node_name, widget_type=NodePropWidgetEnum.QLINE_EDIT) 71 | else: 72 | self.set_property("node_name", self.node_name, push_undo=push_undo) 73 | 74 | def get_node_name(self): 75 | self.node_name = self.get_property("node_name") 76 | return self.node_name 77 | 78 | def set_op(self, op:str, push_undo=False): 79 | self.op = op 80 | self._view.name = op 81 | if not self.has_property("op"): 82 | self.create_property("op", self.op, widget_type=NodePropWidgetEnum.QLABEL) 83 | else: 84 | self.set_property("op", self.op, push_undo=push_undo) 85 | 86 | def set_onnx_inputs(self, onnx_inputs:List[OnnxNodeIO], push_undo=False): 87 | self.onnx_inputs = onnx_inputs 88 | value = [[inp.name, inp.dtype, inp.shape, inp.values] for inp in self.onnx_inputs] 89 | if not self.has_property("inputs_"): 90 | self.create_property("inputs_", value, widget_type=NodePropWidgetEnum.QLINE_EDIT) 91 | else: 92 | self.set_property("inputs_", value, push_undo=push_undo) 93 | 94 | def set_onnx_outputs(self, onnx_outputs:List[OnnxNodeIO], push_undo=False): 95 | self.onnx_outputs = onnx_outputs 96 | value = [[out.name, out.dtype, 
out.shape, out.values] for out in self.onnx_outputs] 97 | if not self.has_property("outputs_"): 98 | self.create_property("outputs_", value, widget_type=NodePropWidgetEnum.QLINE_EDIT) 99 | else: 100 | self.set_property("outputs_", value, push_undo=push_undo) 101 | 102 | def set_color(self, push_undo=False): 103 | self.view.text_color = COLOR_FONT + [255] 104 | color = get_node_color(self.op) 105 | self.set_property('border_color', NODE_BORDER_COLOR + [255], push_undo) 106 | self.set_property('color', color + [255], push_undo) 107 | 108 | def set_font(self, font_size=GRAPH_FONT_SIZE, bold=False): 109 | set_font(self.view.text_item, font_size=font_size, bold=bold) 110 | 111 | 112 | class ONNXInput(BaseNode): 113 | # unique node identifier. 114 | __identifier__ = 'nodes.node' 115 | # initial default node name. 116 | NODE_NAME = 'input' 117 | def __init__(self): 118 | super(ONNXInput, self).__init__(qgraphics_item=CustomNodeItem) 119 | self.set_layout_direction(LayoutDirectionEnum.VERTICAL.value) 120 | self.node_name = "" 121 | self.shape = [] 122 | self.dtype = "" 123 | self.output_names = [] 124 | self.create_property("node_name", self.node_name, widget_type=NodePropWidgetEnum.QLINE_EDIT) 125 | self.create_property("shape", self.shape, widget_type=NodePropWidgetEnum.QLINE_EDIT) 126 | self.create_property("dtype", self.dtype, widget_type=NodePropWidgetEnum.QLINE_EDIT) 127 | self.create_property("output_names", self.output_names, widget_type=NodePropWidgetEnum.QTEXT_EDIT) 128 | # create node outputs. 
129 | self.output_port = self.add_output('multi out', multi_output=True, display_name=False) 130 | self.set_color() 131 | self.set_font() 132 | self._view.name = "input" 133 | 134 | def get_node_name(self): 135 | self.node_name = self.get_property("node_name") 136 | return self.node_name 137 | 138 | def set_node_name(self, node_name:str, push_undo=False): 139 | self.node_name = node_name 140 | self.set_property("node_name", self.node_name, push_undo=push_undo) 141 | 142 | def get_shape(self): 143 | self.shape = self.get_property("shape") 144 | return self.shape 145 | 146 | def set_shape(self, shape, push_undo=False): 147 | self.shape = shape 148 | self.set_property("shape", self.shape, push_undo=push_undo) 149 | 150 | def get_dtype(self): 151 | self.dtype = self.get_property("dtype") 152 | return self.dtype 153 | 154 | def set_dtype(self, dtype, push_undo=False): 155 | self.dtype = str(dtype) 156 | self.set_property("dtype", self.dtype, push_undo=push_undo) 157 | 158 | def get_output_names(self)->List[str]: 159 | self.output_names = self.get_property("output_names") 160 | return self.output_names 161 | 162 | def set_output_names(self, output_names, push_undo=False): 163 | self.output_names = output_names 164 | self.set_property("output_names", self.output_names, push_undo=push_undo) 165 | 166 | def set_color(self, push_undo=False): 167 | self.view.text_color = COLOR_FONT + [255] 168 | self.set_property('border_color', NODE_BORDER_COLOR + [255], push_undo) 169 | self.set_property('color', INPUT_NODE_COLOR + [255], push_undo) 170 | 171 | def set_font(self, font_size=GRAPH_FONT_SIZE, bold=False): 172 | set_font(self.view.text_item, font_size=font_size, bold=bold) 173 | 174 | 175 | class ONNXOutput(BaseNode): 176 | # unique node identifier. 177 | __identifier__ = 'nodes.node' 178 | # initial default node name. 
179 | NODE_NAME = 'output' 180 | def __init__(self): 181 | super(ONNXOutput, self).__init__(qgraphics_item=CustomNodeItem) 182 | self.set_layout_direction(LayoutDirectionEnum.VERTICAL.value) 183 | self.node_name = "" 184 | self.shape = [] 185 | self.dtype:str = "" 186 | self.input_names = [] 187 | self.create_property("node_name", self.node_name, widget_type=NodePropWidgetEnum.QLINE_EDIT) 188 | self.create_property("shape", self.shape, widget_type=NodePropWidgetEnum.QLINE_EDIT) 189 | self.create_property("dtype", self.dtype, widget_type=NodePropWidgetEnum.QLINE_EDIT) 190 | self.create_property("input_names", self.input_names, widget_type=NodePropWidgetEnum.QTEXT_EDIT) 191 | # create node inputs. 192 | self.input_port = self.add_input('multi in', multi_input=True, display_name=False) 193 | self.set_color() 194 | self.set_font() 195 | self._view.name = "output" 196 | 197 | def get_node_name(self): 198 | self.node_name = self.get_property("node_name") 199 | return self.node_name 200 | 201 | def set_node_name(self, node_name:str, push_undo=False): 202 | self.node_name = node_name 203 | self.set_property("node_name", self.node_name, push_undo=push_undo) 204 | 205 | def get_shape(self): 206 | self.shape = self.get_property("shape") 207 | return self.shape 208 | 209 | def set_shape(self, shape, push_undo=False): 210 | self.shape = shape 211 | self.set_property("shape", self.shape, push_undo) 212 | 213 | def get_dtype(self): 214 | self.dtype = self.get_property("dtype") 215 | return self.dtype 216 | 217 | def set_dtype(self, dtype, push_undo=False): 218 | self.dtype = str(dtype) 219 | self.set_property("dtype", self.dtype, push_undo) 220 | 221 | def get_input_names(self): 222 | self.input_names = self.get_property("input_names") 223 | return self.input_names 224 | 225 | def set_input_names(self, input_names, push_undo=False): 226 | self.input_names = input_names 227 | self.set_property("input_names", self.input_names, push_undo) 228 | 229 | def set_color(self, 
push_undo=False): 230 | self.view.text_color = COLOR_FONT + [255] 231 | self.set_property('border_color', NODE_BORDER_COLOR + [255], push_undo) 232 | self.set_property('color', INPUT_NODE_COLOR + [255], push_undo) 233 | 234 | def set_font(self, font_size=GRAPH_FONT_SIZE, bold=False): 235 | set_font(self.view.text_item, font_size=font_size, bold=bold) 236 | -------------------------------------------------------------------------------- /onnxgraphqt/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/utils/__init__.py -------------------------------------------------------------------------------- /onnxgraphqt/utils/color.py: -------------------------------------------------------------------------------- 1 | COLOR_WHITE = [235, 235, 235] 2 | COLOR_PALEGRAY = [230, 230, 230] 3 | COLOR_LIGHTGRAY = [200, 200, 203] 4 | COLOR_GRAY = [127, 135, 143] 5 | COLOR_DARKGRAY = [50, 55, 60] 6 | COLOR_BLACK = [20, 20, 20] 7 | 8 | COLOR_RED = [112, 41, 33] 9 | COLOR_GREEN = [51, 85, 51] 10 | COLOR_BLUE = [51, 85, 136] 11 | COLOR_BROWN = [89, 66, 59] 12 | 13 | COLOR_BG = COLOR_PALEGRAY 14 | COLOR_FONT = COLOR_BLACK 15 | COLOR_GRID = COLOR_GRAY 16 | INPUT_NODE_COLOR = COLOR_LIGHTGRAY 17 | OUTPUT_NODE_COLOR = COLOR_LIGHTGRAY 18 | DEFAULT_COLOR = COLOR_GRAY 19 | 20 | NODE_BG_COLOR = COLOR_WHITE 21 | NODE_BORDER_COLOR = COLOR_DARKGRAY 22 | NODE_SELECTED_BORDER_COLOR = [240, 50, 0] 23 | 24 | NODE_COLORS = { 25 | # Generic 26 | 'Identity': COLOR_DARKGRAY, 27 | # Constant 28 | 'Constant': COLOR_DARKGRAY, 29 | 'ConstantOfShape': COLOR_DARKGRAY, 30 | # Math 31 | 'Add': COLOR_DARKGRAY, 32 | 'Sub': COLOR_DARKGRAY, 33 | 'Mul': COLOR_DARKGRAY, 34 | 'Div': COLOR_DARKGRAY, 35 | 'Sqrt': COLOR_DARKGRAY, 36 | 'Reciprocal': COLOR_DARKGRAY, 37 | 'Pow': COLOR_DARKGRAY, 38 | 'Exp': COLOR_DARKGRAY, 39 | 'Log': COLOR_DARKGRAY, 40 | 'Abs': 
COLOR_DARKGRAY, 41 | 'Neg': COLOR_DARKGRAY, 42 | 'Ceil': COLOR_DARKGRAY, 43 | 'Floor': COLOR_DARKGRAY, 44 | 'Clip': COLOR_RED, 45 | 'Erf': COLOR_DARKGRAY, 46 | 'IsNan': COLOR_DARKGRAY, 47 | 'IsInf': COLOR_DARKGRAY, 48 | 'Sign': COLOR_DARKGRAY, 49 | # Logical 50 | 'Greater': COLOR_DARKGRAY, 51 | 'Less': COLOR_DARKGRAY, 52 | 'Equal': COLOR_DARKGRAY, 53 | 'Not': COLOR_DARKGRAY, 54 | 'And': COLOR_DARKGRAY, 55 | 'Or': COLOR_DARKGRAY, 56 | 'Xor': COLOR_DARKGRAY, 57 | # Trigonometric 58 | 'Sin': COLOR_DARKGRAY, 59 | 'Cos': COLOR_DARKGRAY, 60 | 'Tan': COLOR_DARKGRAY, 61 | 'Asin': COLOR_DARKGRAY, 62 | 'Acos': COLOR_DARKGRAY, 63 | 'Atan': COLOR_DARKGRAY, 64 | 'Sinh': COLOR_DARKGRAY, 65 | 'Cosh': COLOR_DARKGRAY, 66 | 'Tanh': COLOR_DARKGRAY, 67 | 'Acosh': COLOR_DARKGRAY, 68 | 'Asinh': COLOR_DARKGRAY, 69 | 'Atanh': COLOR_DARKGRAY, 70 | 'Max': COLOR_DARKGRAY, 71 | # Reduction 72 | 'Sum': COLOR_DARKGRAY, 73 | 'Mean': COLOR_DARKGRAY, 74 | 'Max': COLOR_DARKGRAY, 75 | 'Min': COLOR_DARKGRAY, 76 | # Activation 77 | 'Sigmoid': COLOR_RED, 78 | 'HardSigmoid': COLOR_RED, 79 | 'Tanh': COLOR_RED, 80 | 'ScaledTanh': COLOR_RED, 81 | 'Relu': COLOR_RED, 82 | 'LeakyRelu': COLOR_RED, 83 | 'PRelu': COLOR_RED, 84 | 'ThresholdedRelu': COLOR_RED, 85 | 'Elu': COLOR_RED, 86 | 'Selu': COLOR_RED, 87 | 'Softmax': COLOR_RED, 88 | 'LogSoftmax': COLOR_RED, 89 | 'Hardmax': COLOR_RED, 90 | 'Softsign': COLOR_RED, 91 | 'Softplus': COLOR_RED, 92 | 'Affine': COLOR_RED, 93 | 'Shrink': COLOR_RED, 94 | # Random 95 | 'RandomNormal': COLOR_DARKGRAY, 96 | 'RandomNormalLike': COLOR_DARKGRAY, 97 | 'RandomUniform': COLOR_DARKGRAY, 98 | 'RandomUniformLike': COLOR_DARKGRAY, 99 | 'Multinomial': COLOR_DARKGRAY, 100 | # Multiplication 101 | 'EyeLike': COLOR_BLUE, 102 | 'Gemm': COLOR_BLUE, 103 | 'MatMul': COLOR_BLUE, 104 | 'Conv': COLOR_BLUE, 105 | 'ConvTranspose': COLOR_BLUE, 106 | # Conversion 107 | 'Cast': COLOR_DARKGRAY, 108 | # Reorganization 109 | 'Transpose': COLOR_GREEN, 110 | 'Expand': COLOR_DARKGRAY, 111 | 'Tile': 
COLOR_DARKGRAY, 112 | 'Split': COLOR_DARKGRAY, 113 | 'Slice': COLOR_BROWN, 114 | 'DynamicSlice': COLOR_BROWN, 115 | 'Concat': COLOR_BROWN, 116 | 'Gather': COLOR_GREEN, 117 | 'GatherElements': COLOR_GREEN, 118 | 'ScatterElements': COLOR_DARKGRAY, 119 | 'Pad': COLOR_DARKGRAY, 120 | 'SpaceToDepth': COLOR_DARKGRAY, 121 | 'DepthToSpace': COLOR_DARKGRAY, 122 | 'Shape': COLOR_DARKGRAY, 123 | 'Size': COLOR_DARKGRAY, 124 | 'Reshape': COLOR_BROWN, 125 | 'Flatten': COLOR_DARKGRAY, 126 | 'Squeeze': COLOR_DARKGRAY, 127 | 'Unsqueeze': COLOR_GREEN, 128 | 'OneHot': COLOR_DARKGRAY, 129 | 'TopK': COLOR_DARKGRAY, 130 | 'Where': COLOR_DARKGRAY, 131 | 'Compress': COLOR_DARKGRAY, 132 | 'Reverse': COLOR_DARKGRAY, 133 | # Pooling 134 | 'GlobalAveragePool': COLOR_GREEN, 135 | 'AveragePool': COLOR_GREEN, 136 | 'GlobalMaxPool': COLOR_GREEN, 137 | 'MaxPool': COLOR_GREEN, 138 | 'MaxUnpool': COLOR_GREEN, 139 | 'LpPool': COLOR_GREEN, 140 | 'GlobalLpPool': COLOR_GREEN, 141 | 'MaxRoiPool': COLOR_GREEN, 142 | # Reduce 143 | 'ReduceSum': COLOR_DARKGRAY, 144 | 'ReduceMean': COLOR_DARKGRAY, 145 | 'ReduceProd': COLOR_DARKGRAY, 146 | 'ReduceLogSum': COLOR_DARKGRAY, 147 | 'ReduceLogSumExp': COLOR_DARKGRAY, 148 | 'ReduceSumSquare': COLOR_DARKGRAY, 149 | 'ReduceL1': COLOR_DARKGRAY, 150 | 'ReduceL2': COLOR_DARKGRAY, 151 | 'ReduceMax': COLOR_DARKGRAY, 152 | 'ReduceMin': COLOR_DARKGRAY, 153 | 'ArgMax': COLOR_DARKGRAY, 154 | 'ArgMin': COLOR_DARKGRAY, 155 | # Imaging 156 | 'Upsample': COLOR_DARKGRAY, 157 | # Flow 158 | 'If': COLOR_DARKGRAY, 159 | 'Loop': COLOR_DARKGRAY, 160 | 'Scan': COLOR_DARKGRAY, 161 | # Normalization 162 | 'InstanceNormalization': COLOR_DARKGRAY, 163 | 'BatchNormalization': COLOR_DARKGRAY, 164 | 'LRN': COLOR_DARKGRAY, 165 | 'MeanVarianceNormalization': COLOR_DARKGRAY, 166 | 'LpNormalization': COLOR_DARKGRAY, 167 | # collation 168 | 'Nonzero': COLOR_DARKGRAY, 169 | # NGram 170 | 'TfldfVectorizer': COLOR_DARKGRAY, 171 | # Aggregate 172 | 'RNN': COLOR_DARKGRAY, 173 | 'GRU': COLOR_DARKGRAY, 174 
| 'LSTM': COLOR_DARKGRAY, 175 | # Training 176 | 'Dropout': COLOR_DARKGRAY, 177 | # Quantize 178 | 'QuantizeLinear': COLOR_DARKGRAY, 179 | 'QLinearConv': COLOR_BLUE, 180 | 'DequantizeLinear': COLOR_DARKGRAY, 181 | 'QLinearGlobalAveragePool': COLOR_GREEN, 182 | 'QLinearAdd': COLOR_DARKGRAY, 183 | 'QLinearMatMul': COLOR_BLUE, 184 | } 185 | 186 | def get_node_color(op_name): 187 | return NODE_COLORS.get(op_name, DEFAULT_COLOR) 188 | 189 | class PrintColor: 190 | BLACK = ['\033[30m', ""] 191 | RED = ['\033[31m', ""] 192 | GREEN = ['\033[32m', ""] 193 | YELLOW = ['\033[33m', ""] 194 | BLUE = ['\033[34m', ""] 195 | MAGENTA = ['\033[35m', ""] 196 | CYAN = ['\033[36m', ""] 197 | WHITE = ['\033[37m', ""] 198 | COLOR_DEFAULT = ['\033[39m', ""] 199 | BOLD = ['\033[1m', ""] 200 | UNDERLINE = ['\033[4m', ""] 201 | INVISIBLE = ['\033[08m', ""] 202 | REVERCE = ['\033[07m', ""] 203 | BG_BLACK = ['\033[40m', ""] 204 | BG_RED = ['\033[41m', ""] 205 | BG_GREEN = ['\033[42m', ""] 206 | BG_YELLOW = ['\033[43m', ""] 207 | BG_BLUE = ['\033[44m', ""] 208 | BG_MAGENTA = ['\033[45m', ""] 209 | BG_CYAN = ['\033[46m', ""] 210 | BG_WHITE = ['\033[47m', ""] 211 | BG_DEFAULT = ['\033[49m', ""] 212 | RESET = ['\033[0m', ""] 213 | 214 | 215 | def remove_PrintColor(message:str)->str: 216 | ret = message 217 | for key, v in vars(PrintColor).items(): 218 | if key[:2] == "__": 219 | continue 220 | ret = ret.replace(v[0], '') 221 | for i in range(32): 222 | v = f'\033[38;5;{i}m' 223 | ret = ret.replace(v, '') 224 | return ret 225 | 226 | def replace_PrintColor(message: str)->str: 227 | ret = message 228 | for key, v in vars(PrintColor).items(): 229 | if key[:2] == "__": 230 | continue 231 | ret = ret.replace(v[0], v[1]) 232 | for i in range(32): 233 | v = f'\033[38;5;{i}m' 234 | ret = ret.replace(v, '') 235 | return ret 236 | 237 | if __name__ == "__main__": 238 | import inspect 239 | text = "\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661839911328] Variable 
(transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\nNote: Producer node(s) of first tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\nAttributes: {'perm': [3, 2, 1, 0]}]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\x1b[0m\n\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\nNote: Producer node(s) of first tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\nAttributes: {'perm': [3, 2, 1, 0]}]\x1b[0m\n\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n[id: 
139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\nNote: Producer node(s) of first tensor:\n[]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\x1b[0m\n\x1b[38;5;11m[W] Found distinct tensors that share the same name:\n[id: 139661839911328] Variable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n[id: 139661797372160] Variable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\nNote: Producer node(s) of first tensor:\n[]\nProducer node(s) of second tensor:\n[input_order_convert_transpose_0 (Transpose)\n\tInputs: [\n\t\tVariable (transpose_out_input): (shape=[224, 224, 3, 'batch_size'], dtype=float32)\n\t]\n\tOutputs: [\n\t\tVariable (transpose_out_input): (shape=['batch_size', 3, 224, 224], dtype=float32)\n\t]\nAttributes: OrderedDict([('perm', [3, 2, 1, 0])])]\x1b[0m\n\x1b[33mWARNING:\x1b[0m The input shape of the next OP does not match the output shape. 
Be sure to open the .onnx file to verify the certainty of the geometry.\n\x1b[33mWARNING:\x1b[0m onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Transpose, node name: input_order_convert_transpose_0): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 1: (224) vs (3)\n\x1b[32mINFO:\x1b[0m Finish!\n" 240 | print(remove_PrintColor(text)) 241 | print("------------------------------") 242 | for key, v in vars(PrintColor).items(): 243 | if key[:2] == "__": 244 | continue 245 | print(f"{v[0]}{key}{PrintColor.RESET[0]}") 246 | print("------------------------------") 247 | print() 248 | print(replace_PrintColor(text)) -------------------------------------------------------------------------------- /onnxgraphqt/utils/dtype.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | import numpy as np 3 | 4 | AVAILABLE_DTYPES = [ 5 | 'float32', 6 | 'float64', 7 | 'int32', 8 | 'int64', 9 | 'str', 10 | ] 11 | 12 | DTYPES_TO_ONNX_DTYPES = { 13 | float: onnx.TensorProto.FLOAT, 14 | int: onnx.TensorProto.INT64, 15 | str: onnx.TensorProto.STRING, 16 | } 17 | 18 | DTYPES_TO_NUMPY_TYPES = { 19 | 'float32': np.float32, 20 | 'float64': np.float64, 21 | 'int32': np.int32, 22 | 'int64': np.int64, 23 | } 24 | 25 | NUMPY_TYPES_TO_ONNX_DTYPES = { 26 | np.dtype('float32'): onnx.TensorProto.FLOAT, 27 | np.dtype('float64'): onnx.TensorProto.DOUBLE, 28 | np.dtype('int32'): onnx.TensorProto.INT32, 29 | np.dtype('int64'): onnx.TensorProto.INT64, 30 | np.float32: onnx.TensorProto.FLOAT, 31 | np.float64: onnx.TensorProto.DOUBLE, 32 | np.int32: onnx.TensorProto.INT32, 33 | np.int64: onnx.TensorProto.INT64, 34 | } 35 | 36 | NUMPY_TYPES_TO_CLASSES = { 37 | np.dtype('float32'): np.float32, 38 | np.dtype('float64'): np.float64, 39 | np.dtype('int32'): np.int32, 40 | np.dtype('int64'): np.int64, 41 | } 42 | -------------------------------------------------------------------------------- 
# /onnxgraphqt/utils/operators.py:
# --------------------------------------------------------------------------------
# Loads the ONNX operator catalogue (data/onnx_opsets.json) into dataclasses.
import os
import math
from dataclasses import dataclass
from typing import List, Dict
from json import JSONDecoder


@dataclass
class OperatorAttribute:
    """One attribute of an operator version, as listed in onnx_opsets.json."""
    name: str
    # NOTE(review): read straight from JSON, so this is a JSON value
    # (presumably a type-name string), not a Python ``type`` object.
    value_type: type
    default_value: str


@dataclass
class OperatorVersion:
    """Input/output arity and attributes of an operator since ``since_opset``."""
    since_opset: int
    inputs: int   # may be math.inf for variadic operators
    outputs: int  # may be math.inf for variadic operators
    attrs: List[OperatorAttribute]


@dataclass
class Operator:
    """An ONNX operator together with all of its versioned definitions."""
    name: str
    versions: List[OperatorVersion]


_DEFAULT_ONNX_OPSETS_JSON_PATH = os.path.join(os.path.dirname(__file__),
                                              '..',
                                              'data', 'onnx_opsets.json')


def _parse_count(value) -> int:
    """Convert an input/output count from the JSON ('inf' means variadic)."""
    return math.inf if value == 'inf' else int(value)


def _load_json(json_path=_DEFAULT_ONNX_OPSETS_JSON_PATH) -> List[Operator]:
    """
    Parse the operator catalogue JSON into ``Operator`` records.

    The file is a mapping ``op_name -> {since_opset: {"inputs", "outputs",
    "attrs"}}`` where ``attrs`` maps ``attr_name -> [value_type, default]``.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(json_path, mode='r', encoding='utf-8') as f:
        json_dict: Dict = JSONDecoder().decode(f.read())

    ret = []
    for op_name, versions_by_opset in json_dict.items():
        versions = []
        for since_opset, spec in versions_by_opset.items():
            attrs = [
                OperatorAttribute(name=attr_name,
                                  value_type=attr_value_type,
                                  default_value=default_value)
                for attr_name, (attr_value_type, default_value) in spec["attrs"].items()
            ]
            versions.append(
                OperatorVersion(since_opset=int(since_opset),
                                inputs=_parse_count(spec["inputs"]),
                                outputs=_parse_count(spec["outputs"]),
                                attrs=attrs)
            )
        ret.append(Operator(name=op_name, versions=versions))
    return ret


def _get_latest_opset_version(opsets: List[Operator]) -> int:
    """Return the highest ``since_opset`` across all operators (at least 1)."""
    since_opsets = [v.since_opset for op in opsets for v in op.versions]
    return max(since_opsets + [1])


onnx_opsets = _load_json()
opnames = [op.name for op in onnx_opsets]
latest_opset = _get_latest_opset_version(onnx_opsets)

if __name__ == "__main__":
    print(_DEFAULT_ONNX_OPSETS_JSON_PATH)
    for op in onnx_opsets:
        print(op.name)
        for v in op.versions:
            print(v)
        print()
# --------------------------------------------------------------------------------
# /onnxgraphqt/utils/opset.py:
# --------------------------------------------------------------------------------
DEFAULT_OPSET = 16
# --------------------------------------------------------------------------------
# /onnxgraphqt/utils/style.py:
# --------------------------------------------------------------------------------
from NodeGraphQt import NodeGraph
from NodeGraphQt.base.menu import NodeGraphMenu


def set_context_menu_style(graph: NodeGraph, text_color, bg_color, selected_color, disabled_text_color=None):
    """Apply ``get_context_menu_stylesheet`` to the graph's "graph" context menu."""
    # get_context_menu() returns the NodeGraphMenu wrapper; .qmenu is its Qt menu.
    context_menu = graph.get_context_menu("graph").qmenu
    style = get_context_menu_stylesheet(text_color, bg_color, selected_color, disabled_text_color)
    context_menu.setStyleSheet(style)


def get_context_menu_stylesheet(text_color, bg_color, selected_color, disabled_text_color=None):
    """
    Build a QMenu stylesheet from RGB colors.

    Args:
        text_color, bg_color, selected_color: RGB sequences (first 3 items used).
        disabled_text_color: RGB for disabled items; defaults to the midpoint
            of text_color and bg_color.

    Returns:
        str: Qt stylesheet text.
    """
    if disabled_text_color is None:
        disabled_text_color = [int(0.5 * abs(text_color[i] + bg_color[i])) for i in range(3)]
    style_dict = {
        'QMenu': {
            'color': 'rgb({0},{1},{2})'.format(*text_color),
            'background-color': 'rgb({0},{1},{2})'.format(*bg_color),
            'border': '1px solid rgba({0},{1},{2},30)'.format(*text_color),
            'border-radius': '3px',
        },
        'QMenu::item': {
            'padding': '5px 18px 2px',
            'background-color': 'transparent',
        },
        'QMenu::item:selected': {
            'color': 'rgb({0},{1},{2})'.format(*text_color),
            'background-color': 'rgba({0},{1},{2},200)'.format(*selected_color),
        },
        'QMenu::item:disabled': {
            'color': 'rgb({0},{1},{2})'.format(*disabled_text_color),
        },
        'QMenu::separator': {
            'height': '1px',
            'background': 'rgba({0},{1},{2}, 50)'.format(*text_color),
            'margin': '4px 8px',
        }
    }
    blocks = []
    for css_class, css in style_dict.items():
        body = ''.join(' {}:{};\n'.format(elm_name, elm_val) for elm_name, elm_val in css.items())
        blocks.append('{} {{\n{}}}\n'.format(css_class, body))
    return ''.join(blocks)
# --------------------------------------------------------------------------------
# /onnxgraphqt/utils/widgets.py:
# --------------------------------------------------------------------------------
from typing import List, Sequence
import math
from PySide2 import QtCore, QtWidgets, QtGui
from NodeGraphQt.constants import PipeEnum
from NodeGraphQt.qgraphics.pipe import PIPE_STYLES


BASE_FONT_SIZE = 16
LARGE_FONT_SIZE = 18
GRAPH_FONT_SIZE = 32
PIPE_WIDTH = 3.0


def set_font(widget: QtWidgets.QWidget, font_size: int = None, bold=False):
    """Set pixel size (when given and non-zero) and boldness on anything with font()/setFont()."""
    f = widget.font()
    if font_size:
        f.setPixelSize(font_size)
    f.setBold(bold)
    widget.setFont(f)


def iconButton_paintEvent(button: QtWidgets.QPushButton, pixmap: QtGui.QPixmap, event: QtGui.QPaintEvent):
    """Paint the normal push button, then draw ``pixmap`` near its left edge."""
    QtWidgets.QPushButton.paintEvent(button, event)
    # Center the pixmap within a 30px-wide slot that starts 5px from the left.
    pos_x = 5 + int((30 - pixmap.width()) * 0.5 + 0.5)
    # Integer division: drawPixmap(int, int, QPixmap) takes integer coordinates.
    pos_y = (button.height() - pixmap.height()) // 2
    painter = QtGui.QPainter(button)
    painter.setRenderHint(QtGui.QPainter.Antialiasing, True)
    painter.setRenderHint(QtGui.QPainter.SmoothPixmapTransform, True)
    painter.drawPixmap(pos_x, pos_y, pixmap)
    painter.end()


def createIconButton(text: str, icon_path: str, icon_size: Sequence[int] = (25, 25), font_size: int = None) -> QtWidgets.QPushButton:
    """
    Create an expanding push button with an icon on the left and a right-aligned label.

    Args:
        text: label text.
        icon_path: image file for the icon.
        icon_size: (width, height) the icon is scaled to, keeping aspect ratio.
        font_size: optional pixel size for the label font.
    """
    button = QtWidgets.QPushButton()
    button.setContentsMargins(QtCore.QMargins(5, 5, 5, 5))
    button.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
    pixmap = QtGui.QPixmap(icon_path).scaled(icon_size[0], icon_size[1],
                                             QtCore.Qt.AspectRatioMode.KeepAspectRatio,
                                             QtCore.Qt.SmoothTransformation)
    # Override this instance's paintEvent so the icon is drawn on top.
    button.paintEvent = lambda event: iconButton_paintEvent(button, pixmap, event)
    button_layout = QtWidgets.QVBoxLayout()
    # QLayout.setMargin is obsolete in Qt5; setContentsMargins is the replacement.
    button_layout.setContentsMargins(0, 0, 0, 0)
    button.setLayout(button_layout)
    label = QtWidgets.QLabel(text=text)
    label.setMargin(0)
    label.setAlignment(QtCore.Qt.AlignRight)
    set_font(label, font_size=font_size)
    button_layout.addWidget(label)
    return button


def pipe_paint(pipe, painter, option, widget, text=""):
    """
    Draws the connection line between nodes.

    Args:
        pipe: the NodeGraphQt pipe item being painted.
        painter (QtGui.QPainter): painter used for drawing the item.
        option (QtGui.QStyleOptionGraphicsItem):
            used to describe the parameters needed to draw.
        widget (QtWidgets.QWidget): not used.
        text (str): optional label drawn next to the pipe's midpoint arrow.
    """
    color = QtGui.QColor(*pipe._color)
    pen_style = PIPE_STYLES.get(pipe.style)
    pen_width = PIPE_WIDTH  # replaces PipeEnum.WIDTH.value
    if pipe._active:
        color = QtGui.QColor(*PipeEnum.ACTIVE_COLOR.value)
        if pen_style == QtCore.Qt.DashDotDotLine:
            pen_width += 1
        else:
            pen_width += 0.35
    elif pipe._highlight:
        color = QtGui.QColor(*PipeEnum.HIGHLIGHT_COLOR.value)
        pen_style = PIPE_STYLES.get(PipeEnum.DRAW_TYPE_DEFAULT.value)

    if pipe.disabled():
        if not pipe._active:
            color = QtGui.QColor(*PipeEnum.DISABLED_COLOR.value)
        pen_width += 0.2
        pen_style = PIPE_STYLES.get(PipeEnum.DRAW_TYPE_DOTTED.value)

    pen = QtGui.QPen(color, pen_width, pen_style)
    pen.setCapStyle(QtCore.Qt.RoundCap)
    pen.setJoinStyle(QtCore.Qt.MiterJoin)

    painter.save()
    painter.setPen(pen)
    painter.setRenderHint(painter.Antialiasing, True)
    painter.drawPath(pipe.path())

    # draw arrow at the midpoint, only for fully connected pipes
    if pipe.input_port and pipe.output_port:
        cen_x = pipe.path().pointAtPercent(0.5).x()
        cen_y = pipe.path().pointAtPercent(0.5).y()
        loc_pt = pipe.path().pointAtPercent(0.49)
        tgt_pt = pipe.path().pointAtPercent(0.51)
        dist = math.hypot(tgt_pt.x() - cen_x, tgt_pt.y() - cen_y)
        if dist < 0.5:
            # too short to draw a meaningful arrow
            painter.restore()
            return

        color.setAlpha(255)
        if pipe._highlight:
            painter.setBrush(QtGui.QBrush(color.lighter(150)))
        elif pipe._active or pipe.disabled():
            painter.setBrush(QtGui.QBrush(color.darker(200)))
        else:
            painter.setBrush(QtGui.QBrush(color.darker(130)))

        pen_width = 0.6
        if dist < 1.0:
            pen_width *= (1.0 + dist)

        pen = QtGui.QPen(color, pen_width)
        pen.setCapStyle(QtCore.Qt.RoundCap)
        pen.setJoinStyle(QtCore.Qt.MiterJoin)
        painter.setPen(pen)

        transform = QtGui.QTransform()
        transform.translate(cen_x, cen_y)
        radians = math.atan2(tgt_pt.y() - loc_pt.y(),
                             tgt_pt.x() - loc_pt.x())
        degrees = math.degrees(radians) - 90
        transform.rotate(degrees)
        if dist < 1.0:
            transform.scale(dist, dist)
        painter.drawPolygon(transform.map(pipe._arrow))
        if text:
            painter.setPen(QtCore.Qt.black)
            set_font(painter, font_size=20)
            painter.drawText(QtCore.QRectF(cen_x, cen_y, 200, 100), text)

    # QPaintDevice: Cannot destroy paint device that is being painted.
    painter.restore()
# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/fateshelled/OnnxGraphQt/60d8957743b2cd3babaf97eef0b063fe1f637cf1/onnxgraphqt/widgets/__init__.py
# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/custom_node_item.py:
# --------------------------------------------------------------------------------

from PySide2 import QtCore, QtWidgets, QtGui
from NodeGraphQt.constants import NodeEnum
from NodeGraphQt.qgraphics.node_base import NodeItem

from onnxgraphqt.utils.color import NODE_BG_COLOR, NODE_SELECTED_BORDER_COLOR
from onnxgraphqt.utils.widgets import set_font


class CustomNodeItem(NodeItem):
    """NodeItem that paints a colored header with a separate display name."""

    def __init__(self, name='node', parent=None):
        super().__init__(name, parent)
        self._display_name = ""

    def set_display_name(self, name: str):
        """Set the text drawn in the node header."""
        self._display_name = name

    def _paint_vertical(self, painter, option, widget):
        """Paint background, header, header text, selection overlay and border."""
        painter.save()
        painter.setPen(QtCore.Qt.NoPen)
        painter.setBrush(QtCore.Qt.NoBrush)

        # base background.
        margin = 1.0
        radius = 4.0
        header_height = 20.0
        rect = self.boundingRect()
        rect = QtCore.QRectF(rect.left() + margin,
                             rect.top() + margin,
                             rect.width() - (margin * 2),
                             rect.height() - (margin * 2))

        painter.setBrush(QtGui.QColor(*NODE_BG_COLOR + [255]))
        painter.drawRoundedRect(rect, radius, radius)

        # header
        header_rect = QtCore.QRectF(rect.x() + margin, rect.y() + margin,
                                    rect.width() - margin * 2, header_height)
        painter.setBrush(QtGui.QColor(*self.color))
        painter.drawRoundedRect(header_rect, radius, radius)

        # header text: black on light headers, white on dark ones (luma threshold).
        r, g, b = self.color[:3]
        if 0.2126 * r + 0.715 * g + 0.0722 * b > 128:
            painter.setPen(QtCore.Qt.black)
        else:
            painter.setPen(QtCore.Qt.white)
        if self.selected:
            set_font(painter, bold=True)
        painter.drawText(QtCore.QRectF(5, 5, rect.width(), rect.height()), self._display_name)
        set_font(painter, bold=False)

        # light overlay on background when selected.
        if self.selected:
            painter.setBrush(
                QtGui.QColor(*NodeEnum.SELECTED_COLOR.value)
            )
            painter.drawRoundedRect(rect, radius, radius)

        # node border
        border_width = 0.8
        border_color = QtGui.QColor(*self.border_color)
        if self.selected:
            border_width = 1.2
            border_color = QtGui.QColor(
                *list(NODE_SELECTED_BORDER_COLOR + [255])
            )
        border_rect = QtCore.QRectF(rect.left(), rect.top(),
                                    rect.width(), rect.height())

        pen = QtGui.QPen(border_color, border_width)
        pen.setCosmetic(self.viewer().get_zoom() < 0.0)
        painter.setBrush(QtCore.Qt.NoBrush)
        painter.setPen(pen)
        painter.drawRoundedRect(border_rect, radius, radius)

        painter.restore()
# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/custom_properties.py:
# --------------------------------------------------------------------------------
from collections import defaultdict
from PySide2 import QtWidgets, QtCore, QtGui
from NodeGraphQt.constants import NodePropWidgetEnum
from NodeGraphQt.widgets.dialogs import FileDialog
from NodeGraphQt.custom_widgets.properties_bin.prop_widgets_base import (
    PropLineEdit,
    PropTextEdit
)
from NodeGraphQt.custom_widgets.properties_bin.node_property_factory import NodePropertyWidgetFactory
from NodeGraphQt.custom_widgets.properties_bin.node_property_widgets import _PropertiesContainer


class CustomPropWindow(QtWidgets.QWidget):
    """Two-column grid of (label, property widget) rows."""

    def __init__(self, parent=None):
        super(CustomPropWindow, self).__init__(parent)
        self.__layout = QtWidgets.QGridLayout()
        self.__layout.setColumnStretch(1, 1)

        layout = QtWidgets.QVBoxLayout(self)
        layout.setAlignment(QtCore.Qt.AlignTop)
        layout.addLayout(self.__layout)

    def __repr__(self):
        return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))

    def add_widget(self, name, widget, value, label=None):
        """
        Add a property widget to the window.

        Args:
            name (str): property name; also stored as the widget tooltip,
                which get_widget() uses for lookup.
            widget (BaseProperty): property widget.
            value (object): property value; not applied when None.
            label (str): custom label to display (defaults to name).
        """
        widget.setToolTip(name)
        if value is not None:
            widget.set_value(value)
        if label is None:
            label = name
        row = self.__layout.rowCount()
        if row > 0:
            row += 1

        label_flags = QtCore.Qt.AlignCenter | QtCore.Qt.AlignRight
        if widget.__class__.__name__ == 'PropTextEdit':
            label_flags = label_flags | QtCore.Qt.AlignTop
        elif widget.__class__.__name__ == 'PropList':
            label_flags = QtCore.Qt.AlignTop | QtCore.Qt.AlignRight

        self.__layout.addWidget(QtWidgets.QLabel(label), row, 0, label_flags)
        self.__layout.addWidget(widget, row, 1)

    def get_widget(self, name):
        """
        Returns the first property widget whose tooltip equals name.

        Args:
            name (str): property name.

        Returns:
            QtWidgets.QWidget: property widget, or None if not found.
        """
        for row in range(self.__layout.rowCount()):
            item = self.__layout.itemAtPosition(row, 1)
            if item and name == item.widget().toolTip():
                return item.widget()


class PropList(QtWidgets.QWidget):
    """Vertical list of read-only labels describing node inputs/outputs."""

    def __init__(self, parent=None):
        super(PropList, self).__init__(parent)
        self.__layout = QtWidgets.QVBoxLayout()
        self.__layout.setAlignment(QtCore.Qt.AlignTop)

        layout = QtWidgets.QVBoxLayout(self)
        layout.setAlignment(QtCore.Qt.AlignTop)
        layout.addLayout(self.__layout)
        layout.setContentsMargins(0, 0, 0, 0)

    def add_item(self, name, dtype, shape, value):
        """
        Add one entry: the name, plus "dtype [shape]" when dtype is set, plus
        a read-only text box with the value when a value is present.
        """
        if dtype is None:
            label = QtWidgets.QLabel(f"{name}")
            self.__layout.addWidget(label)
        # NOTE(review): if value can be a numpy array, `value == -1` is
        # elementwise and its truth value raises; confirm caller's value types.
        elif value is None or value == -1:
            label = QtWidgets.QLabel(f"{name}")
            type_shape = QtWidgets.QLabel(f"{dtype} {list(shape)}")
            self.__layout.addWidget(label)
            self.__layout.addWidget(type_shape)
        else:
            # has value
            content_layout = QtWidgets.QVBoxLayout()
            label = QtWidgets.QLabel(f"{name}")
            type_shape = QtWidgets.QLabel(f"{dtype} {list(shape)}")
            txt = QtWidgets.QLineEdit()
            txt.setText(str(value))
            txt.setReadOnly(True)
            content_layout.addWidget(label)
            content_layout.addWidget(type_shape)
            content_layout.addWidget(txt)
            self.__layout.addLayout(content_layout)


class CustomNodePropWidget(QtWidgets.QWidget):
    """
    Node properties widget for display a Node object.

    Args:
        parent (QtWidgets.QWidget): parent object.
        node (NodeGraphQt.BaseNode): node.
    """

    #: signal (node_id, prop_name, prop_value)
    property_changed = QtCore.Signal(str, str, object)
    property_closed = QtCore.Signal(str)

    def __init__(self, parent=None, node=None):
        super(CustomNodePropWidget, self).__init__(parent)
        self.__node_id = node.id

        self.prop_window = CustomPropWindow(self)

        self._layout = QtWidgets.QVBoxLayout(self)
        self._layout.addWidget(self.prop_window)

        self._read_node(node)

    def __repr__(self):
        return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))

    def _on_close(self):
        """
        called by the close button.
        """
        self.property_closed.emit(self.__node_id)

    def _on_property_changed(self, name, value):
        """
        slot function called when a property widget has changed.
        Args:
            name (str): property name.
            value (object): new value.
        """
        self.property_changed.emit(self.__node_id, name, value)

    def _read_node(self, node):
        """
        Populate widget from a node's custom properties on the 'Properties' tab.

        Text-entry widgets are made read-only; "inputs_"/"outputs_" entries are
        expanded into one PropList row per (name, dtype, shape, value) tuple.

        Args:
            node (NodeGraphQt.BaseNode): node class.
        """
        model = node.model

        properties = [
            (prop_name, prop_val)
            for prop_name, prop_val in model.custom_properties.items()
            if model.get_tab_name(prop_name) == 'Properties'
        ]

        # property widget factory.
        widget_factory = NodePropertyWidgetFactory()

        for prop_name, value in properties:
            wid_type = model.get_widget_type(prop_name).value
            if wid_type == 0:
                # presumably NodePropWidgetEnum.HIDDEN — not displayed
                continue
            if prop_name in ["inputs_", "outputs_"]:
                io_len = len(value)
                for i, (name, dtype, shape, v) in enumerate(value):
                    widget = PropList()
                    widget.add_item(name, dtype, shape, v)
                    label = prop_name.replace('_', ' ')
                    if io_len != 1:
                        label += f"[{i+1}]"
                    self.prop_window.add_widget(prop_name, widget, None, label)
            else:
                widget = widget_factory.get_widget(wid_type)
                if isinstance(widget, (PropLineEdit, PropTextEdit)):
                    widget.setReadOnly(True)

                self.prop_window.add_widget(prop_name, widget, value,
                                            prop_name.replace('_', ' '))

    def node_id(self):
        """
        Returns the node id linked to the widget.

        Returns:
            str: node id
        """
        return self.__node_id

    def add_widget(self, name, widget, tab='Properties'):
        """
        Add a new node property widget and forward its value changes.

        This widget has a single property window, so ``tab`` is accepted for
        interface compatibility but ignored.

        Args:
            name (str): property name.
            widget (BaseProperty): property widget.
            tab (str): tab name (ignored).
        """
        self.prop_window.add_widget(name, widget, None)
        widget.value_changed.connect(self._on_property_changed)

    def get_widget(self, name):
        """
        Get a property widget by name.

        Args:
            name (str): property name.
        Returns:
            QtWidgets.QWidget: property widget, or None if not found.
        """
        return self.prop_window.get_widget(name)


if __name__ == '__main__':
    import sys
    from NodeGraphQt import BaseNode, NodeGraph
    from NodeGraphQt.constants import NodePropWidgetEnum

    class TestNode(BaseNode):
        NODE_NAME = 'test node'

        def __init__(self):
            super(TestNode, self).__init__()
            self.create_property('label_test', 'foo bar',
                                 widget_type=NodePropWidgetEnum.QLABEL)
            self.create_property('line_edit', 'hello',
                                 widget_type=NodePropWidgetEnum.QLINE_EDIT)
            self.create_property('color_picker', (0, 0, 255),
                                 widget_type=NodePropWidgetEnum.COLOR_PICKER)
            self.create_property('integer', 10,
                                 widget_type=NodePropWidgetEnum.QSPIN_BOX)
            self.create_property('list', 'foo',
                                 items=['foo', 'bar'],
                                 widget_type=NodePropWidgetEnum.QCOMBO_BOX)
            self.create_property('range', 50,
                                 range=(45, 55),
                                 widget_type=NodePropWidgetEnum.SLIDER)
            self.create_property('text_edit', 'test text',
                                 widget_type=NodePropWidgetEnum.QTEXT_EDIT,
                                 tab='text')

    def prop_changed(node_id, prop_name, prop_value):
        print('-' * 100)
        print(node_id, prop_name, prop_value)

    def prop_close(node_id):
        print('=' * 100)
        print(node_id)

    app = QtWidgets.QApplication(sys.argv)

    graph = NodeGraph()
    graph.register_node(TestNode)

    test_node = graph.create_node('nodeGraphQt.nodes.TestNode')

    node_prop = CustomNodePropWidget(node=test_node)
    node_prop.property_changed.connect(prop_changed)
    node_prop.property_closed.connect(prop_close)
    node_prop.show()

    app.exec_()
# --------------------------------------------------------------------------------
# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/custom_properties_bin.py
# --------------------------------------------------------------------------------
from PySide2 import QtWidgets, QtCore, QtGui
from NodeGraphQt.custom_widgets.properties_bin.node_property_widgets import _PropertiesList

from onnxgraphqt.widgets.custom_properties import CustomNodePropWidget


class CustomPropertiesBinWidget(QtWidgets.QWidget):
    """
    Properties bin that shows a :class:`CustomNodePropWidget` for the node
    most recently added to it.

    This is a trimmed-down variant of :class:`NodeGraphQt.PropertiesBinWidget`.
    Nodes are added when they are double-clicked in the wired-up node graph
    (or explicitly via :meth:`add_node`). Because :meth:`add_node` drops the
    last row whenever the table is non-empty before inserting, the bin holds
    at most one node at a time, as long as rows are only inserted there.

    .. code-block:: python
        :linenos:

        from NodeGraphQt import NodeGraph

        # create node graph.
        graph = NodeGraph()

        # create properties bin widget.
        properties_bin = CustomPropertiesBinWidget(parent=None, node_graph=graph)
        properties_bin.show()

    Args:
        parent (QtWidgets.QWidget): parent of the new widget.
        node_graph (NodeGraphQt.NodeGraph): node graph to wire up. When ``None``
            the widget is created unconnected and nodes must be added manually
            with :meth:`add_node`.
    """

    #: Signal (node_id, prop_name, prop_value).
    #: NOTE(review): nothing in this class emits it; the per-node widget's own
    #: ``property_changed`` is not forwarded here -- confirm whether it should be.
    property_changed = QtCore.Signal(str, str, object)

    def __init__(self, parent=None, node_graph=None):
        super(CustomPropertiesBinWidget, self).__init__(parent)
        self.setWindowTitle('Properties Bin')
        # _PropertiesList is a private NodeGraphQt table widget. Column 0 of each
        # row holds both the node-id item (used for lookups) and the prop widget.
        self._prop_list = _PropertiesList()
        self.resize(450, 400)

        # Not read anywhere in this class; kept for parity with NodeGraphQt's
        # PropertiesBinWidget.
        self._block_signal = False
        self._lock = False

        layout = QtWidgets.QVBoxLayout(self)
        layout.addWidget(self._prop_list, 1)

        # wire up node graph. The signature allows node_graph=None, so only
        # connect when a graph was actually supplied (previously this raised
        # AttributeError on the default argument).
        if node_graph is not None:
            node_graph.add_properties_bin(self)
            node_graph.node_double_clicked.connect(self.add_node)

    def __repr__(self):
        return '<{} object at {}>'.format(self.__class__.__name__, hex(id(self)))

    def __on_prop_close(self, node_id):
        """Remove every row whose id item text equals ``node_id`` exactly."""
        items = self._prop_list.findItems(node_id, QtCore.Qt.MatchExactly)
        # item.row() is re-read on each iteration, so earlier removals that
        # shift rows up are accounted for.
        for item in items:
            self._prop_list.removeRow(item.row())

    def add_node(self, node):
        """
        Add node to the properties bin, replacing whatever was shown before.

        The last row is removed whenever the table is non-empty, any row that
        already shows ``node`` is removed, and a fresh
        :class:`CustomNodePropWidget` is inserted and selected at row 0.

        Args:
            node (NodeGraphQt.NodeObject): node object.
        """
        rows = self._prop_list.rowCount()
        if rows >= 1:
            self._prop_list.removeRow(rows - 1)

        itm_find = self._prop_list.findItems(node.id, QtCore.Qt.MatchExactly)
        if itm_find:
            self._prop_list.removeRow(itm_find[0].row())

        self._prop_list.insertRow(0)
        prop_widget = CustomNodePropWidget(node=node)
        # Closing the per-node widget removes its row from the bin.
        prop_widget.property_closed.connect(self.__on_prop_close)
        self._prop_list.setCellWidget(0, 0, prop_widget)

        # The id item is what findItems() matches on in the lookups above.
        item = QtWidgets.QTableWidgetItem(node.id)
        self._prop_list.setItem(0, 0, item)
        self._prop_list.selectRow(0)

    def remove_node(self, node):
        """
        Remove node from the properties bin.

        Args:
            node (str or NodeGraphQt.BaseNode): node id or node object.
        """
        node_id = node if isinstance(node, str) else node.id
        self.__on_prop_close(node_id)

    def prop_widget(self, node):
        """
        Returns the node property widget.

        Args:
            node (str or NodeGraphQt.NodeObject): node id or node object.

        Returns:
            CustomNodePropWidget: node property widget, or ``None`` when the
            node is not currently shown in the bin.
        """
        node_id = node if isinstance(node, str) else node.id
        itm_find = self._prop_list.findItems(node_id, QtCore.Qt.MatchExactly)
        if not itm_find:
            return None
        return self._prop_list.cellWidget(itm_find[0].row(), 0)


if __name__ == '__main__':
    import sys
    from NodeGraphQt import BaseNode, NodeGraph
    from NodeGraphQt.constants import NodePropWidgetEnum


    class TestNode(BaseNode):
        NODE_NAME = 'test node'

        def __init__(self):
            super(TestNode, self).__init__()
            self.create_property('label_test', 'foo bar',
                                 widget_type=NodePropWidgetEnum.QLABEL)
            self.create_property('text_edit', 'hello',
                                 widget_type=NodePropWidgetEnum.QLINE_EDIT)
            self.create_property('color_picker', (0, 0, 255),
                                 widget_type=NodePropWidgetEnum.COLOR_PICKER)
            self.create_property('integer', 10,
                                 widget_type=NodePropWidgetEnum.QSPIN_BOX)
            self.create_property('list', 'foo',
                                 items=['foo', 'bar'],
                                 widget_type=NodePropWidgetEnum.QCOMBO_BOX)
            self.create_property('range', 50,
                                 range=(45, 55),
                                 widget_type=NodePropWidgetEnum.SLIDER)

    def prop_changed(node_id, prop_name, prop_value):
        print('-'*100)
        print(node_id, prop_name, prop_value)


    app = QtWidgets.QApplication(sys.argv)

    graph = NodeGraph()
    graph.register_node(TestNode)

    prop_bin = CustomPropertiesBinWidget(node_graph=graph)
    prop_bin.property_changed.connect(prop_changed)

    node = graph.create_node('nodeGraphQt.nodes.TestNode')

    prop_bin.add_node(node)
    prop_bin.show()

    app.exec_()

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/splash_screen.py
# --------------------------------------------------------------------------------
import os
import time
from PySide2 import QtCore, QtWidgets, QtGui

DEFAULT_SPLASH_IMAGE = os.path.join(os.path.dirname(__file__), "../data/splash.png") 6 | 7 | def create_screen(image_path=DEFAULT_SPLASH_IMAGE): 8 | pixmap = QtGui.QPixmap(image_path) 9 | splash = QtWidgets.QSplashScreen(pixmap, QtCore.Qt.WindowStaysOnTopHint) 10 | splash.setEnabled(False) 11 | return splash 12 | 13 | def create_screen_progressbar(image_path=DEFAULT_SPLASH_IMAGE): 14 | pixmap = QtGui.QPixmap(image_path) 15 | splash = QtWidgets.QSplashScreen(pixmap, QtCore.Qt.WindowStaysOnTopHint) 16 | splash.setEnabled(False) 17 | progressBar = QtWidgets.QProgressBar(splash) 18 | progressBar.setMaximum(10) 19 | progressBar.setGeometry(0, pixmap.height() - 50, pixmap.width(), 20) 20 | return splash, progressBar 21 | 22 | if __name__ == "__main__": 23 | import signal 24 | import os 25 | import sys 26 | import time 27 | # handle SIGINT to make the app terminate on CTRL+C 28 | signal.signal(signal.SIGINT, signal.SIG_DFL) 29 | 30 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 31 | 32 | app = QtWidgets.QApplication(sys.argv) 33 | splash, progressBar = create_screen_progressbar() 34 | # splash = create_screen() 35 | splash.show() 36 | time.sleep(0.01) 37 | splash.showMessage("loading...", alignment=QtCore.Qt.AlignBottom, color=QtGui.QColor.fromRgb(255, 255, 255)) 38 | for i in range(1, 11): 39 | progressBar.setValue(i) 40 | t = time.time() 41 | while time.time() < t + 0.1: 42 | app.processEvents() 43 | time.sleep(1) 44 | 45 | window = QtWidgets.QMainWindow() 46 | window.show() 47 | 48 | splash.finish(window) 49 | 50 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_add_node.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import signal 3 | from PySide2 import QtCore, QtWidgets, QtGui 4 | from ast import literal_eval 5 | 6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, 
LARGE_FONT_SIZE 7 | from onnxgraphqt.utils.operators import onnx_opsets, opnames, OperatorVersion 8 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph 9 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 10 | 11 | 12 | AVAILABLE_DTYPES = [ 13 | 'float32', 14 | 'float64', 15 | 'int32', 16 | 'int64', 17 | 'str', 18 | ] 19 | 20 | 21 | AddNodeProperties = namedtuple("AddNodeProperties", 22 | [ 23 | "connection_src_op_output_names", 24 | "connection_dest_op_input_names", 25 | "add_op_type", 26 | "add_op_name", 27 | "add_op_input_variables", 28 | "add_op_output_variables", 29 | "add_op_attributes", 30 | ]) 31 | 32 | 33 | class AddNodeWidgets(QtWidgets.QDialog): 34 | _DEFAULT_WINDOW_WIDTH = 500 35 | _MAX_VARIABLES_COUNT = 5 36 | _MAX_ATTRIBUTES_COUNT = 10 37 | # _LABEL_FONT_SIZE = 16 38 | 39 | def __init__(self, current_opset:int, graph: OnnxGraph=None, parent=None) -> None: 40 | super().__init__(parent) 41 | set_font(self, font_size=BASE_FONT_SIZE) 42 | self.setModal(False) 43 | self.setWindowTitle("add node") 44 | self.current_opset = current_opset 45 | self.graph = graph 46 | self.initUI() 47 | self.updateUI(self.graph) 48 | 49 | def initUI(self): 50 | self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH) 51 | 52 | base_layout = QtWidgets.QVBoxLayout() 53 | base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize) 54 | 55 | # src_output 56 | labels_src_output = ["src_op", "src_op output", "add_op", "add_op input"] 57 | layout_src_output = QtWidgets.QVBoxLayout() 58 | lbl_src_output = QtWidgets.QLabel("connection_src_op_output_names") 59 | set_font(lbl_src_output, font_size=LARGE_FONT_SIZE, bold=True) 60 | layout_src_output.addWidget(lbl_src_output) 61 | layout_src_output_form = QtWidgets.QFormLayout() 62 | self.src_output_names = {} 63 | for i in range(len(labels_src_output)): 64 | if i in [0, 1]: 65 | self.src_output_names[i] = QtWidgets.QComboBox() 66 | self.src_output_names[i].setEditable(True) 67 | else: 68 | self.src_output_names[i] = 
QtWidgets.QLineEdit() 69 | layout_src_output_form.addRow(labels_src_output[i], self.src_output_names[i]) 70 | layout_src_output.addLayout(layout_src_output_form) 71 | 72 | # dest_input 73 | labels_dest_input = ["add_op", "add_op output", "dst_op", "dst_op input"] 74 | layout_dest_input = QtWidgets.QVBoxLayout() 75 | lbl_dest_input = QtWidgets.QLabel("connection_dest_op_input_names") 76 | set_font(lbl_dest_input, font_size=LARGE_FONT_SIZE, bold=True) 77 | layout_dest_input.addWidget(lbl_dest_input) 78 | layout_dest_input_form = QtWidgets.QFormLayout() 79 | self.dest_input_names = {} 80 | for i in range(len(labels_dest_input)): 81 | if i in [0, 1]: 82 | self.dest_input_names[i] = QtWidgets.QLineEdit() 83 | else: 84 | self.dest_input_names[i] = QtWidgets.QComboBox() 85 | self.dest_input_names[i].setEditable(True) 86 | layout_dest_input_form.addRow(labels_dest_input[i], self.dest_input_names[i]) 87 | layout_dest_input.addLayout(layout_dest_input_form) 88 | 89 | # Form layout 90 | layout_op = QtWidgets.QFormLayout() 91 | layout_op.setLabelAlignment(QtCore.Qt.AlignRight) 92 | self.add_op_type = QtWidgets.QComboBox() 93 | for op in onnx_opsets: 94 | for version in op.versions: 95 | if version.since_opset <= self.current_opset: 96 | self.add_op_type.addItem(op.name, op.versions[0]) 97 | break 98 | self.add_op_type.setEditable(True) 99 | self.add_op_type.currentIndexChanged.connect(self.add_op_type_currentIndexChanged) 100 | self.add_op_name = QtWidgets.QLineEdit() 101 | self.add_op_name.setPlaceholderText("Name of op to be added") 102 | lbl_add_op_name = QtWidgets.QLabel("add_op_name") 103 | set_font(lbl_add_op_name, font_size=LARGE_FONT_SIZE, bold=True) 104 | lbl_add_op_type = QtWidgets.QLabel("add_op_type") 105 | set_font(lbl_add_op_type, font_size=LARGE_FONT_SIZE, bold=True) 106 | layout_op.addRow(lbl_add_op_name, self.add_op_name) 107 | layout_op.addRow(lbl_add_op_type, self.add_op_type) 108 | 109 | # variables 110 | layout_valiables = QtWidgets.QVBoxLayout() 111 | 
self.visible_input_valiables_count = 1 112 | self.visible_output_valiables_count = 1 113 | 114 | self.add_input_valiables = {} 115 | self.add_output_valiables = {} 116 | for i in range(self._MAX_VARIABLES_COUNT): 117 | self.create_variables_widget(i, is_input=True) 118 | for i in range(self._MAX_VARIABLES_COUNT): 119 | self.create_variables_widget(i, is_input=False) 120 | 121 | self.btn_add_input_valiables = QtWidgets.QPushButton("+") 122 | self.btn_del_input_valiables = QtWidgets.QPushButton("-") 123 | self.btn_add_input_valiables.clicked.connect(self.btn_add_input_valiables_clicked) 124 | self.btn_del_input_valiables.clicked.connect(self.btn_del_input_valiables_clicked) 125 | self.btn_add_output_valiables = QtWidgets.QPushButton("+") 126 | self.btn_del_output_valiables = QtWidgets.QPushButton("-") 127 | self.btn_add_output_valiables.clicked.connect(self.btn_add_output_valiables_clicked) 128 | self.btn_del_output_valiables.clicked.connect(self.btn_del_output_valiables_clicked) 129 | 130 | layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20)) 131 | lbl_add_input_valiables = QtWidgets.QLabel("add op input valiables [optional]") 132 | set_font(lbl_add_input_valiables, font_size=LARGE_FONT_SIZE, bold=True) 133 | layout_valiables.addWidget(lbl_add_input_valiables) 134 | for key, widgets in self.add_input_valiables.items(): 135 | layout_valiables.addWidget(widgets["base"]) 136 | self.set_visible_input_valiables() 137 | layout_btn_input = QtWidgets.QHBoxLayout() 138 | layout_btn_input.addWidget(self.btn_add_input_valiables) 139 | layout_btn_input.addWidget(self.btn_del_input_valiables) 140 | layout_valiables.addLayout(layout_btn_input) 141 | 142 | layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20)) 143 | lbl_add_output_valiables = QtWidgets.QLabel("add op output valiables [optional]") 144 | set_font(lbl_add_output_valiables, font_size=LARGE_FONT_SIZE, bold=True) 145 | 
layout_valiables.addWidget(lbl_add_output_valiables) 146 | for key, widgets in self.add_output_valiables.items(): 147 | layout_valiables.addWidget(widgets["base"]) 148 | self.set_visible_output_valiables() 149 | layout_btn_output = QtWidgets.QHBoxLayout() 150 | layout_btn_output.addWidget(self.btn_add_output_valiables) 151 | layout_btn_output.addWidget(self.btn_del_output_valiables) 152 | layout_valiables.addLayout(layout_btn_output) 153 | 154 | # add_attributes 155 | layout_attributes = QtWidgets.QVBoxLayout() 156 | layout_attributes.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20)) 157 | lbl_add_atrributes = QtWidgets.QLabel("add op atrributes [optional]") 158 | set_font(lbl_add_atrributes, font_size=LARGE_FONT_SIZE, bold=True) 159 | layout_attributes.addWidget(lbl_add_atrributes) 160 | self.visible_add_attributes_count = 3 161 | self.add_op_attributes = {} 162 | for index in range(self._MAX_ATTRIBUTES_COUNT): 163 | self.add_op_attributes[index] = {} 164 | self.add_op_attributes[index]["base"] = QtWidgets.QWidget() 165 | self.add_op_attributes[index]["layout"] = QtWidgets.QHBoxLayout(self.add_op_attributes[index]["base"]) 166 | self.add_op_attributes[index]["layout"].setContentsMargins(0, 0, 0, 0) 167 | self.add_op_attributes[index]["name"] = QtWidgets.QLineEdit() 168 | self.add_op_attributes[index]["name"].setPlaceholderText("name") 169 | self.add_op_attributes[index]["value"] = QtWidgets.QLineEdit() 170 | self.add_op_attributes[index]["value"].setPlaceholderText("value") 171 | self.add_op_attributes[index]["layout"].addWidget(self.add_op_attributes[index]["name"]) 172 | self.add_op_attributes[index]["layout"].addWidget(self.add_op_attributes[index]["value"]) 173 | layout_attributes.addWidget(self.add_op_attributes[index]["base"]) 174 | self.btn_add_op_attributes = QtWidgets.QPushButton("+") 175 | self.btn_del_op_attributes = QtWidgets.QPushButton("-") 176 | self.btn_add_op_attributes.clicked.connect(self.btn_add_op_attributes_clicked) 177 | 
self.btn_del_op_attributes.clicked.connect(self.btn_del_op_attributes_clicked) 178 | self.set_visible_add_op_attributes() 179 | layout_btn_attributes = QtWidgets.QHBoxLayout() 180 | layout_btn_attributes.addWidget(self.btn_add_op_attributes) 181 | layout_btn_attributes.addWidget(self.btn_del_op_attributes) 182 | layout_attributes.addLayout(layout_btn_attributes) 183 | 184 | # add layout 185 | base_layout.addLayout(layout_src_output) 186 | base_layout.addSpacerItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 10)) 187 | base_layout.addLayout(layout_dest_input) 188 | base_layout.addSpacerItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 10)) 189 | base_layout.addLayout(layout_op) 190 | base_layout.addLayout(layout_valiables) 191 | base_layout.addLayout(layout_attributes) 192 | 193 | # Dialog button 194 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 195 | QtWidgets.QDialogButtonBox.Cancel) 196 | btn.accepted.connect(self.accept) 197 | btn.rejected.connect(self.reject) 198 | # layout.addWidget(btn) 199 | base_layout.addWidget(btn) 200 | 201 | self.setLayout(base_layout) 202 | self.add_op_type_currentIndexChanged(self.add_op_type.currentIndex()) 203 | 204 | def updateUI(self, graph: OnnxGraph): 205 | if graph: 206 | for name in graph.node_inputs.keys(): 207 | self.src_output_names[0].addItem(name) 208 | self.src_output_names[1].addItem(name) 209 | for name in graph.nodes.keys(): 210 | self.src_output_names[0].addItem(name) 211 | self.src_output_names[1].addItem(name) 212 | 213 | for name in graph.node_inputs.keys(): 214 | self.dest_input_names[2].addItem(name) 215 | self.dest_input_names[3].addItem(name) 216 | for name in graph.nodes.keys(): 217 | self.dest_input_names[2].addItem(name) 218 | self.dest_input_names[3].addItem(name) 219 | 220 | def create_variables_widget(self, index:int, is_input=True)->QtWidgets.QBoxLayout: 221 | if is_input: 222 | self.add_input_valiables[index] = {} 223 | self.add_input_valiables[index]["base"] = 
QtWidgets.QWidget() 224 | self.add_input_valiables[index]["layout"] = QtWidgets.QHBoxLayout(self.add_input_valiables[index]["base"]) 225 | self.add_input_valiables[index]["layout"].setContentsMargins(0, 0, 0, 0) 226 | self.add_input_valiables[index]["name"] = QtWidgets.QLineEdit() 227 | self.add_input_valiables[index]["name"].setPlaceholderText("name") 228 | self.add_input_valiables[index]["dtype"] = QtWidgets.QComboBox() 229 | for dtype in AVAILABLE_DTYPES: 230 | self.add_input_valiables[index]["dtype"].addItem(dtype) 231 | self.add_input_valiables[index]["dtype"].setEditable(True) 232 | self.add_input_valiables[index]["dtype"].setFixedSize(100, 20) 233 | self.add_input_valiables[index]["shape"] = QtWidgets.QLineEdit() 234 | self.add_input_valiables[index]["shape"].setPlaceholderText("shape. e.g. `[1, 2, 3]`") 235 | self.add_input_valiables[index]["layout"].addWidget(self.add_input_valiables[index]["name"]) 236 | self.add_input_valiables[index]["layout"].addWidget(self.add_input_valiables[index]["dtype"]) 237 | self.add_input_valiables[index]["layout"].addWidget(self.add_input_valiables[index]["shape"]) 238 | else: 239 | self.add_output_valiables[index] = {} 240 | self.add_output_valiables[index]["base"] = QtWidgets.QWidget() 241 | self.add_output_valiables[index]["layout"] = QtWidgets.QHBoxLayout(self.add_output_valiables[index]["base"]) 242 | self.add_output_valiables[index]["layout"].setContentsMargins(0, 0, 0, 0) 243 | self.add_output_valiables[index]["name"] = QtWidgets.QLineEdit() 244 | self.add_output_valiables[index]["name"].setPlaceholderText("name") 245 | self.add_output_valiables[index]["dtype"] = QtWidgets.QComboBox() 246 | for dtype in AVAILABLE_DTYPES: 247 | self.add_output_valiables[index]["dtype"].addItem(dtype) 248 | self.add_output_valiables[index]["dtype"].setEditable(True) 249 | self.add_output_valiables[index]["dtype"].setFixedSize(100, 20) 250 | self.add_output_valiables[index]["dtype"].setPlaceholderText("dtype. e.g. 
`float32`") 251 | self.add_output_valiables[index]["shape"] = QtWidgets.QLineEdit() 252 | self.add_output_valiables[index]["shape"].setPlaceholderText("shape. e.g. `[1, 2, 3]`") 253 | self.add_output_valiables[index]["layout"].addWidget(self.add_output_valiables[index]["name"]) 254 | self.add_output_valiables[index]["layout"].addWidget(self.add_output_valiables[index]["dtype"]) 255 | self.add_output_valiables[index]["layout"].addWidget(self.add_output_valiables[index]["shape"]) 256 | 257 | def set_visible_input_valiables(self): 258 | for key, widgets in self.add_input_valiables.items(): 259 | widgets["base"].setVisible(key < self.visible_input_valiables_count) 260 | if self.visible_input_valiables_count == 0: 261 | self.btn_add_input_valiables.setEnabled(True) 262 | self.btn_del_input_valiables.setEnabled(False) 263 | elif self.visible_input_valiables_count >= self._MAX_VARIABLES_COUNT: 264 | self.btn_add_input_valiables.setEnabled(False) 265 | self.btn_del_input_valiables.setEnabled(True) 266 | else: 267 | self.btn_add_input_valiables.setEnabled(True) 268 | self.btn_del_input_valiables.setEnabled(True) 269 | 270 | def set_visible_output_valiables(self): 271 | for key, widgets in self.add_output_valiables.items(): 272 | widgets["base"].setVisible(key < self.visible_output_valiables_count) 273 | if self.visible_output_valiables_count == 0: 274 | self.btn_add_output_valiables.setEnabled(True) 275 | self.btn_del_output_valiables.setEnabled(False) 276 | elif self.visible_output_valiables_count >= self._MAX_VARIABLES_COUNT: 277 | self.btn_add_output_valiables.setEnabled(False) 278 | self.btn_del_output_valiables.setEnabled(True) 279 | else: 280 | self.btn_add_output_valiables.setEnabled(True) 281 | self.btn_del_output_valiables.setEnabled(True) 282 | 283 | def set_visible_add_op_attributes(self): 284 | for key, widgets in self.add_op_attributes.items(): 285 | widgets["base"].setVisible(key < self.visible_add_attributes_count) 286 | if self.visible_add_attributes_count 
== 0: 287 | self.btn_add_op_attributes.setEnabled(True) 288 | self.btn_del_op_attributes.setEnabled(False) 289 | elif self.visible_add_attributes_count >= self._MAX_ATTRIBUTES_COUNT: 290 | self.btn_add_op_attributes.setEnabled(False) 291 | self.btn_del_op_attributes.setEnabled(True) 292 | else: 293 | self.btn_add_op_attributes.setEnabled(True) 294 | self.btn_del_op_attributes.setEnabled(True) 295 | 296 | def btn_add_input_valiables_clicked(self, e): 297 | self.visible_input_valiables_count = min(max(0, self.visible_input_valiables_count + 1), self._MAX_VARIABLES_COUNT) 298 | self.set_visible_input_valiables() 299 | 300 | def btn_del_input_valiables_clicked(self, e): 301 | self.visible_input_valiables_count = min(max(0, self.visible_input_valiables_count - 1), self._MAX_VARIABLES_COUNT) 302 | self.set_visible_input_valiables() 303 | 304 | def btn_add_output_valiables_clicked(self, e): 305 | self.visible_output_valiables_count = min(max(0, self.visible_output_valiables_count + 1), self._MAX_VARIABLES_COUNT) 306 | self.set_visible_output_valiables() 307 | 308 | def btn_del_output_valiables_clicked(self, e): 309 | self.visible_output_valiables_count = min(max(0, self.visible_output_valiables_count - 1), self._MAX_VARIABLES_COUNT) 310 | self.set_visible_output_valiables() 311 | 312 | def btn_add_op_attributes_clicked(self, e): 313 | self.visible_add_attributes_count = min(max(0, self.visible_add_attributes_count + 1), self._MAX_ATTRIBUTES_COUNT) 314 | self.set_visible_add_op_attributes() 315 | 316 | def btn_del_op_attributes_clicked(self, e): 317 | self.visible_add_attributes_count = min(max(0, self.visible_add_attributes_count - 1), self._MAX_ATTRIBUTES_COUNT) 318 | self.set_visible_add_op_attributes() 319 | 320 | def add_op_type_currentIndexChanged(self, selected_index:int): 321 | selected_operator: OperatorVersion = self.add_op_type.currentData() 322 | self.visible_input_valiables_count = selected_operator.inputs 323 | self.visible_output_valiables_count = 
selected_operator.outputs 324 | self.visible_add_attributes_count = min(max(0, len(selected_operator.attrs)), self._MAX_ATTRIBUTES_COUNT) 325 | 326 | for i, att in enumerate(selected_operator.attrs): 327 | self.add_op_attributes[i]["name"].setText(att.name) 328 | self.add_op_attributes[i]["value"].setText(att.default_value) 329 | for j in range(len(selected_operator.attrs), self._MAX_ATTRIBUTES_COUNT): 330 | self.add_op_attributes[j]["name"].setText("") 331 | self.add_op_attributes[j]["value"].setText("") 332 | self.set_visible_input_valiables() 333 | self.set_visible_output_valiables() 334 | self.set_visible_add_op_attributes() 335 | 336 | def get_properties(self)->AddNodeProperties: 337 | add_op_input_variables = {} 338 | add_op_output_variables = {} 339 | for i in range(self.visible_input_valiables_count): 340 | name = self.add_input_valiables[i]["name"].text() 341 | dtype = self.add_input_valiables[i]["dtype"].currentText() 342 | shape = self.add_input_valiables[i]["shape"].text() 343 | if name and dtype and shape: 344 | add_op_input_variables[name] = [dtype, literal_eval(shape)] 345 | for i in range(self.visible_output_valiables_count): 346 | name = self.add_output_valiables[i]["name"].text() 347 | dtype = self.add_output_valiables[i]["dtype"].currentText() 348 | shape = self.add_output_valiables[i]["shape"].text() 349 | if name and dtype and shape: 350 | add_op_output_variables[name] = [dtype, literal_eval(shape)] 351 | 352 | if len(add_op_input_variables) == 0: 353 | add_op_input_variables = None 354 | if len(add_op_output_variables) == 0: 355 | add_op_output_variables = None 356 | 357 | add_op_attributes = {} 358 | for i in range(self.visible_add_attributes_count): 359 | name = self.add_op_attributes[i]["name"].text() 360 | value = self.add_op_attributes[i]["value"].text() 361 | if name and value: 362 | try: 363 | # For literal 364 | add_op_attributes[name] = literal_eval(value) 365 | except BaseException as e: 366 | # For str 367 | add_op_attributes[name] 
= value 368 | if len(add_op_attributes) == 0: 369 | add_op_attributes = None 370 | 371 | src_output_names = [] 372 | for i in [0, 1]: 373 | name: str = self.src_output_names[i].currentText().strip() 374 | src_output_names.append(name) 375 | for i in [2, 3]: 376 | name: str = self.src_output_names[i].text().strip() 377 | src_output_names.append(name) 378 | 379 | dest_input_names = [] 380 | for i in [0, 1]: 381 | name: str = self.dest_input_names[i].text().strip() 382 | dest_input_names.append(name) 383 | for i in [2, 3]: 384 | name: str = self.dest_input_names[i].currentText().strip() 385 | dest_input_names.append(name) 386 | 387 | return AddNodeProperties( 388 | connection_src_op_output_names=[src_output_names], 389 | connection_dest_op_input_names=[dest_input_names], 390 | add_op_type=self.add_op_type.currentText(), 391 | add_op_name=self.add_op_name.text(), 392 | add_op_input_variables=add_op_input_variables, 393 | add_op_output_variables=add_op_output_variables, 394 | add_op_attributes=add_op_attributes, 395 | ) 396 | 397 | def accept(self) -> None: 398 | # value check 399 | invalid = False 400 | props = self.get_properties() 401 | print(props) 402 | err_msgs = [] 403 | for src_op_output_name in props.connection_src_op_output_names: 404 | src_op, src_op_output, add_op, add_op_input = src_op_output_name 405 | if not src_op: 406 | err_msgs.append("- [connection_src_op_output_names] src_op is not set.") 407 | invalid = True 408 | if not src_op_output: 409 | err_msgs.append("- [connection_src_op_output_names] src_op output is not set.") 410 | invalid = True 411 | if not add_op: 412 | err_msgs.append("- [connection_src_op_output_names] add_op is not set.") 413 | invalid = True 414 | if not add_op_input: 415 | err_msgs.append("- [connection_src_op_output_names] add_op input is not set.") 416 | invalid = True 417 | for dest_op_input_name in props.connection_dest_op_input_names: 418 | add_op, add_op_output, dst_op, dst_op_input = dest_op_input_name 419 | if not add_op: 
420 | err_msgs.append("- [connection_dest_op_input_names] add_op is not set") 421 | invalid = True 422 | if not add_op_output: 423 | err_msgs.append("- [connection_dest_op_input_names] add_op_output is not set") 424 | invalid = True 425 | if not dst_op: 426 | err_msgs.append("- [connection_dest_op_input_names] dst_op is not set") 427 | invalid = True 428 | if not dst_op_input: 429 | err_msgs.append("- [connection_dest_op_input_names] dst_op input is not set") 430 | invalid = True 431 | if not props.add_op_name: 432 | err_msgs.append("- [add_op_name] not set") 433 | invalid = True 434 | if not props.add_op_type in opnames: 435 | err_msgs.append("- [add_op_type] not support") 436 | invalid = True 437 | if invalid: 438 | for m in err_msgs: 439 | print(m) 440 | MessageBox.error(err_msgs, "add node", parent=self) 441 | return 442 | return super().accept() 443 | 444 | 445 | 446 | if __name__ == "__main__": 447 | import signal 448 | import os 449 | # handle SIGINT to make the app terminate on CTRL+C 450 | signal.signal(signal.SIGINT, signal.SIG_DFL) 451 | 452 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 453 | 454 | app = QtWidgets.QApplication([]) 455 | window = AddNodeWidgets(current_opset=16) 456 | window.show() 457 | 458 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_change_channel.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import signal 3 | from PySide2 import QtCore, QtWidgets, QtGui 4 | from ast import literal_eval 5 | 6 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph 7 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE 8 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 9 | 10 | 11 | ChangeChannelProperties = namedtuple("ChangeChannelProperties", 12 | [ 13 | "input_op_names_and_order_dims", 14 | "channel_change_inputs", 15 
| ]) 16 | 17 | 18 | class ChangeChannelWidgets(QtWidgets.QDialog): 19 | _DEFAULT_WINDOW_WIDTH = 300 20 | _QLINE_EDIT_WIDTH = 170 21 | 22 | def __init__(self, graph: OnnxGraph=None, parent=None) -> None: 23 | super().__init__(parent) 24 | self.setModal(False) 25 | self.setWindowTitle("change channel") 26 | self.graph = graph 27 | 28 | if graph: 29 | self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT = len(self.graph.inputs) 30 | self._MAX_CHANNEL_CHANGE_INPUTS_COUNT = len(self.graph.inputs) 31 | else: 32 | self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT = 4 33 | self._MAX_CHANNEL_CHANGE_INPUTS_COUNT = 4 34 | self.initUI() 35 | self.updateUI(self.graph) 36 | 37 | def initUI(self): 38 | self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH) 39 | set_font(self, font_size=BASE_FONT_SIZE) 40 | 41 | base_layout = QtWidgets.QVBoxLayout() 42 | base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize) 43 | 44 | self.layout_order_dims = QtWidgets.QVBoxLayout() 45 | self.layout_change_channel = QtWidgets.QVBoxLayout() 46 | self.visible_change_order_dim_inputs_count = self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT 47 | self.visible_channel_change_inputs_count = self._MAX_CHANNEL_CHANGE_INPUTS_COUNT 48 | 49 | self.input_op_names_and_order_dims = {} # transpose NCHW <-> NHWC 50 | self.channel_change_inputs = {} # RGB <-> BGR 51 | self.create_order_dims_widgets() 52 | self.create_channel_change_widgets() 53 | 54 | lbl_channel_change = QtWidgets.QLabel("channel_change_inputs") 55 | lbl_channel_change.setToolTip('Change the channel order of RGB and BGR.' + 56 | 'If the original model is RGB, it is transposed to BGR.' + 57 | 'If the original model is BGR, it is transposed to RGB.' + 58 | 'It can be selectively specified from among the OP names' + 59 | 'specified in input_op_names_and_order_dims.' + 60 | 'OP names not specified in input_op_names_and_order_dims are ignored.' + 61 | 'Multiple times can be specified as many times as the number' + 62 | 'of OP names specified in input_op_names_and_order_dims.' 
+ 63 | 'channel_change_inputs = {"op_name": dimension_number_representing_the_channel}' + 64 | 'dimension_number_representing_the_channel must specify' + 65 | 'the dimension position after the change in input_op_names_and_order_dims.' + 66 | 'For example, dimension_number_representing_the_channel is 1 for NCHW and 3 for NHWC.') 67 | set_font(lbl_channel_change, font_size=LARGE_FONT_SIZE, bold=True) 68 | self.layout_change_channel.addWidget(lbl_channel_change) 69 | for i in range(self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT): 70 | self.layout_change_channel.addWidget(self.channel_change_inputs[i]["base"]) 71 | 72 | lbl_order_dims = QtWidgets.QLabel("input_op_names_and_order_dims") 73 | lbl_order_dims.setToolTip("Specify the name of the input_op to be dimensionally changed and\n" + 74 | "the order of the dimensions after the change.\n" + 75 | "The name of the input_op to be dimensionally changed\n" + 76 | "can be specified multiple times.") 77 | set_font(lbl_order_dims, font_size=LARGE_FONT_SIZE, bold=True) 78 | self.layout_order_dims.addWidget(lbl_order_dims) 79 | for i in range(self._MAX_CHANNEL_CHANGE_INPUTS_COUNT): 80 | self.layout_order_dims.addWidget(self.input_op_names_and_order_dims[i]["base"]) 81 | 82 | # add button 83 | self.btn_add_order_dims = QtWidgets.QPushButton("+") 84 | self.btn_del_order_dims = QtWidgets.QPushButton("-") 85 | self.btn_add_order_dims.clicked.connect(self.btn_add_order_dims_clicked) 86 | self.btn_del_order_dims.clicked.connect(self.btn_del_order_dims_clicked) 87 | self.layout_btn_order_dims = QtWidgets.QHBoxLayout() 88 | self.layout_btn_order_dims.addWidget(self.btn_add_order_dims) 89 | self.layout_btn_order_dims.addWidget(self.btn_del_order_dims) 90 | 91 | self.btn_add_channel_change = QtWidgets.QPushButton("+") 92 | self.btn_del_channel_change = QtWidgets.QPushButton("-") 93 | self.btn_add_channel_change.clicked.connect(self.btn_add_channel_change_clicked) 94 | 
self.btn_del_channel_change.clicked.connect(self.btn_del_channel_change_clicked) 95 | self.layout_btn_channel_change = QtWidgets.QHBoxLayout() 96 | self.layout_btn_channel_change.addWidget(self.btn_add_channel_change) 97 | self.layout_btn_channel_change.addWidget(self.btn_del_channel_change) 98 | 99 | self.layout_order_dims.addLayout(self.layout_btn_order_dims) 100 | self.layout_change_channel.addLayout(self.layout_btn_channel_change) 101 | 102 | # add layout 103 | base_layout.addLayout(self.layout_order_dims) 104 | base_layout.addSpacing(15) 105 | base_layout.addLayout(self.layout_change_channel) 106 | 107 | # Dialog button 108 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 109 | QtWidgets.QDialogButtonBox.Cancel) 110 | btn.accepted.connect(self.accept) 111 | btn.rejected.connect(self.reject) 112 | # layout.addWidget(btn) 113 | base_layout.addWidget(btn) 114 | 115 | self.setLayout(base_layout) 116 | self.set_visible_order_dims() 117 | self.set_visible_channel_change() 118 | 119 | def updateUI(self, graph: OnnxGraph): 120 | if graph: 121 | for index in range(self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT): 122 | self.input_op_names_and_order_dims[index]["name"].clear() 123 | for name, input_node in graph.inputs.items(): 124 | self.input_op_names_and_order_dims[index]["name"].addItem(name) 125 | self.input_op_names_and_order_dims[index]["name"].setCurrentIndex(index) 126 | 127 | for index in range(self._MAX_CHANNEL_CHANGE_INPUTS_COUNT): 128 | self.channel_change_inputs[index]["name"].clear() 129 | for name, input_node in graph.inputs.items(): 130 | self.channel_change_inputs[index]["name"].addItem(name) 131 | self.channel_change_inputs[index]["name"].setCurrentIndex(index) 132 | self.set_visible_order_dims() 133 | self.set_visible_channel_change() 134 | 135 | def btn_add_order_dims_clicked(self, e): 136 | self.visible_change_order_dim_inputs_count = min(max(0, self.visible_change_order_dim_inputs_count + 1), self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT) 137 
| self.set_visible_order_dims() 138 | 139 | def btn_del_order_dims_clicked(self, e): 140 | self.visible_change_order_dim_inputs_count = min(max(0, self.visible_change_order_dim_inputs_count - 1), self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT) 141 | self.set_visible_order_dims() 142 | 143 | def btn_add_channel_change_clicked(self, e): 144 | self.visible_channel_change_inputs_count = min(max(0, self.visible_channel_change_inputs_count + 1), self._MAX_CHANNEL_CHANGE_INPUTS_COUNT) 145 | self.set_visible_channel_change() 146 | 147 | def btn_del_channel_change_clicked(self, e): 148 | self.visible_channel_change_inputs_count = min(max(0, self.visible_channel_change_inputs_count - 1), self._MAX_CHANNEL_CHANGE_INPUTS_COUNT) 149 | self.set_visible_channel_change() 150 | 151 | def create_order_dims_widgets(self): 152 | for index in range(self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT): 153 | self.input_op_names_and_order_dims[index] = {} 154 | self.input_op_names_and_order_dims[index]["base"] = QtWidgets.QWidget() 155 | self.input_op_names_and_order_dims[index]["layout"] = QtWidgets.QHBoxLayout(self.input_op_names_and_order_dims[index]["base"]) 156 | self.input_op_names_and_order_dims[index]["layout"].setContentsMargins(0, 0, 0, 0) 157 | self.input_op_names_and_order_dims[index]["name"] = QtWidgets.QComboBox() 158 | self.input_op_names_and_order_dims[index]["name"].setEditable(True) 159 | self.input_op_names_and_order_dims[index]["name"].setFixedWidth(self._QLINE_EDIT_WIDTH) 160 | self.input_op_names_and_order_dims[index]["value"] = QtWidgets.QLineEdit() 161 | self.input_op_names_and_order_dims[index]["value"].setPlaceholderText("List of dims.") 162 | self.input_op_names_and_order_dims[index]["layout"].addWidget(self.input_op_names_and_order_dims[index]["name"]) 163 | self.input_op_names_and_order_dims[index]["layout"].addWidget(self.input_op_names_and_order_dims[index]["value"]) 164 | 165 | def create_channel_change_widgets(self): 166 | for index in 
range(self._MAX_CHANNEL_CHANGE_INPUTS_COUNT): 167 | self.channel_change_inputs[index] = {} 168 | self.channel_change_inputs[index]["base"] = QtWidgets.QWidget() 169 | self.channel_change_inputs[index]["layout"] = QtWidgets.QHBoxLayout(self.channel_change_inputs[index]["base"]) 170 | self.channel_change_inputs[index]["layout"].setContentsMargins(0, 0, 0, 0) 171 | self.channel_change_inputs[index]["name"] = QtWidgets.QComboBox() 172 | self.channel_change_inputs[index]["name"].setEditable(True) 173 | self.channel_change_inputs[index]["name"].setFixedWidth(self._QLINE_EDIT_WIDTH) 174 | self.channel_change_inputs[index]["value"] = QtWidgets.QLineEdit() 175 | self.channel_change_inputs[index]["value"].setPlaceholderText("dim") 176 | self.channel_change_inputs[index]["layout"].addWidget(self.channel_change_inputs[index]["name"]) 177 | self.channel_change_inputs[index]["layout"].addWidget(self.channel_change_inputs[index]["value"]) 178 | 179 | def set_visible_order_dims(self): 180 | for key, widgets in self.input_op_names_and_order_dims.items(): 181 | widgets["base"].setVisible(key < self.visible_change_order_dim_inputs_count) 182 | if self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT == 1: 183 | self.btn_add_order_dims.setEnabled(False) 184 | self.btn_del_order_dims.setEnabled(False) 185 | elif self.visible_change_order_dim_inputs_count == 1: 186 | self.btn_add_order_dims.setEnabled(True) 187 | self.btn_del_order_dims.setEnabled(False) 188 | elif self.visible_change_order_dim_inputs_count >= self._MAX_CHANGE_ORDER_DIM_INPUTS_COUNT: 189 | self.btn_add_order_dims.setEnabled(False) 190 | self.btn_del_order_dims.setEnabled(True) 191 | else: 192 | self.btn_add_order_dims.setEnabled(True) 193 | self.btn_del_order_dims.setEnabled(True) 194 | 195 | def set_visible_channel_change(self): 196 | for key, widgets in self.channel_change_inputs.items(): 197 | widgets["base"].setVisible(key < self.visible_channel_change_inputs_count) 198 | if self._MAX_CHANNEL_CHANGE_INPUTS_COUNT == 1: 199 | 
self.btn_add_channel_change.setEnabled(False) 200 | self.btn_del_channel_change.setEnabled(False) 201 | elif self.visible_channel_change_inputs_count == 1: 202 | self.btn_add_channel_change.setEnabled(True) 203 | self.btn_del_channel_change.setEnabled(False) 204 | elif self.visible_channel_change_inputs_count >= self._MAX_CHANNEL_CHANGE_INPUTS_COUNT: 205 | self.btn_add_channel_change.setEnabled(False) 206 | self.btn_del_channel_change.setEnabled(True) 207 | else: 208 | self.btn_add_channel_change.setEnabled(True) 209 | self.btn_del_channel_change.setEnabled(True) 210 | 211 | 212 | def get_properties(self)->ChangeChannelProperties: 213 | input_op_names_and_order_dims = {} 214 | for i in range(self.visible_change_order_dim_inputs_count): 215 | name = self.input_op_names_and_order_dims[i]["name"].currentText() 216 | value = self.input_op_names_and_order_dims[i]["value"].text() 217 | if name and value: 218 | input_op_names_and_order_dims[name] = literal_eval(value) 219 | channel_change_inputs = {} 220 | for i in range(self.visible_channel_change_inputs_count): 221 | name = self.channel_change_inputs[i]["name"].currentText() 222 | value = self.channel_change_inputs[i]["value"].text() 223 | if name and value: 224 | channel_change_inputs[name] = literal_eval(value) 225 | return ChangeChannelProperties( 226 | input_op_names_and_order_dims=input_op_names_and_order_dims, 227 | channel_change_inputs=channel_change_inputs, 228 | ) 229 | 230 | def accept(self) -> None: 231 | # value check 232 | invalid = False 233 | props = self.get_properties() 234 | print(props) 235 | err_msgs = [] 236 | if len(props.channel_change_inputs) < 1 and len(props.input_op_names_and_order_dims) < 1: 237 | err_msgs.append("At least one of input_op_names_and_order_dims or channel_change_inputs must be specified.") 238 | invalid = True 239 | 240 | for key, val in props.channel_change_inputs.items(): 241 | if type(val) is not int: 242 | err_msgs.append(f'channel_change_inputs value must be integer. 
{key}: {val}') 243 | invalid = True 244 | 245 | if invalid: 246 | for m in err_msgs: 247 | print(m) 248 | MessageBox.error(err_msgs, "change channel", parent=self) 249 | return 250 | return super().accept() 251 | 252 | 253 | 254 | if __name__ == "__main__": 255 | import signal 256 | import os 257 | # handle SIGINT to make the app terminate on CTRL+C 258 | signal.signal(signal.SIGINT, signal.SIG_DFL) 259 | 260 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 261 | 262 | app = QtWidgets.QApplication([]) 263 | window = ChangeChannelWidgets() 264 | window.show() 265 | 266 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_change_input_ouput_shape.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import signal 3 | from PySide2 import QtCore, QtWidgets, QtGui 4 | import os 5 | from ast import literal_eval 6 | 7 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph 8 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE 9 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 10 | 11 | 12 | ChangeInputOutputShapeProperties = namedtuple("ChangeInputOutputShapeProperties", 13 | [ 14 | "input_names", 15 | "input_shapes", 16 | "output_names", 17 | "output_shapes", 18 | ]) 19 | 20 | 21 | class ChangeInputOutputShapeWidget(QtWidgets.QDialog): 22 | _DEFAULT_WINDOW_WIDTH = 500 23 | 24 | def __init__(self, graph: OnnxGraph=None, parent=None) -> None: 25 | super().__init__(parent) 26 | self.setModal(False) 27 | self.setWindowTitle("change input output shape") 28 | self.graph = graph 29 | self.initUI(graph) 30 | 31 | def initUI(self, graph: OnnxGraph=None): 32 | self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH) 33 | set_font(self, font_size=BASE_FONT_SIZE) 34 | 35 | base_layout = QtWidgets.QVBoxLayout() 36 | 37 | # layout 38 | layout = QtWidgets.QVBoxLayout() 39 | 40 | 
label_input_header = QtWidgets.QLabel("Inputs") 41 | set_font(label_input_header, font_size=LARGE_FONT_SIZE, bold=True) 42 | self.cb_change_input_shape = QtWidgets.QCheckBox("change input shape") 43 | self.cb_change_input_shape.setChecked(True) 44 | self.cb_change_input_shape.stateChanged.connect(self.cb_change_input_shape_changed) 45 | layout.addWidget(label_input_header) 46 | layout.addWidget(self.cb_change_input_shape) 47 | 48 | self.input_widgets = {} 49 | for name in graph.inputs.keys(): 50 | input = graph.inputs[name] 51 | shape = input.get_shape() 52 | 53 | self.input_widgets[name] = {} 54 | self.input_widgets[name]["label"] = QtWidgets.QLabel(name) 55 | self.input_widgets[name]["shape"] = QtWidgets.QLineEdit() 56 | self.input_widgets[name]["shape"].setText(str(shape)) 57 | child_layout = QtWidgets.QHBoxLayout() 58 | child_layout.addWidget(self.input_widgets[name]["label"]) 59 | child_layout.addWidget(QtWidgets.QLabel(f": ")) 60 | child_layout.addWidget(self.input_widgets[name]["shape"]) 61 | layout.addLayout(child_layout) 62 | 63 | layout.addSpacing(15) 64 | 65 | label_output_header = QtWidgets.QLabel("Outputs") 66 | set_font(label_output_header, font_size=LARGE_FONT_SIZE, bold=True) 67 | self.cb_change_output_shape = QtWidgets.QCheckBox("change output shape") 68 | self.cb_change_output_shape.setChecked(True) 69 | self.cb_change_output_shape.stateChanged.connect(self.cb_change_output_shape_changed) 70 | layout.addWidget(label_output_header) 71 | layout.addWidget(self.cb_change_output_shape) 72 | 73 | self.output_widgets = {} 74 | for name in graph.outputs.keys(): 75 | output = graph.outputs[name] 76 | shape = output.get_shape() 77 | 78 | self.output_widgets[name] = {} 79 | self.output_widgets[name]["label"] = QtWidgets.QLabel(name) 80 | self.output_widgets[name]["shape"] = QtWidgets.QLineEdit() 81 | self.output_widgets[name]["shape"].setText(str(shape)) 82 | child_layout = QtWidgets.QHBoxLayout() 83 | 
child_layout.addWidget(self.output_widgets[name]["label"]) 84 | child_layout.addWidget(QtWidgets.QLabel(f": ")) 85 | child_layout.addWidget(self.output_widgets[name]["shape"]) 86 | layout.addLayout(child_layout) 87 | 88 | # add layout 89 | base_layout.addLayout(layout) 90 | 91 | # Dialog button 92 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 93 | QtWidgets.QDialogButtonBox.Cancel) 94 | btn.accepted.connect(self.accept) 95 | btn.rejected.connect(self.reject) 96 | # layout.addWidget(btn) 97 | base_layout.addWidget(btn) 98 | 99 | self.setLayout(base_layout) 100 | 101 | def cb_change_input_shape_changed(self, e): 102 | change = self.cb_change_input_shape.isChecked() 103 | for widgets in self.input_widgets.values(): 104 | widgets["shape"].setEnabled(change) 105 | 106 | def cb_change_output_shape_changed(self, e): 107 | change = self.cb_change_output_shape.isChecked() 108 | for widgets in self.output_widgets.values(): 109 | widgets["shape"].setEnabled(change) 110 | 111 | def get_properties(self)->ChangeInputOutputShapeProperties: 112 | input_names = None 113 | input_shapes = None 114 | output_names = None 115 | output_shapes = None 116 | errors = [] 117 | 118 | if self.cb_change_input_shape.isChecked(): 119 | input_names = [] 120 | input_shapes = [] 121 | for widgets in self.input_widgets.values(): 122 | name = widgets["label"].text() 123 | str_shape = widgets["shape"].text().strip() 124 | if str_shape == "": 125 | errors.append(f"{name}: not entered") 126 | continue 127 | try: 128 | shape = literal_eval(str_shape) 129 | except BaseException as e: 130 | print(e) 131 | errors.append(f"{name}: {str(e)}") 132 | continue 133 | input_names.append(name) 134 | input_shapes.append(shape) 135 | 136 | if self.cb_change_output_shape.isChecked(): 137 | output_names = [] 138 | output_shapes = [] 139 | for widgets in self.output_widgets.values(): 140 | name = widgets["label"].text() 141 | str_shape = widgets["shape"].text().strip() 142 | if str_shape == "": 143 | 
errors.append(f"{name}: not entered") 144 | continue 145 | try: 146 | shape = literal_eval(str_shape) 147 | except BaseException as e: 148 | print(e) 149 | errors.append(f"{name}: {e}") 150 | continue 151 | output_names.append(name) 152 | output_shapes.append(shape) 153 | 154 | return ChangeInputOutputShapeProperties( 155 | input_names=input_names, 156 | input_shapes=input_shapes, 157 | output_names=output_names, 158 | output_shapes=output_shapes, 159 | ), errors 160 | 161 | def accept(self) -> None: 162 | # value check 163 | invalid = False 164 | props, errors = self.get_properties() 165 | print(props) 166 | if len(errors) > 0: 167 | err_msgs = ["shape convert error"] 168 | err_msgs += errors 169 | invalid = True 170 | else: 171 | err_msgs = [] 172 | if props.input_names is None and props.output_names is None: 173 | err_msgs.append("input shape or output shape must be change.") 174 | invalid = True 175 | if props.input_names is not None and len(props.input_names) != len(self.graph.inputs): 176 | err_msgs.append("input shape.") 177 | invalid = True 178 | if props.output_names is not None and len(props.output_names) != len(self.graph.outputs): 179 | err_msgs.append("input shape or output shape must be change.") 180 | invalid = True 181 | 182 | if invalid: 183 | for m in err_msgs: 184 | print(m) 185 | MessageBox.error(err_msgs, "change input output shape", parent=self) 186 | return 187 | return super().accept() 188 | 189 | 190 | if __name__ == "__main__": 191 | import signal 192 | import os 193 | import onnx 194 | import onnx_graphsurgeon as gs 195 | from onnxgraphqt.graph.onnx_node_graph import ONNXNodeGraph 196 | # handle SIGINT to make the app terminate on CTRL+C 197 | signal.signal(signal.SIGINT, signal.SIG_DFL) 198 | 199 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 200 | 201 | app = QtWidgets.QApplication([]) 202 | model_path = os.path.join(os.path.dirname(__file__), "../data/mobilenetv2-7.onnx") 203 | onnx_model = 
onnx.load(model_path) 204 | onnx_graph = gs.import_onnx(onnx_model) 205 | graph = ONNXNodeGraph(name=onnx_graph.name, 206 | opset=onnx_graph.opset, 207 | doc_string=onnx_graph.doc_string, 208 | import_domains=onnx_graph.import_domains, 209 | producer_name=onnx_model.producer_name, 210 | producer_version=onnx_model.producer_version, 211 | ir_version=onnx_model.ir_version, 212 | model_version=onnx_model.model_version) 213 | graph.load_onnx_graph(onnx_graph,) 214 | window = ChangeInputOutputShapeWidget(graph.to_data()) 215 | window.show() 216 | 217 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_change_opset.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import signal 3 | from PySide2 import QtCore, QtWidgets, QtGui 4 | 5 | from onnxgraphqt.utils.opset import DEFAULT_OPSET 6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE 7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 8 | 9 | 10 | ChangeOpsetProperties = namedtuple("ChangeOpsetProperties", 11 | [ 12 | "opset", 13 | ]) 14 | 15 | 16 | class ChangeOpsetWidget(QtWidgets.QDialog): 17 | _DEFAULT_WINDOW_WIDTH = 300 18 | 19 | def __init__(self, current_opset, parent=None) -> None: 20 | super().__init__(parent) 21 | self.setModal(False) 22 | self.setWindowTitle("change opset") 23 | self.current_opset = current_opset 24 | self.initUI() 25 | 26 | def initUI(self): 27 | self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH) 28 | set_font(self, font_size=BASE_FONT_SIZE) 29 | 30 | base_layout = QtWidgets.QVBoxLayout() 31 | base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize) 32 | 33 | # Form layout 34 | layout = QtWidgets.QFormLayout() 35 | layout.setLabelAlignment(QtCore.Qt.AlignRight) 36 | self.ledit_opset = QtWidgets.QLineEdit() 37 | self.ledit_opset.setText(str(self.current_opset)) 38 | 
self.ledit_opset.setPlaceholderText("opset") 39 | 40 | label = QtWidgets.QLabel("opset number to be changed") 41 | set_font(label, font_size=LARGE_FONT_SIZE, bold=True) 42 | layout.addRow(label, self.ledit_opset) 43 | 44 | # add layout 45 | base_layout.addLayout(layout) 46 | 47 | # Dialog button 48 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 49 | QtWidgets.QDialogButtonBox.Cancel) 50 | btn.accepted.connect(self.accept) 51 | btn.rejected.connect(self.reject) 52 | # layout.addWidget(btn) 53 | base_layout.addWidget(btn) 54 | 55 | self.setLayout(base_layout) 56 | 57 | def get_properties(self)->ChangeOpsetProperties: 58 | return ChangeOpsetProperties( 59 | opset=self.ledit_opset.text() 60 | ) 61 | 62 | def accept(self) -> None: 63 | # value check 64 | invalid = False 65 | props = self.get_properties() 66 | print(props) 67 | err_msgs = [] 68 | if not str(props.opset).isdecimal(): 69 | err_msgs.append("opset must be unsigned integer") 70 | invalid = True 71 | if invalid: 72 | for m in err_msgs: 73 | print(m) 74 | MessageBox.error(err_msgs, "change opset", parent=self) 75 | return 76 | return super().accept() 77 | 78 | 79 | if __name__ == "__main__": 80 | import signal 81 | import os 82 | # handle SIGINT to make the app terminate on CTRL+C 83 | signal.signal(signal.SIGINT, signal.SIG_DFL) 84 | 85 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 86 | 87 | app = QtWidgets.QApplication([]) 88 | window = ChangeOpsetWidget(current_opset=DEFAULT_OPSET) 89 | window.show() 90 | 91 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_combine_network.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import List, Dict 3 | import signal 4 | from PySide2 import QtCore, QtWidgets, QtGui 5 | from ast import literal_eval 6 | 7 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph 
from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
from onnxgraphqt.widgets.widgets_message_box import MessageBox


CombineNetworkProperties = namedtuple("CombineNetworkProperties",
    [
        "combine_with_current_graph",
        "srcop_destop",
        "op_prefixes_after_merging",
        "input_onnx_file_paths",
        "output_of_onnx_file_in_the_process_of_fusion",
    ])


class CombineNetworkWidgets(QtWidgets.QDialog):
    """Dialog collecting the parameters for combining ONNX graphs.

    Up to ``_MAX_INPUTS_FILE`` model files can be selected. The current graph
    can optionally take part as the first model.
    """
    _DEFAULT_WINDOW_WIDTH = 560
    _MAX_INPUTS_FILE = 3
    # one prefix per input file plus row 0 for the current graph
    _MAX_OP_PREFIXES = _MAX_INPUTS_FILE + 1

    def __init__(self, graph: OnnxGraph=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("combine graphs")
        self.graph = graph
        self.initUI()

    def initUI(self):
        """Build the srcop_destop editor, file pickers, prefix fields and options."""
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # src dst
        layout_src_dst = QtWidgets.QFormLayout()
        lbl_src_dst = QtWidgets.QLabel("srcop_destop")
        set_font(lbl_src_dst, font_size=LARGE_FONT_SIZE, bold=True)
        self.src_op_dst_op = QtWidgets.QTextEdit()
        self.src_op_dst_op.setMinimumHeight(20*7)
        self.src_op_dst_op.setPlaceholderText(
            'e.g.' + '\n' +
            ' [' + '\n' +
            ' [' + '\n' +
            ' "model1_out1_opname","model2_in1_opname",' + '\n' +
            ' "model1_out2_opname","model2_in2_opname",' + '\n' +
            ' ],' + '\n' +
            ' ]'
        )
        layout_src_dst.addRow(lbl_src_dst, self.src_op_dst_op)

        # input_onnx_file_paths
        layout_inputs_file_paths = QtWidgets.QVBoxLayout()
        lbl_inputs = QtWidgets.QLabel("input_onnx_file_paths")
        set_font(lbl_inputs, font_size=LARGE_FONT_SIZE, bold=True)
        layout_inputs_file_paths.addWidget(lbl_inputs)

        def clicked(index):
            # Bind `index` now; a closure created directly in the loop body
            # would see only the loop variable's final value.
            def func(e: bool):
                return self.btn_inputs_file_path_clicked(index, e)
            return func

        self.inputs_file_paths = {}
        for index in range(self._MAX_INPUTS_FILE):
            widgets = {}
            widgets["base"] = QtWidgets.QWidget()
            widgets["layout"] = QtWidgets.QHBoxLayout(widgets["base"])
            widgets["layout"].setContentsMargins(0, 0, 0, 0)
            widgets["label"] = QtWidgets.QLabel(f"file{index+1}")
            widgets["value"] = QtWidgets.QLineEdit()
            widgets["button"] = QtWidgets.QPushButton("select")
            widgets["button"].setObjectName(f"button{index}")
            widgets["button"].clicked.connect(clicked(index))
            widgets["layout"].addWidget(widgets["label"])
            widgets["layout"].addWidget(widgets["value"])
            widgets["layout"].addWidget(widgets["button"])
            self.inputs_file_paths[index] = widgets
            layout_inputs_file_paths.addWidget(widgets["base"])

        self.combine_with_current_graph = QtWidgets.QCheckBox("combine_with_current_graph")
        if self.graph is not None and len(self.graph.nodes) > 0:
            self.combine_with_current_graph.setChecked(True)
        else:
            # nothing to combine with: lock the option off
            self.combine_with_current_graph.setChecked(False)
            self.combine_with_current_graph.setEnabled(False)
        self.combine_with_current_graph.clicked.connect(self.combine_with_current_graph_clicked)
        layout_inputs_file_paths.addWidget(self.combine_with_current_graph)
        layout_inputs_file_paths.setAlignment(self.combine_with_current_graph, QtCore.Qt.AlignRight)

        # op_prefixes_after_merging
        layout_op_prefixes_base = QtWidgets.QVBoxLayout()
        lbl_op_prefixes = QtWidgets.QLabel("op_prefixes_after_merging")
        set_font(lbl_op_prefixes, font_size=LARGE_FONT_SIZE, bold=True)
        layout_op_prefixes_base.addWidget(lbl_op_prefixes)

        layout_op_prefixes = QtWidgets.QFormLayout()
        self.op_prefixes = {}
        for index in range(self._MAX_OP_PREFIXES):
            label = f"file{index}" if index > 0 else "current graph"
            self.op_prefixes[index] = {
                "label": QtWidgets.QLabel(label),
                "value": QtWidgets.QLineEdit(),
            }
            layout_op_prefixes.addRow(self.op_prefixes[index]["label"], self.op_prefixes[index]["value"])
        # Row 0 only applies when the current graph takes part. This also
        # hides it when no graph was given; get_properties ignores it then.
        if not self.combine_with_current_graph.isChecked():
            self.op_prefixes[0]["label"].setVisible(False)
            self.op_prefixes[0]["value"].setVisible(False)
        layout_op_prefixes_base.addLayout(layout_op_prefixes)

        # output_of_onnx_file_in_the_process_of_fusion
        layout_output_in_process = QtWidgets.QVBoxLayout()
        self.output_in_process = QtWidgets.QCheckBox("output_of_onnx_file_in_the_process_of_fusion")
        self.output_in_process.setChecked(False)
        layout_output_in_process.addWidget(self.output_in_process)
        layout_output_in_process.setAlignment(self.output_in_process, QtCore.Qt.AlignRight)

        # add layout
        base_layout.addLayout(layout_src_dst)
        base_layout.addLayout(layout_inputs_file_paths)
        base_layout.addLayout(layout_op_prefixes_base)
        base_layout.addLayout(layout_output_in_process)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    def btn_inputs_file_path_clicked(self, index:int, e:bool):
        """Open a file dialog and store the chosen path in input row ``index``."""
        file_name, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
            self,
            caption=f"Open ONNX Model File({index+1})",
            filter="*.onnx *.json")
        if not file_name:
            return
        self.inputs_file_paths[index]["value"].setText(file_name)
        print(file_name)

    def combine_with_current_graph_clicked(self, e:bool):
        """Show the current-graph prefix row only while combining with the current graph."""
        checked = self.combine_with_current_graph.isChecked()
        self.op_prefixes[0]["label"].setVisible(checked)
        self.op_prefixes[0]["value"].setVisible(checked)

    def get_properties(self)->CombineNetworkProperties:
        """Collect the dialog state.

        Empty file paths and prefixes are dropped. ``op_prefixes_after_merging``
        is None when no prefix was entered. The current-graph prefix is
        ignored unless combining with the current graph.

        Raises:
            ValueError, TypeError, SyntaxError, MemoryError, RecursionError:
                if the srcop_destop text is not a valid Python literal.
        """
        combine_with_current_graph = self.combine_with_current_graph.isChecked()
        output_of_onnx_file_in_the_process_of_fusion = self.output_in_process.isChecked()

        srcop_destop = []
        str_op = self.src_op_dst_op.toPlainText()
        if str_op:
            srcop_destop = literal_eval(str_op)

        input_files = []
        for widgets in self.inputs_file_paths.values():
            file = widgets["value"].text().strip()
            if file:
                input_files.append(file)

        op_prefix = []
        for index, widgets in self.op_prefixes.items():
            if combine_with_current_graph is False and index == 0:
                continue
            prefix = widgets["value"].text().strip()
            if prefix:
                op_prefix.append(prefix)
        if len(op_prefix) == 0:
            op_prefix = None

        return CombineNetworkProperties(
            combine_with_current_graph=combine_with_current_graph,
            srcop_destop=srcop_destop,
            op_prefixes_after_merging=op_prefix,
            input_onnx_file_paths=input_files,
            output_of_onnx_file_in_the_process_of_fusion=output_of_onnx_file_in_the_process_of_fusion,
        )

    def accept(self) -> None:
        """Validate the counts of files, connections and prefixes, then close.

        With M models in total there are M - 1 srcop_destop entries. M is the
        number of files, plus 1 when the current graph takes part.
        """
        invalid = False
        err_msgs = []
        try:
            props = self.get_properties()
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            # report malformed srcop_destop text instead of raising out of the Qt slot
            err_msgs.append(f"- srcop_destop could not be parsed: {e}")
            print(err_msgs[-1])
            MessageBox.error(err_msgs, "combine network", parent=self)
            return
        print(props)
        if len(props.input_onnx_file_paths) == 0:
            err_msgs.append("- input_onnx_file_paths must be specified")
            invalid = True
        if not isinstance(props.srcop_destop, (list, tuple)):
            err_msgs.append("- srcop_destop must be a list.")
            invalid = True
        elif props.combine_with_current_graph:
            # current graph + N files -> N connections
            if len(props.srcop_destop) != len(props.input_onnx_file_paths):
                err_msgs.append("- The number of srcop_destop must match the number of input_onnx_file_paths.")
                invalid = True
        else:
            # N files -> N - 1 connections
            if len(props.srcop_destop) != len(props.input_onnx_file_paths)-1:
                err_msgs.append("- The number of srcop_destop must be (number of input_onnx_file_paths - 1).")
                invalid = True
        if props.op_prefixes_after_merging:
            if props.combine_with_current_graph:
                if len(props.input_onnx_file_paths) + 1 != len(props.op_prefixes_after_merging):
                    err_msgs.append("- The number of op_prefixes_after_merging must match (number of input_onnx_file_paths + 1).")
                    invalid = True
            else:
                if len(props.input_onnx_file_paths) != len(props.op_prefixes_after_merging):
                    err_msgs.append("- The number of op_prefixes_after_merging must match number of input_onnx_file_paths.")
                    invalid = True

        if invalid:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "combine network", parent=self)
            return
        return super().accept()


if __name__ == "__main__":
    import signal
    import os
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = CombineNetworkWidgets()
    window.show()

    app.exec_()
-------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_constant_shrink.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import signal 3 | from PySide2 import QtCore, QtWidgets, QtGui 4 | from ast import literal_eval 5 | 6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE 7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 8 | 9 | 10 | ConstantShrinkProperties = namedtuple("ConstantShrinkProperties", 11 | [ 12 | "mode", 13 | "forced_extraction_op_names", 14 | "forced_extraction_constant_names", 15 | "disable_auto_downcast", 16 | ]) 17 | 18 | MODE = [ 19 | "shrink", 20 | "npy" 21 | ] 22 | 23 | class ConstantShrinkWidgets(QtWidgets.QDialog): 24 | _DEFAULT_WINDOW_WIDTH = 500 25 | 26 | def __init__(self, parent=None) -> None: 27 | super().__init__(parent) 28 | self.setModal(False) 29 | self.setWindowTitle("constant shrink") 30 | self.initUI() 31 | 32 | def initUI(self): 33 | self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH) 34 | set_font(self, font_size=BASE_FONT_SIZE) 35 | 36 | base_layout = QtWidgets.QVBoxLayout() 37 | 38 | # Form layout 39 | layout = QtWidgets.QFormLayout() 40 | layout.setLabelAlignment(QtCore.Qt.AlignRight) 41 | lbl_mode = QtWidgets.QLabel("mode") 42 | set_font(lbl_mode, font_size=LARGE_FONT_SIZE, bold=True) 43 | self.cmb_mode = QtWidgets.QComboBox() 44 | for m in MODE: 45 | self.cmb_mode.addItem(m) 46 | 47 | lbl_forced_extraction_op_names = QtWidgets.QLabel("forced_extraction_op_names") 48 | set_font(lbl_forced_extraction_op_names, font_size=LARGE_FONT_SIZE, bold=True) 49 | self.tb_forced_extraction_op_names = QtWidgets.QLineEdit() 50 | self.tb_forced_extraction_op_names.setPlaceholderText("e.g. 
['aaa','bbb','ccc']") 51 | 52 | lbl_forced_extraction_constant_names = QtWidgets.QLabel("forced_extraction_constant_names") 53 | set_font(lbl_forced_extraction_constant_names, font_size=LARGE_FONT_SIZE, bold=True) 54 | self.tb_forced_extraction_constant_names = QtWidgets.QLineEdit() 55 | self.tb_forced_extraction_constant_names.setPlaceholderText("e.g. ['aaa','bbb','ccc']") 56 | 57 | layout.addRow(lbl_mode, self.cmb_mode) 58 | layout.addRow(lbl_forced_extraction_op_names, self.tb_forced_extraction_op_names) 59 | layout.addRow(lbl_forced_extraction_constant_names, self.tb_forced_extraction_constant_names) 60 | 61 | layout2 = QtWidgets.QVBoxLayout() 62 | self.check_auto_downcast = QtWidgets.QCheckBox("auto_downcast") 63 | self.check_auto_downcast.setChecked(True) 64 | layout2.addWidget(self.check_auto_downcast) 65 | layout2.setAlignment(self.check_auto_downcast, QtCore.Qt.AlignRight) 66 | 67 | # add layout 68 | base_layout.addLayout(layout) 69 | base_layout.addLayout(layout2) 70 | 71 | # Dialog button 72 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 73 | QtWidgets.QDialogButtonBox.Cancel) 74 | btn.accepted.connect(self.accept) 75 | btn.rejected.connect(self.reject) 76 | # layout.addWidget(btn) 77 | base_layout.addWidget(btn) 78 | 79 | self.setLayout(base_layout) 80 | 81 | def get_properties(self)->ConstantShrinkProperties: 82 | mode = self.cmb_mode.currentText() 83 | forced_extraction_op_names = [] 84 | forced_extraction_constant_names = [] 85 | disable_auto_downcast = not self.check_auto_downcast.isChecked() 86 | 87 | op_names = self.tb_forced_extraction_op_names.text() 88 | if op_names: 89 | try: 90 | forced_extraction_op_names = literal_eval(op_names) 91 | except Exception as e: 92 | raise e 93 | 94 | constant_names = self.tb_forced_extraction_constant_names.text() 95 | if constant_names: 96 | try: 97 | forced_extraction_constant_names = literal_eval(constant_names) 98 | except Exception as e: 99 | raise e 100 | 101 | return 
ConstantShrinkProperties( 102 | mode=mode, 103 | forced_extraction_op_names=forced_extraction_op_names, 104 | forced_extraction_constant_names=forced_extraction_constant_names, 105 | disable_auto_downcast=disable_auto_downcast 106 | ) 107 | 108 | def accept(self) -> None: 109 | # value check 110 | invalid = False 111 | try: 112 | props = self.get_properties() 113 | print(props) 114 | err_msgs = [] 115 | except Exception as e: 116 | print(e) 117 | return 118 | if not props.mode in MODE: 119 | err_msgs.append(f"- mode is select from {'or'.join(MODE)}") 120 | invalid = True 121 | if invalid: 122 | for m in err_msgs: 123 | print(m) 124 | MessageBox.error(err_msgs, "constant shrink", parent=self) 125 | return 126 | return super().accept() 127 | 128 | 129 | if __name__ == "__main__": 130 | import signal 131 | import os 132 | # handle SIGINT to make the app terminate on CTRL+C 133 | signal.signal(signal.SIGINT, signal.SIG_DFL) 134 | 135 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 136 | 137 | app = QtWidgets.QApplication([]) 138 | window = ConstantShrinkWidgets() 139 | window.show() 140 | 141 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_delete_node.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import List 3 | import signal 4 | from PySide2 import QtCore, QtWidgets, QtGui 5 | 6 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph 7 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE 8 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 9 | 10 | 11 | DeleteNodeProperties = namedtuple("DeleteNodeProperties", 12 | [ 13 | "remove_node_names", 14 | ]) 15 | 16 | class DeleteNodeWidgets(QtWidgets.QDialog): 17 | # _DEFAULT_WINDOW_WIDTH = 400 18 | _MAX_REMOVE_NODE_NAMES_COUNT = 5 19 | 20 | def __init__(self, graph: OnnxGraph=None, 
selected_nodes:List[str]=[], parent=None) -> None: 21 | super().__init__(parent) 22 | self.setModal(False) 23 | self.setWindowTitle("delete node") 24 | self.graph = graph 25 | self.selected_nodes = selected_nodes 26 | self.initUI() 27 | self.updateUI(self.graph, selected_nodes) 28 | 29 | def initUI(self): 30 | set_font(self, font_size=BASE_FONT_SIZE) 31 | 32 | base_layout = QtWidgets.QVBoxLayout() 33 | base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize) 34 | 35 | # attributes 36 | self.layout = QtWidgets.QVBoxLayout() 37 | lbl = QtWidgets.QLabel("remove_node_names ") 38 | set_font(lbl, font_size=LARGE_FONT_SIZE, bold=True) 39 | self.layout.addWidget(lbl) 40 | self.visible_remove_node_names_count = 1 41 | self.remove_node_names = {} 42 | for index in range(self._MAX_REMOVE_NODE_NAMES_COUNT): 43 | self.remove_node_names[index] = {} 44 | self.remove_node_names[index]["base"] = QtWidgets.QWidget() 45 | self.remove_node_names[index]["layout"] = QtWidgets.QHBoxLayout(self.remove_node_names[index]["base"]) 46 | self.remove_node_names[index]["layout"].setContentsMargins(0, 0, 0, 0) 47 | self.remove_node_names[index]["name"] = QtWidgets.QComboBox() 48 | self.remove_node_names[index]["name"].setEditable(True) 49 | self.remove_node_names[index]["layout"].addWidget(self.remove_node_names[index]["name"]) 50 | self.layout.addWidget(self.remove_node_names[index]["base"]) 51 | self.btn_add = QtWidgets.QPushButton("+") 52 | self.btn_del = QtWidgets.QPushButton("-") 53 | self.btn_add.clicked.connect(self.btn_add_clicked) 54 | self.btn_del.clicked.connect(self.btn_del_clicked) 55 | self.set_visible() 56 | layout_btn = QtWidgets.QHBoxLayout() 57 | layout_btn.addWidget(self.btn_add) 58 | layout_btn.addWidget(self.btn_del) 59 | self.layout.addLayout(layout_btn) 60 | 61 | # add layout 62 | base_layout.addLayout(self.layout) 63 | 64 | # Dialog button 65 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 66 | QtWidgets.QDialogButtonBox.Cancel) 67 | 
btn.accepted.connect(self.accept) 68 | btn.rejected.connect(self.reject) 69 | # layout.addWidget(btn) 70 | base_layout.addWidget(btn) 71 | 72 | self.setLayout(base_layout) 73 | 74 | def updateUI(self, graph: OnnxGraph=None, selected_nodes:List[str]=[]): 75 | if graph: 76 | for index in range(self._MAX_REMOVE_NODE_NAMES_COUNT): 77 | self.remove_node_names[index]["name"].clear() 78 | for name, node in graph.nodes.items(): 79 | self.remove_node_names[index]["name"].addItem(name) 80 | self.remove_node_names[index]["name"].setCurrentIndex(-1) 81 | 82 | index = 0 83 | visible_count = 0 84 | for index in range(self._MAX_REMOVE_NODE_NAMES_COUNT): 85 | if len(selected_nodes) < visible_count + 1: 86 | break 87 | node = selected_nodes[index] 88 | if node in graph.nodes.keys(): 89 | self.remove_node_names[visible_count]["name"].setCurrentText(node) 90 | visible_count += 1 91 | self.visible_remove_node_names_count = min(visible_count + 1, self._MAX_REMOVE_NODE_NAMES_COUNT) 92 | self.set_visible() 93 | 94 | def set_visible(self): 95 | for key, widgets in self.remove_node_names.items(): 96 | widgets["base"].setVisible(key < self.visible_remove_node_names_count) 97 | if self.visible_remove_node_names_count == 1: 98 | self.btn_add.setEnabled(True) 99 | self.btn_del.setEnabled(False) 100 | elif self.visible_remove_node_names_count >= self._MAX_REMOVE_NODE_NAMES_COUNT: 101 | self.btn_add.setEnabled(False) 102 | self.btn_del.setEnabled(True) 103 | else: 104 | self.btn_add.setEnabled(True) 105 | self.btn_del.setEnabled(True) 106 | 107 | def btn_add_clicked(self, e): 108 | self.visible_remove_node_names_count = min(max(0, self.visible_remove_node_names_count + 1), self._MAX_REMOVE_NODE_NAMES_COUNT) 109 | self.set_visible() 110 | 111 | def btn_del_clicked(self, e): 112 | self.visible_remove_node_names_count = min(max(0, self.visible_remove_node_names_count - 1), self._MAX_REMOVE_NODE_NAMES_COUNT) 113 | self.set_visible() 114 | 115 | 116 | def get_properties(self)->DeleteNodeProperties: 
117 | 118 | remove_node_names = [] 119 | for i in range(self.visible_remove_node_names_count): 120 | name = self.remove_node_names[i]["name"].currentText() 121 | if str.strip(name): 122 | remove_node_names.append(name) 123 | 124 | return DeleteNodeProperties( 125 | remove_node_names=remove_node_names 126 | ) 127 | 128 | def accept(self) -> None: 129 | # value check 130 | invalid = False 131 | props = self.get_properties() 132 | print(props) 133 | err_msgs = [] 134 | if len(props.remove_node_names) == 0: 135 | err_msgs.append("- remove_node_names is not set.") 136 | invalid = True 137 | 138 | if invalid: 139 | for m in err_msgs: 140 | print(m) 141 | MessageBox.error(err_msgs, "delete node", parent=self) 142 | return 143 | return super().accept() 144 | 145 | 146 | 147 | if __name__ == "__main__": 148 | import signal 149 | import os 150 | # handle SIGINT to make the app terminate on CTRL+C 151 | signal.signal(signal.SIGINT, signal.SIG_DFL) 152 | 153 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 154 | 155 | app = QtWidgets.QApplication([]) 156 | window = DeleteNodeWidgets() 157 | window.show() 158 | 159 | app.exec_() -------------------------------------------------------------------------------- /onnxgraphqt/widgets/widgets_extract_network.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import signal 3 | from PySide2 import QtCore, QtWidgets, QtGui 4 | 5 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE 6 | from onnxgraphqt.graph.onnx_node_graph import OnnxGraph 7 | from onnxgraphqt.widgets.widgets_message_box import MessageBox 8 | 9 | 10 | ExtractNetworkProperties = namedtuple("ExtractNetworkProperties", 11 | [ 12 | "input_op_names", 13 | "output_op_names", 14 | ]) 15 | 16 | class ExtractNetworkWidgets(QtWidgets.QDialog): 17 | _DEFAULT_WINDOW_WIDTH = 400 18 | _MAX_INPUT_OP_NAMES_COUNT = 5 19 | _MAX_OUTPUT_OP_NAMES_COUNT = 5 20 | 
21 | def __init__(self, graph: OnnxGraph=None, parent=None) -> None: 22 | super().__init__(parent) 23 | self.setModal(False) 24 | self.setWindowTitle("extract network") 25 | self.graph = graph 26 | self.initUI() 27 | self.updateUI(self.graph) 28 | 29 | def initUI(self): 30 | self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH) 31 | set_font(self, font_size=BASE_FONT_SIZE) 32 | 33 | base_layout = QtWidgets.QVBoxLayout() 34 | 35 | # inputs 36 | self.layout_inputs = QtWidgets.QVBoxLayout() 37 | 38 | lbl_input_op_names = QtWidgets.QLabel("input_op_names") 39 | set_font(lbl_input_op_names, font_size=LARGE_FONT_SIZE, bold=True) 40 | 41 | self.layout_inputs.addWidget(lbl_input_op_names) 42 | 43 | self.visible_input_op_names_count = 1 44 | self.widgets_inputs = {} 45 | for index in range(self._MAX_INPUT_OP_NAMES_COUNT): 46 | self.widgets_inputs[index] = {} 47 | self.widgets_inputs[index]["base"] = QtWidgets.QWidget() 48 | self.widgets_inputs[index]["layout"] = QtWidgets.QHBoxLayout(self.widgets_inputs[index]["base"]) 49 | self.widgets_inputs[index]["layout"].setContentsMargins(0, 0, 0, 0) 50 | self.widgets_inputs[index]["name"] = QtWidgets.QComboBox() 51 | self.widgets_inputs[index]["name"].setEditable(True) 52 | self.widgets_inputs[index]["layout"].addWidget(self.widgets_inputs[index]["name"]) 53 | self.layout_inputs.addWidget(self.widgets_inputs[index]["base"]) 54 | self.btn_add_inputs = QtWidgets.QPushButton("+") 55 | self.btn_del_inputs = QtWidgets.QPushButton("-") 56 | self.btn_add_inputs.clicked.connect(self.btn_add_inputs_clicked) 57 | self.btn_del_inputs.clicked.connect(self.btn_del_inputs_clicked) 58 | 59 | layout_btn_inputs = QtWidgets.QHBoxLayout() 60 | layout_btn_inputs.addWidget(self.btn_add_inputs) 61 | layout_btn_inputs.addWidget(self.btn_del_inputs) 62 | self.layout_inputs.addLayout(layout_btn_inputs) 63 | 64 | # outputs 65 | self.layout_outputs = QtWidgets.QVBoxLayout() 66 | 67 | lbl_output_op_names = QtWidgets.QLabel("output_op_names") 68 | 
set_font(lbl_output_op_names, font_size=LARGE_FONT_SIZE, bold=True) 69 | self.layout_outputs.addWidget(lbl_output_op_names) 70 | 71 | self.visible_output_op_names_count = 1 72 | self.widgets_outputs = {} 73 | for index in range(self._MAX_OUTPUT_OP_NAMES_COUNT): 74 | self.widgets_outputs[index] = {} 75 | self.widgets_outputs[index]["base"] = QtWidgets.QWidget() 76 | self.widgets_outputs[index]["layout"] = QtWidgets.QHBoxLayout(self.widgets_outputs[index]["base"]) 77 | self.widgets_outputs[index]["layout"].setContentsMargins(0, 0, 0, 0) 78 | self.widgets_outputs[index]["name"] = QtWidgets.QComboBox() 79 | self.widgets_outputs[index]["name"].setEditable(True) 80 | self.widgets_outputs[index]["layout"].addWidget(self.widgets_outputs[index]["name"]) 81 | self.layout_outputs.addWidget(self.widgets_outputs[index]["base"]) 82 | self.btn_add_outputs = QtWidgets.QPushButton("+") 83 | self.btn_del_outputs = QtWidgets.QPushButton("-") 84 | self.btn_add_outputs.clicked.connect(self.btn_add_outputs_clicked) 85 | self.btn_del_outputs.clicked.connect(self.btn_del_outputs_clicked) 86 | 87 | layout_btn_outputs = QtWidgets.QHBoxLayout() 88 | layout_btn_outputs.addWidget(self.btn_add_outputs) 89 | layout_btn_outputs.addWidget(self.btn_del_outputs) 90 | self.layout_outputs.addLayout(layout_btn_outputs) 91 | 92 | # Dialog button 93 | btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok | 94 | QtWidgets.QDialogButtonBox.Cancel) 95 | btn.accepted.connect(self.accept) 96 | btn.rejected.connect(self.reject) 97 | 98 | # add layout 99 | base_layout.addLayout(self.layout_inputs) 100 | base_layout.addSpacing(10) 101 | base_layout.addLayout(self.layout_outputs) 102 | base_layout.addSpacing(10) 103 | base_layout.addWidget(btn) 104 | 105 | self.setLayout(base_layout) 106 | 107 | self.set_btn_visible() 108 | 109 | def updateUI(self, graph: OnnxGraph=None): 110 | if graph: 111 | for index in range(self._MAX_INPUT_OP_NAMES_COUNT): 112 | self.widgets_inputs[index]["name"].clear() 113 | for 
name in graph.node_inputs.keys(): 114 | self.widgets_inputs[index]["name"].addItem(name) 115 | self.widgets_inputs[index]["name"].setCurrentIndex(-1) 116 | 117 | for index in range(self._MAX_OUTPUT_OP_NAMES_COUNT): 118 | self.widgets_outputs[index]["name"].clear() 119 | for name in graph.node_inputs.keys(): 120 | self.widgets_outputs[index]["name"].addItem(name) 121 | self.widgets_outputs[index]["name"].setCurrentIndex(-1) 122 | 123 | def set_btn_visible(self): 124 | for key, widgets in self.widgets_inputs.items(): 125 | widgets["base"].setVisible(key < self.visible_input_op_names_count) 126 | for key, widgets in self.widgets_outputs.items(): 127 | widgets["base"].setVisible(key < self.visible_output_op_names_count) 128 | 129 | if self.visible_input_op_names_count == 1: 130 | self.btn_add_inputs.setEnabled(True) 131 | self.btn_del_inputs.setEnabled(False) 132 | elif self.visible_input_op_names_count >= self._MAX_INPUT_OP_NAMES_COUNT: 133 | self.btn_add_inputs.setEnabled(False) 134 | self.btn_del_inputs.setEnabled(True) 135 | else: 136 | self.btn_add_inputs.setEnabled(True) 137 | self.btn_del_inputs.setEnabled(True) 138 | 139 | if self.visible_output_op_names_count == 1: 140 | self.btn_add_outputs.setEnabled(True) 141 | self.btn_del_outputs.setEnabled(False) 142 | elif self.visible_output_op_names_count >= self._MAX_OUTPUT_OP_NAMES_COUNT: 143 | self.btn_add_outputs.setEnabled(False) 144 | self.btn_del_outputs.setEnabled(True) 145 | else: 146 | self.btn_add_outputs.setEnabled(True) 147 | self.btn_del_outputs.setEnabled(True) 148 | 149 | def btn_add_inputs_clicked(self, e): 150 | self.visible_input_op_names_count = min(max(0, self.visible_input_op_names_count + 1), self._MAX_INPUT_OP_NAMES_COUNT) 151 | self.set_btn_visible() 152 | 153 | def btn_del_inputs_clicked(self, e): 154 | self.visible_input_op_names_count = min(max(0, self.visible_input_op_names_count - 1), self._MAX_INPUT_OP_NAMES_COUNT) 155 | self.set_btn_visible() 156 | 157 | def 
btn_add_outputs_clicked(self, e): 158 | self.visible_output_op_names_count = min(max(0, self.visible_output_op_names_count + 1), self._MAX_OUTPUT_OP_NAMES_COUNT) 159 | self.set_btn_visible() 160 | 161 | def btn_del_outputs_clicked(self, e): 162 | self.visible_output_op_names_count = min(max(0, self.visible_output_op_names_count - 1), self._MAX_OUTPUT_OP_NAMES_COUNT) 163 | self.set_btn_visible() 164 | 165 | 166 | def get_properties(self)->ExtractNetworkProperties: 167 | 168 | inputs_op_names = [] 169 | for i in range(self.visible_input_op_names_count): 170 | name = self.widgets_inputs[i]["name"].currentText() 171 | if str.strip(name): 172 | inputs_op_names.append(name) 173 | 174 | outputs_op_names = [] 175 | for i in range(self.visible_output_op_names_count): 176 | name = self.widgets_outputs[i]["name"].currentText() 177 | if str.strip(name): 178 | outputs_op_names.append(name) 179 | 180 | return ExtractNetworkProperties( 181 | input_op_names=inputs_op_names, 182 | output_op_names=outputs_op_names 183 | ) 184 | 185 | def accept(self) -> None: 186 | # value check 187 | invalid = False 188 | props = self.get_properties() 189 | print(props) 190 | err_msgs = [] 191 | if len(props.input_op_names) == 0: 192 | err_msgs.append("- input_op_names is not set") 193 | invalid = True 194 | if len(props.output_op_names) == 0: 195 | err_msgs.append("- output_op_names is not set") 196 | invalid = True 197 | 198 | if invalid: 199 | for m in err_msgs: 200 | print(m) 201 | MessageBox.error(err_msgs, "extract network", parent=self) 202 | return 203 | return super().accept() 204 | 205 | 206 | 207 | if __name__ == "__main__": 208 | import signal 209 | import os 210 | # handle SIGINT to make the app terminate on CTRL+C 211 | signal.signal(signal.SIGINT, signal.SIG_DFL) 212 | 213 | QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling) 214 | 215 | app = QtWidgets.QApplication([]) 216 | window = ExtractNetworkWidgets() 217 | window.show() 218 | 219 | app.exec_() 
# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_generate_operator.py
# --------------------------------------------------------------------------------
from collections import namedtuple
from typing import List, Dict
import signal
from PySide2 import QtCore, QtWidgets, QtGui
from ast import literal_eval
import numpy as np

from onnxgraphqt.utils.opset import DEFAULT_OPSET
from onnxgraphqt.utils.operators import onnx_opsets, opnames, OperatorVersion, latest_opset
from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
from onnxgraphqt.widgets.widgets_message_box import MessageBox


# Choices pre-filled into each variable's dtype combo box (the combo box is
# editable, so other dtype strings can still be typed).
AVAILABLE_DTYPES = [
    'float32',
    'float64',
    'int32',
    'int64',
    'str',
]


# Result of GenerateOperatorWidgets.get_properties().
GenerateOperatorProperties = namedtuple("GenerateOperatorProperties",
    [
        "op_type",
        "opset",
        "op_name",
        "input_variables",
        "output_variables",
        "attributes",
    ])


class GenerateOperatorWidgets(QtWidgets.QDialog):
    """Non-modal dialog that collects the settings for a new ONNX operator.

    Selecting an opset repopulates the op_type list with the operators known
    at that opset; selecting an op_type resizes the variable rows and pre-fills
    the attribute rows from that operator's metadata.
    """
    _DEFAULT_WINDOW_WIDTH = 500
    _MAX_INPUT_VARIABLES_COUNT = 5
    _MAX_OUTPUT_VARIABLES_COUNT = 5
    _MAX_ATTRIBUTES_COUNT = 10

    def __init__(self, opset=DEFAULT_OPSET, parent=None) -> None:
        """Build the dialog with ``opset`` initially selected in the opset box."""
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("generate operator")
        self.initUI(opset)

    def initUI(self, opset:int):
        """Create all widgets, connect the combo-box signals and select ``opset``."""
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()
        base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)

        # Form layout: op_type / opset / op_name
        layout = QtWidgets.QFormLayout()
        layout.setLabelAlignment(QtCore.Qt.AlignRight)
        self.cmb_optype = QtWidgets.QComboBox()
        self.cmb_optype.setEditable(True)
        lbl_op_type = QtWidgets.QLabel("op_type")
        set_font(lbl_op_type, font_size=LARGE_FONT_SIZE, bold=True)
        layout.addRow(lbl_op_type, self.cmb_optype)

        self.cmb_opset = QtWidgets.QComboBox()
        self.cmb_opset.setEditable(True)
        for i in range(1, latest_opset + 1):
            # item text is the opset number; item data is the same number as int
            self.cmb_opset.addItem(str(i), i)
        lbl_opset = QtWidgets.QLabel("opset")
        set_font(lbl_opset, font_size=LARGE_FONT_SIZE, bold=True)
        layout.addRow(lbl_opset, self.cmb_opset)

        self.tb_opname = QtWidgets.QLineEdit()
        self.tb_opname.setText("")
        lbl_op_name = QtWidgets.QLabel("op_name")
        set_font(lbl_op_name, font_size=LARGE_FONT_SIZE, bold=True)
        layout.addRow(lbl_op_name, self.tb_opname)

        # variables: fixed pools of rows, of which only the first N are shown
        self.layout_valiables = QtWidgets.QVBoxLayout()
        self.visible_input_valiables_count = 1
        self.visible_output_valiables_count = 1

        self.add_input_valiables = {}
        self.add_output_valiables = {}
        for i in range(self._MAX_INPUT_VARIABLES_COUNT):
            self.create_variables_widget(i, is_input=True)
        for i in range(self._MAX_OUTPUT_VARIABLES_COUNT):
            self.create_variables_widget(i, is_input=False)

        self.btn_add_input_valiables = QtWidgets.QPushButton("+")
        self.btn_del_input_valiables = QtWidgets.QPushButton("-")
        self.btn_add_input_valiables.clicked.connect(self.btn_add_input_valiables_clicked)
        self.btn_del_input_valiables.clicked.connect(self.btn_del_input_valiables_clicked)
        self.btn_add_output_valiables = QtWidgets.QPushButton("+")
        self.btn_del_output_valiables = QtWidgets.QPushButton("-")
        self.btn_add_output_valiables.clicked.connect(self.btn_add_output_valiables_clicked)
        self.btn_del_output_valiables.clicked.connect(self.btn_del_output_valiables_clicked)

        self.layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
        lbl_inp_val = QtWidgets.QLabel("input variables [optional]")
        set_font(lbl_inp_val, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout_valiables.addWidget(lbl_inp_val)
        for widgets in self.add_input_valiables.values():
            self.layout_valiables.addWidget(widgets["base"])
        layout_btn_input = QtWidgets.QHBoxLayout()
        layout_btn_input.addWidget(self.btn_add_input_valiables)
        layout_btn_input.addWidget(self.btn_del_input_valiables)
        self.layout_valiables.addLayout(layout_btn_input)

        self.layout_valiables.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
        lbl_out_val = QtWidgets.QLabel("output variables [optional]")
        set_font(lbl_out_val, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout_valiables.addWidget(lbl_out_val)
        for widgets in self.add_output_valiables.values():
            self.layout_valiables.addWidget(widgets["base"])
        layout_btn_output = QtWidgets.QHBoxLayout()
        layout_btn_output.addWidget(self.btn_add_output_valiables)
        layout_btn_output.addWidget(self.btn_del_output_valiables)
        self.layout_valiables.addLayout(layout_btn_output)

        # attributes: name/value rows
        self.layout_attributes = QtWidgets.QVBoxLayout()
        self.layout_attributes.addItem(QtWidgets.QSpacerItem(self._DEFAULT_WINDOW_WIDTH, 20))
        lbl_atrributes = QtWidgets.QLabel("attributes [optional]")
        set_font(lbl_atrributes, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout_attributes.addWidget(lbl_atrributes)
        self.visible_attributes_count = 3
        self.attributes = {}
        for index in range(self._MAX_ATTRIBUTES_COUNT):
            base = QtWidgets.QWidget()
            row_layout = QtWidgets.QHBoxLayout(base)
            row_layout.setContentsMargins(0, 0, 0, 0)
            name = QtWidgets.QLineEdit()
            name.setPlaceholderText("name")
            value = QtWidgets.QLineEdit()
            value.setPlaceholderText("value")
            row_layout.addWidget(name)
            row_layout.addWidget(value)
            self.layout_attributes.addWidget(base)
            self.attributes[index] = {"base": base, "layout": row_layout, "name": name, "value": value}
        self.btn_add_attributes = QtWidgets.QPushButton("+")
        self.btn_del_attributes = QtWidgets.QPushButton("-")
        self.btn_add_attributes.clicked.connect(self.btn_add_attributes_clicked)
        self.btn_del_attributes.clicked.connect(self.btn_del_attributes_clicked)
        layout_btn_attributes = QtWidgets.QHBoxLayout()
        layout_btn_attributes.addWidget(self.btn_add_attributes)
        layout_btn_attributes.addWidget(self.btn_del_attributes)
        self.layout_attributes.addLayout(layout_btn_attributes)

        # add layout
        base_layout.addLayout(layout)
        base_layout.addLayout(self.layout_valiables)
        base_layout.addLayout(self.layout_attributes)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

        self.cmb_optype.currentIndexChanged.connect(self.cmb_optype_currentIndexChanged)
        self.cmb_opset.currentIndexChanged.connect(self.cmb_opset_currentIndexChanged)
        initial_index = opset - 1
        if self.cmb_opset.currentIndex() == initial_index:
            # setCurrentIndex() emits nothing when the index is unchanged (e.g.
            # opset == 1, already current after the addItem loop), so populate
            # the op_type list explicitly.
            self.cmb_opset_currentIndexChanged(initial_index)
        else:
            self.cmb_opset.setCurrentIndex(initial_index)

    def create_variables_widget(self, index:int, is_input=True) -> None:
        """Build row ``index`` (name / dtype / shape) of the input or output variables.

        The row is stored in ``self.add_input_valiables[index]`` (``is_input``)
        or ``self.add_output_valiables[index]`` as a dict with the keys "base",
        "layout", "name", "dtype" and "shape". The row is not added to any
        layout here.
        """
        rows = self.add_input_valiables if is_input else self.add_output_valiables
        base = QtWidgets.QWidget()
        row_layout = QtWidgets.QHBoxLayout(base)
        row_layout.setContentsMargins(0, 0, 0, 0)
        name = QtWidgets.QLineEdit()
        name.setPlaceholderText("name")
        dtype_box = QtWidgets.QComboBox()
        for dtype in AVAILABLE_DTYPES:
            dtype_box.addItem(dtype)
        dtype_box.setEditable(True)
        dtype_box.setFixedSize(100, 20)
        dtype_box.setPlaceholderText("dtype. e.g. `float32`")
        shape = QtWidgets.QLineEdit()
        shape.setPlaceholderText("shape. e.g. `[1, 2, 3]`")
        row_layout.addWidget(name)
        row_layout.addWidget(dtype_box)
        row_layout.addWidget(shape)
        rows[index] = {"base": base, "layout": row_layout, "name": name,
                       "dtype": dtype_box, "shape": shape}

    @staticmethod
    def _update_rows(rows: dict, visible_count: int, max_count: int, btn_add, btn_del) -> None:
        """Show the first ``visible_count`` rows and enable "+" / "-" accordingly.

        "-" is disabled at zero visible rows and "+" once ``max_count`` is reached.
        """
        for key, widgets in rows.items():
            widgets["base"].setVisible(key < visible_count)
        if visible_count == 0:
            btn_add.setEnabled(True)
            btn_del.setEnabled(False)
        elif visible_count >= max_count:
            btn_add.setEnabled(False)
            btn_del.setEnabled(True)
        else:
            btn_add.setEnabled(True)
            btn_del.setEnabled(True)

    def set_visible_input_valiables(self):
        """Sync the input-variable rows and their +/- buttons with the visible count."""
        self._update_rows(self.add_input_valiables, self.visible_input_valiables_count,
                          self._MAX_INPUT_VARIABLES_COUNT,
                          self.btn_add_input_valiables, self.btn_del_input_valiables)

    def set_visible_output_valiables(self):
        """Sync the output-variable rows and their +/- buttons with the visible count."""
        self._update_rows(self.add_output_valiables, self.visible_output_valiables_count,
                          self._MAX_OUTPUT_VARIABLES_COUNT,
                          self.btn_add_output_valiables, self.btn_del_output_valiables)

    def set_visible_add_op_attributes(self):
        """Sync the attribute rows and their +/- buttons with the visible count."""
        self._update_rows(self.attributes, self.visible_attributes_count,
                          self._MAX_ATTRIBUTES_COUNT,
                          self.btn_add_attributes, self.btn_del_attributes)

    def btn_add_input_valiables_clicked(self, e):
        self.visible_input_valiables_count = min(max(0, self.visible_input_valiables_count + 1), self._MAX_INPUT_VARIABLES_COUNT)
        self.set_visible_input_valiables()

    def btn_del_input_valiables_clicked(self, e):
        self.visible_input_valiables_count = min(max(0, self.visible_input_valiables_count - 1), self._MAX_INPUT_VARIABLES_COUNT)
        self.set_visible_input_valiables()

    def btn_add_output_valiables_clicked(self, e):
        self.visible_output_valiables_count = min(max(0, self.visible_output_valiables_count + 1), self._MAX_OUTPUT_VARIABLES_COUNT)
        self.set_visible_output_valiables()

    def btn_del_output_valiables_clicked(self, e):
        self.visible_output_valiables_count = min(max(0, self.visible_output_valiables_count - 1), self._MAX_OUTPUT_VARIABLES_COUNT)
        self.set_visible_output_valiables()

    def btn_add_attributes_clicked(self, e):
        self.visible_attributes_count = min(max(0, self.visible_attributes_count + 1), self._MAX_ATTRIBUTES_COUNT)
        self.set_visible_add_op_attributes()

    def btn_del_attributes_clicked(self, e):
        self.visible_attributes_count = min(max(0, self.visible_attributes_count - 1), self._MAX_ATTRIBUTES_COUNT)
        self.set_visible_add_op_attributes()

    def cmb_optype_currentIndexChanged(self, selected_index:int):
        """Resize the variable rows and pre-fill attributes for the chosen operator.

        All counts are clamped to the number of rows actually built:
        get_properties() indexes the row dicts with ``range(visible_*_count)``,
        so an operator with more inputs/outputs/attributes than rows would
        otherwise raise KeyError.
        """
        selected_operator: OperatorVersion = self.cmb_optype.currentData()
        if selected_operator:
            self.visible_input_valiables_count = min(max(0, selected_operator.inputs), self._MAX_INPUT_VARIABLES_COUNT)
            self.visible_output_valiables_count = min(max(0, selected_operator.outputs), self._MAX_OUTPUT_VARIABLES_COUNT)
            attrs = list(selected_operator.attrs)[:self._MAX_ATTRIBUTES_COUNT]
            self.visible_attributes_count = len(attrs)

            for i, att in enumerate(attrs):
                self.attributes[i]["name"].setText(att.name)
                # NOTE(review): setText needs a str; assumes default_value is
                # already stored as text -- confirm in utils.operators.
                self.attributes[i]["value"].setText(att.default_value)
            for j in range(len(attrs), self._MAX_ATTRIBUTES_COUNT):
                self.attributes[j]["name"].setText("")
                self.attributes[j]["value"].setText("")
        self.set_visible_input_valiables()
        self.set_visible_output_valiables()
        self.set_visible_add_op_attributes()

    def cmb_opset_currentIndexChanged(self, selected_index:int):
        """Repopulate op_type with the operators available at the selected opset.

        Each operator is added once, carrying as item data the first entry of
        ``op.versions`` whose ``since_opset`` does not exceed the opset. The
        previously chosen op_type stays selected if it is still listed;
        otherwise the first item is selected.
        """
        current_opset = self.cmb_opset.currentData()
        if current_opset is None:
            # Entries typed into the editable combo box carry no item data.
            try:
                current_opset = int(self.cmb_opset.currentText())
            except ValueError:
                return
        current_optype = self.cmb_optype.currentText()
        current_optype_index = 0
        self.cmb_optype.clear()
        for op in onnx_opsets:
            for v in op.versions:
                if v.since_opset <= current_opset:
                    if op.name == current_optype:
                        current_optype_index = self.cmb_optype.count()
                    self.cmb_optype.addItem(op.name, v)
                    break
        self.cmb_optype.setCurrentIndex(current_optype_index)

    @staticmethod
    def _parse_opset(text: str):
        """Return ``text`` as an opset number (int >= 1), or "" if it is not one."""
        if not text:
            return ""
        try:
            value = literal_eval(text)
        except (ValueError, SyntaxError):
            return ""
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            return ""
        return value

    @staticmethod
    def _collect_variables(rows: dict, count: int):
        """Return {name: [dtype, shape]} for the first ``count`` fully-filled rows, or None.

        Raises:
            ValueError, SyntaxError: a shape is not a valid Python literal.
        """
        variables = {}
        for i in range(count):
            name = rows[i]["name"].text()
            dtype = rows[i]["dtype"].currentText()
            shape = rows[i]["shape"].text()
            if name and dtype and shape:
                variables[name] = [dtype, literal_eval(shape)]
        return variables or None

    def get_properties(self)->GenerateOperatorProperties:
        """Collect the dialog values.

        Returns:
            GenerateOperatorProperties in which ``opset`` is an int >= 1, or ""
            when the opset text is empty or not such an integer.
            ``input_variables`` / ``output_variables`` map name -> [dtype, shape]
            and ``attributes`` maps name -> value; each is None when no complete
            row was entered.

        Raises:
            ValueError, SyntaxError: a shape or attribute value is not a valid
                Python literal (parsed with ``ast.literal_eval``).
        """
        op_type = self.cmb_optype.currentText()
        opset = self._parse_opset(self.cmb_opset.currentText())
        op_name = self.tb_opname.text()

        input_variables = self._collect_variables(self.add_input_valiables, self.visible_input_valiables_count)
        output_variables = self._collect_variables(self.add_output_valiables, self.visible_output_valiables_count)

        attributes = {}
        for i in range(self.visible_attributes_count):
            name = self.attributes[i]["name"].text()
            value = self.attributes[i]["value"].text()
            if name and value:
                attributes[name] = literal_eval(value)
        attributes = attributes or None

        return GenerateOperatorProperties(
            op_type=op_type,
            opset=opset,
            op_name=op_name,
            input_variables=input_variables,
            output_variables=output_variables,
            attributes=attributes,
        )

    def _show_errors(self, err_msgs: list) -> None:
        """Print ``err_msgs`` and show them in an error message box."""
        for m in err_msgs:
            print(m)
        MessageBox.error(err_msgs, "generate operator", parent=self)

    def accept(self) -> None:
        """Validate the inputs; on error show a message box and keep the dialog open."""
        try:
            props = self.get_properties()
        except (ValueError, SyntaxError, TypeError) as e:
            # Malformed shape/attribute literal: report it instead of letting
            # the exception escape the Qt slot.
            self._show_errors([f"- failed to parse a shape or attribute value: {e}"])
            return
        print(props)
        err_msgs = []
        if props.op_type not in opnames:
            err_msgs.append("- op_type is invalid.")
        if not isinstance(props.opset, int):
            err_msgs.append("- opset must be unsigned integer.")
        if not props.op_name:
            err_msgs.append("- op_name is not set.")
        if err_msgs:
            self._show_errors(err_msgs)
            return
        return super().accept()



if __name__ == "__main__":
    import signal
    import os
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = GenerateOperatorWidgets()
    window.show()

    app.exec_()

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_inference_test.py
# --------------------------------------------------------------------------------
import subprocess
from typing import List, Union
import tempfile
import signal
from PySide2 import QtCore, QtWidgets, QtGui
from ast import literal_eval
import onnx
import onnxruntime as ort

import os
from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
from onnxgraphqt.utils.color import replace_PrintColor

# onnxruntime provider name -> values offered for sit4onnx's --onnx_execution_provider.
ONNX_PROVIDER_TABLE = {
    "TensorrtExecutionProvider": ["tensorrt"],
    "CUDAExecutionProvider": ["cuda"],
    "OpenVINOExecutionProvider": ["openvino_cpu", "openvino_gpu"],
    "CPUExecutionProvider": ["cpu"]
}


class InferenceProcess(QtCore.QThread):
    """Worker thread that runs ``python3 -m sit4onnx`` and streams its output.

    Signals:
        signal: one line of console output, passed through replace_PrintColor.
        btn_signal: False when a run starts, True once it has finished.
    """
    signal = QtCore.Signal(str)
    btn_signal = QtCore.Signal(bool)

    def __init__(self, parent=None) -> None:
        super().__init__(parent)
        self.onnx_file_path: str = ""
        self.batch_size: int = 1
        self.fixes_shapes: Union[List[int], int, None] = None
        self.test_loop_count: int = 1
        self.onnx_execution_provider: str = ""

    def set_properties(self,
                       onnx_file_path: str,
                       batch_size: int, fixes_shapes: Union[List[int], None],
                       test_loop_count: int, onnx_execution_provider: str):
        """Store the arguments used by the next ``start()``."""
        self.onnx_file_path = onnx_file_path
        self.batch_size = batch_size
        self.fixes_shapes = fixes_shapes
        self.test_loop_count = test_loop_count
        self.onnx_execution_provider = onnx_execution_provider

    def _build_command(self) -> List[str]:
        """Return the sit4onnx invocation as an argument list (no shell involved)."""
        # "python3 -m" rather than the console script: fix for docker.
        cmd = [
            "python3", "-m", "sit4onnx",
            "--input_onnx_file_path", str(self.onnx_file_path),
            "--batch_size", str(self.batch_size),
            "--test_loop_count", str(self.test_loop_count),
            "--onnx_execution_provider", str(self.onnx_execution_provider),
        ]
        if self.fixes_shapes is not None:
            shapes = self.fixes_shapes
            if isinstance(shapes, (list, tuple)):
                # One token per dimension; str() of a list would yield "[1," "3," ...
                cmd += ["--fixes_shapes", *(str(d) for d in shapes)]
            else:
                cmd += ["--fixes_shapes", str(shapes)]
        return cmd

    def run(self):
        """Run sit4onnx, emit each output line, then a final "\\n" line."""
        self.btn_signal.emit(False)
        try:
            try:
                proc = subprocess.Popen(self._build_command(),
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.STDOUT)
            except OSError as e:
                # e.g. interpreter not found: report it instead of dying silently
                txt = f"failed to start sit4onnx: {e}"
                print(txt)
                self.signal.emit(txt)
                return
            # "with" closes the pipe and waits for the process to exit.
            with proc:
                for line in proc.stdout:
                    txt = line.decode("utf-8", errors="replace").rstrip("\r\n")
                    print(txt)
                    self.signal.emit(replace_PrintColor(txt))
            txt = "\n"
            print(txt)
            self.signal.emit(txt)
        finally:
            # Always re-enable the button, even if the run failed.
            self.btn_signal.emit(True)


class InferenceTestWidgets(QtWidgets.QDialog):
    """Non-modal dialog that saves a model to a temp file and benchmarks it with sit4onnx."""
    _DEFAULT_WINDOW_WIDTH = 600

    def __init__(self, onnx_model: onnx.ModelProto=None, parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("inference test")
        # The NamedTemporaryFile is only used to obtain a unique name; the model
        # is written to "<name>.onnx" (presumably so it carries an .onnx suffix).
        self.fp = tempfile.NamedTemporaryFile()
        self.tmp_filename = self.fp.name + ".onnx"
        onnx.save(onnx_model, self.tmp_filename)
        self.initUI()

    def __del__(self) -> None:
        """Remove the temporary model file written in __init__."""
        tmp_filename = getattr(self, "tmp_filename", None)
        if tmp_filename is None:
            return
        print(f"delete {tmp_filename}")
        try:
            os.remove(tmp_filename)
        except FileNotFoundError:
            # onnx.save may have failed before the file was created.
            pass

    def initUI(self):
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # add layout
        footer_layout = QtWidgets.QVBoxLayout()
        main_layout = QtWidgets.QFormLayout()
        base_layout.addLayout(main_layout)
        base_layout.addLayout(footer_layout)

        # Form layout
        main_layout.setLabelAlignment(QtCore.Qt.AlignRight)

        self.tb_batch_size = QtWidgets.QLineEdit()
        self.tb_batch_size.setText("1")
        self.tb_batch_size.setAlignment(QtCore.Qt.AlignRight)
        lbl_batch_size = QtWidgets.QLabel("Batch Size: ")
        set_font(lbl_batch_size, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_batch_size, self.tb_batch_size)

        self.tb_fixed_shapes = QtWidgets.QLineEdit()
        self.tb_fixed_shapes.setAlignment(QtCore.Qt.AlignRight)
        lbl_fixed_shapes = QtWidgets.QLabel("Fixed Shapes: ")
        set_font(lbl_fixed_shapes, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_fixed_shapes, self.tb_fixed_shapes)

        self.tb_test_loop_count = QtWidgets.QLineEdit()
        self.tb_test_loop_count.setText("10")
        self.tb_test_loop_count.setAlignment(QtCore.Qt.AlignRight)
        lbl_test_loop_count = QtWidgets.QLabel("Test Loop Count: ")
        set_font(lbl_test_loop_count, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_test_loop_count, self.tb_test_loop_count)

        self.cmb_onnx_execution_provider = QtWidgets.QComboBox()
        self.cmb_onnx_execution_provider.setEditable(False)
        # Offer only the entries whose onnxruntime provider is available.
        for provider in ort.get_available_providers():
            for p in ONNX_PROVIDER_TABLE.get(provider, []):
                self.cmb_onnx_execution_provider.addItem(f"{p} ", p)
        lbl_onnx_execution_provider = QtWidgets.QLabel("Execution Provider: ")
        set_font(lbl_onnx_execution_provider, font_size=LARGE_FONT_SIZE, bold=True)
        main_layout.addRow(lbl_onnx_execution_provider, self.cmb_onnx_execution_provider)

        # textbox
        self.tb_console = QtWidgets.QTextBrowser()
        self.tb_console.setReadOnly(True)
        self.tb_console.setStyleSheet(f"font-size: {BASE_FONT_SIZE}px; color: #FFFFFF; background-color: #505050;")
        footer_layout.addWidget(self.tb_console)

        # Dialog button
        self.btn_infer = QtWidgets.QPushButton("inference")
        self.btn_infer.clicked.connect(self.btn_infer_clicked)
        footer_layout.addWidget(self.btn_infer)

        # inferenceProcess
        self.inference_process = InferenceProcess()
        self.inference_process.signal.connect(self.update_text)
        self.inference_process.btn_signal.connect(self.btn_infer.setEnabled)

        self.setLayout(base_layout)

    def update_text(self, txt: str):
        """Append *txt* to the console view."""
        self.tb_console.append(txt)

    def btn_infer_clicked(self) -> None:
        """Parse the form and start an asynchronous sit4onnx run."""
        try:
            batch_size = literal_eval(self.tb_batch_size.text())
            test_loop_count = literal_eval(self.tb_test_loop_count.text())
        except (ValueError, SyntaxError, TypeError):
            batch_size = test_loop_count = None
        if not isinstance(batch_size, int) or not isinstance(test_loop_count, int):
            # Report bad input in the console instead of raising inside the Qt slot.
            self.update_text("batch size and test loop count must be integers.")
            return
        try:
            fixes_shapes = literal_eval(self.tb_fixed_shapes.text())
        except (ValueError, SyntaxError, TypeError):
            # Empty or unparsable field: run without --fixes_shapes.
            fixes_shapes = None
        onnx_execution_provider = self.cmb_onnx_execution_provider.currentData()
        self.inference_process.set_properties(self.tmp_filename,
                                              batch_size,
                                              fixes_shapes,
                                              test_loop_count,
                                              onnx_execution_provider)
        self.inference_process.start()


if __name__ == "__main__":
    import signal
    import os
    from onnxgraphqt.utils.color import PrintColor
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])

    model_path = os.path.join(os.path.dirname(__file__), "../data/mobilenetv2-7.onnx")
    model = onnx.load_model(model_path)
    window = InferenceTestWidgets(model)
    window.show()

    window.update_text(f"{PrintColor.BLACK[1]}BLACK{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.RED[1]}RED{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.GREEN[1]}GREEN{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.YELLOW[1]}YELLOW{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.BLUE[1]}BLUE{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.MAGENTA[1]}MAGENTA{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.CYAN[1]}CYAN{PrintColor.RESET[1]}")
    window.update_text(f"{PrintColor.WHITE[1]}WHITE{PrintColor.RESET[1]}")

    app.exec_()
    del window

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_initialize_batchsize.py:
# --------------------------------------------------------------------------------
from collections import namedtuple
import signal
from PySide2 import QtCore, QtWidgets, QtGui

from onnxgraphqt.utils.opset import DEFAULT_OPSET
from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
from onnxgraphqt.widgets.widgets_message_box import MessageBox


InitializeBatchsizeProperties = namedtuple("InitializeBatchsizeProperties",
    [
        "initialization_character_string",
    ])


class InitializeBatchsizeWidget(QtWidgets.QDialog):
    """Dialog asking for the string used to initialize the batch size."""
    _DEFAULT_WINDOW_WIDTH = 350

    def __init__(self, current_batchsize="-1", parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("initialize batchsize")
        self.initUI()
        self.updateUI(current_batchsize)

    def initUI(self):
        self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()

        # layout
        layout = QtWidgets.QVBoxLayout()
        lbl_name = QtWidgets.QLabel("Input string to initialize batch size.")
        set_font(lbl_name, font_size=LARGE_FONT_SIZE, bold=True)
        self.ledit_character = QtWidgets.QLineEdit()
        self.ledit_character.setText("-1")
        self.ledit_character.setPlaceholderText("initialization_character_string")
        layout.addWidget(lbl_name)
        layout.addWidget(self.ledit_character)

        # add layout
        base_layout.addLayout(layout)

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)
        base_layout.addWidget(btn)

        self.setLayout(base_layout)

    def updateUI(self, current_batchsize):
        """Show ``str(current_batchsize)`` in the text box."""
        self.ledit_character.setText(str(current_batchsize))

    def get_properties(self)->InitializeBatchsizeProperties:
        """Return the entered string with surrounding whitespace stripped."""
        character = self.ledit_character.text().strip()
        return InitializeBatchsizeProperties(
            initialization_character_string=character
        )

    def accept(self) -> None:
        # value check
        invalid = False
        props = self.get_properties()
        print(props)
        err_msgs = []
        if props.initialization_character_string == "":
            err_msgs.append("- initialization_character_string is not set.")
            invalid = True
        if invalid:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "initialize batchsize", parent=self)
            return
        return super().accept()


if __name__ == "__main__":
    import signal
    import os
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = InitializeBatchsizeWidget()
    window.show()

    app.exec_()

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_menubar.py:
# --------------------------------------------------------------------------------
from typing import Callable, Union, List
from dataclasses import dataclass
from PySide2 import QtCore, QtWidgets, QtGui


@dataclass
class SubMenu:
    """A menu entry: label, callback, and icon path (falsy for no icon)."""
    name: str
    func: Callable
    icon: str


class Separator:
    """Marker placed in ``Menu.contents`` to insert a separator line."""


@dataclass
class Menu:
    """A top-level menu title and its ordered entries."""
    name: str
    contents: List[Union[SubMenu, Separator]]


class MenuBarWidget(QtWidgets.QMenuBar):
    """Menu bar built from a list of Menu definitions.

    Created QActions are stored in ``menu_actions`` keyed by ``SubMenu.name``;
    entries sharing a name overwrite each other there.
    """
    def __init__(self, menu_list: List[Menu], parent=None) -> None:
        super().__init__(parent)

        self.menu_actions = {}
        for menu in menu_list:
            m = self.addMenu(menu.name)
            for content in menu.contents:
                if isinstance(content, Separator):
                    m.addSeparator()
                elif isinstance(content, SubMenu):
                    self.menu_actions[content.name] = m.addAction(content.name, content.func)
                    if content.icon:
                        self.menu_actions[content.name].setIcon(QtGui.QIcon(content.icon))


if __name__ == "__main__":
    import sys
    app = QtWidgets.QApplication(sys.argv)

    # menu_list is a required argument; build a minimal menu for the demo.
    m = MenuBarWidget([Menu("File", [SubMenu("Exit", app.quit, None)])])
    m.show()

    app.exec_()

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_message_box.py:
# --------------------------------------------------------------------------------
from PySide2 import QtCore, QtWidgets, QtGui
from typing import Union, List


class MessageBox(QtWidgets.QMessageBox):
    """QMessageBox preconfigured with text, title, buttons and icon.

    ``text`` may be a list of lines, which are joined with newlines. The
    classmethods show the box modally via ``exec_()`` and return its result.
    """
    def __init__(self,
                 text:Union[str, List[str]],
                 title:str,
                 default_button=QtWidgets.QMessageBox.Ok,
                 icon=QtWidgets.QMessageBox.Icon.Information,
                 parent=None) -> None:
        super().__init__(parent)
        if isinstance(text, list):
            self.setText('\n'.join(text))
        else:
            self.setText(text)
        self.setWindowTitle(title)
        # Despite its name, default_button is the full set of standard buttons shown.
        self.setStandardButtons(default_button)
        self.setIcon(icon)

    @classmethod
    def info(cls,
             text:Union[str, List[str]],
             title:str,
             default_button=QtWidgets.QMessageBox.Ok,
             parent=None):
        return MessageBox(text, "[INFO] " + title, default_button, icon=MessageBox.Icon.Information, parent=parent).exec_()

    @classmethod
    def question(cls,
                 text:Union[str, List[str]],
                 title:str,
                 default_button=QtWidgets.QMessageBox.Yes|QtWidgets.QMessageBox.No,
                 parent=None):
        return MessageBox(text, "[QUESTION] " + title, default_button, icon=MessageBox.Icon.Question, parent=parent).exec_()

    @classmethod
    def warn(cls,
             text:Union[str, List[str]],
             title:str,
             default_button=QtWidgets.QMessageBox.Ok,
             parent=None):
        return MessageBox(text, "[WARN] " + title, default_button, icon=MessageBox.Icon.Warning, parent=parent).exec_()

    @classmethod
    def error(cls,
              text:Union[str, List[str]],
              title:str,
              default_button=QtWidgets.QMessageBox.Ok,
              parent=None):
        return MessageBox(text, "[ERROR] " + title, default_button, icon=MessageBox.Icon.Critical, parent=parent).exec_()

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_modify_attrs.py:
# --------------------------------------------------------------------------------
from collections import namedtuple
from typing import List, Dict
import signal
from PySide2 import QtCore, QtWidgets, QtGui
from ast import literal_eval
import numpy as np
from sam4onnx.onnx_attr_const_modify import (
    ATTRIBUTE_DTYPES_TO_NUMPY_TYPES,
    CONSTANT_DTYPES_TO_NUMPY_TYPES
)

from onnxgraphqt.graph.onnx_node_graph import OnnxGraph
from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
from onnxgraphqt.widgets.widgets_message_box import MessageBox


ModifyAttrsProperties = namedtuple("ModifyAttrsProperties",
    [
        "op_name",
        "attributes",
        "input_constants",
        "delete_attributes",
    ])


def get_dtype_str(list_or_scalar)->str:
    """Guess the dtype-combo label ("bool", "int32", ...) for a value.

    Lists, tuples and numpy arrays are classified by their first element after
    flattening; anything else by its own type. Returns "" for None or an
    empty sequence, and "str" for unrecognised types.
    """
    if isinstance(list_or_scalar, (list, tuple, np.ndarray)):
        flat = np.ravel(list_or_scalar)
        if flat.size == 0:
            return ""
        v0 = flat[0]
    else:
        v0 = list_or_scalar
    if isinstance(v0, np.generic):
        # numpy scalar -> Python scalar (np.int64 / np.float32 are not int / float)
        v0 = v0.item()

    # bool must be tested before int: bool is a subclass of int.
    if isinstance(v0, bool):
        return "bool"
    if isinstance(v0, int):
        return "int32"
    if isinstance(v0, float):
        return "float32"
    if isinstance(v0, complex):
        return "complex64"
    if v0 is None:
        return ""
    return "str"


def _to_literal_text(value) -> str:
    """Render *value* as text that ``literal_eval`` in get_properties can parse.

    ``str()`` of a numpy array ("[1 2 3]", elided with "..." when large) is not
    a Python literal, so arrays are converted with ``tolist()`` first.
    """
    if isinstance(value, np.ndarray):
        value = value.tolist()
    return str(value)


class ModifyAttrsWidgets(QtWidgets.QDialog):
    """Dialog for editing a node's attributes, input constants and attribute deletions."""
    _DEFAULT_WINDOW_WIDTH = 500
    _MAX_ATTRIBUTES_COUNT = 5
    _MAX_DELETE_ATTRIBUTES_COUNT = 5
    _MAX_CONST_COUNT = 4

    def __init__(self, graph: OnnxGraph=None, selected_node: str="", parent=None) -> None:
        super().__init__(parent)
        self.setModal(False)
        self.setWindowTitle("modify attributes")
        self.graph = graph
        self.initUI()
        self.updateUI(self.graph, selected_node)

    @staticmethod
    def _create_value_row(dtype_table) -> dict:
        """Build one "name / value / dtype" editor row; returns its widgets by key."""
        row = {}
        row["base"] = QtWidgets.QWidget()
        row["layout"] = QtWidgets.QHBoxLayout(row["base"])
        row["layout"].setContentsMargins(0, 0, 0, 0)
        row["name"] = QtWidgets.QComboBox()
        row["name"].setEditable(True)
        row["name"].setFixedWidth(200)
        row["value"] = QtWidgets.QLineEdit()
        row["value"].setPlaceholderText("value")
        row["dtype"] = QtWidgets.QComboBox()
        for key, dtype in dtype_table.items():
            row["dtype"].addItem(key, dtype)
        row["layout"].addWidget(row["name"])
        row["layout"].addWidget(row["value"])
        row["layout"].addWidget(row["dtype"])
        return row

    @staticmethod
    def _row_slot(handler, row_index: int):
        """Return a currentIndexChanged slot that calls ``handler(row_index)``."""
        def slot(_current_index):
            handler(row_index)
        return slot

    @staticmethod
    def _reset_name_combo(combo, names) -> None:
        """Replace *combo*'s items with *names* and leave nothing selected."""
        # Block signals so repopulating does not trigger the prefill handlers.
        was_blocked = combo.blockSignals(True)
        try:
            combo.clear()
            combo.addItems(list(names))
            combo.setCurrentIndex(-1)
        finally:
            combo.blockSignals(was_blocked)

    @staticmethod
    def _update_rows_visibility(rows, visible_count, max_count, btn_add, btn_del) -> None:
        """Show the first *visible_count* rows; enable +/- within [1, max_count]."""
        for key, widgets in rows.items():
            widgets["base"].setVisible(key < visible_count)
        btn_add.setEnabled(visible_count < max_count)
        btn_del.setEnabled(visible_count > 1)

    def initUI(self):
        set_font(self, font_size=BASE_FONT_SIZE)

        base_layout = QtWidgets.QVBoxLayout()
        base_layout.setSizeConstraint(base_layout.SizeConstraint.SetFixedSize)

        # Form layout
        layout = QtWidgets.QFormLayout()
        layout.setLabelAlignment(QtCore.Qt.AlignRight)
        self.cmb_opname = QtWidgets.QComboBox()
        self.cmb_opname.setEditable(True)
        lbl_opname = QtWidgets.QLabel("opname")
        set_font(lbl_opname, font_size=LARGE_FONT_SIZE, bold=True)
        layout.addRow(lbl_opname, self.cmb_opname)

        # attributes
        self.layout_attributes = QtWidgets.QVBoxLayout()
        lbl_attributes = QtWidgets.QLabel("attributes")
        set_font(lbl_attributes, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout_attributes.addWidget(lbl_attributes)
        self.visible_attributes_count = 3
        self.edit_attributes = {}
        for index in range(self._MAX_ATTRIBUTES_COUNT):
            self.edit_attributes[index] = self._create_value_row(ATTRIBUTE_DTYPES_TO_NUMPY_TYPES)
            self.layout_attributes.addWidget(self.edit_attributes[index]["base"])
        self.btn_add_attributes = QtWidgets.QPushButton("+")
        self.btn_del_attributes = QtWidgets.QPushButton("-")
        self.btn_add_attributes.clicked.connect(self.btn_add_attributes_clicked)
        self.btn_del_attributes.clicked.connect(self.btn_del_attributes_clicked)
        layout_btn_attributes = QtWidgets.QHBoxLayout()
        layout_btn_attributes.addWidget(self.btn_add_attributes)
        layout_btn_attributes.addWidget(self.btn_del_attributes)
        self.layout_attributes.addLayout(layout_btn_attributes)

        # input_const
        self.layout_const = QtWidgets.QVBoxLayout()
        lbl_input_constants = QtWidgets.QLabel("input_constants")
        set_font(lbl_input_constants, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout_const.addWidget(lbl_input_constants)
        self.visible_const_count = 3
        self.edit_const = {}
        for index in range(self._MAX_CONST_COUNT):
            self.edit_const[index] = self._create_value_row(CONSTANT_DTYPES_TO_NUMPY_TYPES)
            self.layout_const.addWidget(self.edit_const[index]["base"])
        self.btn_add_const = QtWidgets.QPushButton("+")
        self.btn_del_const = QtWidgets.QPushButton("-")
        self.btn_add_const.clicked.connect(self.btn_add_const_clicked)
        self.btn_del_const.clicked.connect(self.btn_del_const_clicked)
        layout_btn_const = QtWidgets.QHBoxLayout()
        layout_btn_const.addWidget(self.btn_add_const)
        layout_btn_const.addWidget(self.btn_del_const)
        self.layout_const.addLayout(layout_btn_const)

        # delete_attributes
        self.layout_delete_attributes = QtWidgets.QVBoxLayout()
        lbl_delete_attributes = QtWidgets.QLabel("delete_attributes")
        set_font(lbl_delete_attributes, font_size=LARGE_FONT_SIZE, bold=True)
        self.layout_delete_attributes.addWidget(lbl_delete_attributes)
        self.visible_delete_attributes_count = 3
        self.delete_attributes = {}
        for index in range(self._MAX_DELETE_ATTRIBUTES_COUNT):
            self.delete_attributes[index] = {}
            self.delete_attributes[index]["base"] = QtWidgets.QWidget()
            self.delete_attributes[index]["layout"] = QtWidgets.QHBoxLayout(self.delete_attributes[index]["base"])
            self.delete_attributes[index]["layout"].setContentsMargins(0, 0, 0, 0)
            self.delete_attributes[index]["name"] = QtWidgets.QComboBox()
            self.delete_attributes[index]["name"].setPlaceholderText("name")
            self.delete_attributes[index]["name"].setEditable(True)
            self.delete_attributes[index]["layout"].addWidget(self.delete_attributes[index]["name"])
            self.layout_delete_attributes.addWidget(self.delete_attributes[index]["base"])
        self.btn_add_delete_attributes = QtWidgets.QPushButton("+")
        self.btn_del_delete_attributes = QtWidgets.QPushButton("-")
        self.btn_add_delete_attributes.clicked.connect(self.btn_add_delete_attributes_clicked)
        self.btn_del_delete_attributes.clicked.connect(self.btn_del_delete_attributes_clicked)
        layout_btn_delete_attributes = QtWidgets.QHBoxLayout()
        layout_btn_delete_attributes.addWidget(self.btn_add_delete_attributes)
        layout_btn_delete_attributes.addWidget(self.btn_del_delete_attributes)
        self.layout_delete_attributes.addLayout(layout_btn_delete_attributes)

        # Selection handlers are connected exactly once here; connecting them on
        # every op change would make each handler fire repeatedly.
        self.cmb_opname.currentIndexChanged.connect(self._on_opname_changed)
        for index in range(self._MAX_ATTRIBUTES_COUNT):
            self.edit_attributes[index]["name"].currentIndexChanged.connect(
                self._row_slot(self._on_attribute_name_changed, index))
        for index in range(self._MAX_CONST_COUNT):
            self.edit_const[index]["name"].currentIndexChanged.connect(
                self._row_slot(self._on_const_name_changed, index))

        # Dialog button
        btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
                                         QtWidgets.QDialogButtonBox.Cancel)
        btn.accepted.connect(self.accept)
        btn.rejected.connect(self.reject)

        # add layout
        base_layout.addLayout(layout)
        base_layout.addSpacing(10)
        base_layout.addLayout(self.layout_attributes)
        base_layout.addSpacing(10)
        base_layout.addLayout(self.layout_const)
        base_layout.addSpacing(10)
        base_layout.addLayout(self.layout_delete_attributes)
        base_layout.addSpacing(10)
        base_layout.addWidget(btn)

        self.set_visible_attributes()
        self.set_visible_const()
        self.set_visible_delete_attributes()
        self.setLayout(base_layout)

    def updateUI(self, graph: OnnxGraph, selected_node: str=""):
        """Load *graph*'s node names and optionally preselect *selected_node*."""
        self.graph = graph

        was_blocked = self.cmb_opname.blockSignals(True)
        self.cmb_opname.clear()
        if self.graph:
            self.cmb_opname.addItems(list(self.graph.nodes.keys()))
        self.cmb_opname.setEditable(True)
        self.cmb_opname.setCurrentIndex(-1)
        self.cmb_opname.blockSignals(was_blocked)

        # Reset the dependent combos (input-constant names are op-independent).
        self._on_opname_changed()

        if selected_node and self.graph and selected_node in self.graph.nodes.keys():
            # On an editable combo setCurrentText() only edits the text and never
            # changes the index, so select by index to fire currentIndexChanged.
            self.cmb_opname.setCurrentIndex(self.cmb_opname.findText(selected_node))

    def _on_opname_changed(self, _current_index: int = -1) -> None:
        """Repopulate attribute / delete-attribute / constant name choices."""
        if not self.graph:
            return
        op_name = self.cmb_opname.currentText()
        if op_name in self.graph.nodes.keys():
            attr_names = list(self.graph.nodes[op_name].attrs.keys())
        else:
            # Text that does not name a node: offer no attribute names.
            attr_names = []
        const_names = [name for name, _ in self.graph.node_inputs.items()]

        for index in range(self._MAX_ATTRIBUTES_COUNT):
            self._reset_name_combo(self.edit_attributes[index]["name"], attr_names)
            self.edit_attributes[index]["value"].setText("")
        for index in range(self._MAX_DELETE_ATTRIBUTES_COUNT):
            self._reset_name_combo(self.delete_attributes[index]["name"], attr_names)
        for index in range(self._MAX_CONST_COUNT):
            self._reset_name_combo(self.edit_const[index]["name"], const_names)
            self.edit_const[index]["value"].setText("")

    def _on_attribute_name_changed(self, attr_index: int) -> None:
        """Prefill value and dtype of attribute row *attr_index* from the graph."""
        if not self.graph:
            return
        op_name = self.cmb_opname.currentText()
        attr_name = self.edit_attributes[attr_index]["name"].currentText()
        if not attr_name or op_name not in self.graph.nodes.keys():
            return
        attrs = self.graph.nodes[op_name].attrs
        if attr_name not in attrs.keys():
            # Typed-in name of a new attribute: nothing to prefill.
            return
        value = attrs[attr_name]
        self.edit_attributes[attr_index]["value"].setText(_to_literal_text(value))
        self.edit_attributes[attr_index]["dtype"].setCurrentText(get_dtype_str(value))

    def _on_const_name_changed(self, const_index: int) -> None:
        """Prefill value and dtype of constant row *const_index* from the graph."""
        if not self.graph:
            return
        input_name = self.edit_const[const_index]["name"].currentText()
        node_input = self.graph.node_inputs.get(input_name)
        if node_input:
            self.edit_const[const_index]["value"].setText(_to_literal_text(node_input.values))
            self.edit_const[const_index]["dtype"].setCurrentText(get_dtype_str(node_input.values))

    def set_visible_attributes(self):
        self._update_rows_visibility(self.edit_attributes,
                                     self.visible_attributes_count,
                                     self._MAX_ATTRIBUTES_COUNT,
                                     self.btn_add_attributes,
                                     self.btn_del_attributes)

    def set_visible_delete_attributes(self):
        self._update_rows_visibility(self.delete_attributes,
                                     self.visible_delete_attributes_count,
                                     self._MAX_DELETE_ATTRIBUTES_COUNT,
                                     self.btn_add_delete_attributes,
                                     self.btn_del_delete_attributes)

    def set_visible_const(self):
        self._update_rows_visibility(self.edit_const,
                                     self.visible_const_count,
                                     self._MAX_CONST_COUNT,
                                     self.btn_add_const,
                                     self.btn_del_const)

    def btn_add_attributes_clicked(self, e):
        self.visible_attributes_count = min(max(0, self.visible_attributes_count + 1), self._MAX_ATTRIBUTES_COUNT)
        self.set_visible_attributes()

    def btn_del_attributes_clicked(self, e):
        self.visible_attributes_count = min(max(0, self.visible_attributes_count - 1), self._MAX_ATTRIBUTES_COUNT)
        self.set_visible_attributes()

    def btn_add_const_clicked(self, e):
        self.visible_const_count = min(max(0, self.visible_const_count + 1), self._MAX_CONST_COUNT)
        self.set_visible_const()

    def btn_del_const_clicked(self, e):
        self.visible_const_count = min(max(0, self.visible_const_count - 1), self._MAX_CONST_COUNT)
        self.set_visible_const()

    def btn_add_delete_attributes_clicked(self, e):
        self.visible_delete_attributes_count = min(max(0, self.visible_delete_attributes_count + 1), self._MAX_DELETE_ATTRIBUTES_COUNT)
        self.set_visible_delete_attributes()

    def btn_del_delete_attributes_clicked(self, e):
        self.visible_delete_attributes_count = min(max(0, self.visible_delete_attributes_count - 1), self._MAX_DELETE_ATTRIBUTES_COUNT)
        self.set_visible_delete_attributes()

    def get_properties(self)->ModifyAttrsProperties:
        """Collect the visible rows into a ModifyAttrsProperties.

        Propagates whatever ``literal_eval`` / ``np.asarray`` raise for
        malformed values.
        """
        opname = self.cmb_opname.currentText()

        attributes = {}
        for i in range(self.visible_attributes_count):
            name = self.edit_attributes[i]["name"].currentText()
            value = self.edit_attributes[i]["value"].text()
            dtype = self.edit_attributes[i]["dtype"].currentData()
            if name and value:
                if dtype == np.str_:
                    # string attributes are taken verbatim, not parsed
                    attributes[name] = value
                else:
                    value = literal_eval(value)
                    if isinstance(value, list):
                        attributes[name] = np.asarray(value, dtype=dtype)
                    else:
                        attributes[name] = value

        delete_attributes = []
        for i in range(self.visible_delete_attributes_count):
            name = self.delete_attributes[i]["name"].currentText()
            if name:
                delete_attributes.append(name)

        input_constants = {}
        for i in range(self.visible_const_count):
            name = self.edit_const[i]["name"].currentText()
            value = self.edit_const[i]["value"].text()
            dtype = self.edit_const[i]["dtype"].currentData()
            if name and value:
                value = literal_eval(value)
                input_constants[name] = np.asarray(value, dtype=dtype)

        return ModifyAttrsProperties(
            op_name=opname,
            attributes=attributes,
            input_constants=input_constants,
            delete_attributes=delete_attributes,
        )

    def accept(self) -> None:
        # value check
        try:
            props = self.get_properties()
        except (ValueError, SyntaxError, TypeError, OverflowError) as e:
            # A typed value could not be parsed/converted; report instead of raising in the slot.
            MessageBox.error([f"- invalid value: {e}"], "modify attrs", parent=self)
            return
        print(props)
        err_msgs = []
        edit_attr = len(props.attributes) > 0
        delete_attr = len(props.delete_attributes) > 0
        if not props.op_name and (edit_attr or delete_attr):
            err_msgs.append("- op_name and attributes must always be specified at the same time.")
        if err_msgs:
            for m in err_msgs:
                print(m)
            MessageBox.error(err_msgs, "modify attrs", parent=self)
            return
        return super().accept()


if __name__ == "__main__":
    import signal
    import os
    # handle SIGINT to make the app terminate on CTRL+C
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)

    app = QtWidgets.QApplication([])
    window = ModifyAttrsWidgets()
    window.show()

    app.exec_()

# --------------------------------------------------------------------------------
# /onnxgraphqt/widgets/widgets_node_search.py:
--------------------------------------------------------------------------------
1 | import signal
2 | from PySide2 import QtCore, QtWidgets, QtGui
3 | 
4 | from onnxgraphqt.graph.onnx_node_graph import ONNXNodeGraph
5 | from onnxgraphqt.graph.onnx_node import ONNXInput, ONNXOutput, ONNXNode, OnnxNodeIO
6 | 
7 | 
8 | class NodeSearchWidget(QtWidgets.QDialog):
9 |     """Non-modal dialog that lists the nodes of an ONNXNodeGraph and filters them by keyword.
10 | 
11 |     Each row shows a node's name, type, input names and output names.
12 |     Double-clicking a row focuses that node in the graph.
13 |     """
14 |     _DEFAULT_WINDOW_WIDTH = 500
15 |     _DEFAULT_WINDOW_HEIGHT = 600
16 | 
17 |     def __init__(self, graph:ONNXNodeGraph=None, parent=None) -> None:
18 |         super().__init__(parent)
19 |         self.setModal(False)
20 |         self.setWindowTitle("node search")
21 |         self.graph = graph
22 |         self.initUI()
23 | 
24 |     def initUI(self):
25 |         """Build the search box and result tree, then list the nodes of self.graph."""
26 |         if self.parentWidget():
27 |             # Place the dialog just to the right of the parent window.
28 |             x = self.parentWidget().x()
29 |             y = self.parentWidget().y()
30 |             parent_w = self.parentWidget().width()
31 |             self.setGeometry(x + parent_w, y, self._DEFAULT_WINDOW_WIDTH, self._DEFAULT_WINDOW_HEIGHT)
32 |         else:
33 |             self.setGeometry(0, 0, self._DEFAULT_WINDOW_WIDTH, self._DEFAULT_WINDOW_HEIGHT)
34 | 
35 |         base_layout = QtWidgets.QVBoxLayout()
36 | 
37 |         # search box + button
38 |         layout = QtWidgets.QHBoxLayout()
39 |         self.tb = QtWidgets.QLineEdit()
40 |         self.tb.setBaseSize(300, 50)
41 |         self.btn = QtWidgets.QPushButton("search")
42 |         self.btn.clicked.connect(self.btn_clicked)
43 |         layout.addWidget(self.tb)
44 |         layout.addWidget(self.btn)
45 | 
46 |         base_layout.addLayout(layout)
47 | 
48 |         # result table: name / type / input_names / output_names
49 |         self.model = QtGui.QStandardItemModel(0, 4)
50 |         self.model.setHeaderData(0, QtCore.Qt.Horizontal, "name")
51 |         self.model.setHeaderData(1, QtCore.Qt.Horizontal, "type")
52 |         self.model.setHeaderData(2, QtCore.Qt.Horizontal, "input_names")
53 |         self.model.setHeaderData(3, QtCore.Qt.Horizontal, "output_names")
54 | 
55 |         self.view = QtWidgets.QTreeView()
56 |         self.view.setSortingEnabled(True)
57 |         self.view.doubleClicked.connect(self.viewClicked)
58 |         self.view.setModel(self.model)
59 |         base_layout.addWidget(self.view)
60 | 
61 |         self.setLayout(base_layout)
62 |         self.update(self.graph)
63 |         self.search("")
64 |         self.view.sortByColumn(0, QtCore.Qt.SortOrder.AscendingOrder)
65 | 
66 |     def viewClicked(self, index:QtCore.QModelIndex):
67 |         """Focus the graph on the double-clicked row's node and show its properties."""
68 |         if self.graph is None:
69 |             return
70 |         # The node name is in column 0 whichever cell of the row was clicked.
71 |         indexItem = self.model.index(index.row(), 0, index.parent())
72 |         node_name = self.model.data(indexItem)
73 |         inputs = self.graph.get_input_node_by_name(node_name)
74 |         nodes = self.graph.get_node_by_name(node_name)
75 |         outputs = self.graph.get_output_node_by_name(node_name)
76 |         nodes = inputs + nodes + outputs
77 |         if not nodes:
78 |             # No node with this name was found; nothing to focus.
79 |             return
80 |         self.graph.fit_to_selection_node(nodes[0])
81 |         parent = self.parent()
82 |         if hasattr(parent, "properties_bin"):
83 |             parent.properties_bin.add_node(nodes[0])
84 | 
85 |     def btn_clicked(self, e):
86 |         self.search(self.tb.text())
87 | 
88 |     def update(self, graph: ONNXNodeGraph=None):
89 |         """Rebuild every row from *graph* and display them unfiltered.
90 | 
91 |         The rows are cached in self.all_row_items for search(). Passing None
92 |         (or calling update() with no argument -- note this shadows
93 |         QWidget.update) only clears the list.
94 |         """
95 |         self.all_row_items = []
96 |         self._clear_rows()
97 | 
98 |         if graph is None:
99 |             return
100 |         # Keep the graph that populated the rows so viewClicked() resolves
101 |         # node names against that same graph.
102 |         self.graph = graph
103 | 
104 |         for n in graph.all_nodes():
105 |             if isinstance(n, ONNXNode):
106 |                 row = self._make_row(n.get_node_name(), n.op,
107 |                                      [io.name for io in n.onnx_inputs],
108 |                                      [io.name for io in n.onnx_outputs])
109 |             elif isinstance(n, ONNXInput):
110 |                 row = self._make_row(n.get_node_name(), "Input",
111 |                                      [], list(n.get_output_names()))
112 |             elif isinstance(n, ONNXOutput):
113 |                 row = self._make_row(n.get_node_name(), "Output",
114 |                                      list(n.get_input_names()), [])
115 |             else:
116 |                 # Other node types are not listed.
117 |                 continue
118 |             # appendRow keeps rows contiguous even when some nodes are skipped.
119 |             self.model.appendRow(row)
120 |             self.all_row_items.append(row)
121 | 
122 |     def _make_row(self, name, type_name, input_names, output_names):
123 |         """Return four read-only items: name, type, joined inputs, joined outputs."""
124 |         texts = (name, type_name, ", ".join(input_names), ", ".join(output_names))
125 |         items = tuple(QtGui.QStandardItem(text) for text in texts)
126 |         for item in items:
127 |             item.setEditable(False)
128 |         return items
129 | 
130 |     def _clear_rows(self):
131 |         """Detach every row from the model without deleting its items.
132 | 
133 |         takeRow() hands the items back instead of destroying them, so the
134 |         QStandardItem objects cached in self.all_row_items remain valid and
135 |         can be re-inserted by search().
136 |         """
137 |         for _ in range(self.model.rowCount()):
138 |             self.model.takeRow(0)
139 | 
140 |     def search(self, word):
141 |         """Show only the rows containing every whitespace-separated keyword of *word*.
142 | 
143 |         A keyword matches a row when any column contains it as a case-sensitive
144 |         substring. An empty query shows every row.
145 |         """
146 |         self._clear_rows()
147 | 
148 |         search_words = word.split()
149 |         for row in self.all_row_items:
150 |             values = [item.text() for item in row]
151 |             if all(any(w in val for val in values) for w in search_words):
152 |                 self.model.appendRow(row)
153 | 
154 | 
155 | if __name__ == "__main__":
156 |     import signal
157 |     import os
158 |     # handle SIGINT to make the app terminate on CTRL+C
159 |     signal.signal(signal.SIGINT, signal.SIG_DFL)
160 | 
161 |     QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
162 | 
163 |     app = QtWidgets.QApplication([])
164 |     window = NodeSearchWidget(graph=ONNXNodeGraph("", 10, "", "", "", "", 0, 0))
165 |     window.show()
166 | 
167 |     app.exec_()
--------------------------------------------------------------------------------
/onnxgraphqt/widgets/widgets_rename_op.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import signal
3 | from PySide2 import QtCore, QtWidgets, QtGui
4 | 
5 | from onnxgraphqt.widgets.widgets_message_box import MessageBox
6 | from onnxgraphqt.utils.widgets import set_font, BASE_FONT_SIZE, LARGE_FONT_SIZE
7 | 
8 | 
9 | RenameOpProperties = namedtuple("RenameOpProperties",
10 |     [
11 |         "old_new",
12 |     ])
13 | 
14 | 
15 | class RenameOpWidget(QtWidgets.QDialog):
16 |     _DEFAULT_WINDOW_WIDTH = 420
17 | 
18 |     def __init__(self, parent=None) -> None:
19 |         super().__init__(parent)
20 |         self.setModal(False)
21 |         self.setWindowTitle("rename op")
22 |         self.initUI()
23 | 
24 |     def initUI(self):
25 |         self.setFixedWidth(self._DEFAULT_WINDOW_WIDTH)
26 |         set_font(self, font_size=BASE_FONT_SIZE)
27 | 
28 |         base_layout = QtWidgets.QVBoxLayout()
29 | 
30 |         # layout
31 |         layout = QtWidgets.QVBoxLayout()
32 |         lbl_name = QtWidgets.QLabel("Replace substring 'old' in opname with 'new'.")
33 |         set_font(lbl_name, font_size=LARGE_FONT_SIZE, bold=True)
34 |         layout.addWidget(lbl_name)
35 | 
36 |         layout_ledit = QtWidgets.QHBoxLayout()
37 |         self.ledit_old = QtWidgets.QLineEdit()
38 |         self.ledit_old.setPlaceholderText('old. e.g. "onnx::"')
39 |         self.ledit_new = QtWidgets.QLineEdit()
40 |         self.ledit_new.setPlaceholderText('new. e.g. "" ')
41 |         layout_ledit.addWidget(self.ledit_old)
42 |         layout_ledit.addWidget(self.ledit_new)
43 | 
44 |         # add layout
45 |         base_layout.addLayout(layout)
46 |         base_layout.addLayout(layout_ledit)
47 | 
48 |         # Dialog button
49 |         btn = QtWidgets.QDialogButtonBox(QtWidgets.QDialogButtonBox.Ok |
50 |                                          QtWidgets.QDialogButtonBox.Cancel)
51 |         btn.accepted.connect(self.accept)
52 |         btn.rejected.connect(self.reject)
53 |         # layout.addWidget(btn)
54 |         base_layout.addWidget(btn)
55 | 
56 |         self.setLayout(base_layout)
57 | 
58 |     def get_properties(self)->RenameOpProperties:
59 |         old = self.ledit_old.text().strip()
60 |         new = self.ledit_new.text().strip()
61 |         return RenameOpProperties(
62 |             old_new=[old, new]
63 |         )
64 | 
65 |     def accept(self) -> None:
66 |         # value check
67 |         invalid = False
68 |         props = self.get_properties()
69 |         print(props)
70 |         err_msgs = []
71 |         if props.old_new[0] == "":
72 |             err_msgs.append("substring old must be set.")
73 |             invalid = True
74 |         if invalid:
75 |             for m in err_msgs:
76 |                 print(m)
77 |             MessageBox.error(err_msgs, "rename op", parent=self)
78 |             return
79 |         return super().accept()
80 | 
81 | 
82 | if __name__ == "__main__":
83 |     import signal
84 |     import os
85 |     # handle SIGINT to make the app terminate on CTRL+C
86 |     signal.signal(signal.SIGINT, signal.SIG_DFL)
87 | 
88 |     QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
89 | 
90 |     app = QtWidgets.QApplication([])
91 |     window = RenameOpWidget()
92 |     window.show()
93 | 
94 |     app.exec_()
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 | 
5 | [project]
6 | name = "onnxgraphqt"
7 | version = "0.0.1"
8 | authors = [
9 |     { name="fateshelled", email="53618876+fateshelled@users.noreply.github.com" },
10 | ]
11 | description = "ONNX model visualizer"
12 | readme = "README.md"
13 | license = {text = "MIT LICENSE"}
14 | classifiers = [
15 |     "Programming Language :: Python :: 3",
16 | ]
17 | dependencies = [
18 |     "PySide2",
19 |     "numpy",
20 |     "pillow",
21 |     "onnx",
22 |     "onnx-simplifier",
23 |     "protobuf==3.20.0",
24 |     #"onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com",
25 |     "onnx_graphsurgeon",
26 |     #"git+https://github.com/jchanvfx/NodeGraphQt.git@v0.5.2#egg=NodeGraphQt",
27 |     "simple-onnx-processing-tools",
28 |     "grandalf",
29 |     "networkx",
30 | ]
31 | 
32 | [project.entry-points.console_scripts]
33 | onnxgraphqt = "onnxgraphqt:main"
34 | 
35 | [tool.setuptools.packages.find]
36 | exclude = ["docker", "build", "tmp"]
37 | 
38 | [project.urls]
39 | "Homepage" = "https://github.com/fateshelled/OnnxGraphQt"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | PySide2
2 | numpy
3 | pillow
4 | onnx
5 | onnx-simplifier
6 | protobuf==3.20.0
7 | onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
8 | git+https://github.com/jchanvfx/NodeGraphQt.git@v0.5.2#egg=NodeGraphQt
9 | simple-onnx-processing-tools
10 | grandalf
11 | networkx
--------------------------------------------------------------------------------