├── nnViewer ├── back │ ├── __init__.py │ ├── models.py │ ├── wrapped_function.py │ ├── node_getter.py │ ├── graph_initializer.py │ ├── utils.py │ ├── hook_functions.py │ ├── wrapper.py │ ├── nodes.py │ └── graph.py ├── front │ ├── __init__.py │ ├── maps.py │ ├── gui.py │ ├── utils.py │ └── node_item.py └── __init__.py ├── requirements.txt ├── .gitignore └── README.md /nnViewer/back/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nnViewer/front/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt : -------------------------------------------------------------------------------- 1 | torch==2.2.2 2 | pygraphviz==1.14 3 | PyQt5==5.15.11 -------------------------------------------------------------------------------- /nnViewer/__init__.py: -------------------------------------------------------------------------------- 1 | from nnViewer.back.graph_initializer import wrap_model 2 | from nnViewer.front.gui import run_gui 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | graphs/ 2 | .idea/ 3 | env/ 4 | __pycache__/ 5 | .DS_store 6 | __pycache__/ 7 | *.pyc 8 | *.pyo 9 | ignore/ 10 | setup.py 11 | pyproject.toml 12 | nnViewer.egg-info/ 13 | dist/ -------------------------------------------------------------------------------- /nnViewer/back/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class PosData: 5 | x:float = 0 6 | y:float = 0 7 | height: float = 0 8 | width: float = 0 9 | 10 | @dataclass 11 | class PosDataUpperModules: 12 | x:float = 0 13 | y:float = 0 14 | 
height: float = 0 15 | width: float = 0 16 | name: str = "" 17 | class_name: str = "" 18 | margin: float = 0 19 | level: float = 0 20 | -------------------------------------------------------------------------------- /nnViewer/front/maps.py: -------------------------------------------------------------------------------- 1 | def map_strings_to_colors(strings): 2 | colors = [ 3 | "#186B66", 4 | "#B97A6A", 5 | "#66014A", 6 | "#085D8F", 7 | "#065D6F", 8 | "#7C3F40", 9 | "#3C5B5C", 10 | "#D26D7E", 11 | "#4A2B73", 12 | "#4a536b", 13 | ] 14 | 15 | 16 | i = 0 17 | while len(colors) < len(strings): 18 | colors.append(colors[i]) 19 | i += 1 20 | 21 | return {string: colors[i] for i, string in enumerate(strings)} -------------------------------------------------------------------------------- /nnViewer/back/wrapped_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | function_to_wrap = [ 6 | # torch.Tensor 7 | (torch.Tensor.__add__, torch.Tensor), 8 | (torch.Tensor.__div__, torch.Tensor), 9 | (torch.Tensor.__getitem__, torch.Tensor), 10 | (torch.Tensor.__iadd__, torch.Tensor), 11 | (torch.Tensor.__isub__, torch.Tensor), 12 | (torch.Tensor.__matmul__, torch.Tensor), 13 | (torch.Tensor.__mul__, torch.Tensor), 14 | (torch.Tensor.__rmul__, torch.Tensor), 15 | (torch.Tensor.__neg__, torch.Tensor), 16 | (torch.Tensor.__sub__, torch.Tensor), 17 | (torch.Tensor.__truediv__, torch.Tensor), 18 | (torch.Tensor.__pow__, torch.Tensor), 19 | (torch.Tensor.exp, torch.Tensor), 20 | (torch.Tensor.expand, torch.Tensor), 21 | (torch.Tensor.flatten, torch.Tensor), 22 | (torch.Tensor.mean, torch.Tensor), 23 | (torch.Tensor.neg, torch.Tensor), 24 | (torch.Tensor.neg_, torch.Tensor), 25 | (torch.Tensor.negative, torch.Tensor), 26 | (torch.Tensor.permute, torch.Tensor), 27 | (torch.Tensor.reshape, torch.Tensor), 28 | (torch.Tensor.t, torch.Tensor), 29 | 
(torch.Tensor.transpose, torch.Tensor), 30 | (torch.Tensor.view, torch.Tensor), 31 | 32 | # torch 33 | (torch.cat, torch), 34 | (torch.div, torch), 35 | (torch.exp, torch), 36 | (torch.exp_, torch), 37 | (torch.matmul, torch), 38 | (torch.pow, torch), 39 | (torch.stack, torch), 40 | (torch.sum, torch), 41 | (torch.neg, torch), 42 | (torch.t, torch), 43 | (torch.rsqrt, torch), 44 | (torch.sqrt, torch), 45 | (torch.sigmoid, torch), 46 | 47 | # torch.nn 48 | (nn.Conv2d, nn), 49 | (nn.Embedding, nn), 50 | 51 | # torch.nn.functional 52 | (F.conv1d, F), 53 | (F.conv2d, F), 54 | (F.embedding, F), 55 | (F.scaled_dot_product_attention, F), 56 | (F.interpolate, F), 57 | ] 58 | -------------------------------------------------------------------------------- /nnViewer/back/node_getter.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple, Set 2 | 3 | from torch import Tensor 4 | 5 | from nnViewer.back.nodes import VarNode, FunctionNode, Node 6 | from nnViewer.back.utils import split_module_name 7 | 8 | def get_node_from_base_tensor(var: Union[Tensor, Tuple[Tensor, ...]], 9 | params: Dict, 10 | stop_grad_fns: Set = None, 11 | with_base_tensor: bool = True) -> Tuple[Set[Node], Set[Tuple[str, str]], bool]: 12 | stoped = False 13 | nodes = set() 14 | edges = set() 15 | 16 | if stop_grad_fns is None: 17 | stop_grad_fns = set() 18 | 19 | param_map = {id(v): k for k, v in params.items()} if params else {} 20 | seen = {None} 21 | 22 | def search_nodes(fn): 23 | nonlocal stoped 24 | if fn in seen or fn in stop_grad_fns: 25 | stoped = stoped or (fn in stop_grad_fns) 26 | return 27 | 28 | seen.add(fn) 29 | 30 | if hasattr(fn, 'variable'): 31 | var = fn.variable 32 | var_id = str(id(var)) 33 | seen.add(var) 34 | name = split_module_name(param_map.get(id(var), "var"))[-1] 35 | nodes.add(VarNode(id=var_id, name=name, variable=var)) 36 | edges.add((var_id, str(id(fn)))) 37 | 38 | nodes.add(FunctionNode(id=str(id(fn)), 
name=str(type(fn).__name__), function=fn)) 39 | 40 | if hasattr(fn, 'next_functions'): 41 | for u in fn.next_functions: 42 | if u[0] is not None: 43 | edges.add((str(id(u[0])), str(id(fn)))) 44 | search_nodes(u[0]) 45 | 46 | def add_base_tensor(var: Tensor): 47 | var_id = str(id(var)) 48 | if var in seen: 49 | return 50 | 51 | seen.add(var) 52 | if with_base_tensor: 53 | nodes.add(VarNode(id=var_id, variable=var, name="output")) 54 | if var.grad_fn: 55 | search_nodes(var.grad_fn) 56 | if with_base_tensor: 57 | edges.add((str(id(var.grad_fn)), var_id)) 58 | 59 | if var._is_view(): 60 | add_base_tensor(var._base) 61 | if with_base_tensor: 62 | edges.add((str(id(var._base)), var_id)) 63 | 64 | if isinstance(var, tuple): 65 | for v in var: 66 | add_base_tensor(v) 67 | else: 68 | add_base_tensor(var) 69 | 70 | return nodes, edges, stoped 71 | -------------------------------------------------------------------------------- /nnViewer/back/graph_initializer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Tuple, Callable, Dict, Optional 3 | import torch 4 | 5 | from torch import Tensor, nn 6 | 7 | from nnViewer.back.graph import Graph 8 | from nnViewer.back.node_getter import get_node_from_base_tensor 9 | from nnViewer.back.utils import get_var_as_tuple_tensor 10 | from nnViewer.back.hook_functions import hooks, root_node_belong_to_module, set_hooks, remove_hooks 11 | from nnViewer.back.wrapper import unwrap_functions, wrapped_output, enable_function_wrapping 12 | 13 | class GraphInitializer: 14 | def __init__(self, 15 | model: nn.Module, 16 | fn_to_wrap: Callable,): 17 | self.model = model 18 | if not fn_to_wrap: 19 | fn_to_wrap = model.forward 20 | self.forward_fn = fn_to_wrap 21 | self.graph = None 22 | setattr(self.model, fn_to_wrap.__name__, self.wrap_forward) 23 | 24 | def wrap_forward(self, *args, **kwargs): 25 | if self.graph: 26 | return self.forward_fn(*args, **kwargs) 27 | 28 | self.graph, 
output = build_nn_graph(self.model, self.forward_fn, args, kwargs) 29 | return output 30 | 31 | def wrap_model(model, fn_to_wrap: Optional[Callable] = None) -> GraphInitializer: 32 | if isinstance(model, nn.Module): 33 | return GraphInitializer(model, fn_to_wrap) 34 | 35 | sub_models = (getattr(model, attr) for attr in dir(model)) 36 | for sub_model in sub_models: 37 | if isinstance(sub_model, nn.Module): 38 | return GraphInitializer(sub_model, fn_to_wrap) 39 | 40 | raise ValueError("No nn.Module found") 41 | 42 | def build_nn_graph(model: nn.Module, 43 | forward_fn: Callable, 44 | args: Tuple, 45 | kwargs: Dict) -> Tuple[Graph, Tuple[Tensor]]: 46 | gc.disable() 47 | hook_handles = set_hooks(model) 48 | enable_function_wrapping() 49 | 50 | with torch.set_grad_enabled(True): 51 | output = forward_fn(*args, **kwargs) 52 | 53 | unwrap_functions() 54 | 55 | remove_hooks(hook_handles) 56 | 57 | output_tuple = get_var_as_tuple_tensor(output) 58 | output_tuple = tuple(output for output in output_tuple if hasattr(output, "grad_fn")) 59 | 60 | nodes, edges, _ = get_node_from_base_tensor(output_tuple, dict(model.named_parameters())) 61 | 62 | gc.enable() 63 | 64 | graph = Graph(nodes, edges) 65 | 66 | graph.set_up(hooks, root_node_belong_to_module, wrapped_output) 67 | 68 | return graph, output 69 | -------------------------------------------------------------------------------- /nnViewer/back/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict 2 | import re 3 | 4 | from torch import Tensor 5 | 6 | from nnViewer.back.models import PosData, PosDataUpperModules 7 | 8 | 9 | def parse_pos(pos_str): 10 | points = pos_str.replace('e', '').replace(',', ' ').split() 11 | points = [(float(points[i]), -float(points[i+1])) for i in range(0, len(points), 2)] 12 | last_point = points.pop(0) 13 | points.append(last_point) 14 | return points 15 | 16 | def split_module_name(name: str) -> List[str]: 17 | if name 
== ".": 18 | return [""] 19 | return re.split(r"\.+", name) 20 | 21 | def create_bounding_rectangle(rectangles: List[PosData], 22 | class_name: str, 23 | margin_height: float = 15, 24 | margin_width: float = 15, 25 | level: float = 0): 26 | min_x = rectangles[0].x - rectangles[0].width / 2 27 | min_y = rectangles[0].y - rectangles[0].height / 2 28 | max_x = rectangles[0].x + rectangles[0].width / 2 29 | max_y = rectangles[0].y + rectangles[0].height / 2 30 | 31 | for rect in rectangles[1:]: 32 | min_x = min(min_x, rect.x - rect.width / 2) 33 | min_y = min(min_y, rect.y - rect.height / 2) 34 | max_x = max(max_x, rect.x + rect.width / 2) 35 | max_y = max(max_y, rect.y + rect.height / 2) 36 | 37 | bounding_width = max_x - min_x 38 | bounding_height = max_y - min_y 39 | 40 | return PosDataUpperModules( 41 | height=bounding_height + margin_height*2, 42 | width=bounding_width + margin_width*2, 43 | margin = margin_width, 44 | x=(min_x + max_x)/2, 45 | y=(min_y + max_y)/2, 46 | class_name=class_name, 47 | level=level 48 | ) 49 | 50 | 51 | def get_var_as_tuple_tensor(var): 52 | var_output = [] 53 | if isinstance(var, tuple): 54 | for l in var: 55 | if isinstance(l, Tensor): 56 | var_output.append(l) 57 | return tuple(var_output) 58 | 59 | elif isinstance(var, Tensor): 60 | return (var,) 61 | 62 | elif isinstance(var, Dict): 63 | for _, value in var.items(): 64 | if isinstance(value, Tensor): 65 | var_output.append(value) 66 | return tuple(var_output) 67 | 68 | elif (not isinstance(var, Tensor)) and (not isinstance(var, Tuple)): 69 | for _, value in var.__dict__.items(): 70 | if isinstance(value, Tensor): 71 | var_output.append(value) 72 | return tuple(var_output) 73 | 74 | else: 75 | return var 76 | -------------------------------------------------------------------------------- /nnViewer/back/hook_functions.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, Dict 2 | 3 | import torch 4 | from torch 
import nn, Tensor 5 | 6 | from nnViewer.back.node_getter import get_node_from_base_tensor 7 | from nnViewer.back.nodes import (LinearNode, LayerNormNode, ModuleNode, Conv2dNode, 8 | EmbeddingNode, Conv1dNode) 9 | from nnViewer.back.utils import get_var_as_tuple_tensor 10 | 11 | hooks = set() 12 | known_nodes = set() 13 | root_node_belong_to_module = {} 14 | 15 | 16 | NODE_CLASSES = { 17 | "Linear": LinearNode, 18 | "LayerNorm": LayerNormNode, 19 | "Conv2d": Conv2dNode, 20 | "Conv1d": Conv1dNode, 21 | "Embedding": EmbeddingNode, 22 | } 23 | 24 | SKIP_MODULE = ["Dropout", "Identity"] 25 | 26 | def hook_fn_factory(name: str) -> Callable: 27 | def hook_fn(module: nn.Module, 28 | args: Tuple, 29 | kwargs: Dict, 30 | output: Tuple) -> None: 31 | 32 | if not name or module.__class__.__name__ in SKIP_MODULE: 33 | return None 34 | 35 | input_tuple = tuple( 36 | arg for arg in list(args) + list(kwargs.values()) 37 | if isinstance(arg, torch.Tensor) 38 | ) 39 | 40 | module_id = create_unique_id() 41 | 42 | stop_grad_fns = {arg.grad_fn for arg in input_tuple 43 | if hasattr(arg, "grad_fn") and arg.grad_fn is not None} 44 | 45 | output_tuple = get_var_as_tuple_tensor(output) 46 | nodes, _, _ = get_node_from_base_tensor( 47 | var = output_tuple, 48 | params = dict(module.named_parameters()), 49 | stop_grad_fns=stop_grad_fns, 50 | with_base_tensor=False 51 | ) 52 | 53 | nodes_ids = {node.id for node in nodes} 54 | 55 | for id_root_node in nodes_ids: 56 | if id_root_node not in known_nodes: 57 | root_node_belong_to_module[id_root_node] = module_id 58 | known_nodes.add(id_root_node) 59 | 60 | if nodes_ids: 61 | node = create_module_node( 62 | module_id, 63 | name, 64 | input_tuple, 65 | module, 66 | output_tuple, 67 | ) 68 | node.all_root_sub_ids = nodes_ids 69 | hooks.add(node) 70 | return hook_fn 71 | 72 | def create_module_node( 73 | module_id: str, 74 | name:str, 75 | input_tuple: Tuple[Tensor], 76 | module: nn.Module, 77 | output_tuple:Tuple[Tensor]): 78 | node_class = 
NODE_CLASSES.get(module.__class__.__name__, ModuleNode) 79 | 80 | return node_class( 81 | module_id, 82 | name, 83 | input_tuple, 84 | module, 85 | output_tuple, 86 | ) 87 | 88 | def set_hooks(model): 89 | hook_handles = [] 90 | if isinstance(model, nn.Module): 91 | for name, layer in model.named_modules(): 92 | hook_handles.append( 93 | layer.register_forward_hook(hook_fn_factory(name), with_kwargs=True) 94 | ) 95 | return hook_handles 96 | 97 | def remove_hooks(hook_handles): 98 | for handle in hook_handles: 99 | handle.remove() 100 | 101 | current_id = 0 102 | def create_unique_id(): 103 | global current_id 104 | current_id += 1 105 | return str(current_id) 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # nnViewer 3 | 4 | **nnViewer** is a Python library designed to provide an intuitive GUI for visualizing the structure and flow of a `torch.nn.Module`. Whether you're debugging or exploring complex neural networks, nnViewer makes it easier to understand your models. 5 | 6 | **[Watch the demo video](https://drive.google.com/file/d/1YUzYWcpsEofURfNiDq7SVexJiN6lWCuu/view?usp=sharing)** 7 | 8 | ## Installation 9 | 10 | Before installing `nnViewer`, you need to install `graphviz` and its development libraries. You can install them based on your operating system. 11 | 12 | ### Linux (Ubuntu/Debian) 13 | 14 | 1. **Install system dependencies**: 15 | ```bash 16 | sudo apt install graphviz libgraphviz-dev 17 | ``` 18 | 19 | 2. **Install the Python package**: 20 | ```bash 21 | pip install nnViewer 22 | ``` 23 | 24 | ### macOS 25 | 26 | 1. **Install system dependencies using Homebrew**: 27 | If you don't have Homebrew installed, you can install it from [here](https://brew.sh/). Then, install `graphviz`: 28 | ```bash 29 | brew install graphviz 30 | ``` 31 | 32 | 2. 
**Install the Python package**: 33 | ```bash 34 | pip install nnViewer 35 | ``` 36 | 37 | 38 | ## Quick Start 39 | 40 | Here's an example of how to use nnViewer with a Hugging Face model: 41 | 42 | ```python 43 | from transformers import AutoImageProcessor, AutoModel 44 | from PIL import Image 45 | import requests 46 | from nnViewer import wrap_model, run_gui 47 | 48 | # Load an image 49 | url = 'http://images.cocodataset.org/val2017/000000039769.jpg' 50 | image = Image.open(requests.get(url, stream=True).raw) 51 | 52 | # Load the model and processor 53 | processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large') 54 | model = AutoModel.from_pretrained('facebook/dinov2-large') 55 | 56 | # Prepare the inputs 57 | inputs = processor(images=image, return_tensors="pt") 58 | 59 | # Wrap the model to initialize the graph 60 | graph_init = wrap_model(model) 61 | 62 | # Perform a forward pass to populate the graph 63 | model(**inputs) 64 | 65 | # Launch the GUI to visualize the computational graph 66 | run_gui(graph_init.graph) 67 | ``` 68 | 69 | ## Overview 70 | 71 | ### `wrap_model(model: nn.Module, fn_to_wrap: Optional[Callable] = None) -> GraphInitializer` 72 | 73 | The `wrap_model` function wraps a `torch.nn.Module` so that its first forward pass records the computational graph for visualization. 74 | 75 | - **Arguments**: 76 | - `model` (`nn.Module`): The model to wrap. If `model` is not an `nn.Module` (for example, a pipeline object), the first of its attributes, in `dir()` order, that is an `nn.Module` is wrapped instead. 77 | - `fn_to_wrap` (`Callable`, optional): A custom forward function to wrap. If not provided, the model's `forward` method is used. The output of `fn_to_wrap` should be a tensor with a `grad_fn` (or a tuple, dict, or object holding such tensors), meaning the output must be part of the computation graph (i.e., it requires gradients). 78 | 79 | - **Returns**: A `GraphInitializer` object. After the first call of the wrapped function, its `graph` attribute holds the computational graph, ready to pass to `run_gui`. 80 | 81 | ### How to Use 82 | 83 | 1. **Wrap the model**: First, you need to call `wrap_model(model)` to initialize the computational graph. 84 | 2.
**Perform a forward pass**: Once the model is wrapped, run a forward pass with some sample inputs. This will populate the computational graph. Only the first forward pass builds the graph; later calls run the model normally. 85 | 3. **Visualize the graph**: After the forward pass, you can visualize the model’s graph by calling `run_gui(graph_init.graph)`. 86 | 87 | ## Navigating the Graph 88 | 89 | Once the GUI is running, you can interact with the computational graph in the following ways: 90 | 91 | - **Click on a node**: Expands the node to show more information about the operation it represents. 92 | - **Double-click on a node**: Contracts the node and hides the details, simplifying the view. 93 | - **Right-click on a node**: Opens a context menu with options: 94 | - **Get More Information**: Shows additional details about the node, including its attributes. You can click on the attributes to view them. 95 | - **Show Computation**: Displays the matrices before and after the module for that specific node. 96 | 97 | ## Contributing 98 | 99 | Contributions are welcome! If you find any issues or have feature requests, feel free to open a GitHub issue or submit a pull request. 100 | 101 | ## License 102 | 103 | This project is licensed under the MIT License. See the `LICENSE` file for more details.
104 | -------------------------------------------------------------------------------- /nnViewer/back/wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | 6 | from nnViewer.back.nodes import Node, AddNode, CatNode, ViewNode, GetItemNode, FunctionNode, MulNode, SubNode, PowNode, \ 7 | MeanNode, ExpNode, DivNode, SumNode, MatMulNode, AttentionProductNode 8 | from nnViewer.back.wrapped_function import function_to_wrap 9 | 10 | wrapped_output = [] 11 | wrapped_nodes = [] 12 | 13 | def wrapper(function_name, function): 14 | def custom_function(*args, **kwargs): 15 | output = function(*args, **kwargs) 16 | if hasattr(output, "grad_fn"): 17 | if output.grad_fn is not None: 18 | visited = set() 19 | grad_fns_found = [] 20 | 21 | def traverse_grad_fn(grad_fn): 22 | if grad_fn in visited: 23 | return 24 | visited.add(grad_fn) 25 | 26 | for arg in args: 27 | if isinstance(arg, torch.Tensor) and arg.grad_fn == grad_fn: 28 | return 29 | for kwarg in kwargs.values(): 30 | if isinstance(kwarg, torch.Tensor) and kwarg.grad_fn == grad_fn: 31 | return 32 | 33 | grad_fns_found.append(str(id(grad_fn))) 34 | 35 | for next_fn in grad_fn.next_functions: 36 | if next_fn[0] is not None: 37 | traverse_grad_fn(next_fn[0]) 38 | 39 | traverse_grad_fn(output.grad_fn) 40 | 41 | wrapped_output.append({ 42 | "args": args, 43 | "kwargs": kwargs, 44 | "grad_fn_created":grad_fns_found[1:], 45 | "node": create_function_node(output, args, kwargs, function_name, function) 46 | }) 47 | 48 | return output 49 | return custom_function 50 | 51 | 52 | def enable_function_wrapping(): 53 | for original_function, module in function_to_wrap: 54 | function_name = original_function.__name__ 55 | 56 | if not hasattr(module, f"_original_{function_name}"): 57 | setattr(module, f"_original_{function_name}", original_function) 58 | 59 | wrapped_function = wrapper(function_name, original_function) 60 | setattr(module, 
function_name, wrapped_function) 61 | 62 | def unwrap_functions(): 63 | for original_function, module in function_to_wrap: 64 | function_name = original_function.__name__ 65 | 66 | if hasattr(module, f"_original_{function_name}"): 67 | original_function = getattr(module, f"_original_{function_name}") 68 | setattr(module, function_name, original_function) 69 | 70 | def create_function_node(output, args, kwargs, function_name, function) -> Node: 71 | grad_fn_id = str(id(output.grad_fn)) 72 | 73 | common_params = { 74 | "id": grad_fn_id, 75 | "name": function_name, 76 | "function": function, 77 | } 78 | 79 | try: 80 | if function_name in ["__add__", "__iadd__"]: 81 | return AddNode(mat1=args[0], mat2=args[1], output=output, **common_params) 82 | 83 | elif function_name in ["cat", "stack"]: 84 | return CatNode(input=args[0], output=output, **common_params) 85 | 86 | elif function_name in ["view", "transpose", "reshape", "expand", "flatten", "t", "permute"]: 87 | return ViewNode(input=args[0], output=output, **common_params) 88 | 89 | elif function_name in ["__mul__", "__rmul__"]: 90 | return MulNode(mat1=args[0], mat2=args[1], output=output, **common_params) 91 | 92 | elif function_name == "__sub__": 93 | return SubNode(mat1=args[0], mat2=args[1], output=output, **common_params) 94 | 95 | elif function_name == "__getitem__": 96 | if args[1]: 97 | if not any(slice is None for slice in args[1]): 98 | slice_str = format_slice(args[1], args[0].shape) 99 | return GetItemNode(input=args[0], slice=slice_str, output=output, **common_params) 100 | 101 | slice_str = f"{args[0].shape}->{output.shape}" 102 | return GetItemNode(input=args[0], slice=slice_str, output=output, **common_params) 103 | 104 | elif function_name == "pow": 105 | return PowNode(input=args[0], pow_value=args[1], output=output, **common_params) 106 | 107 | elif function_name == "__truediv__": 108 | return DivNode(mat1=args[0], mat2=args[1], output=output, **common_params) 109 | 110 | elif function_name == 
"exp": 111 | return ExpNode(exp_value=args[0], output=output, **common_params) 112 | 113 | elif function_name == "mean": 114 | dim = kwargs.get("dim", args[1] if len(args) > 1 else None) 115 | return MeanNode(input=args[0], output=output, dim=dim, **common_params) 116 | 117 | elif function_name == "sum": 118 | dim = kwargs.get("dim", args[1] if len(args) > 1 else None) 119 | return SumNode(input=args[0], output=output, dim=dim, **common_params) 120 | 121 | elif function_name in ["matmul", "__matmul__"]: 122 | return MatMulNode(mat1=args[0], mat2=args[1], output=output, **common_params) 123 | 124 | elif function_name == "scaled_dot_product_attention": 125 | mask = kwargs.get("mask", args[3] if len(args) > 3 else None) 126 | return AttentionProductNode(key=args[0], 127 | query=args[1], 128 | value=args[2], 129 | output=output, 130 | mask=mask, 131 | **common_params) 132 | 133 | else: 134 | warnings.warn(f"{function_name} not implemented") 135 | 136 | except: 137 | warnings.warn(f"Wrapped function {function_name} hasn't found its node") 138 | return FunctionNode(**common_params) 139 | 140 | return FunctionNode(**common_params) 141 | 142 | def format_slice(slice_tuple, tensor_size): 143 | """ 144 | Format a slice tuple into a readable string representation for PyTorch tensors. 
145 | 146 | Args: 147 | slice_tuple: Tuple containing slice objects, integers, tensors, and/or Ellipsis 148 | tensor_size: torch.Size object representing tensor dimensions 149 | 150 | Returns: 151 | str: Formatted string representation of the slice 152 | """ 153 | dimensions = len(tensor_size) 154 | slice_str = [] 155 | 156 | for i, slice_item in enumerate(slice_tuple): 157 | # Handle tensors 158 | if isinstance(slice_item, torch.Tensor): 159 | tensor_list = slice_item.tolist() 160 | if isinstance(tensor_list, list): 161 | # Si tous les éléments sont consécutifs, utiliser le format start:stop 162 | if len(tensor_list) == 2 and tensor_list[1] - tensor_list[0] == 1: 163 | slice_str.append(f"{tensor_list[0]}:{tensor_list[1]}") 164 | # Si les éléments sont identiques, utiliser une seule valeur 165 | elif all(x == tensor_list[0] for x in tensor_list): 166 | slice_str.append(str(tensor_list[0])) 167 | else: 168 | slice_str.append(str(tensor_list)) 169 | else: 170 | slice_str.append(str(tensor_list)) 171 | 172 | # Handle Ellipsis 173 | elif slice_item is Ellipsis: 174 | remaining_dims = dimensions - len(slice_tuple) + 1 175 | slice_str.extend([":"] * remaining_dims) 176 | 177 | # Handle slice objects 178 | elif isinstance(slice_item, slice): 179 | start, stop, step = slice_item.start, slice_item.stop, slice_item.step 180 | if start is None and stop is None and step is None: 181 | slice_str.append(":") 182 | else: 183 | slice_text = "" 184 | if start is not None: 185 | slice_text += str(start) 186 | slice_text += ":" 187 | if stop is not None: 188 | slice_text += str(stop) 189 | if step is not None: 190 | slice_text += ":" + str(step) 191 | slice_str.append(slice_text) 192 | 193 | # Handle integer indices 194 | elif isinstance(slice_item, int): 195 | slice_str.append(str(slice_item)) 196 | 197 | else: 198 | raise TypeError(f"Unsupported slice type: {type(slice_item)}") 199 | 200 | # Add missing dimensions as ":" 201 | while len(slice_str) < dimensions: 202 | 
slice_str.append(":") 203 | 204 | return f"[{', '.join(slice_str)}]" -------------------------------------------------------------------------------- /nnViewer/front/gui.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from PyQt5.QtWidgets import (QApplication, QGraphicsView, QGraphicsScene, QGraphicsTextItem 4 | , QMainWindow, QGraphicsItem, QGraphicsPathItem) 5 | from PyQt5.QtCore import Qt, QRectF 6 | from PyQt5.QtGui import QBrush, QPen, QPainter, QPainterPath, QFont, QColor 7 | from PyQt5.QtCore import QTimer 8 | 9 | from nnViewer.back.graph import Graph 10 | from nnViewer.back.models import PosDataUpperModules 11 | from nnViewer.back.nodes import (ModuleNode, FunctionNode, VarNode, LinearNode, Node, 12 | LayerNormNode, ViewNode, CatNode, AddNode, GetItemNode, 13 | Conv2dNode, MulNode, SubNode, Conv1dNode, DivNode, MatMulNode, 14 | AttentionProductNode) 15 | from nnViewer.front.node_item import ClickableRectItem 16 | from nnViewer.front.utils import create_centered_text_item, get_tensor_shape_as_string 17 | from nnViewer.front.maps import map_strings_to_colors 18 | 19 | FONT = "Courier New" 20 | 21 | 22 | class GraphViewer(QMainWindow): 23 | def __init__(self, graph: Graph): 24 | super().__init__() 25 | 26 | self.setWindowTitle("Graph Viewer") 27 | screen_geometry = QApplication.desktop().screenGeometry() 28 | self.setGeometry(screen_geometry.x(), screen_geometry.y(), screen_geometry.width(), 800) 29 | 30 | self.view = QGraphicsView(self) 31 | self.scene = QGraphicsScene(self) 32 | 33 | self.view.setScene(self.scene) 34 | self.setCentralWidget(self.view) 35 | 36 | self.view.setRenderHint(QPainter.Antialiasing) 37 | self.view.setDragMode(QGraphicsView.ScrollHandDrag) 38 | self.view.setTransformationAnchor(QGraphicsView.AnchorUnderMouse) 39 | self.view.setInteractive(True) 40 | 41 | self.zoom_factor = 1.15 42 | 43 | self.graph = graph 44 | self.modules_colors = 
map_strings_to_colors(self.graph.get_module_class_name()) 45 | self.rectangles = [] 46 | 47 | self.set_items() 48 | self.render_graph() 49 | 50 | def render_graph(self): 51 | for edge in self.graph.pos_edges: 52 | self.draw_edge(edge) 53 | 54 | for node in self.graph.flying_nodes: 55 | self.draw_node(node) 56 | 57 | for upper_module in list(self.graph.flying_upper_modules): 58 | self.draw_upper_module(upper_module) 59 | 60 | self.scene.setSceneRect(self.scene.itemsBoundingRect()) 61 | 62 | def draw_upper_module(self, upper_module: PosDataUpperModules): 63 | 64 | path = QPainterPath() 65 | rect = QRectF( 66 | upper_module.x - upper_module.width / 2, 67 | upper_module.y - upper_module.height / 2, 68 | upper_module.width, 69 | upper_module.height 70 | ) 71 | corner_radius = 10 72 | path.addRoundedRect(rect, corner_radius, corner_radius) 73 | 74 | rect = QGraphicsPathItem(path) 75 | 76 | color = QColor(f"{self.modules_colors[upper_module.class_name]}") 77 | color.setAlpha(128) 78 | 79 | rect.setPen(QPen(Qt.black, 2)) 80 | rect.setBrush(QBrush(color)) 81 | rect.setZValue(upper_module.level) 82 | 83 | self.scene.addItem(rect) 84 | 85 | label = QGraphicsTextItem(upper_module.class_name) 86 | label.setFont(QFont(FONT, 12)) 87 | 88 | nb_label = int(upper_module.height/(label.boundingRect().width()+1)) 89 | label = QGraphicsTextItem(" ".join([upper_module.class_name]*nb_label)) 90 | font = QFont(FONT, 12) 91 | font.setBold(True) 92 | label.setFont(font) 93 | 94 | label.setRotation(-90) 95 | 96 | label.setPos( 97 | upper_module.x - upper_module.width / 2, 98 | upper_module.y + label.boundingRect().width()/2 99 | ) 100 | 101 | label.setZValue(upper_module.level) 102 | self.scene.addItem(label) 103 | 104 | def draw_node(self, 105 | node: Node): 106 | 107 | rect = ClickableRectItem(node, 108 | node.pos.x - node.pos.width/2, 109 | node.pos.y - node.pos.height/2, 110 | node.pos.width, 111 | node.pos.height) 112 | rect.setFlag(QGraphicsItem.ItemIsSelectable) 113 | 
rect.setFlag(QGraphicsItem.ItemIsFocusable) 114 | rect.setPen(QPen(Qt.black)) 115 | rect.setBrush(QBrush(node.color)) 116 | rect.signal_proxy.clicked.connect(lambda: self.single_click_envent(rect)) 117 | rect.signal_proxy.doubleClicked.connect(lambda: self.contract_graph(node)) 118 | rect.setZValue(100) 119 | 120 | self.scene.addItem(rect) 121 | self.rectangles.append(rect) 122 | 123 | node.item.setPos(node.pos.x - node.item.boundingRect().width()/2, 124 | node.pos.y - node.item.boundingRect().height()/2) 125 | node.item.setZValue(100) 126 | self.scene.addItem(node.item) 127 | 128 | def draw_edge(self, coordinates): 129 | path = QPainterPath() 130 | 131 | path.moveTo(coordinates[0][0], coordinates[0][1]) 132 | for coordinate in coordinates[1:]: 133 | path.lineTo(coordinate[0], coordinate[1]) 134 | 135 | pen = QPen(Qt.black, 2) 136 | path_item = self.scene.addPath(path, pen) 137 | path_item.setZValue(99) 138 | 139 | def set_items(self, 140 | margin: int = 10): 141 | 142 | for node in self.graph.flying_nodes: 143 | node.color, node.item = self.get_illustration_item(node) 144 | 145 | # if isinstance(node, VarNode) or isinstance(node, FunctionNode): 146 | # if node.name in function_mapping_plot.keys(): 147 | # margin = 0 148 | 149 | rect_width = node.item.boundingRect().width() + margin 150 | rect_height = node.item.boundingRect().height() + margin 151 | 152 | if isinstance(node, ModuleNode): 153 | rect_width = max(rect_width, node.pos.width) 154 | rect_height = max(rect_height, node.pos.height) 155 | 156 | node.pos.width = rect_width 157 | node.pos.height = rect_height 158 | 159 | def get_illustration_item(self, node): 160 | if isinstance(node, LinearNode): 161 | label_text = "Fully Connected Layer" 162 | color = QColor(250, 128, 114) 163 | item = QGraphicsTextItem(label_text) 164 | item.setFont(QFont(FONT, 12)) 165 | return color, item 166 | 167 | elif isinstance(node, LayerNormNode): 168 | label_text = "Layer Normalization" 169 | color = QColor("#5b0e2d") 170 | item 
= QGraphicsTextItem(label_text) 171 | item.setFont(QFont(FONT, 12)) 172 | return color, item 173 | 174 | elif isinstance(node, Conv2dNode): 175 | label_text = "Conv 2D" 176 | color = QColor("#7AAB9F") 177 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 178 | return color, item 179 | 180 | elif isinstance(node, Conv1dNode): 181 | label_text = "Conv 1D" 182 | color = QColor("#7AAB9F") 183 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 184 | return color, item 185 | 186 | elif isinstance(node, ModuleNode): # any other module: color keyed by class name, font size shrinks as nesting depth (len(up_modules)) grows 187 | label_text = node.module.__class__.__name__ 188 | color = QColor(self.modules_colors[label_text]) 189 | item = QGraphicsTextItem(label_text) 190 | level = len(node.up_modules) + 1 191 | item.setFont(QFont(FONT, int(12 + 12/level))) 192 | return color, item 193 | 194 | elif isinstance(node, ViewNode): # three-line label: input shape, "->", output shape 195 | label_text = (f"{get_tensor_shape_as_string(node.input)}\n" 196 | f"->\n" 197 | f"{get_tensor_shape_as_string(node.output)}") 198 | color = QColor("#26495c") 199 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 200 | 201 | return color, item 202 | 203 | elif isinstance(node, CatNode): # label: "concat" followed by the output shape 204 | str_shape = (get_tensor_shape_as_string(node.output)) 205 | label_text = (f"concat \n" 206 | f"{str_shape}") 207 | color = QColor("#26495c") 208 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 209 | return color, item 210 | 211 | elif isinstance(node, AddNode): # arithmetic-operator nodes: dark-gray box with a one-symbol label 212 | label_text = "+" 213 | color = Qt.darkGray 214 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 215 | return color, item 216 | 217 | elif isinstance(node, DivNode): 218 | label_text = "/" 219 | color = Qt.darkGray 220 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 221 | return color, item 222 | 223 | elif isinstance(node, AttentionProductNode): 224 | label_text = "Attention Product" 225 | color = Qt.darkGray 226 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 227 | return color, item 228 | 229 |
elif isinstance(node, MatMulNode): 230 | label_text = "Matrix Multiplication" 231 | color = Qt.darkGray 232 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 233 | return color, item 234 | 235 | elif isinstance(node, MulNode): 236 | label_text = "*" 237 | color = Qt.darkGray 238 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 239 | return color, item 240 | 241 | elif isinstance(node, SubNode): 242 | label_text = "-" 243 | color = Qt.darkGray 244 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 245 | return color, item 246 | 247 | elif isinstance(node, GetItemNode): # label is the node's stored slice text 248 | label_text = node.slice 249 | color = Qt.darkGray 250 | item = create_centered_text_item(label_text, QFont(FONT, 12)) 251 | return color, item 252 | 253 | elif isinstance(node, FunctionNode): # any other function node: plain text label with its name 254 | label_text = node.name 255 | # if label_text in function_mapping_plot.keys(): 256 | # item = function_mapping_plot[label_text]() 257 | # color = QColor(173, 216, 230) 258 | # return color, item 259 | color = Qt.darkGray 260 | item = QGraphicsTextItem(label_text) 261 | item.setFont(QFont(FONT, 12)) 262 | return color, item 263 | 264 | elif isinstance(node, VarNode): 265 | label_text = node.name 266 | color = QColor(93, 76, 61) 267 | item = QGraphicsTextItem(label_text) 268 | item.setFont(QFont(FONT, 12)) 269 | return color, item # NOTE(review): there is no final else, so a node matching no branch makes this method return None and the caller's (color, item) unpacking fail 270 | 271 | def contract_graph(self, node): # double-click handler: ask the graph to contract this node, recompute sizes and layout, then schedule a redraw 272 | self.graph.contract_flying_node(node.id) 273 | self.set_items() 274 | self.graph.compute_pos_and_edges() 275 | self.graph.compute_flying_upper_modules_pos() 276 | QTimer.singleShot(0, self.clear_and_render_graph) # redraw once control returns to the Qt event loop; presumably so the clicked item is not destroyed inside its own signal handler (confirm) 277 | 278 | def single_click_envent(self, rect): # single-click handler: call graph.expend_flying_node for the clicked node, then re-layout and schedule a redraw 279 | self.graph.expend_flying_node(rect.node.id) 280 | self.set_items() 281 | self.graph.compute_pos_and_edges() 282 | self.graph.compute_flying_upper_modules_pos() 283 | QTimer.singleShot(0, self.clear_and_render_graph) 284 | 285 | def clear_and_render_graph(self): # wipe the scene and draw the current graph state again 286 | self.clear_scene() 287 | self.render_graph() 288 | 289 |
def clear_scene(self): 290 | for item in self.scene.items(): 291 | if isinstance(item, ClickableRectItem): 292 | item.signal_proxy.clicked.disconnect() 293 | self.scene.clear() 294 | 295 | def keyPressEvent(self, event): 296 | if event.key() in (Qt.Key_Plus, Qt.Key_Equal): 297 | self.view.scale(self.zoom_factor, self.zoom_factor) 298 | elif event.key() == Qt.Key_Minus: 299 | self.view.scale(1 / self.zoom_factor, 1 / self.zoom_factor) 300 | 301 | def run_gui(graph): 302 | app = QApplication(sys.argv) 303 | 304 | viewer = GraphViewer(graph) 305 | viewer.show() 306 | 307 | sys.exit(app.exec_()) 308 | -------------------------------------------------------------------------------- /nnViewer/back/nodes.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from PyQt5.QtGui import QColor 4 | from PyQt5.QtWidgets import QGraphicsRectItem 5 | from torch import Tensor, nn, empty 6 | 7 | from nnViewer.back.models import PosData 8 | 9 | LEVEL_1_WIDTH = 250 10 | MAX_HEIGHT = 500 11 | 12 | class Node(): 13 | def __init__(self, 14 | id: str, 15 | name: str): 16 | self.id = id 17 | self.name = name 18 | 19 | self.parents: List[Node] = [] 20 | self.childrens: List[Node] = [] 21 | self.flying_childrens: List[Node] = [] 22 | self.flying_parents: List[Node] = [] 23 | self.next_ids: List[str] = [] 24 | self.previous_ids: List[str] = [] 25 | self.pos: PosData = PosData() 26 | self.upper_module: ModuleNode = None 27 | self.up_modules = [] 28 | self.item: QGraphicsRectItem = None 29 | self.color: QColor = QColor(0, 0, 0) 30 | 31 | def add_parent(self, 32 | parent) -> None: 33 | self.parents.append(parent) 34 | 35 | def add_children(self, 36 | children) -> None: 37 | self.childrens.append(children) 38 | 39 | def __hash__(self): 40 | return hash(self.id) 41 | 42 | 43 | class VarNode(Node): 44 | def __init__(self, 45 | id: str, 46 | name: str, 47 | variable: Tensor): 48 | 49 | super().__init__(id, name) 50 | 
self.variable = variable 51 | 52 | 53 | class FunctionNode(Node): 54 | def __init__(self, 55 | id: str, 56 | name: str, 57 | function: str, 58 | ): 59 | 60 | super().__init__(id, name) 61 | self.function = function 62 | 63 | self.output = None 64 | self.input = None 65 | self.parents_id = [] 66 | 67 | class ModuleNode(Node): 68 | def __init__(self, 69 | id: str, 70 | name: str, 71 | input: Tuple[Tensor], 72 | module: nn.Module, 73 | output: Tuple[Tensor]): 74 | 75 | super().__init__(id, name) 76 | self.module = module 77 | self.input = input 78 | self.output = output 79 | 80 | self.sub_nodes = set() 81 | self.all_root_sub_ids = [] 82 | self.all_sub_childrens = [] 83 | self.all_sub_parents = [] 84 | self.nb_parameters:int = 0 85 | 86 | def set_height_and_width(self, max_number_parameters): 87 | level = len(self.up_modules) - 2 88 | width = LEVEL_1_WIDTH * 0.8**level 89 | height = (self.nb_parameters/max_number_parameters) * MAX_HEIGHT 90 | self.pos.height = max(height, 25) 91 | self.pos.width = max(width, 25) 92 | 93 | class LinearNode(ModuleNode): 94 | def __init__(self, 95 | id: str, 96 | name: str, 97 | input: Tuple[Tensor], 98 | module: nn.Module, 99 | output: Tensor, 100 | ): 101 | 102 | super().__init__(id, name, input, module, output) 103 | 104 | self.weight: Tensor = None 105 | self.bias: Tensor = None 106 | 107 | # def get_weights(self): 108 | # for sub_node in self.sub_nodes: 109 | # if sub_node.name == "weight": 110 | # self.weight = sub_node.variable 111 | # if sub_node.name== "bias": 112 | # self.bias = sub_node.variable 113 | 114 | def set_height_and_width(self, max_number_parameters): 115 | self.pos.height = 70 * (1+(self.nb_parameters/max_number_parameters)) 116 | self.pos.width = 50 * (1+(self.nb_parameters/max_number_parameters)) 117 | 118 | class LayerNormNode(ModuleNode): 119 | def __init__(self, 120 | id: str, 121 | name: str, 122 | input: Tuple[Tensor], 123 | module: nn.Module, 124 | output: Tensor, 125 | ): 126 | 127 | super().__init__(id, 
name, input, module, output) 128 | 129 | self.weight: Tensor = empty(()) 130 | self.bias: Tensor = empty(()) 131 | 132 | # def get_weights(self): 133 | # for sub_node in self.sub_nodes: 134 | # if sub_node.name == "weight": 135 | # self.weight = sub_node.variable 136 | # if sub_node.name == "bias": 137 | # self.bias = sub_node.variable 138 | 139 | class Conv2dNode(ModuleNode): # no extra state; the type alone selects the "Conv 2D" label and the Conv2d branch of get_node_info 140 | def __init__(self, 141 | id: str, 142 | name: str, 143 | input: Tuple[Tensor], 144 | module: nn.Module, 145 | output: Tuple[Tensor], 146 | ): 147 | 148 | super().__init__(id, name, input, module, output) 149 | 150 | class Conv1dNode(ModuleNode): # no extra state; the type alone selects the "Conv 1D" label and the Conv1d branch of get_node_info 151 | def __init__(self, 152 | id: str, 153 | name: str, 154 | input: Tuple[Tensor], 155 | module: nn.Module, 156 | output: Tuple[Tensor], 157 | ): 158 | 159 | super().__init__(id, name, input, module, output) 160 | 161 | class EmbeddingNode(ModuleNode): # no extra state; the type selects the embedding branch of get_node_info 162 | def __init__(self, 163 | id: str, 164 | name: str, 165 | input: Tuple[Tensor], 166 | module: nn.Module, 167 | output: Tuple[Tensor], 168 | ): 169 | 170 | super().__init__(id, name, input, module, output) 171 | 172 | class BMMNode(FunctionNode): # both operands and the result of a batched matrix product (name suggests torch.bmm; confirm) 173 | def __init__(self, 174 | id: str, 175 | name: str, 176 | function: str, 177 | mat1: Tensor, 178 | mat2: Tensor, 179 | output): 180 | 181 | super().__init__(id, name, function) 182 | self.mat2 = mat2 183 | self.mat1 = mat1 184 | self.output = output 185 | 186 | class MulNode(FunctionNode): # both operands and the result of a multiplication 187 | def __init__(self, 188 | id: str, 189 | name: str, 190 | function: str, 191 | mat1: Tensor, 192 | mat2: Tensor, 193 | output: Tensor): 194 | 195 | super().__init__(id, name, function) 196 | self.mat1 = mat1 197 | self.mat2 = mat2 198 | self.output = output 199 | 200 | class AttentionProductNode(FunctionNode): # key/query/value, output and optional mask; also precomputes query @ key^T as attention_matrix (no scaling, softmax or masking applied) 201 | def __init__(self, 202 | id: str, 203 | name: str, 204 | function: str, 205 | key: Tensor, 206 | query: Tensor, 207 | value: Tensor, 208 | output: Tensor, 209 | mask: Union[Tensor, None]): 210 | 211 | super().__init__(id, name, function) 212
| self.key = key 213 | self.query = query 214 | self.value = value 215 | 216 | self.output = output 217 | self.mask = mask 218 | self.attention_matrix = query @ key.transpose(-2, -1) # NOTE(review): `@` calls torch.Tensor.__matmul__, which wrapped_function.py lists in function_to_wrap; if the wrappers are active when this node is built, this product may be recorded too (confirm) 219 | 220 | 221 | class AddNode(FunctionNode): # both operands and the result of an addition 222 | def __init__(self, 223 | id: str, 224 | name: str, 225 | function: str, 226 | mat1: Tensor, 227 | mat2: Tensor, 228 | output: Tensor): 229 | 230 | super().__init__(id, name, function) 231 | self.mat1 = mat1 232 | self.mat2 = mat2 233 | self.output = output 234 | 235 | class MatMulNode(FunctionNode): # both operands and the result of a matrix product 236 | def __init__(self, 237 | id: str, 238 | name: str, 239 | function: str, 240 | mat1: Tensor, 241 | mat2: Tensor, 242 | output: Tensor): 243 | 244 | super().__init__(id, name, function) 245 | self.mat1 = mat1 246 | self.mat2 = mat2 247 | self.output = output 248 | 249 | class MeanNode(FunctionNode): # reduction node; unlike most FunctionNode subclasses, `dim` comes before `function` in the signature 250 | def __init__(self, 251 | id: str, 252 | name: str, 253 | dim: int, 254 | function: str, 255 | input: Tensor, 256 | output: Tensor): 257 | 258 | super().__init__(id, name, function) 259 | self.input = input 260 | self.dim = dim 261 | self.output = output 262 | 263 | class SumNode(FunctionNode): # same layout as MeanNode (dim before function) 264 | def __init__(self, 265 | id: str, 266 | name: str, 267 | dim: int, 268 | function: str, 269 | input: Tensor, 270 | output: Tensor): 271 | 272 | super().__init__(id, name, function) 273 | self.input = input 274 | self.dim = dim 275 | self.output = output 276 | 277 | class PowNode(FunctionNode): # input, exponent (pow_value) and result 278 | def __init__(self, 279 | id: str, 280 | name: str, 281 | function: str, 282 | input: Tensor, 283 | pow_value: float, 284 | output: Tensor): 285 | 286 | super().__init__(id, name, function) 287 | self.input = input 288 | self.pow_value = pow_value 289 | self.output = output 290 | 291 | class SubNode(FunctionNode): # both operands and the result of a subtraction 292 | def __init__(self, 293 | id: str, 294 | name: str, 295 | function: str, 296 | mat1: Tensor, 297 | mat2: Tensor, 298 | output: Tensor): 299 | 300 | super().__init__(id, name, function) 301 | self.mat1 = mat1 302 | self.mat2 =
mat2 303 | self.output = output 304 | 305 | class DivNode(FunctionNode): # both operands and the result of a division 306 | def __init__(self, 307 | id: str, 308 | name: str, 309 | function: str, 310 | mat1: Tensor, 311 | mat2: Tensor, 312 | output: Tensor): 313 | 314 | super().__init__(id, name, function) 315 | self.mat1 = mat1 316 | self.mat2 = mat2 317 | self.output = output 318 | 319 | class StackNode(FunctionNode): # input and output of a stack operation 320 | def __init__(self, 321 | id: str, 322 | name: str, 323 | function: str, 324 | input: Tensor, 325 | output: Tensor): 326 | 327 | super().__init__(id, name, function) 328 | self.input = input 329 | self.output = output 330 | 331 | class ViewNode(FunctionNode): # shape-changing op; the GUI labels it with its input and output shapes 332 | def __init__(self, 333 | id: str, 334 | name: str, 335 | function: str, 336 | input: Tensor, 337 | output: Tensor): 338 | 339 | super().__init__(id, name, function) 340 | self.input = input 341 | self.output = output 342 | 343 | class TransposeNode(FunctionNode): # input and output of a transpose 344 | def __init__(self, 345 | id: str, 346 | name: str, 347 | function: str, 348 | input: Tensor, 349 | output: Tensor): 350 | 351 | super().__init__(id, name, function) 352 | self.input = input 353 | self.output = output 354 | 355 | class ExpandNode(FunctionNode): # input and output of an expand 356 | def __init__(self, 357 | id: str, 358 | name: str, 359 | function: str, 360 | input: Tensor, 361 | output: Tensor): 362 | 363 | super().__init__(id, name, function) 364 | self.input = input 365 | self.output = output 366 | 367 | class ExpNode(FunctionNode): # stores exp_value and the result; `input` stays None as set by FunctionNode 368 | def __init__(self, 369 | id: str, 370 | name: str, 371 | function: str, 372 | exp_value: Tensor, 373 | output: Tensor): 374 | 375 | super().__init__(id, name, function) 376 | self.exp_value = exp_value 377 | self.output = output 378 | 379 | class CatNode(FunctionNode): # NOTE(review): declared Tensor, but get_node_info indexes input[0]; presumably a sequence of tensors (confirm) 380 | def __init__(self, 381 | id: str, 382 | name: str, 383 | function: str, 384 | input: Tensor, 385 | output: Tensor): 386 | 387 | super().__init__(id, name, function) 388 | self.input = input 389 | self.output = output 390 | 391 | class GetItemNode(FunctionNode): 392 | def
__init__(self, 393 | id: str, 394 | name: str, 395 | function: str, 396 | input: Tensor, 397 | output: Tensor, 398 | slice: str): 399 | 400 | super().__init__(id, name, function) 401 | self.input = input 402 | self.output = output 403 | self.slice = slice 404 | -------------------------------------------------------------------------------- /nnViewer/front/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Tuple 3 | import re 4 | 5 | import numpy as np 6 | from PyQt5.QtGui import QImage, QColor, QPixmap, qRgb 7 | from PyQt5.QtWidgets import QGraphicsTextItem, QHBoxLayout, QLineEdit 8 | from PyQt5.QtCore import Qt 9 | from torch import Tensor 10 | import torch 11 | 12 | from nnViewer.back.nodes import ModuleNode, FunctionNode, VarNode, BMMNode, MulNode, AddNode, ViewNode, \ 13 | ExpandNode, GetItemNode, Conv2dNode, EmbeddingNode, PowNode, MeanNode, StackNode, CatNode, Conv1dNode, ExpNode, \ 14 | DivNode, SumNode, MatMulNode, AttentionProductNode, LinearNode 15 | 16 | 17 | def get_tuple_of_tensors_shapes_as_string(tensor_tuple): # join the shapes of an iterable of tensors, e.g. "(2, 3), (4)"; None entries print as "None" 18 | return ", ".join(["None" if tensor is None else get_tensor_shape_as_string(tensor) for tensor in tensor_tuple]) 19 | 20 | def get_tensor_shape_as_string(tensor): # a tensor of shape (2, 3) gives "(2, 3)" 21 | return f"({', '.join(map(str, tensor.shape))})" 22 | 23 | def split_module_name(name: str) -> List[str]: # split a dotted module path on runs of dots; "." alone maps to [""] 24 | if name == ".": 25 | return [""] 26 | return re.split(r"\.+", name) 27 | 28 | def get_node_info(node): # build the info dict for a node: each label maps to a plain string or a create_tensor_info dict 29 | def create_tensor_info(label, tensor, visualization_type="tensor"): 30 | return { 31 | "value": label, 32 | "tensor": tensor, 33 | "visualization_type": visualization_type 34 | } 35 | 36 | if isinstance(node, Conv2dNode): 37 | return { 38 | "Input Shape": create_tensor_info( 39 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 40 | node.input[0] 41 | ), 42 | "Output Shape": create_tensor_info( 43 |
get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 44 | node.output[0] 45 | ), 46 | "Number of Parameters": "{:,}".format(node.nb_parameters), 47 | "Input Channels": str(node.module.in_channels), 48 | "Output Channels": str(node.module.out_channels), 49 | "Kernel Size": create_tensor_info( 50 | f"{str(node.module.kernel_size[0])} x {str(node.module.kernel_size[1])}", 51 | node.module.weight if hasattr(node.module, 'weight') else None, 52 | "kernel" 53 | ), 54 | "Stride": f"{str(node.module.stride[0])} x {str(node.module.stride[1])}", 55 | "Padding": f"{str(node.module.padding[0])} x {str(node.module.padding[1])}", # NOTE(review): a string padding such as 'same' would be indexed character by character here 56 | } 57 | 58 | if isinstance(node, Conv1dNode): # same fields as Conv2d, but kernel size, stride and padding use only index 0 59 | return { 60 | "Input Shape": create_tensor_info( 61 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 62 | node.input[0] 63 | ), 64 | "Output Shape": create_tensor_info( 65 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 66 | node.output[0] 67 | ), 68 | "Number of Parameters": "{:,}".format(node.nb_parameters), 69 | "Input Channels": str(node.module.in_channels), 70 | "Output Channels": str(node.module.out_channels), 71 | "Kernel Size": create_tensor_info( 72 | f"{str(node.module.kernel_size[0])}", 73 | node.module.weight, 74 | "kernel" 75 | ), 76 | "Stride": f"{str(node.module.stride[0])}", 77 | "Padding": f"{str(node.module.padding[0])}", 78 | } 79 | 80 | if isinstance(node, LinearNode): # weight and bias are read from the wrapped nn.Module 81 | return { 82 | "Input Shape": create_tensor_info( 83 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 84 | node.input[0] 85 | ), 86 | "Output Shape": create_tensor_info( 87 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 88 | node.output[0] 89 | ), 90 | "Number of Parameters": "{:,}".format(node.nb_parameters), 91 | "Weights":create_tensor_info( 92 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.module.weight)), 93 | node.module.weight 94 | ), 95 |
"Bias": create_tensor_info( 96 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.module.bias)), # NOTE(review): bias is None for a bias-free layer; this relies on get_var_as_tuple_tensor(None) returning an iterable 97 | node.module.bias 98 | ) 99 | } 100 | 101 | if isinstance(node, EmbeddingNode): # embedding table size is num_embeddings x embedding_dim 102 | return { 103 | "Input Shape": create_tensor_info( 104 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 105 | node.input[0] 106 | ), 107 | "Output Shape": create_tensor_info( 108 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 109 | node.output[0] 110 | ), 111 | "Number of Parameters": "{:,}".format(node.nb_parameters), 112 | "Size of the Embedding Matrix": create_tensor_info( 113 | f"{str(node.module.num_embeddings)} x {str(node.module.embedding_dim)}", 114 | node.module.weight 115 | ), 116 | "Size of Embedding Vector": str(node.module.embedding_dim), 117 | } 118 | 119 | elif isinstance(node, ModuleNode): # any other module: shapes and parameter count only 120 | return { 121 | "Input Shape": create_tensor_info( 122 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 123 | node.input[0] 124 | ), 125 | "Output Shape": create_tensor_info( 126 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 127 | node.output[0] 128 | ), 129 | "Number of Parameters": "{:,}".format(node.nb_parameters), 130 | } 131 | 132 | elif isinstance(node, MulNode): # binary-op nodes list both operands and the output through format_matrix_data 133 | return { 134 | "First Element": create_tensor_info(format_matrix_data(node.mat1), node.mat1), 135 | "Second Element": create_tensor_info(format_matrix_data(node.mat2), node.mat2), 136 | "Output": create_tensor_info(format_matrix_data(node.output), node.output), 137 | } 138 | 139 | elif isinstance(node, ExpNode): 140 | return { 141 | "Exponential value": str(node.exp_value), 142 | "Output": create_tensor_info(format_matrix_data(node.output), node.output), 143 | } 144 | 145 | elif isinstance(node, DivNode): 146 | return { 147 | "First Element": create_tensor_info(format_matrix_data(node.mat1), node.mat1), 148 | "Second Element":
create_tensor_info(format_matrix_data(node.mat2), node.mat2), 149 | "Output": create_tensor_info(format_matrix_data(node.output), node.output), 150 | 151 | } 152 | 153 | elif isinstance(node, AddNode): 154 | return { 155 | "First Element": create_tensor_info(format_matrix_data(node.mat1), node.mat1), 156 | "Second Element": create_tensor_info(format_matrix_data(node.mat2), node.mat2), 157 | "Output": create_tensor_info(format_matrix_data(node.output), node.output), 158 | } 159 | 160 | elif isinstance(node, MatMulNode): 161 | return { 162 | "First Element": create_tensor_info(format_matrix_data(node.mat1), node.mat1), 163 | "Second Element": create_tensor_info(format_matrix_data(node.mat2), node.mat2), 164 | "Output": create_tensor_info(format_matrix_data(node.output), node.output), 165 | } 166 | 167 | elif isinstance(node, AttentionProductNode): 168 | return { 169 | "Key": create_tensor_info(format_matrix_data(node.key), node.key), 170 | "Query": create_tensor_info(format_matrix_data(node.query), node.query), 171 | "Value": create_tensor_info(format_matrix_data(node.value), node.value), 172 | "Attention Matrix": create_tensor_info(format_matrix_data(node.attention_matrix), node.attention_matrix), 173 | "mask": "Not masked attention" if not node.mask else (format_matrix_data(node.mask), node.mask), 174 | "Output": create_tensor_info(format_matrix_data(node.output), node.output), 175 | } 176 | 177 | elif isinstance(node, PowNode): 178 | return { 179 | "Input Tensor Shape": create_tensor_info(format_matrix_data(node.input), node.input), 180 | "Pow Value": create_tensor_info(format_matrix_data(node.pow_value), node.pow_value), 181 | } 182 | 183 | elif isinstance(node, MeanNode): 184 | return { 185 | "Input Tensor Shape": create_tensor_info(format_matrix_data(node.input), node.input), 186 | "Output Tensor Shape": create_tensor_info(format_matrix_data(node.output), node.output), 187 | "Mean on Dim": str(node.dim) 188 | } 189 | 190 | elif isinstance(node, SumNode): 191 | 
return { 192 | "Input Tensor Shape": create_tensor_info(format_matrix_data(node.input), node.input), 193 | "Output Tensor Shape": create_tensor_info(format_matrix_data(node.output), node.output), 194 | "Sum on Dim": str(node.dim) 195 | } 196 | 197 | elif isinstance(node, CatNode): # NOTE(review): the key says "Stack" although this branch handles concatenation 198 | return { 199 | "Shape of Tensor to Stack": create_tensor_info( 200 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 201 | node.input[0] 202 | ), 203 | "Output Tensor Shape": create_tensor_info( 204 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 205 | node.output[0] 206 | ), 207 | } 208 | 209 | elif isinstance(node, ViewNode): 210 | return { 211 | "Input Shape": create_tensor_info( 212 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 213 | node.input[0] 214 | ), 215 | "Output Shape": create_tensor_info( 216 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 217 | node.output[0] 218 | ), 219 | } 220 | 221 | elif isinstance(node, GetItemNode): # shapes plus the slice text 222 | return { 223 | "Input Shape": create_tensor_info( 224 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.input)), 225 | node.input[0] 226 | ), 227 | "Output Shape": create_tensor_info( 228 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.output)), 229 | node.output[0] 230 | ), 231 | "Slices": str(node.slice) 232 | } 233 | 234 | elif isinstance(node, ExpandNode): # only the name is shown, same as the generic FunctionNode branch below 235 | return { 236 | "Node Name": str(node.name), 237 | } 238 | 239 | elif isinstance(node, FunctionNode): 240 | return { 241 | "Node Name": str(node.name), 242 | } 243 | 244 | elif isinstance(node, VarNode): 245 | return { 246 | "Node Name": str(node.name), 247 | "Variable Shape": create_tensor_info( 248 | get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(node.variable)), 249 | node.variable 250 | ), 251 | } # NOTE(review): any other node type falls through and get_node_info returns None 252 | 253 | def format_matrix_data(matrix): # tensors are described by their shape, anything else via str() 254 | if isinstance(matrix, Tensor): 255 | return f"Tensor of
shape: {get_tuple_of_tensors_shapes_as_string(get_var_as_tuple_tensor(matrix))}" 256 | return str(matrix) 257 | 258 | 259 | 260 | def make_color_paler(color, factor=0.2): 261 | r, g, b = color 262 | 263 | new_r = int(r + (255 - r) * factor) 264 | new_g = int(g + (255 - g) * factor) 265 | new_b = int(b + (255 - b) * factor) 266 | 267 | return ( 268 | min(255, max(0, new_r)), 269 | min(255, max(0, new_g)), 270 | min(255, max(0, new_b)) 271 | ) 272 | 273 | def get_string(var): 274 | if isinstance(var, Tensor): 275 | return get_tensor_shape_as_string(var) 276 | elif isinstance(var, float): 277 | return str(var) 278 | elif isinstance(var, int): 279 | return str(var) 280 | elif isinstance(var, str): 281 | return var 282 | elif var is None: 283 | return "None" 284 | else: 285 | return str(var) 286 | 287 | def get_var_as_tuple_tensor(var): 288 | if var is None: 289 | return (None) 290 | elif (not isinstance(var, Tensor)) and (not isinstance(var, Tuple)): 291 | var_output = [] 292 | for _, value in var.__dict__.items(): 293 | if isinstance(value, Tensor): 294 | var_output.append(value) 295 | return tuple(var_output) 296 | elif not isinstance(var, tuple): 297 | return (var,) 298 | else: 299 | return var 300 | 301 | def create_centered_text_item(label_text, font): 302 | lines = label_text.splitlines() 303 | 304 | max_length = max(len(line) for line in lines) 305 | 306 | centered_lines = [line.center(max_length) for line in lines] 307 | 308 | centered_text = "\n".join(centered_lines) 309 | item = QGraphicsTextItem(centered_text) 310 | item.setFont(font) 311 | 312 | item.setTextWidth(item.boundingRect().width()) 313 | 314 | item.setPos(-item.boundingRect().width() / 2, -item.boundingRect().height() / 2) 315 | 316 | return item 317 | 318 | 319 | def create_image_from_matrix(matrix, color="gray"): 320 | height, width = matrix.shape 321 | 322 | # Calculate new dimensions while maintaining aspect ratio 323 | if height > 256 or width > 256: 324 | ratio = min(256 / height, 256 / 
width) 325 | new_height = int(height * ratio) 326 | new_width = int(width * ratio) 327 | 328 | # Resize the matrix using numpy 329 | from scipy.ndimage import zoom 330 | scale_factors = (new_height / height, new_width / width) 331 | matrix = zoom(matrix, scale_factors, order=0) 332 | height, width = new_height, new_width 333 | 334 | # Ensure matrix values are within [0, 255] 335 | matrix = np.clip(matrix, 0, 255).astype(np.uint8) 336 | 337 | # Create image with the new dimensions 338 | image = QImage(width, height, QImage.Format_RGB32) 339 | 340 | # Create the color array 341 | if color == "gray": 342 | for y in range(height): 343 | for x in range(width): 344 | val = int(matrix[y, x]) 345 | image.setPixel(x, y, qRgb(val, val, val)) 346 | else: # red 347 | for y in range(height): 348 | for x in range(width): 349 | val = int(matrix[y, x]) 350 | image.setPixel(x, y, qRgb(val, 0, 0)) 351 | 352 | # Si l'image est plus petite que 256x256, on l'agrandit en mode pixelisé 353 | if width < 256 or height < 256: 354 | image = image.scaled(256, 256, Qt.KeepAspectRatio, Qt.FastTransformation) 355 | 356 | return QPixmap.fromImage(image) 357 | 358 | 359 | def normalize_to_255(matrix): 360 | min_val = np.min(matrix) 361 | max_val = np.max(matrix) 362 | 363 | normalized_matrix = (matrix - min_val) / (max_val - min_val) * 255 364 | normalized_matrix = np.clip(normalized_matrix, 0, 255) # Ensure values are in range [0, 255] 365 | 366 | return normalized_matrix.astype(np.uint8) 367 | 368 | def get_image_from_slice_layout_and_tensor(slice_layout, tensor, color="gray"): 369 | slice = get_slice_from_layout(slice_layout) 370 | tensor = tensor.detach() if tensor.requires_grad else tensor 371 | slice_tensor = tensor[slice] 372 | 373 | slice_numpy = slice_tensor.numpy() 374 | 375 | slice_numpy = normalize_to_255(slice_numpy) 376 | 377 | return create_image_from_matrix(slice_numpy, color) 378 | 379 | def get_slice_from_layout(slice_layout): 380 | slice_list = [] 381 | for i in 
range(slice_layout.count()): 382 | s = slice_layout.itemAt(i).widget().text() # each field holds ":" (whole dimension) or an integer index 383 | if s == ':': 384 | slice_list.append(slice(None)) 385 | else: 386 | slice_list.append(int(s)) 387 | return tuple(slice_list) 388 | 389 | def set_up_slice_layout_from_tensor(tensor): # one 18 px QLineEdit per dimension: leading dims default to "0", the last two to ":" so a 2-D slice is shown 390 | input_slice_layout = QHBoxLayout() 391 | slice_fields = [] 392 | for i, _ in enumerate(tensor.shape): 393 | slice_field = QLineEdit() 394 | if i < len(tensor.shape) - 2: 395 | slice_field.setText("0") 396 | else: 397 | slice_field.setText(":") 398 | slice_field.setAlignment(Qt.AlignCenter) 399 | input_slice_layout.addWidget(slice_field) 400 | slice_field.setFixedWidth(18) 401 | slice_fields.append(slice_field) 402 | return input_slice_layout, slice_fields 403 | -------------------------------------------------------------------------------- /nnViewer/back/graph.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from copy import copy 3 | from typing import Tuple, Dict, List, Union, Set 4 | 5 | from pygraphviz import AGraph 6 | from torch import Tensor 7 | import torch 8 | 9 | from nnViewer.back.nodes import Node, VarNode, ModuleNode, FunctionNode 10 | from nnViewer.back.utils import create_bounding_rectangle, parse_pos 11 | 12 | SAVED_PREFIX = "_saved_" 13 | NODE_FUNCTION_TO_DELETE = ["AccumulateGrad", "ToCopyBackward0", "CloneBackward0", "SliceBackward0"] 14 | ACCUMULATE_GRAD_FN = ["AccumulateGrad"] 15 | 16 | class Graph(): # holds every node/edge plus the "flying" subset that the GUI sizes and draws 17 | def __init__(self, 18 | nodes: Set[Node], 19 | edges: Set[Tuple[str, str]]): 20 | 21 | self.nodes = set() 22 | self.edges = set() 23 | self.ids = set() 24 | 25 | self.add_nodes(nodes) 26 | self.add_edges(edges) 27 | 28 | self.module_nodes = set() 29 | self.flying_nodes = set() 30 | self.flying_edges = set() 31 | self.flying_upper_modules = [] 32 | self.max_number_parameters = 0 33 | self.level_max = 0 34 | self.pos_edges = [] 35 | 36 | def add_nodes(self, nodes: Union[Node, Set[Node]]) -> None: # accepts one node or a set; their ids are tracked in self.ids 37 | if
isinstance(nodes, Node): 38 | nodes = {nodes} 39 | self.nodes.update(nodes) 40 | self.ids.update({node.id for node in nodes}) 41 | 42 | def set_level_max(self) -> None: 43 | for node in self.nodes: 44 | node_level = len(node.up_modules) - 1 45 | self.level_max = max(self.level_max, node_level) 46 | 47 | def get_node(self, node_id: str, graph_name:str="all") -> Union[Node, bool]: 48 | graph_map = { 49 | "all": self.nodes, 50 | "flying": self.flying_nodes, 51 | "module": self.module_nodes 52 | } 53 | 54 | nodes = graph_map.get(graph_name) 55 | if nodes is None: 56 | raise Exception(f"Unknown graph name: {graph_name}") 57 | 58 | return next((node for node in nodes if node.id == node_id), False) 59 | 60 | def get_nodes(self, node_ids: List[str]) -> List[Node]: 61 | return [self.get_node(idx) for idx in node_ids] 62 | 63 | def get_module_class_name(self) -> List[str]: 64 | return list(set(sorted([node.module.__class__.__name__ for node in self.module_nodes]))) 65 | 66 | def delete_nodes_type(self, funcs_to_delete: List[str]) -> None: 67 | for node in copy(self.nodes): 68 | if isinstance(node, FunctionNode): 69 | if node.function.__class__.__name__ in funcs_to_delete: 70 | self.remove_node_and_reset_relatives(node) 71 | self.reset_edges() 72 | 73 | def safe_delete(self, 74 | node:Node) -> None: 75 | for child in node.childrens: 76 | if node in child.parents: 77 | child.parents.remove(node) 78 | 79 | for parent_node in node.parents: 80 | if node in parent_node.childrens: 81 | parent_node.childrens.remove(node) 82 | 83 | if node in self.nodes: 84 | self.nodes.remove(node) 85 | 86 | if node.id in self.ids: 87 | self.ids.remove(node.id) 88 | 89 | for module_node in self.module_nodes: 90 | if node.id in module_node.all_root_sub_ids: 91 | module_node.all_root_sub_ids.remove(node.id) 92 | if node in module_node.sub_nodes: 93 | module_node.sub_nodes.remove(node) 94 | 95 | def remove_node_and_reset_relatives(self, node: Node) -> None: 96 | childrens = node.childrens 97 | parents 
= node.parents 98 | for child in childrens: 99 | if node in child.parents: 100 | child.parents.remove(node) 101 | child.parents.extend(parents) # splice: each child inherits the removed node's parents (duplicates are not filtered) 102 | for parent in parents: 103 | if node in parent.childrens: 104 | parent.childrens.remove(node) 105 | parent.childrens.extend(childrens) 106 | self.nodes.remove(node) 107 | self.ids.remove(node.id) 108 | 109 | for module_node in self.module_nodes: 110 | if node.id in module_node.all_root_sub_ids: 111 | module_node.all_root_sub_ids.remove(node.id) 112 | if node in module_node.sub_nodes: 113 | module_node.sub_nodes.remove(node) 114 | 115 | def add_edge(self, 116 | edge: Tuple[str, str]) -> None: 117 | self.edges.add(edge) 118 | 119 | def add_edges(self, edges: Union[Tuple[str, str], Set[Tuple[str, str]]]) -> None: # accepts one (tail_id, head_id) tuple or a set of them 120 | if isinstance(edges, tuple): 121 | edges = {edges} 122 | self.edges.update(edges) 123 | 124 | def set_relatives(self)-> None: # fill parents/childrens from the (tail_id, head_id) edges; get_node returns False for an unknown id, so a dangling edge fails here 125 | for edge in self.edges: 126 | head_node = self.get_node(edge[1]) 127 | tail_node = self.get_node(edge[0]) 128 | head_node.add_parent(tail_node) 129 | tail_node.add_children(head_node) 130 | 131 | def set_sub_parents_and_childrens(self)-> None: # a module's outside neighbours: the concatenated next/previous ids of its root sub-nodes (duplicates kept) 132 | for node in self.module_nodes: 133 | for sub_node_id in node.all_root_sub_ids: 134 | sub_node = self.get_node(sub_node_id) 135 | 136 | node.all_sub_childrens.extend(sub_node.next_ids) 137 | node.all_sub_parents.extend(sub_node.previous_ids) 138 | 139 | def set_next_data(self) -> None: # cache child ids on every node 140 | for node in self.nodes: 141 | node.next_ids = [child.id for child in node.childrens] 142 | 143 | def set_previous_data(self) -> None: # cache parent ids on every node 144 | for node in self.nodes: 145 | node.previous_ids = [parent.id for parent in node.parents] 146 | 147 | def init_flying_graph(self) -> None : # start the flying graph with every node that has no enclosing module (empty up_modules) 148 | for flying_node in self.nodes: 149 | if not flying_node.up_modules: 150 | self.set_flying_relatives(flying_node) 151 | self.flying_nodes.add(flying_node) 152 | 153 | self.set_flying_edges() 154 | 155 | def set_flying_edges(self) -> None: 156 | self.flying_edges =
set() 157 | for node in self.flying_nodes: 158 | for child in node.flying_childrens: 159 | self.flying_edges.add((node.id, child.id)) 160 | 161 | def reset_edges(self) -> None: 162 | self.edges = set() 163 | for node in self.nodes: 164 | for child in node.childrens: 165 | self.edges.add((node.id, child.id)) 166 | for parent in node.parents: 167 | self.edges.add((parent.id, node.id)) 168 | 169 | def get_flying_ids(self) -> List[str]: 170 | return [node.id for node in self.flying_nodes] 171 | 172 | def get_flying_modules(self) -> List[str]: 173 | return [node.name for node in self.flying_nodes] 174 | 175 | def delete_flying_node(self, node: Node) -> None: 176 | self.flying_nodes.remove(node) 177 | for flying_node in self.flying_nodes: 178 | if node in flying_node.flying_childrens: 179 | flying_node.flying_childrens.remove(node) 180 | if node in flying_node.flying_parents: 181 | flying_node.flying_parents.remove(node) 182 | 183 | def set_flying_relatives(self, node: Node) -> None: 184 | node.flying_childrens = [] 185 | node.flying_parents = [] 186 | if isinstance(node, ModuleNode): 187 | all_sub_childrens = node.all_sub_childrens 188 | all_sub_parents = node.all_sub_parents 189 | else: 190 | all_sub_childrens = node.next_ids 191 | all_sub_parents = node.previous_ids 192 | 193 | for flying_node in self.flying_nodes: 194 | if isinstance(flying_node, ModuleNode): 195 | all_root_sub_ids = flying_node.all_root_sub_ids 196 | else: 197 | all_root_sub_ids = [flying_node.id] 198 | 199 | if len(set(all_sub_childrens) & set(all_root_sub_ids)) > 0: 200 | node.flying_childrens.append(flying_node) 201 | if node not in flying_node.flying_parents: 202 | flying_node.flying_parents.append(node) 203 | if len(set(all_sub_parents) & set(all_root_sub_ids)) > 0: 204 | node.flying_parents.append(flying_node) 205 | if node not in flying_node.flying_childrens: 206 | flying_node.flying_childrens.append(node) 207 | 208 | def expend_flying_node(self, node_id:str) -> None: 209 | if node_id not in 
self.get_flying_ids(): 210 | raise Exception("this node is not flying") 211 | node = self.get_node(node_id) 212 | if type(node) is ModuleNode: 213 | self.delete_flying_node(node) 214 | 215 | for new_node in node.sub_nodes: 216 | self.set_flying_relatives(new_node) 217 | self.flying_nodes.add(new_node) 218 | self.set_flying_edges() 219 | 220 | else: 221 | warnings.warn("this node is not can't be expended") 222 | 223 | def contract_flying_node(self, node_id:str) -> None: 224 | upper_node = self.get_node(node_id).upper_module 225 | if upper_node: 226 | sub_nodes = { 227 | flying_node 228 | for flying_node in self.flying_nodes 229 | if upper_node.id in flying_node.up_modules 230 | } 231 | 232 | for sub_node in sub_nodes: 233 | self.delete_flying_node(sub_node) 234 | 235 | self.set_flying_relatives(upper_node) 236 | self.flying_nodes.add(upper_node) 237 | self.set_flying_edges() 238 | 239 | else: 240 | warnings.warn("this node cannot be contract") 241 | 242 | def compute_flying_upper_modules_pos(self) -> None: 243 | self.flying_upper_modules = [] 244 | 245 | upper_modules = {node.upper_module 246 | for node in self.flying_nodes if node.upper_module is not None} 247 | 248 | for upper_module in upper_modules: 249 | upper_modules_pos = [] 250 | for flying_node in self.flying_nodes: 251 | if upper_module.id in flying_node.up_modules: 252 | upper_modules_pos.append(flying_node.pos) 253 | 254 | level = len(upper_module.up_modules) 255 | self.flying_upper_modules.append( 256 | create_bounding_rectangle( 257 | rectangles=upper_modules_pos, 258 | class_name=upper_module.module.__class__.__name__, 259 | margin_width=20 * (self.level_max - level), 260 | margin_height=20, 261 | level=level 262 | ) 263 | ) 264 | 265 | def compute_pos_and_edges(self) -> None: 266 | G = AGraph(directed=True) 267 | 268 | for node in self.flying_nodes: 269 | if node.pos.width: 270 | G.add_node(node.id, shape="box", width=node.pos.width/72, height=node.pos.height/72) 271 | else: 272 | G.add_node(node.id, 
shape="box", width=1/72, height=1/72) 273 | 274 | for edge in self.flying_edges: 275 | G.add_edge(edge[0], edge[1], sep="0.001") 276 | 277 | G.layout(prog="dot") 278 | 279 | for node in G.nodes(): 280 | graph_node = self.get_node(str(node), graph_name="flying") 281 | graph_node.pos.x = float(node.attr["pos"].split(",")[0]) 282 | graph_node.pos.y = -float(node.attr["pos"].split(",")[1]) 283 | 284 | self.pos_edges = [] 285 | for edge in G.edges(): 286 | edge_data = G.get_edge(edge[0], edge[1]) 287 | pos = edge_data.attr["pos"] 288 | curve_points = parse_pos(pos) 289 | pose1 = self.get_node(edge[1]) 290 | pose0 = self.get_node(edge[0]) 291 | curve_points.insert(0, (pose0.pos.x, pose0.pos.y)) 292 | curve_points.append((pose1.pos.x, pose1.pos.y)) 293 | self.pos_edges.append(curve_points) 294 | 295 | def set_modules(self) -> None: 296 | for module_node in self.module_nodes: 297 | module_node.all_root_sub_ids = module_node.all_root_sub_ids & self.ids 298 | module_node.upper_module = None 299 | num_sub_nodes = float('inf') 300 | 301 | for node_include in self.module_nodes: 302 | if module_node is not node_include: 303 | if module_node.all_root_sub_ids < node_include.all_root_sub_ids: 304 | if (module_node.upper_module is None 305 | or len(node_include.all_root_sub_ids) < num_sub_nodes): 306 | module_node.upper_module = node_include 307 | num_sub_nodes = len(node_include.all_root_sub_ids) 308 | 309 | if module_node.upper_module: 310 | module_node.upper_module.sub_nodes.add(module_node) 311 | 312 | for module_node in self.module_nodes: 313 | max_upper_module = module_node.upper_module 314 | while max_upper_module is not None: 315 | module_node.up_modules.append(max_upper_module.id) 316 | max_upper_module = max_upper_module.upper_module 317 | module_node.up_modules = module_node.up_modules[::-1] 318 | 319 | def set_nodes_in_module(self, 320 | root_node_belong_to_module: Dict[str, str]) -> None: 321 | for node_id, module_id in root_node_belong_to_module.items(): 322 | node = 
self.get_node(node_id) 323 | module = self.get_node(module_id, graph_name="module") 324 | if node: 325 | module.sub_nodes.add(node) 326 | node.up_modules = module.up_modules + [module_id] 327 | node.upper_module = module 328 | 329 | module_nodes_to_delete = {module_node for module_node in self.module_nodes 330 | if not module_node.sub_nodes} 331 | 332 | for module_node in module_nodes_to_delete: 333 | self.safe_delete(module_node) 334 | 335 | def set_number_parameters(self) -> None: 336 | for node in self.module_nodes: 337 | for idx in node.all_root_sub_ids: 338 | sub_node = self.get_node(idx) 339 | if isinstance(sub_node, VarNode): 340 | node.nb_parameters += sub_node.variable.numel() 341 | 342 | self.max_number_parameters = max(self.max_number_parameters, node.nb_parameters) 343 | 344 | def set_height_and_width(self) -> None: 345 | for node in self.module_nodes: 346 | node.set_height_and_width(self.max_number_parameters) 347 | 348 | def set_wrapped_data(self, 349 | wrapped_output:Dict) -> None: 350 | intermediate_nodes = [] 351 | for wrapped_data in wrapped_output: 352 | node = self.get_node(wrapped_data["node"].id) 353 | if node: 354 | new_node = wrapped_data["node"] 355 | self.set_input(node, wrapped_data) 356 | 357 | new_node.up_modules = node.up_modules 358 | new_node.upper_module = node.upper_module 359 | self.safe_replace(node, new_node) 360 | 361 | parents_id = [] 362 | for arg in wrapped_data["args"]+tuple(wrapped_data["kwargs"].values()): 363 | if hasattr(arg, "grad_fn"): 364 | if arg.grad_fn is not None: 365 | parents_id.append(str(id(arg.grad_fn))) 366 | 367 | for parent_id in parents_id: 368 | parent_node = self.get_node(parent_id) 369 | if parent_node: 370 | new_node.parents.append(parent_node) 371 | parent_node.childrens.append(new_node) 372 | 373 | for grad_fn_to_del in wrapped_data["grad_fn_created"]: 374 | intermediate_nodes.append(self.get_node(grad_fn_to_del)) 375 | if grad_fn_to_del in parent_node.childrens: 376 | 
parent_node.childrens.remove(grad_fn_to_del) 377 | 378 | for del_node in set(intermediate_nodes): 379 | if del_node in self.nodes: 380 | self.safe_delete(del_node) 381 | 382 | def safe_replace(self, 383 | node:Node, 384 | new_node:Node) -> None: 385 | 386 | new_node.childrens = node.childrens 387 | new_node.parents = node.parents 388 | 389 | for child in node.childrens: 390 | child.parents.append(new_node) 391 | 392 | for parent in node.parents: 393 | parent.childrens.append(new_node) 394 | 395 | if node.upper_module: 396 | node.upper_module.sub_nodes.remove(node) 397 | node.upper_module.sub_nodes.add(new_node) 398 | 399 | self.nodes.remove(node) 400 | self.nodes.add(new_node) 401 | 402 | def set_input(self, node: Node, wrap_data: Dict) -> None: 403 | args = wrap_data["args"] 404 | 405 | if all(isinstance(parent, VarNode) for parent in node.parents): 406 | params = [parent.variable for parent in node.parents] 407 | 408 | for arg in args: 409 | if isinstance(arg, Tensor) and not any(torch.equal(arg, t) for t in params): 410 | input_node = VarNode( 411 | id = str(id(arg)), 412 | name = "input", 413 | variable = arg, 414 | ) 415 | input_node.add_children(node) 416 | node.add_parent(input_node) 417 | self.add_nodes( 418 | input_node 419 | ) 420 | 421 | def set_up(self, 422 | hooks: Set[ModuleNode], 423 | root_node_belong_to_module: Dict[str, str], 424 | wrapped_output: Dict) -> None: 425 | 426 | self.add_nodes(hooks) 427 | self.module_nodes = hooks 428 | 429 | self.set_modules() 430 | self.set_nodes_in_module(root_node_belong_to_module) 431 | 432 | self.set_relatives() 433 | 434 | self.delete_nodes_type(ACCUMULATE_GRAD_FN) 435 | self.set_wrapped_data(wrapped_output) 436 | self.delete_nodes_type(NODE_FUNCTION_TO_DELETE) 437 | 438 | self.set_level_max() 439 | 440 | self.set_next_data() 441 | self.set_previous_data() 442 | 443 | self.set_sub_parents_and_childrens() 444 | 445 | self.set_number_parameters() 446 | self.set_height_and_width() 447 | 448 | 
self.init_flying_graph() 449 | self.set_flying_edges() 450 | 451 | self.compute_pos_and_edges() 452 | self.compute_flying_upper_modules_pos() 453 | -------------------------------------------------------------------------------- /nnViewer/front/node_item.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import (QGraphicsRectItem, QMenu, QDialog, QLabel, QVBoxLayout, QFormLayout, 2 | QPushButton, 3 | QHBoxLayout, QTableWidget, QTableWidgetItem, QGraphicsDropShadowEffect, 4 | QWidget, QFrame) 5 | from PyQt5.QtCore import Qt, QTimer, QObject, pyqtSignal, QPropertyAnimation, QRect, QEasingCurve 6 | from PyQt5.QtGui import QFont, QColor, QPixmap, QPainter, QPainterPath 7 | from torch import float16 8 | 9 | from nnViewer.back.nodes import Conv2dNode, VarNode, ModuleNode, FunctionNode 10 | from nnViewer.front.utils import (get_node_info, get_tensor_shape_as_string, 11 | get_image_from_slice_layout_and_tensor, 12 | set_up_slice_layout_from_tensor) 13 | 14 | 15 | class ComputationCard(QFrame): 16 | def __init__(self, title, shape_text=None, parent=None): 17 | super().__init__(parent) 18 | self.setObjectName("computationCard") 19 | self.setStyleSheet(""" 20 | QFrame#computationCard { 21 | background-color: #1a1a1a; 22 | border-radius: 15px; 23 | padding: 15px; 24 | min-width: 250px; 25 | } 26 | QLabel { 27 | color: #ffffff; 28 | } 29 | QLineEdit { 30 | background-color: #2d2d2d; 31 | border: none; 32 | border-radius: 5px; 33 | color: #ffffff; 34 | padding: 5px; 35 | min-width: 30px; 36 | max-width: 50px; 37 | font-size: 12px; 38 | } 39 | QLineEdit:focus { 40 | border: 1px solid #2196F3; 41 | } 42 | """) 43 | 44 | layout = QVBoxLayout(self) 45 | layout.setSpacing(10) 46 | 47 | # Title 48 | title_label = QLabel(title) 49 | title_label.setFont(QFont("Segoe UI", 12, QFont.Bold)) 50 | title_label.setStyleSheet("color: #2196F3;") 51 | layout.addWidget(title_label) 52 | 53 | # Shape text if provided 54 | if shape_text: 55 
| shape_label = QLabel(shape_text) 56 | shape_label.setFont(QFont("Segoe UI", 10)) 57 | shape_label.setStyleSheet("color: #808080;") 58 | layout.addWidget(shape_label) 59 | 60 | self.slice_container = QWidget() 61 | self.slice_layout = QFormLayout(self.slice_container) 62 | self.slice_layout.setSpacing(8) 63 | self.slice_layout.setContentsMargins(0, 5, 0, 5) 64 | layout.addWidget(self.slice_container) 65 | 66 | # Image container avec une taille plus grande 67 | self.image_scroll = QWidget() 68 | self.image_scroll.setFixedSize(250, 250) 69 | 70 | # Image container 71 | self.image_label = QLabel(self.image_scroll) 72 | self.image_label.setAlignment(Qt.AlignCenter) 73 | self.image_label.setStyleSheet("background-color: transparent;") 74 | 75 | # Layout pour le widget d'image 76 | image_layout = QVBoxLayout(self.image_scroll) 77 | image_layout.addWidget(self.image_label) 78 | image_layout.setContentsMargins(0, 0, 0, 0) 79 | 80 | layout.addWidget(self.image_scroll) 81 | 82 | # Effet d'ombre 83 | shadow = QGraphicsDropShadowEffect() 84 | shadow.setBlurRadius(20) 85 | shadow.setColor(QColor(0, 0, 0, 150)) 86 | shadow.setOffset(0, 0) 87 | self.setGraphicsEffect(shadow) 88 | 89 | class SignalProxy(QObject): 90 | clicked = pyqtSignal(QGraphicsRectItem) 91 | doubleClicked = pyqtSignal(QGraphicsRectItem) 92 | 93 | 94 | class ClickableRectItem(QGraphicsRectItem): 95 | def __init__(self, node, *args, **kwargs): 96 | super().__init__(*args, **kwargs) 97 | self.setAcceptHoverEvents(True) 98 | self.signal_proxy = SignalProxy() 99 | self.click_timer = QTimer() 100 | self.click_timer.setSingleShot(True) 101 | self.click_timer.timeout.connect(self.on_single_click) 102 | self.double_click_detected = False 103 | self.initial_pos = None 104 | self.as_moved = False 105 | self.node = node 106 | 107 | def mousePressEvent(self, event): 108 | self.initial_pos = self.pos() 109 | self.as_moved = False 110 | 111 | def mouseReleaseEvent(self, event): 112 | if self.click_timer.isActive(): 113 | 
self.double_click_detected = True 114 | self.click_timer.stop() 115 | else: 116 | if event.button() == Qt.RightButton: 117 | self.show_context_menu(event.screenPos()) 118 | else: 119 | self.double_click_detected = False 120 | self.click_timer.start(200) 121 | super().mousePressEvent(event) 122 | 123 | if self.pos() != self.initial_pos: 124 | self.as_moved = True 125 | 126 | def on_single_click(self): 127 | if not self.double_click_detected and not self.as_moved: 128 | self.signal_proxy.clicked.emit(self) 129 | 130 | def mouseDoubleClickEvent(self, event): 131 | self.signal_proxy.doubleClicked.emit(self) 132 | super().mouseDoubleClickEvent(event) 133 | 134 | def paint(self, painter, option, widget=None): 135 | painter.setRenderHint(QPainter.Antialiasing) 136 | pen = self.pen() 137 | brush = self.brush() 138 | painter.setPen(pen) 139 | painter.setBrush(brush) 140 | path = QPainterPath() 141 | rect = self.rect() 142 | path.addRoundedRect(rect, 15, 15) 143 | painter.drawPath(path) 144 | 145 | def show_context_menu(self, global_pos): 146 | menu = QMenu() 147 | menu.setStyleSheet(""" 148 | QMenu { 149 | background-color: #1a1a1a; 150 | border: 1px solid #2d2d2d; 151 | border-radius: 5px; 152 | color: #ffffff; 153 | } 154 | QMenu::item { 155 | padding: 8px 20px; 156 | } 157 | QMenu::item:selected { 158 | background-color: #2196F3; 159 | } 160 | """) 161 | 162 | expand_action = menu.addAction("Expand") 163 | contract_action = menu.addAction("Contract") 164 | info_action = menu.addAction("Get More Information") 165 | 166 | show_computation_action = None 167 | if isinstance(self.node, ModuleNode): 168 | show_computation_action = menu.addAction("Show Computation") 169 | 170 | if isinstance(self.node, VarNode): 171 | show_computation_action = menu.addAction("Show Variable") 172 | 173 | action = menu.exec(global_pos) 174 | 175 | if action == expand_action: 176 | self.signal_proxy.clicked.emit(self) 177 | elif action == contract_action: 178 | 
self.signal_proxy.doubleClicked.emit(self) 179 | elif action == info_action: 180 | self.get_more_information() 181 | elif action == show_computation_action: 182 | if not isinstance(self.node, FunctionNode): 183 | self.show_computation() 184 | 185 | def show_computation(self): 186 | if isinstance(self.node, Conv2dNode): 187 | self.show_conv2d_computation() 188 | elif isinstance(self.node, VarNode): 189 | self.show_var() 190 | else: 191 | self.show_default_computation() 192 | 193 | def create_computation_dialog(self, title): 194 | dialog = QDialog() 195 | dialog.setWindowFlag(Qt.FramelessWindowHint) 196 | 197 | # Style 198 | dialog.setStyleSheet(""" 199 | QDialog { 200 | background-color: #121212; 201 | border-radius: 15px; 202 | } 203 | QPushButton { 204 | background-color: #2196F3; 205 | color: white; 206 | border: none; 207 | border-radius: 5px; 208 | padding: 10px 20px; 209 | font-weight: bold; 210 | } 211 | QPushButton:hover { 212 | background-color: #1976D2; 213 | } 214 | """) 215 | 216 | # Header 217 | header = QWidget() 218 | header.setStyleSheet("background-color: #1a1a1a; border-radius: 15px 15px 0 0;") 219 | header_layout = QHBoxLayout(header) 220 | 221 | title_label = QLabel(title) 222 | title_label.setStyleSheet("color: #2196F3; font-size: 18px; font-weight: bold;") 223 | 224 | close_button = QPushButton("×") 225 | close_button.setStyleSheet(""" 226 | QPushButton { 227 | background-color: transparent; 228 | color: #808080; 229 | font-size: 20px; 230 | font-weight: bold; 231 | padding: 5px 10px; 232 | } 233 | QPushButton:hover { 234 | color: #ff4444; 235 | } 236 | """) 237 | close_button.clicked.connect(dialog.close) 238 | 239 | header_layout.addWidget(title_label) 240 | header_layout.addWidget(close_button, alignment=Qt.AlignRight) 241 | 242 | # Main layout 243 | main_layout = QVBoxLayout(dialog) 244 | main_layout.addWidget(header) 245 | 246 | # Content container 247 | content = QWidget() 248 | main_layout.addWidget(content) 249 | 250 | # Shadow effect 
251 | shadow = QGraphicsDropShadowEffect() 252 | shadow.setBlurRadius(20) 253 | shadow.setColor(QColor(0, 0, 0, 150)) 254 | shadow.setOffset(0, 0) 255 | dialog.setGraphicsEffect(shadow) 256 | 257 | return dialog, content 258 | 259 | def show_conv2d_computation(self): 260 | dialog, content = self.create_computation_dialog("Convolution Layer Visualization") 261 | # dialog.setMinimumSize(1200, 900) 262 | 263 | layout = QHBoxLayout(content) 264 | layout.setSpacing(15) 265 | layout.setContentsMargins(15, 15, 15, 15) 266 | 267 | # Input card 268 | input_card = ComputationCard("Input Tensor", 269 | f"Shape: {get_tensor_shape_as_string(self.node.input[0])}") 270 | input_slice_layout, _ = set_up_slice_layout_from_tensor(self.node.input[0]) 271 | input_card.slice_layout.addRow(input_slice_layout) 272 | layout.addWidget(input_card) 273 | 274 | # Conv weights card 275 | conv_card = ComputationCard("Convolution Weights", 276 | f"Shape: {get_tensor_shape_as_string(self.node.module.weight)}") 277 | conv_slice_layout, _ = set_up_slice_layout_from_tensor(self.node.module.weight) 278 | conv_card.slice_layout.addRow(conv_slice_layout) 279 | layout.addWidget(conv_card) 280 | 281 | # Output card 282 | output_card = ComputationCard("Output Tensor", 283 | f"Shape: {get_tensor_shape_as_string(self.node.output[0])}") 284 | output_slice_layout, _ = set_up_slice_layout_from_tensor(self.node.output[0]) 285 | output_card.slice_layout.addRow(output_slice_layout) 286 | layout.addWidget(output_card) 287 | 288 | # Update button with container 289 | button_container = QWidget() 290 | button_layout = QHBoxLayout(button_container) 291 | update_button = QPushButton("Update Visualization") 292 | button_layout.addWidget(update_button, alignment=Qt.AlignCenter) 293 | content.layout().addWidget(button_container) 294 | 295 | def update_displays(): 296 | self.display_slices([ 297 | (input_slice_layout, input_card.image_label, self.node.input[0], "gray"), 298 | (conv_slice_layout, conv_card.image_label, 
self.node.module.weight, "red"), 299 | (output_slice_layout, output_card.image_label, self.node.output[0], "gray") 300 | ]) 301 | 302 | update_button.clicked.connect(update_displays) 303 | update_displays() # Initial display 304 | 305 | dialog.exec_() 306 | 307 | def show_var(self): 308 | dialog, content = self.create_computation_dialog("Variable Visualization") 309 | 310 | layout = QVBoxLayout(content) 311 | 312 | # Variable card 313 | var_card = ComputationCard("Variable Tensor", 314 | f"Shape: {get_tensor_shape_as_string(self.node.variable)}") 315 | var_slice_layout, _ = set_up_slice_layout_from_tensor(self.node.variable) 316 | var_card.slice_layout.addRow(var_slice_layout) 317 | layout.addWidget(var_card) 318 | 319 | # Update button 320 | button_container = QWidget() 321 | button_layout = QHBoxLayout(button_container) 322 | update_button = QPushButton("Update Visualization") 323 | button_layout.addWidget(update_button) 324 | button_layout.setAlignment(Qt.AlignCenter) 325 | 326 | layout.addWidget(button_container) 327 | 328 | def update_displays(): 329 | self.display_slices([ 330 | (var_slice_layout, var_card.image_label, self.node.variable, "gray") 331 | ]) 332 | 333 | update_button.clicked.connect(update_displays) 334 | update_displays() # Initial display 335 | 336 | dialog.exec_() 337 | 338 | def show_default_computation(self): 339 | dialog, content = self.create_computation_dialog("Default Computation Visualization") 340 | 341 | layout = QHBoxLayout(content) 342 | layout.setSpacing(20) 343 | 344 | # Input card 345 | input_card = ComputationCard("Input Tensor", 346 | f"Shape: {get_tensor_shape_as_string(self.node.input[0])}") 347 | input_slice_layout, _ = set_up_slice_layout_from_tensor(self.node.input[0]) 348 | input_card.slice_layout.addRow(input_slice_layout) 349 | layout.addWidget(input_card) 350 | 351 | # Output card 352 | output_card = ComputationCard("Output Tensor", 353 | f"Shape: {get_tensor_shape_as_string(self.node.output[0])}") 354 | 
output_slice_layout, _ = set_up_slice_layout_from_tensor(self.node.output[0]) 355 | output_card.slice_layout.addRow(output_slice_layout) 356 | layout.addWidget(output_card) 357 | 358 | # Update button 359 | button_container = QWidget() 360 | button_layout = QHBoxLayout(button_container) 361 | update_button = QPushButton("Update Visualization") 362 | button_layout.addWidget(update_button) 363 | button_layout.setAlignment(Qt.AlignCenter) 364 | 365 | content.layout().addWidget(button_container) 366 | 367 | def update_displays(): 368 | self.display_slices([ 369 | (input_slice_layout, input_card.image_label, self.node.input[0], "gray"), 370 | (output_slice_layout, output_card.image_label, self.node.output[0], "gray") 371 | ]) 372 | 373 | update_button.clicked.connect(update_displays) 374 | update_displays() # Initial display 375 | 376 | dialog.exec_() 377 | 378 | def display_slices(self, matrix_to_display): 379 | try: 380 | for slice_layout, image_label, tensor, color in matrix_to_display: 381 | image = get_image_from_slice_layout_and_tensor(slice_layout, tensor.to(float16), color) 382 | 383 | parent_widget = image_label.parent() 384 | 385 | if parent_widget: 386 | available_width = parent_widget.width() - 20 387 | available_height = parent_widget.height() - 20 388 | 389 | qimage = image.toImage() 390 | 391 | scaled_image = qimage.scaled(available_width, available_height, 392 | Qt.KeepAspectRatio, 393 | Qt.SmoothTransformation) 394 | 395 | image_label.setPixmap(QPixmap.fromImage(scaled_image)) 396 | image_label.setAlignment(Qt.AlignCenter) 397 | else: 398 | image_label.setPixmap(image) 399 | image_label.setAlignment(Qt.AlignCenter) 400 | 401 | except Exception as e: 402 | print(f"Error in displaying slices: {e}") 403 | 404 | def get_more_information(self): 405 | info_dialog = QDialog() 406 | info_dialog.setWindowTitle("Node Information") 407 | info_dialog.setWindowFlag(Qt.FramelessWindowHint) 408 | 409 | screen_geometry = info_dialog.screen().geometry() 410 | 
info_dialog.setGeometry(screen_geometry.width(), 50, 600, 400) 411 | 412 | table = QTableWidget(info_dialog) 413 | node_info = get_node_info(self.node) 414 | table.setRowCount(len(node_info)) 415 | table.setColumnCount(2) 416 | 417 | table.setStyleSheet(""" 418 | QTableWidget { 419 | background-color: #1a1a1a; 420 | border: none; 421 | border-radius: 15px; 422 | color: #ffffff; 423 | gridline-color: #2d2d2d; 424 | selection-background-color: #2196F3; 425 | } 426 | QTableWidget::item { 427 | padding: 12px; 428 | border-bottom: 1px solid #2d2d2d; 429 | } 430 | QTableWidget::item:hover { 431 | background-color: #2d2d2d; 432 | } 433 | QScrollBar:vertical { 434 | border: none; 435 | background: #1a1a1a; 436 | width: 10px; 437 | border-radius: 5px; 438 | } 439 | QScrollBar::handle:vertical { 440 | background: #404040; 441 | border-radius: 5px; 442 | } 443 | QScrollBar::handle:vertical:hover { 444 | background: #4a4a4a; 445 | } 446 | """) 447 | 448 | if isinstance(self.node, ModuleNode): 449 | title_label = QLabel(f"Node Details - {self.node.module.__class__.__name__}") 450 | else: 451 | title_label = QLabel(f"Node Details - {self.node.name}") 452 | 453 | title_label.setStyleSheet(""" 454 | QLabel { 455 | color: #2196F3; 456 | font-size: 18px; 457 | font-weight: bold; 458 | padding: 15px; 459 | background-color: #1a1a1a; 460 | border-radius: 15px 15px 0 0; 461 | } 462 | """) 463 | 464 | close_button = QPushButton("×") 465 | close_button.setStyleSheet(""" 466 | QPushButton { 467 | background-color: transparent; 468 | color: #808080; 469 | font-size: 20px; 470 | font-weight: bold; 471 | border: none; 472 | padding: 5px 10px; 473 | } 474 | QPushButton:hover { 475 | color: #ff4444; 476 | } 477 | """) 478 | close_button.clicked.connect(info_dialog.close) 479 | 480 | header_layout = QHBoxLayout() 481 | header_layout.addWidget(title_label) 482 | header_layout.addWidget(close_button, alignment=Qt.AlignRight) 483 | header_layout.setContentsMargins(0, 0, 0, 0) 484 | 485 | for row, 
(key, value) in enumerate(node_info.items()): 486 | key_item = QTableWidgetItem(f" {key}") 487 | key_item.setBackground(QColor("#212121")) 488 | key_item.setForeground(QColor("#2196F3")) 489 | key_item.setFont(QFont("Segoe UI", 11, QFont.Bold)) 490 | key_item.setFlags(key_item.flags() & ~Qt.ItemIsEditable) 491 | table.setItem(row, 0, key_item) 492 | 493 | if isinstance(value, dict) and "tensor" in value: 494 | value_item = CustomTableItem(value) 495 | else: 496 | value_item = QTableWidgetItem(str(value)) 497 | 498 | value_item.setBackground(QColor("#1a1a1a")) 499 | value_item.setFont(QFont("Segoe UI", 11)) 500 | value_item.setFlags(value_item.flags() & ~Qt.ItemIsEditable) 501 | table.setItem(row, 1, value_item) 502 | 503 | def on_cell_clicked(item): 504 | if isinstance(item, CustomTableItem) and item.tensor is not None: 505 | show_tensor_visualization(item.tensor, item.visualization_type, info_dialog) 506 | 507 | table.itemClicked.connect(on_cell_clicked) 508 | table.verticalHeader().setVisible(False) 509 | table.horizontalHeader().setVisible(False) 510 | table.setShowGrid(False) 511 | table.setAlternatingRowColors(True) 512 | 513 | animation = QPropertyAnimation(info_dialog, b"geometry") 514 | animation.setDuration(300) 515 | animation.setStartValue(QRect(screen_geometry.width(), 50, 600, 400)) 516 | animation.setEndValue(QRect(screen_geometry.width() - 620, 50, 600, 400)) 517 | animation.setEasingCurve(QEasingCurve.OutCubic) 518 | 519 | main_layout = QVBoxLayout() 520 | main_layout.addLayout(header_layout) 521 | main_layout.addWidget(table) 522 | main_layout.setContentsMargins(10, 10, 10, 10) 523 | main_layout.setSpacing(0) 524 | 525 | info_dialog.setLayout(main_layout) 526 | 527 | shadow = QGraphicsDropShadowEffect() 528 | shadow.setBlurRadius(20) 529 | shadow.setColor(QColor(0, 0, 0, 150)) 530 | shadow.setOffset(0, 0) 531 | info_dialog.setGraphicsEffect(shadow) 532 | 533 | table.resizeColumnsToContents() 534 | table.resizeRowsToContents() 535 | total_width = 
table.horizontalHeader().length() + 40 536 | total_height = table.verticalHeader().length() + 100 537 | info_dialog.setFixedSize(max(total_width, 400), max(total_height, 200)) 538 | 539 | animation.start() 540 | info_dialog.exec_() 541 | 542 | 543 | def show_tensor_visualization(tensor, visualization_type="tensor", parent=None): 544 | if tensor is None or not hasattr(tensor, 'shape'): 545 | return 546 | 547 | dialog = QDialog(parent) 548 | dialog.setWindowFlag(Qt.FramelessWindowHint) 549 | 550 | main_layout = QVBoxLayout(dialog) 551 | main_layout.setContentsMargins(10, 10, 10, 10) 552 | main_layout.setSpacing(0) 553 | 554 | header = QWidget() 555 | header.setStyleSheet("background-color: #1a1a1a; border-radius: 15px 15px 0 0;") 556 | header_layout = QHBoxLayout(header) 557 | 558 | title = "Kernel Visualization" if visualization_type == "kernel" else "Tensor Visualization" 559 | title_label = QLabel(title) 560 | title_label.setStyleSheet("color: #2196F3; font-size: 18px; font-weight: bold;") 561 | 562 | close_button = QPushButton("×") 563 | close_button.setStyleSheet(""" 564 | QPushButton { 565 | background-color: transparent; 566 | color: #808080; 567 | font-size: 20px; 568 | font-weight: bold; 569 | padding: 5px 10px; 570 | } 571 | QPushButton:hover { 572 | color: #ff4444; 573 | } 574 | """) 575 | close_button.clicked.connect(dialog.close) 576 | 577 | header_layout.addWidget(title_label) 578 | header_layout.addWidget(close_button, alignment=Qt.AlignRight) 579 | 580 | main_layout.addWidget(header) 581 | 582 | # Content 583 | content = QWidget() 584 | content_layout = QVBoxLayout(content) 585 | 586 | # Shape information 587 | tensor_card = ComputationCard(title, f"Shape: ({', '.join(map(str, tensor.shape))})") 588 | tensor_slice_layout, _ = set_up_slice_layout_from_tensor(tensor) 589 | tensor_card.slice_layout.addRow(tensor_slice_layout) 590 | content_layout.addWidget(tensor_card) 591 | 592 | # Update button 593 | update_button = QPushButton("Update Visualization") 
594 | content_layout.addWidget(update_button, alignment=Qt.AlignCenter) 595 | 596 | main_layout.addWidget(content) 597 | 598 | def update_display(): 599 | color = "red" if visualization_type == "kernel" else "gray" 600 | try: 601 | image = get_image_from_slice_layout_and_tensor(tensor_slice_layout, tensor.to(float16), color) 602 | 603 | parent_widget = tensor_card.image_label.parent() 604 | if parent_widget: 605 | available_width = parent_widget.width() - 20 606 | available_height = parent_widget.height() - 20 607 | qimage = image.toImage() 608 | scaled_image = qimage.scaled(available_width, available_height, 609 | Qt.KeepAspectRatio, 610 | Qt.SmoothTransformation) 611 | tensor_card.image_label.setPixmap(QPixmap.fromImage(scaled_image)) 612 | else: 613 | tensor_card.image_label.setPixmap(image) 614 | 615 | tensor_card.image_label.setAlignment(Qt.AlignCenter) 616 | 617 | except Exception as e: 618 | print(f"Error in displaying tensor: {e}") 619 | 620 | update_button.clicked.connect(update_display) 621 | update_display() 622 | 623 | shadow = QGraphicsDropShadowEffect() 624 | shadow.setBlurRadius(20) 625 | shadow.setColor(QColor(0, 0, 0, 150)) 626 | shadow.setOffset(0, 0) 627 | dialog.setGraphicsEffect(shadow) 628 | 629 | dialog.exec_() 630 | 631 | class CustomTableItem(QTableWidgetItem): 632 | def __init__(self, info_dict): 633 | super().__init__(info_dict.get("value", "")) 634 | self.tensor = info_dict.get("tensor") 635 | self.visualization_type = info_dict.get("visualization_type", "tensor") 636 | 637 | if self.tensor is not None: 638 | self.setToolTip("Click to visualize tensor") 639 | self.setForeground(QColor("#4FC3F7")) --------------------------------------------------------------------------------