├── AUTHORS ├── src ├── bird.jpg └── lib.typ ├── gallery ├── networks │ ├── FCN-8.png │ ├── U-Net.png │ ├── VGG16.png │ ├── VGG19.png │ ├── AlexNet.png │ ├── LeNet-5.png │ └── ResNet18.png └── features │ ├── customize.png │ ├── ALL-features.png │ ├── FCN-8(cold).png │ ├── basic-layout.png │ ├── connections.png │ ├── dimensions-labels.png │ └── predefined-layers.png ├── examples ├── networks │ ├── mnist-img-sample.jpg │ ├── LeNet-5.typ │ ├── U-Net.typ │ ├── AlexNet.typ │ ├── VGG16.typ │ ├── ResNet18.typ │ ├── VGG19.typ │ └── FCN-8.typ └── features │ ├── basic-layout.typ │ ├── connections.typ │ ├── customize.typ │ ├── dimensions-labels.typ │ ├── predefined-layers.typ │ ├── FCN-8(cold).typ │ └── ALL-features.typ ├── typst.toml ├── LICENSE ├── export-to-gallery.sh ├── codemeta.json └── README.md /AUTHORS: -------------------------------------------------------------------------------- 1 | Edgar Remy -------------------------------------------------------------------------------- /src/bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/src/bird.jpg -------------------------------------------------------------------------------- /gallery/networks/FCN-8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/FCN-8.png -------------------------------------------------------------------------------- /gallery/networks/U-Net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/U-Net.png -------------------------------------------------------------------------------- /gallery/networks/VGG16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/VGG16.png 
-------------------------------------------------------------------------------- /gallery/networks/VGG19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/VGG19.png -------------------------------------------------------------------------------- /gallery/networks/AlexNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/AlexNet.png -------------------------------------------------------------------------------- /gallery/networks/LeNet-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/LeNet-5.png -------------------------------------------------------------------------------- /gallery/features/customize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/customize.png -------------------------------------------------------------------------------- /gallery/networks/ResNet18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/networks/ResNet18.png -------------------------------------------------------------------------------- /gallery/features/ALL-features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/ALL-features.png -------------------------------------------------------------------------------- /gallery/features/FCN-8(cold).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/FCN-8(cold).png 
-------------------------------------------------------------------------------- /gallery/features/basic-layout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/basic-layout.png -------------------------------------------------------------------------------- /gallery/features/connections.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/connections.png -------------------------------------------------------------------------------- /examples/networks/mnist-img-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/examples/networks/mnist-img-sample.jpg -------------------------------------------------------------------------------- /gallery/features/dimensions-labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/dimensions-labels.png -------------------------------------------------------------------------------- /gallery/features/predefined-layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgaremy/neural-netz/HEAD/gallery/features/predefined-layers.png -------------------------------------------------------------------------------- /typst.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "neural-netz" 3 | version = "0.3.0" 4 | compiler = "0.14.0" 5 | entrypoint = "src/lib.typ" 6 | repository = "https://github.com/edgaremy/neural-netz" 7 | authors = ["Edgar Remy <@edgaremy>"] 8 | license = "MIT-0" 9 | description = "Visualize Neural Network Architectures with high-quality diagrams." 
10 | keywords = ["neural", "network", "deep learning", "computer vision", "machine learning"] 11 | categories = ["visualization", "components"] 12 | disciplines = ["computer-science", "engineering"] 13 | exclude = ["/examples/*", "/gallery/*"] -------------------------------------------------------------------------------- /examples/features/basic-layout.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | (type: "input", image: "default"), 7 | (type: "conv", offset: 2), // Next layers are automatically connected with arrows 8 | (type: "conv", offset: 2), 9 | (type: "pool"), // Pool layers are attached to the previous convolution block (by default) 10 | (type: "conv", widths: (1, 1), offset: 3) // you can offset layers 11 | )) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 4 | software and associated documentation files (the "Software"), to deal in the Software 5 | without restriction, including without limitation the rights to use, copy, modify, 6 | merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 7 | permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 10 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 11 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 12 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 13 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 14 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /examples/features/connections.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | (type: "input", label: "A", name: "a", show-connection: true), 7 | (type: "conv", label: "B", name: "b", offset: 2), 8 | (type: "conv", label: "C", name: "c", offset: 2), 9 | (type: "conv", label: "D", name: "d", offset: 2, show-connection: false), 10 | (type: "conv", label: "E", name: "e", offset: 2), 11 | ), connections: ( 12 | (from: "a", to: "c", type: "skip", mode: "depth", label: "depth mode", pos: 6), 13 | (from: "b", to: "d", type: "skip", mode: "flat", label: "flat mode", pos: 5), 14 | (from: "c", to: "e", type: "skip", mode: "air", label: "air mode (+touch layer instead of arrow)", pos: 5, touch-layer: true), 15 | ), 16 | palette: "cold", // There is a "warm" and a "cold" color palette. 
17 | show-relu: true // visualize relu using darker color on convolution layers 18 | ) -------------------------------------------------------------------------------- /examples/features/customize.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | ( 7 | type: "custom", 8 | width: 0.3, height: 5, depth: 5, 9 | label: "custom..", 10 | fill: rgb("#FF6B6B"), 11 | opacity: 0.9, 12 | legend: "Custom Color", 13 | ),( 14 | type: "custom", 15 | width: 0.3, height: 5, depth: 5, 16 | label: "..colors !", 17 | fill: rgb("#FF6B6B"), 18 | opacity: 0.9, 19 | offset: 1.7, 20 | image: [hi] // Add any content (image, text etc.) 21 | ),( 22 | type: "custom", 23 | widths: (0.3, 0.4, 0.3), height: 5, depth: 5, 24 | label: "custom color+bandfill", 25 | fill: rgb("#4ECDC4"), 26 | bandfill: rgb("#FFE66D"), 27 | show-relu: true, 28 | offset: 2, 29 | legend: "Custom Color+Bandfill", 30 | ), 31 | ), 32 | show-legend: true, 33 | legend-title: "My new layers" // You can also change the legend title 34 | ) -------------------------------------------------------------------------------- /examples/features/dimensions-labels.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | ( 7 | type: "convres", // Each layer type has its own color 8 | widths: (1, 2), 9 | channels: (32, 64, 128), // An extra channel will be used as diagonal axis label 10 | height: 6, 11 | depth: 8, 12 | label: "residual convolution", 13 | ),( 14 | type: "pool", 15 | channels: ("", "text also works"), 16 | height: 4, 17 | depth: 6, 18 | 
connection-label: "connection label", // label of the connection to the NEXT layer 19 | ),( 20 | type: "conv", 21 | widths: (1.5, 1.5), 22 | height: 2, 23 | depth: 3, 24 | label: "whole block label", 25 | legend: "CUSTOM NAME", // you can overwrite the default legend of predefined layers 26 | offset: 4, 27 | ),( 28 | type: "fc", 29 | channels: (10,), 30 | height: 5, 31 | depth: 0, // With no depth, the layer is drawn as a 2D rectangle 32 | label: "2D layer", 33 | offset: 2, 34 | ), 35 | ), 36 | show-legend: true, 37 | ) -------------------------------------------------------------------------------- /export-to-gallery.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to export Typst example files to PNG images 4 | # Usage: ./export-to-gallery.sh 5 | # (you may need to make it executable with chmod +x export-to-gallery.sh) 6 | # (ensure you have Typst installed and accessible in your PATH) 7 | 8 | set -e # Exit on error 9 | 10 | # Configuration 11 | DPI=300 # Adjust DPI as needed (e.g., 150, 300, 600) 12 | OUTPUT_DIR="gallery" 13 | 14 | # Colors for output 15 | GREEN='\033[0;32m' 16 | BLUE='\033[0;34m' 17 | NC='\033[0m' # No Color 18 | 19 | echo -e "${BLUE}Starting Typst export process...${NC}" 20 | echo "DPI: $DPI" 21 | echo "Output directory: $OUTPUT_DIR" 22 | echo "" 23 | 24 | # Create output directories if they don't exist 25 | mkdir -p "$OUTPUT_DIR/features" 26 | mkdir -p "$OUTPUT_DIR/networks" 27 | 28 | # Export features examples 29 | echo -e "${GREEN}Exporting features examples...${NC}" 30 | for file in examples/features/*.typ; do 31 | if [ -f "$file" ]; then 32 | filename=$(basename "$file" .typ) 33 | typst compile "$file" "$OUTPUT_DIR/features/$filename.png" --ppi "$DPI" --root "." 
34 | echo "✓ Exported $filename.typ" 35 | fi 36 | done 37 | 38 | # Export network examples 39 | echo -e "${GREEN}Exporting network examples...${NC}" 40 | for file in examples/networks/*.typ; do 41 | if [ -f "$file" ]; then 42 | filename=$(basename "$file" .typ) 43 | typst compile "$file" "$OUTPUT_DIR/networks/$filename.png" --ppi "$DPI" --root "." 44 | echo "✓ Exported $filename.typ" 45 | fi 46 | done 47 | 48 | echo "" 49 | echo -e "${BLUE}Export complete! Images saved to $OUTPUT_DIR/${NC}" 50 | -------------------------------------------------------------------------------- /examples/features/predefined-layers.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #let layers = ( 6 | ( 7 | type: "input", 8 | label: "input" 9 | ),( 10 | type: "conv", 11 | widths: (0.3, 0.3), 12 | label: "conv" 13 | ),( 14 | type: "pool", 15 | label: "pool", 16 | offset: 1 17 | ),( 18 | type: "convres", 19 | widths: (0.3, 0.3), 20 | label: "convres", 21 | offset: 1 22 | ),( 23 | type: "unpool", 24 | label: "unpool", 25 | offset: 1 26 | ),( 27 | type: "deconv", 28 | label: "deconv", 29 | offset: 1 30 | ),( 31 | type: "concat", 32 | label: "concat", 33 | offset: 1.4 34 | ),( 35 | type: "gap", 36 | label: "gap" 37 | ),( 38 | type: "fc", 39 | label: "fc", 40 | offset: 0.7 41 | ),( 42 | type: "convsoftmax", 43 | label: "convsoftmax", 44 | offset: 0.6 45 | ),( 46 | type: "sum", 47 | symbol: "+", 48 | channels: (""), 49 | offset: 0.7 50 | ),( 51 | type: "softmax", 52 | label: "softmax", 53 | offset: 0.6 54 | ),( 55 | type: "output", 56 | label: "output", 57 | offset: 1 58 | ),( 59 | type: "custom", 60 | widths: (0.3, 0.3), 61 | height: 3, 62 | depth: 3, 63 | label: "custom", 64 | legend: "Custom Layer", 65 | offset: 0.6 66 | ), 67 | ) 68 | 69 | #draw-network(layers, 70 | 
show-relu: true, 71 | show-legend: true, 72 | palette: "warm" 73 | ) 74 | 75 | #draw-network(layers, 76 | show-relu: true, 77 | show-legend: true, 78 | palette: "cold" 79 | ) -------------------------------------------------------------------------------- /codemeta.json: -------------------------------------------------------------------------------- 1 | { 2 | "@context": "https://w3id.org/codemeta/3.0", 3 | "type": "SoftwareSourceCode", 4 | "applicationCategory": "Computer Vision", 5 | "author": [ 6 | { 7 | "id": "https://orcid.org/0009-0006-4838-5183", 8 | "type": "Person", 9 | "affiliation": { 10 | "type": "Organization", 11 | "name": "CNRS, University of Toulouse, France" 12 | }, 13 | "email": "edgar.remy@cnrs.fr", 14 | "familyName": "Remy", 15 | "givenName": "Edgar" 16 | } 17 | ], 18 | "codeRepository": "https://github.com/edgaremy/neural-netz", 19 | "dateCreated": "2025-12-02", 20 | "dateModified": "2025-12-04", 21 | "datePublished": "2025-12-02", 22 | "description": "A Typst package for visualizing Neural Network Architectures with high-quality diagrams.", 23 | "downloadUrl": "https://github.com/edgaremy/neural-netz/releases/download/v0.3.0/neural-netz.0.3.0.zip", 24 | "keywords": [ 25 | "deep learning", 26 | "visualization", 27 | "neural networks", 28 | "typst" 29 | ], 30 | "license": "https://spdx.org/licenses/MIT-0", 31 | "name": "neural-netz, a Typst Package", 32 | "programmingLanguage": "Typst", 33 | "relatedLink": "https://typst.app/universe/package/neural-netz", 34 | "releaseNotes": "- Added new generic custom layer type.\n- Improved smart legend generation\n- Fixed use of custom images in input layer and made it more robust to various image widths.\n- More detailed documentation in the README.\n- Minor fixes.\n", 35 | "version": "0.3.0", 36 | "issueTracker": "https://github.com/edgaremy/neural-netz/issues" 37 | } -------------------------------------------------------------------------------- /examples/networks/LeNet-5.typ: 
-------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | 6 | #let layers = ( 7 | ( 8 | type: "input", 9 | name: "I", 10 | image: image("mnist-img-sample.jpg", scaling: "pixelated"), 11 | height: 5, 12 | depth: 5, 13 | label: "input", 14 | channels: (1, 32), 15 | show-connection: true, 16 | ), 17 | ( 18 | type: "conv", 19 | name: "C1", 20 | widths: (0.5,), 21 | height: 4.8, 22 | depth: 4.8, 23 | label: "f = 5", 24 | channels: (6, 28), 25 | offset: 1.9, 26 | legend: "Convolution+ReLU", 27 | ), 28 | ( 29 | type: "pool", 30 | height: 2.4, 31 | depth: 2.4, 32 | ), 33 | ( 34 | type: "conv", 35 | widths: (1.2,), 36 | height: 2, 37 | depth: 2, 38 | label: "f = 5", 39 | channels: (16, 10), 40 | ), 41 | ( 42 | type: "pool", 43 | height: 1, 44 | depth: 1, 45 | ), 46 | ( 47 | type: "fc", 48 | label: "", 49 | channels: (120,), 50 | height: 4, 51 | depth: 0.3, 52 | offset: 0.8, 53 | legend: "Fully Connected+ReLU", 54 | ), 55 | ( 56 | type: "fc", 57 | label: "", 58 | channels: (84,), 59 | height: 3, 60 | depth: 0.3, 61 | offset: 0.5, 62 | ), 63 | ( 64 | type: "fc", 65 | label: "", 66 | channels: (10,), 67 | height: 2, 68 | depth: 0.3, 69 | offset: 0.5, 70 | ), 71 | ( 72 | type: "softmax", 73 | label: "softmax", 74 | height: 2, 75 | depth: 0.3, 76 | offset: 0.9, 77 | ), 78 | ) 79 | 80 | #draw-network( 81 | layers, 82 | show-legend: true, 83 | show-relu: true, 84 | ) -------------------------------------------------------------------------------- /examples/networks/U-Net.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | (type: 
"input", image: "default", channels: ("1", "128"), widths: (0.2,), height: 8, depth: 8, name: "input"), 7 | 8 | (type: "conv", channels: ("16", "128"), widths: (0.4,), height: 8, depth: 8, name: "down1", offset: 1.9), 9 | (type: "pool", height: 6.5, depth: 6.5, name: "pool1"), 10 | 11 | (type: "conv", channels: ("32", "64"), widths: (0.5,), height: 6.5, depth: 6.5, name: "down2"), 12 | (type: "pool", height: 5, depth: 5, name: "pool2"), 13 | 14 | (type: "conv", channels: ("64", "32"), widths: (0.8,), height: 5, depth: 5, name: "down3"), 15 | (type: "pool", height: 3.5, depth: 3.5, name: "pool3"), 16 | 17 | (type: "conv", channels: ("128", "16"), widths: (1.6,), height: 3.5, depth: 3.5, name: "down4"), 18 | (type: "pool", height: 2.5, depth: 2.5, name: "pool4"), 19 | 20 | (type: "conv", channels: ("256", "8"), widths: (3.2,), height: 2.5, depth: 2.5, name: "middle"), 21 | 22 | (type: "conv", channels: ("128", "16"), widths: (1.6,), height: 3.5, depth: 3.5, name: "up1", offset: 1.5), 23 | 24 | (type: "conv", channels: ("64", "32"), widths: (0.8,), height: 5, depth: 5, name: "up2", offset: 1.5), 25 | 26 | (type: "conv", channels: ("32", "64"), widths: (0.5,), height: 6.5, depth: 6.5, name: "up3", offset: 1.5), 27 | 28 | (type: "conv", channels: ("16", "128"), widths: (0.4,), height: 8, depth: 8, name: "up4", offset: 1.5), 29 | 30 | (type: "conv", channels: ("3", "128"), widths: (0.2,), height: 8, depth: 8, name: "output"), 31 | ), connections: ( 32 | // Decoder skip connections: each encoder stage feeds the decoder stage of matching resolution 33 | (from: "down4", to: "up1", type: "skip", mode: "air", pos: 2.5, touch-layer: true), 34 | (from: "down3", to: "up2", type: "skip", mode: "air", pos: 3.4, touch-layer: true), 35 | (from: "down2", to: "up3", type: "skip", mode: "air", pos: 4.1, touch-layer: true), 36 | (from: "down1", to: "up4", type: "skip", mode: "air", pos: 4.8, touch-layer: true), 37 | )) -------------------------------------------------------------------------------- 
/examples/networks/AlexNet.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #let layers = ( 6 | ( 7 | type: "input", 8 | image: "default", 9 | height: 8, 10 | depth: 8, 11 | label: "input", 12 | channels: (3, 224), 13 | show-connection: true, 14 | ), 15 | ( 16 | type: "conv", 17 | widths: (0.5,), 18 | height: 6, 19 | depth: 6, 20 | label: "\n f = 11\nstride = 4", 21 | channels: (96, 55), 22 | offset: 2, 23 | legend: "Convolution+ReLU" 24 | ), 25 | ( 26 | type: "pool", 27 | height: 4, 28 | depth: 4, 29 | ), 30 | ( 31 | type: "conv", 32 | widths: (1.2,), 33 | height: 4, 34 | depth: 4, 35 | label: "f = 5", 36 | channels: (256, 27), 37 | ), 38 | ( 39 | type: "pool", 40 | height: 2, 41 | depth: 2, 42 | ), 43 | ( 44 | type: "conv", 45 | widths: (1.6,), 46 | height: 2, 47 | depth: 2, 48 | label: "f = 3", 49 | channels: (384, 13), 50 | ), 51 | ( 52 | type: "conv", 53 | widths: (1.6,), 54 | height: 2, 55 | depth: 2, 56 | label: "f = 3", 57 | channels: (384, 13), 58 | ), 59 | ( 60 | type: "conv", 61 | widths: (1.2,), 62 | height: 2, 63 | depth: 2, 64 | label: "f = 3", 65 | channels: (256, 13), 66 | ), 67 | ( 68 | type: "fc", 69 | label: "", 70 | channels: (4096,), 71 | height: 5, 72 | depth: 0.3, 73 | offset: 0.8, 74 | legend: "Fully Connected+ReLU" 75 | ), 76 | ( 77 | type: "fc", 78 | label: "", 79 | channels: (4096,), 80 | height: 5, 81 | depth: 0.3, 82 | offset: 0.5, 83 | ), 84 | ( 85 | type: "fc", 86 | label: "", 87 | channels: (1000,), 88 | height: 4, 89 | depth: 0.3, 90 | offset: 0.5, 91 | ), 92 | ( 93 | type: "softmax", 94 | label: "softmax", 95 | height: 4, 96 | depth: 0.3, 97 | offset: 0.9, 98 | ), 99 | ) 100 | 101 | #draw-network( 102 | layers, 103 | show-legend: true, 104 | show-relu: true, 105 | ) 
-------------------------------------------------------------------------------- /examples/networks/VGG16.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #let layers = ( 6 | ( 7 | type: "input", 8 | image: "default", 9 | height: 8, 10 | depth: 8, 11 | label: "input", 12 | channels: (3, 224), 13 | ), 14 | ( 15 | type: "conv", 16 | widths: (0.3, 0.3), 17 | height: 8, 18 | depth: 8, 19 | label: "conv1", 20 | channels: (64, 64, 224), 21 | offset: 1.9, 22 | ), 23 | ( 24 | type: "pool", 25 | height: 6, 26 | depth: 6, 27 | ), 28 | ( 29 | type: "conv", 30 | widths: (0.4,0.4,), 31 | height: 6, 32 | depth: 6, 33 | label: "conv2", 34 | channels: (128, 128, 112), 35 | ), 36 | ( 37 | type: "pool", 38 | height: 4, 39 | depth: 4, 40 | ), 41 | ( 42 | type: "conv", 43 | widths: (0.5, 0.5), 44 | height: 4, 45 | depth: 4, 46 | label: "conv3", 47 | channels: (256, 256, 56), 48 | ), 49 | ( 50 | type: "pool", 51 | height: 2, 52 | depth: 2, 53 | ), 54 | ( 55 | type: "conv", 56 | widths: (0.6, 0.6, 0.6), 57 | height: 2, 58 | depth: 2, 59 | label: "conv4", 60 | channels: (512, 512, 512, 28), 61 | offset: 1, 62 | ), 63 | ( 64 | type: "pool", 65 | height: 1, 66 | depth: 1, 67 | ), 68 | ( 69 | type: "conv", 70 | widths: (0.6, 0.8, 0.8), 71 | height: 1, 72 | depth: 1, 73 | label: "conv5", 74 | channels: (512, 512, 512, 14), 75 | offset: 0.8, 76 | ), 77 | ( 78 | type: "pool", 79 | height: 0.5, 80 | depth: 0.5, 81 | ), 82 | ( 83 | type: "fc", 84 | label: "fc", 85 | channels: (4096,), 86 | height: 5, 87 | depth: 0.3, 88 | offset: 0.8, 89 | ), 90 | ( 91 | type: "fc", 92 | label: "fc", 93 | channels: (4096,), 94 | height: 5, 95 | depth: 0.3, 96 | offset: 0.5, 97 | ), 98 | ( 99 | type: "fc", 100 | label: "fc", 101 | channels: (1000,), 102 | height: 4, 103 | depth: 0.3, 104 
| offset: 0.5, 105 | ), 106 | ( 107 | type: "softmax", 108 | label: "softmax", 109 | height: 4, 110 | depth: 0.3, 111 | offset: 0.9, 112 | ), 113 | ) 114 | 115 | #draw-network( 116 | layers, 117 | show-relu: true, 118 | ) -------------------------------------------------------------------------------- /examples/networks/ResNet18.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #let layers = ( 6 | ( 7 | type: "input", 8 | image: "default", 9 | height: 8, 10 | depth: 8, 11 | label: "input", 12 | channels: (3, 224), 13 | ), 14 | ( 15 | type: "conv", 16 | widths: (0.3, 0.3), 17 | height: 8, 18 | depth: 8, 19 | label: "conv1", 20 | channels: (64, 64, 224), 21 | offset: 1.9, 22 | ), 23 | ( 24 | type: "pool", 25 | height: 6, 26 | depth: 6, 27 | ), 28 | ( 29 | type: "convres", 30 | widths: (0.4,0.4,), 31 | height: 6, 32 | depth: 6, 33 | label: "res2", 34 | channels: (128, 128, 112), 35 | ), 36 | ( 37 | type: "pool", 38 | height: 4, 39 | depth: 4, 40 | ), 41 | ( 42 | type: "convres", 43 | widths: (0.5, 0.5, 0.5), 44 | height: 4, 45 | depth: 4, 46 | label: "res3", 47 | channels: (256, 256, 256, 56), 48 | ), 49 | ( 50 | type: "pool", 51 | height: 2, 52 | depth: 2, 53 | ), 54 | ( 55 | type: "convres", 56 | widths: (0.6, 0.6, 0.6), 57 | height: 2, 58 | depth: 2, 59 | label: "res4", 60 | channels: (512, 512, 512, 28), 61 | offset: 1, 62 | ), 63 | ( 64 | type: "pool", 65 | height: 1, 66 | depth: 1, 67 | ), 68 | ( 69 | type: "convres", 70 | widths: (0.6, 0.8, 0.8), 71 | height: 1, 72 | depth: 1, 73 | label: "res5", 74 | channels: (512, 512, 512, 14), 75 | offset: 0.8, 76 | ), 77 | ( 78 | type: "pool", 79 | height: 0.5, 80 | depth: 0.5, 81 | ), 82 | ( 83 | type: "fc", 84 | label: "fc", 85 | channels: (4096,), 86 | height: 5, 87 | depth: 0.3, 88 | offset: 
0.8, 89 | ), 90 | ( 91 | type: "fc", 92 | label: "fc", 93 | channels: (4096,), 94 | height: 5, 95 | depth: 0.3, 96 | offset: 0.5, 97 | ), 98 | ( 99 | type: "fc", 100 | label: "fc", 101 | channels: (1000,), 102 | height: 4, 103 | depth: 0.3, 104 | offset: 0.5, 105 | ), 106 | ( 107 | type: "softmax", 108 | label: "softmax", 109 | height: 4, 110 | depth: 0.3, 111 | offset: 0.9, 112 | ), 113 | ) 114 | 115 | #draw-network( 116 | layers, 117 | show-relu: true, 118 | ) -------------------------------------------------------------------------------- /examples/networks/VGG19.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #let layers = ( 6 | ( 7 | type: "input", 8 | image: "default", 9 | height: 8, 10 | depth: 8, 11 | label: "input", 12 | channels: (3, 224), 13 | ), 14 | ( 15 | type: "conv", 16 | widths: (0.3, 0.3), 17 | height: 8, 18 | depth: 8, 19 | label: "conv1", 20 | channels: (64, 64, 224), 21 | offset: 1.9, 22 | ), 23 | ( 24 | type: "pool", 25 | height: 6, 26 | depth: 6, 27 | ), 28 | ( 29 | type: "conv", 30 | widths: (0.4,0.4,), 31 | height: 6, 32 | depth: 6, 33 | label: "conv2", 34 | channels: (128, 128, 112), 35 | ), 36 | ( 37 | type: "pool", 38 | height: 4, 39 | depth: 4, 40 | ), 41 | ( 42 | type: "conv", 43 | widths: (0.5, 0.5), 44 | height: 4, 45 | depth: 4, 46 | label: "conv3", 47 | channels: (256, 256, 56), 48 | ), 49 | ( 50 | type: "pool", 51 | height: 2, 52 | depth: 2, 53 | ), 54 | ( 55 | type: "conv", 56 | widths: (0.6, 0.6, 0.6, 0.6), 57 | height: 2, 58 | depth: 2, 59 | label: "conv4", 60 | channels: (512, 512, 512, 512, 28), 61 | offset: 1, 62 | ), 63 | ( 64 | type: "pool", 65 | height: 1, 66 | depth: 1, 67 | ), 68 | ( 69 | type: "conv", 70 | widths: (0.6, 0.8, 0.8, 0.8), 71 | height: 1, 72 | depth: 1, 73 | label: "conv5", 74 | 
channels: (512, 512, 512, 512, 14), 75 | offset: 0.8, 76 | ), 77 | ( 78 | type: "pool", 79 | height: 0.5, 80 | depth: 0.5, 81 | ), 82 | ( 83 | type: "fc", 84 | label: "fc", 85 | channels: (4096,), 86 | height: 5, 87 | depth: 0.3, 88 | offset: 0.8, 89 | ), 90 | ( 91 | type: "fc", 92 | label: "fc", 93 | channels: (4096,), 94 | height: 5, 95 | depth: 0.3, 96 | offset: 0.5, 97 | ), 98 | ( 99 | type: "fc", 100 | label: "fc", 101 | channels: (1000,), 102 | height: 4, 103 | depth: 0.3, 104 | offset: 0.5, 105 | ), 106 | ( 107 | type: "softmax", 108 | label: "softmax", 109 | height: 4, 110 | depth: 0.3, 111 | offset: 0.9, 112 | ), 113 | ) 114 | 115 | #draw-network( 116 | layers, 117 | show-relu: true, 118 | ) -------------------------------------------------------------------------------- /examples/networks/FCN-8.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | (type: "input", image: "default", height: 8, depth: 8, label: "Input", name: "img"), 7 | (type: "conv", channels: ("64", "64", "I"), widths: (0.5, 0.5), height: 8, depth: 8, label: "Conv1", name: "c1", offset: 1.9), 8 | (type: "pool", height: 6.5, depth: 6.5, name: "p1"), 9 | (type: "conv", channels: ("128", "128", "I/2"), widths: (0.6, 0.6), height: 6.5, depth: 6.5, label: "Conv2", name: "c2"), 10 | (type: "pool", height: 5, depth: 5, name: "p2"), 11 | (type: "conv", channels: ("256", "256", "256", "I/4"), widths: (0.7, 0.7, 0.7), height: 5, depth: 5, label: "Conv3", name: "c3"), 12 | (type: "pool", height: 3.5, depth: 3.5, name: "p3"), 13 | (type: "conv", channels: ("512", "512", "512", "I/8"), widths: (0.8, 0.8, 0.8), height: 3.5, depth: 3.5, label: "Conv4", name: "c4"), 14 | (type: "pool", height: 2.5, depth: 2.5, name: "p4"), 15 | (type: "conv", channels: ("512", 
"512", "512", "I/16"), widths: (0.8, 0.8, 0.8), height: 2.5, depth: 2.5, label: "Conv5", name: "c5"), 16 | (type: "pool", height: 1.5, depth: 1.5, name: "p5"), 17 | (type: "conv", channels: ("4096", "4096"), widths: (1.5, 1.5), height: 1.5, depth: 1.5, label: "fc to conv", name: "fc"), 18 | 19 | // Upsampling path 20 | (type: "conv", channels: ("K", "I/32"), widths: (0.3,), height: 2.5, depth: 2.5, label: "fc8 to conv", name: "s32", offset: 0.8, show-relu: false), 21 | (type: "deconv", channels: ("K", "I/16"), height: 3.5, depth: 3.5, name: "up1", offset: 1), 22 | (type: "sum", radius: 0.5, symbol: "+", name: "add1", offset: 1), 23 | (type: "deconv", height: 5, depth: 5, channels: ("K", "I/8"), name: "up2", offset: 0.5), 24 | (type: "sum", radius: 0.5, symbol: "+", name: "add2", offset: 1), 25 | (type: "deconv", height: 8, depth: 8, channels: ("K",), label: "Deconv", name: "up3", offset: 0.5), 26 | (type: "convsoftmax", height: 8, depth: 8, channels: ("K", "I"), label: "softmax", offset: 1), 27 | ), connections: ( 28 | (from: "p4", to: "add1", type: "skip", mode: "flat", pos: 3, 29 | layers: ( 30 | (type: "conv", channels: ("K", "I/16"), widths: (0.3,), height: 2, depth: 3.5, name: "s16", show-relu: false), 31 | ) 32 | ), 33 | (from: "p3", to: "add2", type: "skip", mode: "flat", pos: 6, 34 | layers: ( 35 | (type: "conv", channels: ("K", "I/8"), widths: (0.3,), height: 2, depth: 3.5, name: "s8", show-relu: false), 36 | )) 37 | ), 38 | show-relu: true, 39 | ) -------------------------------------------------------------------------------- /examples/features/FCN-8(cold).typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | (type: "input", image: "default", height: 8, depth: 8, label: "Input", name: "img"), 7 | (type: "conv", 
channels: ("64", "64", "I"), widths: (0.5, 0.5), height: 8, depth: 8, label: "Conv1", name: "c1", offset: 1.9), 8 | (type: "pool", height: 6.5, depth: 6.5, name: "p1"), 9 | (type: "conv", channels: ("128", "128", "I/2"), widths: (0.6, 0.6), height: 6.5, depth: 6.5, label: "Conv2", name: "c2"), 10 | (type: "pool", height: 5, depth: 5, name: "p2"), 11 | (type: "conv", channels: ("256", "256", "256", "I/4"), widths: (0.7, 0.7, 0.7), height: 5, depth: 5, label: "Conv3", name: "c3"), 12 | (type: "pool", height: 3.5, depth: 3.5, name: "p3"), 13 | (type: "conv", channels: ("512", "512", "512", "I/8"), widths: (0.8, 0.8, 0.8), height: 3.5, depth: 3.5, label: "Conv4", name: "c4"), 14 | (type: "pool", height: 2.5, depth: 2.5, name: "p4"), 15 | (type: "conv", channels: ("512", "512", "512", "I/16"), widths: (0.8, 0.8, 0.8), height: 2.5, depth: 2.5, label: "Conv5", name: "c5"), 16 | (type: "pool", height: 1.5, depth: 1.5, name: "p5"), 17 | (type: "conv", channels: ("4096", "4096"), widths: (1.5, 1.5), height: 1.5, depth: 1.5, label: "fc to conv", name: "fc"), 18 | 19 | // Upsampling path 20 | (type: "conv", channels: ("K", "I/32"), widths: (0.3,), height: 2.5, depth: 2.5, label: "fc8 to conv", name: "s32", offset: 0.8, show-relu: false), 21 | (type: "deconv", channels: ("K", "I/16"), height: 3.5, depth: 3.5, name: "up1", offset: 1), 22 | (type: "sum", radius: 0.5, symbol: "+", name: "add1", offset: 1), 23 | (type: "deconv", height: 5, depth: 5, channels: ("K", "I/8"), name: "up2", offset: 0.5), 24 | (type: "sum", radius: 0.5, symbol: "+", name: "add2", offset: 1), 25 | (type: "deconv", height: 8, depth: 8, channels: ("K",), label: "Deconv", name: "up3", offset: 0.5), 26 | (type: "convsoftmax", height: 8, depth: 8, channels: ("K", "I"), label: "softmax", offset: 1), 27 | ), connections: ( 28 | (from: "p4", to: "add1", type: "skip", mode: "flat", pos: 3, 29 | layers: ( 30 | (type: "conv", channels: ("K", "I/16"), widths: (0.3,), height: 2, depth: 3.5, name: "s16", show-relu: 
false), 31 | ) 32 | ), 33 | (from: "p3", to: "add2", type: "skip", mode: "flat", pos: 6, 34 | layers: ( 35 | (type: "conv", channels: ("K", "I/8"), widths: (0.3,), height: 2, depth: 3.5, name: "s16", show-relu: false), 36 | )) 37 | ), 38 | show-relu: true, 39 | palette: "cold", 40 | ) -------------------------------------------------------------------------------- /examples/features/ALL-features.typ: -------------------------------------------------------------------------------- 1 | #import "../../src/lib.typ": draw-network // FOR YOUR OWN FILES, IMPORT THE FUNCTION FROM THE NEURAL-NETZ PACKAGE INSTEAD 2 | 3 | #set page(width: auto, height: auto, margin: 5mm) 4 | 5 | #draw-network(( 6 | (type: "input", image: "default"), 7 | (type: "conv", offset: 2), // Next layers are automatically connected with arrows 8 | (type: "conv", offset: 2), 9 | (type: "pool"), // Pool layers are sticked to previous convolution block (by default)) 10 | (type: "conv", widths: (1, 1), offset: 3) // you can offset layers 11 | )) 12 | 13 | #draw-network(( 14 | ( 15 | type: "convres", // Each layer type has its own color 16 | widths: (1, 2), 17 | channels: (32, 64, 128), // An extra channel will be used as diagonal axis label 18 | height: 6, 19 | depth: 8, 20 | label: "residual convolution", 21 | ),( 22 | type: "pool", 23 | channels: ("", "text also works"), 24 | height: 4, 25 | depth: 6, 26 | connection-label: "connection label", // label of the connection to the NEXT layer 27 | ),( 28 | type: "conv", 29 | widths: (1.5, 1.5), 30 | height: 2, 31 | depth: 3, 32 | label: "whole block label", 33 | legend: "CUSTOM NAME", // you can overwrite the default legend of predefined layers 34 | offset: 4, 35 | ),( 36 | type: "fc", 37 | channels: (10,), 38 | height: 5, 39 | depth: 0, // With no depth, the layer is drawn as a 2D rectangle 40 | label: "2D layer", 41 | offset: 2, 42 | ), 43 | ), 44 | show-legend: true, 45 | ) 46 | 47 | #draw-network(( 48 | (type: "input", label: "A", name: "a", 
show-connection: true), 49 | (type: "conv", label: "B", name: "b", offset: 2), 50 | (type: "conv", label: "C", name: "c", offset: 2), 51 | (type: "conv", label: "D", name: "d", offset: 2, show-connection: false), 52 | (type: "conv", label: "E", name: "e", offset: 2), 53 | ), connections: ( 54 | (from: "a", to: "c", type: "skip", mode: "depth", label: "depth mode", pos: 6), 55 | (from: "b", to: "d", type: "skip", mode: "flat", label: "flat mode", pos: 5), 56 | (from: "c", to: "e", type: "skip", mode: "air", label: "air mode (+touch layer instead of arrow)", pos: 5, touch-layer: true), 57 | ), 58 | palette: "cold", // There is a "warm" and a "cold" color palette. 59 | show-relu: true // visualize relu using darker color on convolution layers 60 | ) 61 | 62 | #draw-network(( 63 | ( 64 | type: "custom", 65 | width: 0.3, height: 5, depth: 5, 66 | label: "custom..", 67 | fill: rgb("#FF6B6B"), 68 | opacity: 0.9, 69 | legend: "Custom Color", 70 | ),( 71 | type: "custom", 72 | width: 0.3, height: 5, depth: 5, 73 | label: "..colors !", 74 | fill: rgb("#FF6B6B"), 75 | opacity: 0.9, 76 | offset: 1.7, 77 | image: [hi] // Add any content (image, text etc.) 78 | ),( 79 | type: "custom", 80 | widths: (0.3, 0.4, 0.3), height: 5, depth: 5, 81 | label: "custom color+bandfill", 82 | fill: rgb("#4ECDC4"), 83 | bandfill: rgb("#FFE66D"), 84 | show-relu: true, 85 | offset: 2, 86 | legend: "Custom Color+Bandfill", 87 | ), 88 | ), 89 | show-legend: true, 90 | legend-title: "My new layers" // You can also change the legend title 91 | ) 92 | 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # neural-netz 2 | 3 | Visualize Neural Network Architectures in high-quality diagrams using [Typst](https://typst.app), with style and API inspired by [PlotNeuralNet](https://github.com/HarisIqbal88/PlotNeuralNet). 
4 | 5 | 6 | [![Static Badge](https://img.shields.io/badge/HAL-05401124-%23fcac8f?style=flat-square&logo=HAL&logoColor=%23fc6d3a&labelColor=%23171768)](https://hal.science/hal-05401124)  7 | [![GitHub Release](https://img.shields.io/github/v/release/edgaremy/neural-netz?style=flat-square&labelColor=%23aa2589&color=%23e2a6ed)](https://github.com/edgaremy/neural-netz/releases)  8 | ![GitHub License](https://img.shields.io/github/license/edgaremy/neural-netz?style=flat-square&labelColor=%2326ad84&color=%2396e7c8)  9 | ![GitHub Repo stars](https://img.shields.io/github/stars/edgaremy/neural-netz?style=flat-square&labelColor=%23e8963a&color=%23ffe0a1) 10 | 11 | 12 |

13 | Example of Neural Net visualization with cold color palette 14 | Example of Neural Net visualization with warm color palette 15 |

16 | 17 | Under the hood, this package only uses the native Typst package [CeTZ](https://typst.app/universe/package/cetz/) for building the diagrams. 18 | 19 | ## Usage 20 | 21 | Simply import the all-in-one drawing function from the neural-netz package: 22 | ```typ 23 | #import "@preview/neural-netz:0.3.0": draw-network 24 | ``` 25 | You can then call `draw-network`, which takes the following arguments: 26 | ```typ 27 | #draw-network( 28 | layers, 29 | connections: (), 30 | palette: "warm", 31 | show-legend: false, 32 | legend-title: "Layers", 33 | scale: 100%, 34 | stroke-thickness: 1, 35 | depth-multiplier: 0.3, 36 | show-relu: false, 37 | ) 38 | ``` 39 | See the examples in the following section to understand how to use it. Alternatively, you can also start from already written architecture examples (see the Examples section, near the end). 40 | 41 | ## Getting started 42 | 43 | Here are a few simple features for getting started. 44 | 45 | ### Basic layout 46 | 47 | ```typ 48 | #draw-network(( 49 | (type: "input", image: "default"), 50 | (type: "conv", offset: 2), // Next layers are automatically connected with arrows 51 | (type: "conv", offset: 2), 52 | (type: "pool"), // Pool layers are attached to the previous convolution block (by default) 53 | (type: "conv", widths: (1, 1), offset: 3) // you can offset layers 54 | )) 55 | ``` 56 |

57 | Basic layout example 58 |

59 | 60 | For the input type layer, you can also specify a custom image by giving `image: image("path/to/your/image.jpg")`. Additionally not giving any image is equivalent to giving `image: none`. 61 | 62 | ### Dimensions and labels 63 | 64 | 65 | ```typ 66 | #draw-network(( 67 | ( 68 | type: "convres", // Each layer type has its own color 69 | widths: (1, 2), 70 | channels: (32, 64, 128), // An extra channel will be used as diagonal axis label 71 | height: 6, 72 | depth: 8, 73 | label: "residual convolution", 74 | ),( 75 | type: "pool", 76 | channels: ("", "text also works"), 77 | height: 4, 78 | depth: 6, 79 | connection-label: "connection label", // label of the connection to the NEXT layer 80 | ),( 81 | type: "conv", 82 | widths: (1.5, 1.5), 83 | height: 2, 84 | depth: 3, 85 | label: "whole block label", 86 | legend: "CUSTOM NAME", // you can overwrite the default legend of predefined layers 87 | offset: 4, 88 | ),( 89 | type: "fc", 90 | channels: (10,), 91 | height: 5, 92 | depth: 0, // With no depth, the layer is drawn as a 2D rectangle 93 | label: "2D layer", 94 | offset: 2, 95 | ), 96 | ), 97 | show-legend: true, 98 | ) 99 | ``` 100 |

101 | Dimensions and labels example 102 |

103 | 104 | Using `show-legend: true`, you can add a smart legend to your diagram! 105 | 106 | If your network does not fit the page width of your Typst document, **you can reduce the scale by passing `scale: 50%` as an argument to `draw-network`** (adjust the scale value to your needs). 107 | 108 | 109 | ### Adding other connections 110 | 111 | The main axis connections are drawn automatically, except for the input layer. You can override this with the boolean `show-connection`, which controls whether the connection **after** a layer is drawn. You can also draw extra connections using the `connections` argument of `draw-network`. To reference a layer, it needs a `name`: 112 | 113 | ```typ 114 | #draw-network(( 115 | (type: "input", label: "A", name: "a", show-connection: true), 116 | (type: "conv", label: "B", name: "b", offset: 2), 117 | (type: "conv", label: "C", name: "c", offset: 2), 118 | (type: "conv", label: "D", name: "d", offset: 2, show-connection: false), 119 | (type: "conv", label: "E", name: "e", offset: 2), 120 | ), connections: ( 121 | (from: "a", to: "c", type: "skip", mode: "depth", label: "depth mode", pos: 6), 122 | (from: "b", to: "d", type: "skip", mode: "flat", label: "flat mode", pos: 5), 123 | (from: "c", to: "e", type: "skip", mode: "air", label: "air mode (+touch layer instead of arrow)", pos: 5, touch-layer: true), 124 | ), 125 | palette: "cold", // There is a "warm" and a "cold" color palette. 126 | show-relu: true // visualize relu using darker color on convolution layers 127 | ) 128 | ``` 129 |

130 | Adding connections example 131 |

132 | 133 | ### Predefined layer types 134 | 135 | Here is a visualization of all the predefined layer types, in both available color palettes (`"warm"` (default) and `"cold"`). The type name of each layer is shown underneath it. Of course, this is just a starting point: you can modify most of their default attributes. 136 |

137 | Predefined layers example 138 |

139 |

code for this image

140 | 141 | ### Custom layers 142 | 143 | If you prefer to create your own layer types, use `type: "custom"` as a starting point. It is a generic, easily customizable layer. It can have one or multiple channels, with an optional "bandfill" color for symbolizing activation functions (e.g. ReLU). Note that the visibility of activations can be set globally with the boolean `show-relu` argument of `draw-network`, and can be overridden on a per-layer basis. 144 | 145 | A custom layer can also be added to the smart legend by specifying a `legend` label (no need to repeat the legend every time for custom layers of the same color). 146 | 147 | ```typ 148 | #draw-network(( 149 | ( 150 | type: "custom", 151 | width: 0.3, height: 5, depth: 5, 152 | label: "custom..", 153 | fill: rgb("#FF6B6B"), 154 | opacity: 0.9, 155 | legend: "Custom Color", 156 | ),( 157 | type: "custom", 158 | width: 0.3, height: 5, depth: 5, 159 | label: "..colors !", 160 | fill: rgb("#FF6B6B"), 161 | opacity: 0.9, 162 | offset: 1.7, 163 | image: [hi] // Add any content (image, text etc.) 164 | ),( 165 | type: "custom", 166 | widths: (0.3, 0.4, 0.3), height: 5, depth: 5, 167 | label: "custom color+bandfill", 168 | fill: rgb("#4ECDC4"), 169 | bandfill: rgb("#FFE66D"), 170 | show-relu: true, 171 | offset: 2, 172 | legend: "Custom Color+Bandfill", 173 | ), 174 | ), 175 | show-legend: true, 176 | legend-title: "My new layers" // You can also change the legend title 177 | ) 178 | ``` 179 |

180 | Custom layer example 181 |

182 | 183 | 184 | ## Examples 185 | Here are a few network architectures implemented with neural-netz (more examples can be found [in the repo](https://github.com/edgaremy/neural-netz/tree/db550ba2eda99ffbcbb01c1e0374ea6519e16a74/examples/networks)). 186 | 187 |

ResNet18

188 |

189 | ResNet18 visualization 190 |

191 |

code for this image

192 | 193 |

U-Net

194 |

195 | U-Net visualization 196 |

197 |

code for this image

198 | 199 |

FCN-8

200 |

201 | FCN-8 visualization 202 |

203 |

code for this image

204 | 205 | ## Cite this work 206 | If you use the neural-netz package for a scientific publication, you can [cite its initial publication on HAL](https://hal.science/hal-05401124), indicating current version as follows: 207 | #### APA 208 | ``` 209 | Remy, E. (2025). neural-netz, a Typst Package (Version 0.3.0) [Computer software]. https://hal.science/hal-05401124 210 | ``` 211 | #### BibTeX 212 | 213 | ```bib 214 | @softwareversion{remy:hal-05401124v1, 215 | TITLE = {{neural-netz, a Typst Package}}, 216 | AUTHOR = {Remy, Edgar}, 217 | URL = {https://hal.science/hal-05401124}, 218 | NOTE = {}, 219 | YEAR = {2025}, 220 | MONTH = Dec, 221 | SWHID = {swh:1:dir:c0d8294e6b01cdb5bc8703eeb51e546275244c0a;origin=https://github.com/edgaremy/neural-netz;visit=swh:1:snp:052f4efa29f793bf84901593b27254f2e0e15ffb;anchor=swh:1:rev:3669d7922f3581afc2da1f12b9e62c54e4242048}, 222 | VERSION = {0.3.0}, 223 | REPOSITORY = {https://github.com/edgaremy/neural-netz}, 224 | LICENSE = {https://spdx.org/licenses/MIT-0}, 225 | KEYWORDS = {visualization ; typst ; neural networks ; deep learning}, 226 | HAL_ID = {hal-05401124}, 227 | HAL_VERSION = {v1}, 228 | } 229 | ``` 230 | 231 | ## Acknowledgements 232 | 233 | This package could not have existed without the great Python+LaTeX visualization package [PlotNeuralNet](https://github.com/HarisIqbal88/PlotNeuralNet) made by Haris Iqbal. It proposes an elegant way for viewing neural networks, and its visual style was obviously a strong inspiration for the implementation of neural-netz. 234 | 235 | Default input image was [taken from iNaturalist](https://www.inaturalist.org/observations/205901632) (colors are slightly edited). 
236 | 237 | If you feel like contributing to this package (bug fixes, features or even code refactoring), or want your model added to the model gallery, feel free to [make a PR to the neural-netz repo](https://github.com/edgaremy/neural-netz/pulls) :) -------------------------------------------------------------------------------- /src/lib.typ: -------------------------------------------------------------------------------- 1 | #import "@preview/cetz:0.4.2": canvas, draw 2 | 3 | // Draw a neural network from layer specifications 4 | #let draw-network( 5 | layers, 6 | connections: (), 7 | palette: "warm", 8 | show-legend: false, 9 | legend-title: "Layers", 10 | scale: 100%, 11 | stroke-thickness: 1, 12 | depth-multiplier: 0.3, 13 | show-relu: false, 14 | ) = { 15 | 16 | 17 | let colors-warm = ( 18 | conv: rgb("#ffe0a1"), 19 | conv-relu: rgb("#ffa947"), 20 | pool: rgb("#e04227"), 21 | unpool: rgb("#2E7D7D"), 22 | deconv: rgb("#88C1D0"), 23 | concat: rgb("#B39DDB"), 24 | softmax: rgb("#6A0066"), 25 | gap: rgb("#FF69B4"), 26 | fc: rgb("#B39DDB"), 27 | fc-relu: rgb("#9575CD"), 28 | sum: rgb("#70cf9b"), 29 | convres: rgb("#e681a8"), 30 | convres-relu: rgb("#ad507e"), 31 | convsoftmax: rgb("#6A0066"), 32 | input: rgb("#f7f1ed"), 33 | output: rgb("#6A0066"), 34 | custom: rgb("#dad9d7"), 35 | custom-relu: rgb("#a8a7a4"), 36 | arrow: rgb("#0f4d52"), 37 | connection: rgb("#0f4d52"), 38 | ) 39 | 40 | // Cold palette 41 | let colors-cold = ( 42 | conv: rgb("#CDEDFE"), 43 | conv-relu: rgb("#89C7E8"), 44 | pool: rgb("#af78e6"), 45 | unpool: rgb("#B8A3E8"), 46 | deconv: rgb("#96e7c8"), 47 | concat: rgb("#7EC8E3"), 48 | softmax: rgb("#4A148C"), 49 | gap: rgb("#E91E63"), 50 | fc: rgb("#9FA8DA"), 51 | fc-relu: rgb("#7986CB"), 52 | sum: rgb("#70cf9b"), 53 | convres: rgb("#8edbd5"), 54 | convres-relu: rgb("#54adac"), 55 | convsoftmax: rgb("#4A148C"), 56 | input: rgb("#ecebf5"), 57 | output: rgb("#4A148C"), 58 | custom: rgb("#d7d9da"), 59 | custom-relu: rgb("#a1a4ad"), 60 | arrow: 
rgb("#0f4d52"), 61 | connection: rgb("#0f4d52"), 62 | ) 63 | 64 | let strokes = ( 65 | solid: (paint: black.lighten(20%), thickness: 0.65pt * stroke-thickness), 66 | hidden: (paint: gray.darken(50%).transparentize(50%), thickness: 0.45pt * stroke-thickness, dash: (1pt, 0.8pt)), 67 | arrow: (thickness: 0pt), 68 | connection: (thickness: 1pt * stroke-thickness), 69 | ) 70 | 71 | let dynamic-color-strokes(fill) = { 72 | ( 73 | solid: (paint: fill.darken(50%).saturate(80%), thickness: strokes.solid.thickness), 74 | hidden: (paint: fill.darken(60%).saturate(80%).transparentize(60%), thickness: strokes.hidden.thickness, dash: strokes.hidden.dash), 75 | ) 76 | } 77 | 78 | let font-sizes = ( 79 | label: 8.5pt, 80 | channel-number: 7pt, 81 | layer-label: 8.5pt, 82 | output-number: 8pt, 83 | legend-title: 10pt, 84 | legend-item: 8pt, 85 | ) 86 | 87 | let opacity-values = ( 88 | front-face: 30%, 89 | top-face: 30%, 90 | right-face: 30%, 91 | band: 60%, 92 | ball: 10%, 93 | edge: 70%, 94 | ) 95 | 96 | let darken-amounts = ( 97 | top: 0%, 98 | right: 0%, 99 | ) 100 | 101 | let arrow-config = ( 102 | triangle-size: 0.2, 103 | axis-y: 2.5 104 | ) 105 | 106 | let depth-angle-deg = 45deg //calc.atan(depth-multiplier) * 180 / calc.pi 107 | 108 | let get-depth-offsets(d) = { 109 | (d * depth-multiplier, d * depth-multiplier) 110 | } 111 | 112 | let get-y-offset-for-center-on-axis(h, d, axis-y) = { 113 | let (_, oy) = get-depth-offsets(d) 114 | axis-y - h / 2 - oy / 2 115 | } 116 | 117 | let get-perspective-center-y(y-offset, h, oy) = { 118 | y-offset + h / 2 + oy / 2 119 | } 120 | 121 | let get-layer-anchors(x, y, w, h, ox, oy) = { 122 | let center-x = x + w/2 + ox/2 123 | let center-y = y + h/2 + oy/2 124 | ( 125 | west: (x, center-y), 126 | east: (x + w + ox, center-y), 127 | // True west/east are the geometric centers of the 3D west/east faces 128 | // West face center: halfway through depth, centered vertically 129 | true_west: (x + ox/2, center-y), 130 | // East face center: at 
right edge minus half depth, centered vertically 131 | true_east: (x + w + ox/2, center-y), 132 | north: (center-x, y + h + oy), 133 | south: (center-x, y), 134 | anchor: (center-x, center-y), 135 | near: (center-x, center-y), 136 | northeast: (x + w + ox, y + h + oy), 137 | southeast: (x + w + ox, y), 138 | northwest: (x, y + h + oy), 139 | southwest: (x, y), 140 | ) 141 | } 142 | 143 | let coord-along-path(start, end, pos: 1.0) = { 144 | (start.at(0) + (end.at(0) - start.at(0)) * pos, 145 | start.at(1) + (end.at(1) - start.at(1)) * pos) 146 | } 147 | 148 | let get-circle-boundary-point(from-pt, center-pt, radius) = { 149 | let dx = center-pt.at(0) - from-pt.at(0) 150 | let dy = center-pt.at(1) - from-pt.at(1) 151 | let dist = calc.sqrt(dx * dx + dy * dy) 152 | if dist > 0 { 153 | let ux = dx / dist 154 | let uy = dy / dist 155 | (center-pt.at(0) - ux * radius, center-pt.at(1) - uy * radius) 156 | } else { 157 | (center-pt.at(0) + radius, center-pt.at(1)) 158 | } 159 | } 160 | 161 | let colors = if palette == "cold" { colors-cold } else { colors-warm } 162 | let scale-factor = scale / 100% 163 | 164 | 165 | canvas(length: 1cm * scale-factor, { 166 | import draw: * 167 | 168 | let scaled-font = (size) => size * scale-factor 169 | 170 | // Helper function: Draw isometric image on right face 171 | let draw-isometric-image(x, y, w, h, ox, oy, image) = { 172 | let img-height = (h) * 28.25pt * scale-factor 173 | let img-width = (oy / depth-multiplier) * 28.25pt * scale-factor 174 | 175 | let actual-img-width() = measure(image).width 176 | let actual-img-height() = measure(image).height 177 | 178 | content((x+w+ox/2,y+h/2+oy/2), 179 | context { 180 | pad( 181 | x: -((1+depth-multiplier) * img-height - img-width)/2, 182 | y: +(img-height/2 - img-width)/2 183 | )[ 184 | #std.rotate(90deg)[ 185 | #std.skew(ax: 45deg)[ 186 | #std.rotate(-90deg)[ 187 | #pad( 188 | x: -(actual-img-width() - img-width * depth-multiplier)/2, 189 | y: -(actual-img-height() - img-height)/2 190 | 
)[ 191 | #std.scale(x: img-width * depth-multiplier, y: img-height)[ 192 | #image] 193 | ]]]]] 194 | } 195 | ) 196 | } 197 | 198 | let box-3d(x, y, w, h, d, fill, opacity: 1, show-left: true, show-right: true, ylabel: none, zlabel: none, is-input: false, image: none) = { 199 | let (ox, oy) = get-depth-offsets(d) 200 | let alpha = 100% - opacity * 100% 201 | 202 | let dyn-strokes = dynamic-color-strokes(fill) 203 | 204 | line((x, y), (x + ox, y + oy), stroke: dyn-strokes.hidden) 205 | line((x + ox, y + oy), (x + w + ox, y + oy), stroke: dyn-strokes.hidden) 206 | line((x + ox, y + oy), (x + ox, y + h + oy), stroke: dyn-strokes.hidden) 207 | 208 | rect((x, y), (x + w, y + h), fill: fill.transparentize(alpha), stroke: none) 209 | 210 | if show-left { 211 | line((x, y), (x, y + h), stroke: dyn-strokes.solid) 212 | } 213 | if show-right { 214 | line((x + w, y), (x + w, y + h), stroke: dyn-strokes.solid) 215 | } 216 | line((x, y + h), (x + w, y + h), stroke: dyn-strokes.solid) 217 | line((x, y), (x + w, y), stroke: dyn-strokes.solid) 218 | 219 | line((x, y + h), (x + ox, y + h + oy), (x + w + ox, y + h + oy), (x + w, y + h), 220 | close: true, fill: fill.darken(darken-amounts.top).transparentize(alpha), stroke: dyn-strokes.solid) 221 | 222 | // Draw right face normally 223 | line((x + w, y), (x + w + ox, y + oy), (x + w + ox, y + h + oy), (x + w, y + h), 224 | close: true, fill: fill.darken(darken-amounts.right).transparentize(alpha), stroke: dyn-strokes.solid) 225 | 226 | // DRAW IMAGE ON TOP OF RIGHT FACE WITH ISOMETRIC PERSPECTIVE 227 | if image != none { 228 | draw-isometric-image(x, y, w, h, ox, oy, image) 229 | } 230 | 231 | if is-input { 232 | if ylabel != none { 233 | content((x - 0.2, y + h/2), anchor: "east", 234 | [#text(size: scaled-font(font-sizes.layer-label), weight: "bold", str(ylabel))]) 235 | } 236 | if zlabel != none { 237 | content((x + w/2 + ox/2, y + h + oy - 0.9), angle: depth-angle-deg, 238 | [#text(size: scaled-font(font-sizes.layer-label), 
weight: "bold", str(zlabel))]) 239 | } 240 | } else { 241 | if ylabel != none { 242 | content((x - 0.3, y + h/2), anchor: "east", 243 | [#text(size: scaled-font(font-sizes.layer-label), str(ylabel))]) 244 | } 245 | if zlabel != none { 246 | content((x + w/2 + ox/2, y - 0.4), angle: depth-angle-deg, 247 | [#text(size: scaled-font(font-sizes.layer-label), str(zlabel))]) 248 | } 249 | } 250 | } 251 | 252 | // Helper function: Draw front face of a single band with optional relu split 253 | let draw-band-front-face(band-x, y, band-width, h, fill-color, bandfill-color, alpha, show-relu) = { 254 | if show-relu { 255 | let conv-width = band-width * 2 / 3 256 | rect((band-x, y), (band-x + conv-width, y + h), 257 | fill: fill-color.transparentize(calc.max(opacity-values.front-face, alpha)), stroke: none) 258 | rect((band-x + conv-width, y), (band-x + band-width, y + h), 259 | fill: bandfill-color.transparentize(calc.max(opacity-values.front-face, alpha)), stroke: none) 260 | } else { 261 | rect((band-x, y), (band-x + band-width, y + h), 262 | fill: fill-color.transparentize(calc.max(opacity-values.front-face, alpha)), stroke: none) 263 | } 264 | } 265 | 266 | // Helper function: Draw top face of a single band with optional relu split 267 | let draw-band-top-face(band-x, y, band-width, h, ox, oy, fill-color, bandfill-color, show-relu) = { 268 | if show-relu { 269 | let conv-width = band-width * 2 / 3 270 | line((band-x, y + h), (band-x + ox, y + h + oy), 271 | (band-x + conv-width + ox, y + h + oy), (band-x + conv-width, y + h), 272 | close: true, 273 | fill: fill-color.darken(darken-amounts.top).transparentize(opacity-values.top-face), 274 | stroke: none) 275 | line((band-x + conv-width, y + h), (band-x + conv-width + ox, y + h + oy), 276 | (band-x + band-width + ox, y + h + oy), (band-x + band-width, y + h), 277 | close: true, 278 | fill: bandfill-color.darken(darken-amounts.top).transparentize(opacity-values.top-face), 279 | stroke: none) 280 | } else { 281 | line((band-x, 
y + h), (band-x + ox, y + h + oy), 282 | (band-x + band-width + ox, y + h + oy), (band-x + band-width, y + h), 283 | close: true, 284 | fill: fill-color.darken(darken-amounts.top).transparentize(opacity-values.top-face), 285 | stroke: none) 286 | } 287 | } 288 | 289 | // Helper function: Draw band separator edges 290 | let draw-band-separator-edges(band-x, y, h, ox, oy, band-width, is-first, fill-color) = { 291 | 292 | let dyn-strokes = dynamic-color-strokes(fill-color) 293 | 294 | if is-first { 295 | // First band: draw the three hidden back edges 296 | line((band-x, y), (band-x + ox, y + oy), stroke: dyn-strokes.hidden) 297 | line((band-x + ox, y + oy), (band-x + ox, y + h + oy), stroke: dyn-strokes.hidden) 298 | line((band-x + ox, y + oy), (band-x + band-width + ox, y + oy), stroke: dyn-strokes.hidden) 299 | } else { 300 | // Front vertical separator (solid) 301 | line((band-x, y), (band-x, y + h), stroke: dyn-strokes.solid) 302 | // Diagonal connector from front top to back top (solid) 303 | line((band-x, y + h), (band-x + ox, y + h + oy), stroke: dyn-strokes.solid) 304 | // Diagonal connector from front bottom to back bottom (dashed) 305 | line((band-x, y), (band-x + ox, y + oy), stroke: dyn-strokes.hidden) 306 | // Back vertical edge (dashed) 307 | line((band-x + ox, y + oy), (band-x + ox, y + h + oy), stroke: dyn-strokes.hidden) 308 | // Back horizontal edge (dashed) 309 | line((band-x + ox, y + oy), (band-x + band-width + ox, y + oy), stroke: dyn-strokes.hidden) 310 | } 311 | } 312 | 313 | // Helper function: Display channels labels (single label below, second label on diagonal if provided) 314 | let draw-channels-labels(channels, center-x, right-x, y, ox, oy) = { 315 | if channels != none and channels.len() > 0 { 316 | // First element: display below the layer 317 | content((center-x, y - 0.15), 318 | [#text(size: scaled-font(font-sizes.channel-number), str(channels.at(0)))]) 319 | 320 | // Second element (if exists): display along depth diagonal 321 | if 
channels.len() > 1 { 322 | let diag-mid-x = right-x + ox / 2.5 323 | let diag-mid-y = y + oy / 2.5 324 | content((diag-mid-x, diag-mid-y - 0.23), angle: depth-angle-deg, 325 | [#text(size: scaled-font(font-sizes.channel-number), str(channels.at(1)))]) 326 | } 327 | } 328 | } 329 | 330 | let draw-arrow-icon(x1, y1, x2, y2, opacity: 0.7) = { 331 | let dx = x2 - x1 332 | let dy = y2 - y1 333 | let len = calc.sqrt(dx * dx + dy * dy) 334 | 335 | if len > 0 { 336 | let mid-x = (x1 + x2) / 2 337 | let mid-y = (y1 + y2) / 2 338 | let ux = dx / len 339 | let uy = dy / len 340 | let px = -uy 341 | let py = ux 342 | 343 | let size = arrow-config.triangle-size 344 | let tip = size * 0.9 345 | let back = size * 0.9 346 | let wing = size * 0.45 347 | 348 | let tip-pt = (mid-x + ux * tip, mid-y + uy * tip) 349 | let back-mid = (mid-x - ux * back, mid-y - uy * back) 350 | let right-pt = (back-mid.at(0) + px * wing, back-mid.at(1) + py * wing) 351 | let left-pt = (back-mid.at(0) - px * wing, back-mid.at(1) - py * wing) 352 | let back-tip = (back-mid.at(0) + ux * back * 0.5, back-mid.at(1) + uy * back * 0.5) 353 | 354 | let arrow-color = if opacity < 1.0 { 355 | colors.arrow.transparentize(100% - opacity * 100%) 356 | } else { 357 | colors.arrow 358 | } 359 | 360 | line(tip-pt, right-pt, back-tip, left-pt, close: true, 361 | fill: arrow-color, stroke: (paint: arrow-color, thickness: strokes.arrow.thickness)) 362 | } 363 | } 364 | 365 | let draw-segment-with-arrow(x1, y1, x2, y2, opacity: 0.7) = { 366 | let paint = if opacity < 1.0 { 367 | colors.connection.transparentize(100% - opacity * 100%) 368 | } else { 369 | colors.connection 370 | } 371 | line((x1, y1), (x2, y2), stroke: (paint: paint, thickness: strokes.connection.thickness)) 372 | draw-arrow-icon(x1, y1, x2, y2, opacity: opacity) 373 | } 374 | 375 | let draw-connection-path(segments, opacity: 0.7, layers: none, layer-positions-ref: (:), show-relu: false) = { 376 | // If there are layers to draw on segment idx==1, we need to 
split that segment 377 | if layers != none and layers.len() > 0 { 378 | // Draw first segment (idx==0) normally 379 | if segments.len() > 0 { 380 | let seg = segments.at(0) 381 | draw-segment-with-arrow(seg.at(0).at(0), seg.at(0).at(1), seg.at(1).at(0), seg.at(1).at(1), opacity: opacity) 382 | } 383 | 384 | // Process segment idx==1 with layers 385 | if segments.len() > 1 { 386 | let seg = segments.at(1) 387 | let seg-start = seg.at(0) 388 | let seg-end = seg.at(1) 389 | 390 | // Calculate positions for all layers along the segment 391 | let layer-infos = () 392 | for layer-spec in layers { 393 | let layer-type = layer-spec.at("type") 394 | 395 | if layer-type == "conv" { 396 | let widths = layer-spec.at("widths", default: (0.5,)) 397 | let total-width = widths.fold(0, (acc, w) => acc + w) 398 | let layer-h = layer-spec.at("height", default: 2) 399 | let layer-d = layer-spec.at("depth", default: 2) 400 | let (lox, loy) = get-depth-offsets(layer-d) 401 | 402 | layer-infos.push(( 403 | spec: layer-spec, 404 | width: total-width, 405 | height: layer-h, 406 | depth: layer-d, 407 | ox: lox, 408 | oy: loy, 409 | )) 410 | } 411 | } 412 | 413 | // Calculate positions along the segment for each layer 414 | let num-layers = layer-infos.len() 415 | let positions = () 416 | for (i, info) in layer-infos.enumerate() { 417 | let t = (i + 1) / (num-layers + 1) 418 | let center-x = seg-start.at(0) + (seg-end.at(0) - seg-start.at(0)) * t 419 | let center-y = seg-start.at(1) + (seg-end.at(1) - seg-start.at(1)) * t 420 | let layer-x = center-x - info.width / 2 421 | let layer-y = center-y - info.height / 2 - info.oy / 2 422 | 423 | // Use true_west (depth-adjusted) for connections 424 | let west-x = layer-x + info.ox / 2 425 | let east-x = layer-x + info.width + info.ox / 2 426 | 427 | positions.push(( 428 | x: layer-x, 429 | y: layer-y, 430 | center-x: center-x, 431 | center-y: center-y, 432 | west: (west-x, center-y), 433 | east: (east-x, center-y), 434 | )) 435 | } 436 | 437 | // 
Draw connection segments and layers in proper order (interleaved) 438 | // First arrow: from seg-start to first layer 439 | if positions.len() > 0 { 440 | draw-segment-with-arrow(seg-start.at(0), seg-start.at(1), positions.at(0).west.at(0), positions.at(0).west.at(1), opacity: opacity) 441 | } 442 | 443 | // Interleave layers and arrows in propagation order 444 | for (i, info) in layer-infos.enumerate() { 445 | let pos = positions.at(i) 446 | let layer-spec = info.spec 447 | let layer-name = layer-spec.at("name", default: none) 448 | 449 | let mid-x = pos.x 450 | let mid-y = pos.y 451 | let total-width = info.width 452 | let layer-h = info.height 453 | let lox = info.ox 454 | let loy = info.oy 455 | 456 | let fill-color = layer-spec.at("fill", default: colors.conv) 457 | let bandfill-color = layer-spec.at("bandfill", default: colors.at("conv-relu")) 458 | let layer-opacity = layer-spec.at("opacity", default: 1.0) 459 | let alpha-front = 100% - layer-opacity * 100% 460 | let widths = layer-spec.at("widths", default: (0.5,)) 461 | let channels = layer-spec.at("channels", default: none) 462 | let layer-show-relu = layer-spec.at("show-relu", default: show-relu) 463 | 464 | // Use dynamic color strokes for fill-color and bandfill-color 465 | let dyn-strokes = dynamic-color-strokes(fill-color) 466 | let dyn-band-strokes = dynamic-color-strokes(bandfill-color) 467 | 468 | // Determine if we have a diagonal label 469 | let has-diagonal-label = channels != none and channels.len() == widths.len() + 1 470 | let diagonal-label = if has-diagonal-label { channels.at(widths.len()) } else { none } 471 | 472 | let cumulative-x = mid-x 473 | for (j, w) in widths.enumerate() { 474 | let band-width = w 475 | let band-x = cumulative-x 476 | 477 | draw-band-front-face(band-x, mid-y, band-width, layer-h, fill-color, bandfill-color, alpha-front, layer-show-relu) 478 | 479 | if channels != none and j < channels.len() { 480 | content((band-x + band-width / 2, mid-y - 0.15), 481 | 
[#text(size: scaled-font(font-sizes.channel-number), str(channels.at(j)))]) 482 | } 483 | 484 | cumulative-x += band-width 485 | } 486 | 487 | line((mid-x, mid-y), (mid-x, mid-y + layer-h), stroke: dyn-strokes.solid) 488 | line((mid-x + total-width, mid-y), (mid-x + total-width, mid-y + layer-h), stroke: dyn-strokes.solid) 489 | line((mid-x, mid-y + layer-h), (mid-x + total-width, mid-y + layer-h), stroke: dyn-strokes.solid) 490 | line((mid-x, mid-y), (mid-x + total-width, mid-y), stroke: dyn-strokes.solid) 491 | 492 | cumulative-x = mid-x 493 | for (j, w) in widths.enumerate() { 494 | let band-width = w 495 | let band-x = cumulative-x 496 | 497 | draw-band-top-face(band-x, mid-y, band-width, layer-h, lox, loy, fill-color, bandfill-color, layer-show-relu) 498 | 499 | cumulative-x += band-width 500 | } 501 | 502 | let right-face-color = if layer-show-relu { bandfill-color } else { fill-color } 503 | let right-face-strokes = if layer-show-relu { dyn-band-strokes } else { dyn-strokes } 504 | line((mid-x + total-width, mid-y), (mid-x + total-width + lox, mid-y + loy), 505 | (mid-x + total-width + lox, mid-y + layer-h + loy), (mid-x + total-width, mid-y + layer-h), 506 | close: true, fill: right-face-color.darken(darken-amounts.right).transparentize(opacity-values.right-face), 507 | stroke: right-face-strokes.solid) 508 | 509 | cumulative-x = mid-x 510 | for (j, w) in widths.enumerate() { 511 | let band-width = w 512 | let band-x = cumulative-x 513 | // Use bandfill-color for band separator edges if relu, else fill-color 514 | let edge-strokes = if layer-show-relu { dyn-band-strokes } else { dyn-strokes } 515 | draw-band-separator-edges(band-x, mid-y, layer-h, lox, loy, band-width, j == 0, fill-color) 516 | cumulative-x += band-width 517 | } 518 | 519 | line((mid-x, mid-y + layer-h), (mid-x + lox, mid-y + layer-h + loy), stroke: dyn-strokes.solid) 520 | line((mid-x + lox, mid-y + layer-h + loy), (mid-x + total-width + lox, mid-y + layer-h + loy), stroke: 
dyn-strokes.solid) 521 | line((mid-x + total-width, mid-y + layer-h), (mid-x + total-width + lox, mid-y + layer-h + loy), stroke: dyn-strokes.solid) 522 | line((mid-x + total-width + lox, mid-y + loy), (mid-x + total-width + lox, mid-y + layer-h + loy), stroke: dyn-strokes.solid) 523 | line((mid-x + total-width, mid-y), (mid-x + total-width + lox, mid-y + loy), stroke: dyn-strokes.solid) 524 | 525 | let label = layer-spec.at("label", default: none) 526 | if label != none { 527 | content((mid-x + total-width / 2, mid-y - 0.5), 528 | [#text(size: scaled-font(font-sizes.layer-label), weight: "bold", label)]) 529 | } 530 | 531 | // Display diagonal label if provided 532 | if diagonal-label != none { 533 | let diag-start-x = mid-x + total-width 534 | let diag-start-y = mid-y 535 | let diag-mid-x = diag-start-x + lox / 2.5 536 | let diag-mid-y = diag-start-y + loy / 2.5 537 | content((diag-mid-x, diag-mid-y - 0.23), angle: depth-angle-deg, 538 | [#text(size: scaled-font(font-sizes.channel-number), str(diagonal-label))]) 539 | } 540 | 541 | if layer-name != none { 542 | layer-positions-ref.insert(layer-name, ( 543 | x: mid-x, y: mid-y, w: total-width, h: layer-h, ox: lox, oy: loy, 544 | anchors: get-layer-anchors(mid-x, mid-y, total-width, layer-h, lox, loy) 545 | )) 546 | } 547 | 548 | // Draw arrow to next layer (or to seg-end if this is the last layer) 549 | if i < layer-infos.len() - 1 { 550 | // Arrow to next layer 551 | let from-east = positions.at(i).east 552 | let to-west = positions.at(i + 1).west 553 | draw-segment-with-arrow(from-east.at(0), from-east.at(1), to-west.at(0), to-west.at(1), opacity: opacity) 554 | } else { 555 | // Last layer: arrow to seg-end 556 | draw-segment-with-arrow(positions.at(-1).east.at(0), positions.at(-1).east.at(1), seg-end.at(0), seg-end.at(1), opacity: opacity) 557 | } 558 | } 559 | } 560 | 561 | // Draw remaining segments (idx >= 2) normally 562 | for idx in range(2, segments.len()) { 563 | let seg = segments.at(idx) 564 | 
draw-segment-with-arrow(seg.at(0).at(0), seg.at(0).at(1), seg.at(1).at(0), seg.at(1).at(1), opacity: opacity) 565 | } 566 | } else { 567 | // No layers, draw all segments normally 568 | for seg in segments { 569 | draw-segment-with-arrow(seg.at(0).at(0), seg.at(0).at(1), seg.at(1).at(0), seg.at(1).at(1), opacity: opacity) 570 | } 571 | } 572 | } 573 | 574 | let x = 0 575 | let arrow-axis-y = arrow-config.axis-y 576 | let prev-center-y = arrow-axis-y 577 | let prev-x = 0 578 | let prev-depth-offset = 0 579 | let prev-pool-width = 0 580 | let used-layer-types = (:) 581 | let layer-positions = (:) 582 | let arrow-segments = (:) 583 | let legend-entries = () // Collect legend entries in order of appearance: array of (key, label, color, ...) 584 | 585 | // Default legend labels for each layer type 586 | let default-legend-labels = ( 587 | input: "Input", 588 | conv: "Convolution", 589 | convres: "Conv Residual", 590 | pool: "Pooling", 591 | unpool: "Unpooling", 592 | deconv: "Deconvolution", 593 | concat: "Concatenation", 594 | sum: "Element-wise Sum", 595 | gap: "Global Avg Pool", 596 | fc: "Fully Connected", 597 | convsoftmax: "Conv Softmax", 598 | softmax: "Softmax", 599 | output: "Output", 600 | ) 601 | 602 | for (i, l) in layers.enumerate() { 603 | used-layer-types.insert(l.type, true) 604 | 605 | // Ensure height and depth are set for arrow calculation (using type-specific defaults) 606 | if not l.keys().contains("height") { 607 | let default-h = if l.type == "pool" or l.type == "unpool" { 4 } else if l.type == "concat" or l.type == "fc" or l.type == "softmax" or l.type == "output" { 3 } else if l.type == "gap" { 1.5 } else if l.type == "convsoftmax" { 4 } else { 5 } 608 | l.insert("height", default-h) 609 | } 610 | if not l.keys().contains("depth") { 611 | let default-d = if l.type == "pool" or l.type == "unpool" { 4 } else if l.type == "concat" { 3 } else if l.type == "gap" { 1.5 } else if l.type == "fc" or l.type == "softmax" or l.type == "output" { 0.4 } else 
if l.type == "convsoftmax" { 4 } else { 5 } 612 | l.insert("depth", default-d) 613 | } 614 | 615 | let gap = if i == 0 { 616 | 0 617 | } else if l.type == "pool" or l.type == "unpool" { 618 | 0 619 | } else { 620 | l.at("offset", default: 1.2) 621 | } 622 | 623 | x += gap 624 | 625 | // Calculate and store arrow segment positions for ALL layers (for skip connections) 626 | // Only draw arrows if previous layer has show-connection enabled (controls outgoing arrows) 627 | if i > 0 { 628 | let prev-layer = layers.at(i - 1) 629 | let prev-show-connection = prev-layer.at("show-connection", default: if prev-layer.type == "input" { false } else { true }) 630 | if prev-show-connection { 631 | // Arrow starts from true_east of previous layer (depth-adjusted) 632 | let start-x = prev-x + prev-pool-width + prev-depth-offset / 2 633 | let start-y = prev-center-y 634 | 635 | // Read height and depth directly from layer (already set by each layer type) 636 | let curr-h = l.at("height") 637 | let curr-d = l.at("depth") 638 | let (curr-ox, curr-oy) = get-depth-offsets(curr-d) 639 | let curr-depth-offset = curr-ox 640 | let curr-y-offset = get-y-offset-for-center-on-axis(curr-h, curr-d, arrow-axis-y) 641 | let end-y = get-perspective-center-y(curr-y-offset, curr-h, curr-oy) 642 | 643 | // Arrow ends at true_west of current layer (depth-adjusted) 644 | // Special handling for sum node (use radius instead of depth) 645 | let end-x = if l.type == "sum" { 646 | let radius = l.at("radius", default: 0.4) 647 | x + prev-depth-offset / 2 648 | } else { 649 | // For pool/unpool with offset, calculate actual layer position first 650 | let is-curr-pool-or-unpool = l.type == "pool" or l.type == "unpool" 651 | let curr-offset = if is-curr-pool-or-unpool { l.at("offset", default: none) } else { none } 652 | let curr-layer-x = if curr-offset != none { x + curr-offset } else if is-curr-pool-or-unpool { x + prev-depth-offset / 2 - curr-ox / 2 } else { x } 653 | curr-layer-x + curr-depth-offset / 2 
654 | } 655 | 656 | let prev-name = prev-layer.at("name", default: none) 657 | let curr-name = l.at("name", default: none) 658 | 659 | // Store true arrow endpoints (with depth) and midpoint 660 | let mid-arrow-x = (start-x + end-x) / 2 661 | let mid-arrow-y = (start-y + end-y) / 2 662 | 663 | // Store as outgoing arrow for previous layer (includes start point and midpoint) 664 | if prev-name != none { 665 | arrow-segments.insert(prev-name + "-out", ( 666 | start: (start-x, start-y), 667 | mid: (mid-arrow-x, mid-arrow-y), 668 | x: mid-arrow-x, 669 | y: mid-arrow-y 670 | )) 671 | } 672 | // Store as incoming arrow for current layer (includes end point and midpoint) 673 | if curr-name != none { 674 | arrow-segments.insert(curr-name + "-in", ( 675 | end: (end-x, end-y), 676 | mid: (mid-arrow-x, mid-arrow-y), 677 | x: mid-arrow-x, 678 | y: mid-arrow-y 679 | )) 680 | } 681 | 682 | // Draw arrow for non-pool/unpool layers, or pool/unpool with offset 683 | let is-pool-or-unpool = l.type == "pool" or l.type == "unpool" 684 | let has-offset = l.at("offset", default: none) != none 685 | if not is-pool-or-unpool or has-offset { 686 | draw-segment-with-arrow(start-x, start-y, end-x, end-y, opacity: 0.7) 687 | 688 | // Draw connection label if specified (read from previous layer, like show-connection) 689 | let conn-label = prev-layer.at("connection-label", default: none) 690 | if conn-label != none { 691 | content((mid-arrow-x, mid-arrow-y + 0.28), 692 | [#text(size: scaled-font(font-sizes.layer-label), conn-label)]) 693 | } 694 | } 695 | } 696 | } 697 | 698 | // CUSTOM LAYER (Universal layer type with full flexibility) 699 | if l.type == "custom" { 700 | let h = l.at("height", default: 5) 701 | let d = l.at("depth", default: 5) 702 | l.insert("height", h) 703 | l.insert("depth", d) 704 | let w = l.at("width", default: none) 705 | let widths = l.at("widths", default: none) 706 | let label = l.at("label", default: none) 707 | let xlabel = l.at("xlabel", default: none) 708 | let 
name = l.at("name", default: none) 709 | let fill-color = l.at("fill", default: colors.custom) 710 | let bandfill-color = l.at("bandfill", default: colors.at("custom-relu")) 711 | let layer-opacity = l.at("opacity", default: 0.7) 712 | let channels = l.at("channels", default: none) 713 | let ylabel-val = l.at("ylabel", default: none) 714 | let zlabel-val = l.at("zlabel", default: none) 715 | let layer-show-relu = l.at("show-relu", default: show-relu) 716 | let layer-show-connection = l.at("show-connection", default: true) 717 | let connection-label = l.at("connection-label", default: none) 718 | let img = l.at("image", default: none) 719 | let is-input-style = l.at("input-style", default: false) 720 | 721 | let (ox, oy) = get-depth-offsets(d) 722 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 723 | 724 | // Determine rendering mode: simple box or multi-band 725 | let use-simple-box = widths == none 726 | 727 | if use-simple-box { 728 | // Simple box rendering (like input, pool, fc, etc.) 
729 | let actual-w = if w == none { 0.2 } else { w } 730 | 731 | if img == "default" { 732 | img = image("bird.jpg") 733 | } 734 | 735 | box-3d(x, y-offset, actual-w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 736 | 737 | // Display channels labels 738 | draw-channels-labels(channels, x + actual-w/2, x + actual-w, y-offset, ox, oy) 739 | 740 | // Track position if named 741 | if name != none { 742 | layer-positions.insert(name, ( 743 | x: x, y: y-offset, w: actual-w, h: h, ox: ox, oy: oy, type: "custom", 744 | anchors: get-layer-anchors(x, y-offset, actual-w, h, ox, oy), 745 | pool-offset: 0 746 | )) 747 | } 748 | 749 | if label != none { 750 | content((x + actual-w/2, y-offset - 0.5), 751 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 752 | } 753 | 754 | prev-x = x + actual-w 755 | prev-depth-offset = ox 756 | x += actual-w 757 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 758 | prev-pool-width = 0 759 | } else { 760 | // Multi-band rendering (like conv, convres) 761 | let dyn-strokes = dynamic-color-strokes(fill-color) 762 | let dyn-band-strokes = dynamic-color-strokes(bandfill-color) 763 | 764 | let has-diagonal-label = channels != none and channels.len() == widths.len() + 1 765 | let diagonal-label = if has-diagonal-label { channels.at(widths.len()) } else { none } 766 | let channel-labels = if channels != none { 767 | if has-diagonal-label { channels.slice(0, widths.len()) } else { channels } 768 | } else { 769 | (widths.map(w => "")) 770 | } 771 | 772 | let start-x = x 773 | let total-width = widths.fold(0, (acc, w) => acc + w) 774 | 775 | // Draw front face as colored bands 776 | let cumulative-x = start-x 777 | let alpha-front = 100% - layer-opacity * 100% 778 | for (j, ch) in channel-labels.enumerate() { 779 | let band-width = widths.at(j) 780 | let band-x = cumulative-x 781 | 782 | draw-band-front-face(band-x, y-offset, band-width, h, fill-color, bandfill-color, 
alpha-front, layer-show-relu) 783 | 784 | // Display channel label under each band 785 | let band-center-x = band-x + band-width / 2 786 | content((band-center-x, y-offset - 0.15), 787 | [#text(size: scaled-font(font-sizes.channel-number), str(ch))]) 788 | 789 | cumulative-x += band-width 790 | } 791 | 792 | // Draw front face outer edges 793 | line((start-x, y-offset), (start-x, y-offset + h), stroke: dyn-strokes.solid) 794 | line((start-x + total-width, y-offset), (start-x + total-width, y-offset + h), stroke: dyn-strokes.solid) 795 | line((start-x, y-offset + h), (start-x + total-width, y-offset + h), stroke: dyn-strokes.solid) 796 | line((start-x, y-offset), (start-x + total-width, y-offset), stroke: dyn-strokes.solid) 797 | 798 | // Draw top face segmented by band 799 | cumulative-x = start-x 800 | for (j, ch) in channel-labels.enumerate() { 801 | let band-width = widths.at(j) 802 | let band-x = cumulative-x 803 | 804 | draw-band-top-face(band-x, y-offset, band-width, h, ox, oy, fill-color, bandfill-color, layer-show-relu) 805 | 806 | cumulative-x += band-width 807 | } 808 | 809 | // Draw right face 810 | let right-face-color = if layer-show-relu { bandfill-color } else { fill-color } 811 | line((start-x + total-width, y-offset), (start-x + total-width + ox, y-offset + oy), 812 | (start-x + total-width + ox, y-offset + h + oy), (start-x + total-width, y-offset + h), 813 | close: true, 814 | fill: right-face-color.darken(darken-amounts.right).transparentize(opacity-values.right-face), 815 | stroke: dyn-strokes.solid) 816 | 817 | // Draw image on top of right face if provided 818 | if img != none { 819 | draw-isometric-image(start-x, y-offset, total-width, h, ox, oy, img) 820 | } 821 | 822 | // Draw all edges for band divisions 823 | cumulative-x = start-x 824 | for (j, ch) in channel-labels.enumerate() { 825 | let band-width = widths.at(j) 826 | let band-x = cumulative-x 827 | 828 | draw-band-separator-edges(band-x, y-offset, h, ox, oy, band-width, j == 0, 
fill-color) 829 | 830 | cumulative-x += band-width 831 | } 832 | 833 | // Draw outer edges (excluding right face edges which are already drawn) 834 | line((start-x, y-offset + h), (start-x + ox, y-offset + h + oy), stroke: dyn-strokes.solid) 835 | line((start-x + ox, y-offset + h + oy), (start-x + total-width + ox, y-offset + h + oy), stroke: dyn-strokes.solid) 836 | line((start-x + total-width, y-offset + h), (start-x + total-width + ox, y-offset + h + oy), stroke: dyn-strokes.solid) 837 | 838 | prev-x = start-x + total-width 839 | prev-depth-offset = ox 840 | x = start-x + total-width 841 | let center-x = start-x + total-width / 2 842 | 843 | // Display label below channel numbers 844 | if label != none { 845 | content((center-x, y-offset - 0.5), 846 | [#text(size: scaled-font(font-sizes.layer-label), weight: "bold", label)]) 847 | } 848 | 849 | // Display xlabel if provided 850 | if xlabel != none { 851 | content((center-x, y-offset - 0.8), 852 | [#text(size: scaled-font(font-sizes.layer-label), xlabel)]) 853 | } 854 | 855 | // Display ylabel and zlabel if provided 856 | if ylabel-val != none { 857 | content((start-x - 0.4, y-offset + h/2), anchor: "east", 858 | [#text(size: scaled-font(font-sizes.layer-label), str(ylabel-val))]) 859 | } 860 | if zlabel-val != none { 861 | content((start-x + total-width + ox + 0.4, y-offset + h/2 + oy/2), anchor: "west", 862 | [#text(size: scaled-font(font-sizes.layer-label), str(zlabel-val))]) 863 | } 864 | 865 | // Display diagonal label if provided 866 | if diagonal-label != none { 867 | let diag-start-x = start-x + total-width 868 | let diag-start-y = y-offset 869 | let diag-mid-x = diag-start-x + ox / 2.5 870 | let diag-mid-y = diag-start-y + oy / 2.5 871 | content((diag-mid-x, diag-mid-y - 0.23), angle: depth-angle-deg, 872 | [#text(size: scaled-font(font-sizes.channel-number), str(diagonal-label))]) 873 | } 874 | 875 | // Track position if named 876 | if name != none { 877 | layer-positions.insert(name, ( 878 | x: 
start-x, y: y-offset, w: total-width, h: h, ox: ox, oy: oy, type: "custom", 879 | anchors: get-layer-anchors(start-x, y-offset, total-width, h, ox, oy), 880 | pool-offset: 0 881 | )) 882 | } 883 | 884 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 885 | prev-pool-width = 0 886 | } 887 | 888 | // Register legend entry for custom layers (only if legend parameter is provided) 889 | let custom-legend = l.at("legend", default: none) 890 | if custom-legend != none { 891 | // Use a unique key for each custom legend entry (legend text + color) 892 | let legend-key = "custom-" + str(custom-legend) + "-" + str(fill-color.to-hex()) 893 | if not legend-entries.any(e => e.key == legend-key) { 894 | legend-entries.push((key: legend-key, label: custom-legend, color: fill-color, bandfill: bandfill-color, show-relu: layer-show-relu, opacity: layer-opacity)) 895 | } 896 | } 897 | } 898 | 899 | // INPUT IMAGE - uses custom type with input-specific defaults 900 | else if l.type == "input" { 901 | // Re-route to custom handler with input defaults 902 | l.insert("type", "custom") 903 | if not l.keys().contains("width") { l.insert("width", 0) } 904 | if not l.keys().contains("fill") { l.insert("fill", colors.input) } 905 | if not l.keys().contains("opacity") { l.insert("opacity", 0.9) } 906 | if not l.keys().contains("input-style") { l.insert("input-style", true) } 907 | if not l.keys().contains("show-connection") { l.insert("show-connection", false) } 908 | 909 | // Fall through to process as custom (handled by previous if block) 910 | // But since we're in else-if, we need to inline the custom logic 911 | let h = l.at("height", default: 5) 912 | let d = l.at("depth", default: 5) 913 | l.insert("height", h) 914 | l.insert("depth", d) 915 | let w = l.at("width", default: 0) 916 | let label = l.at("label", default: none) 917 | let name = l.at("name", default: none) 918 | let fill-color = l.at("fill", default: colors.input) 919 | let layer-opacity = l.at("opacity", default: 
0.9) 920 | let layer-show-connection = l.at("show-connection", default: false) 921 | let connection-label = l.at("connection-label", default: none) 922 | let channels = l.at("channels", default: none) 923 | let img = l.at("image", default: none) 924 | 925 | let (ox, oy) = get-depth-offsets(d) 926 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 927 | 928 | if img == "default" { 929 | img = image("bird.jpg") 930 | } 931 | 932 | // Special rendering for INPUT: draw image first, then highly transparent face on top 933 | if img != none { 934 | // Draw isometric image first 935 | draw-isometric-image(x, y-offset, w, h, ox, oy, img) 936 | 937 | // Then draw highly transparent right face on top 938 | let alpha-right = layer-opacity * 100% 939 | line((x + w, y-offset), (x + w + ox, y-offset + oy), 940 | (x + w + ox, y-offset + h + oy), (x + w, y-offset + h), 941 | close: true, 942 | fill: fill-color.darken(darken-amounts.right).transparentize(alpha-right), 943 | stroke: dynamic-color-strokes(fill-color).solid) 944 | } else { 945 | // No image: use standard box-3d rendering 946 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 947 | } 948 | 949 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 950 | 951 | if name != none { 952 | layer-positions.insert(name, ( 953 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "input", 954 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy), 955 | pool-offset: 0 956 | )) 957 | } 958 | 959 | if label != none { 960 | content((x + w/2, y-offset - 0.8), 961 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 962 | } 963 | 964 | prev-x = x + w 965 | prev-depth-offset = ox 966 | x += w 967 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 968 | prev-pool-width = 0 969 | 970 | // Register legend entry (check for legend parameter override) 971 | let layer-legend = l.at("legend", default: 
default-legend-labels.at("input")) 972 | if not legend-entries.any(e => e.key == "input") { 973 | legend-entries.push((key: "input", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 974 | } 975 | } 976 | 977 | // CONVOLUTIONAL BLOCK types - delegates to custom with conv-specific defaults 978 | else if l.type == "conv" or l.type == "convres"{ 979 | let fill-color = if l.type == "conv" { 980 | l.at("fill", default: colors.conv) 981 | } else if l.type == "convres" { 982 | l.at("fill", default: colors.convres) 983 | } 984 | let bandfill-color = if l.type == "conv" { 985 | l.at("bandfill", default: colors.at("conv-relu")) 986 | } else if l.type == "convres" { 987 | l.at("bandfill", default: colors.at("convres-relu")) 988 | } 989 | 990 | // Set up parameters for custom handler with conv defaults 991 | if not l.keys().contains("fill") { l.insert("fill", fill-color) } 992 | if not l.keys().contains("bandfill") { l.insert("bandfill", bandfill-color) } 993 | if not l.keys().contains("widths") { l.insert("widths", (1,)) } 994 | let channels = l.at("channels", default: none) 995 | let widths = l.at("widths", default: (1,)) 996 | let h = l.at("height", default: 5) 997 | let d = l.at("depth", default: 5) 998 | l.insert("height", h) 999 | l.insert("depth", d) 1000 | let label = l.at("label", default: none) 1001 | let xlabel = l.at("xlabel", default: none) 1002 | let name = l.at("name", default: none) 1003 | let layer-opacity = l.at("opacity", default: 1.0) 1004 | let ylabel-val = l.at("ylabel", default: none) 1005 | let zlabel-val = l.at("zlabel", default: none) 1006 | let layer-show-relu = l.at("show-relu", default: show-relu) 1007 | let layer-show-connection = l.at("show-connection", default: true) 1008 | let connection-label = l.at("connection-label", default: none) 1009 | let img = l.at("image", default: none) 1010 | 1011 | if img == "default" { 1012 | img = image("bird.jpg") 1013 | } 1014 | 1015 | // Use dynamic color 
strokes for fill-color and bandfill-color 1016 | let dyn-strokes = dynamic-color-strokes(fill-color) 1017 | let dyn-band-strokes = dynamic-color-strokes(bandfill-color) 1018 | 1019 | // Determine if we have a diagonal label (channels has one extra element) 1020 | let has-diagonal-label = channels != none and channels.len() == widths.len() + 1 1021 | let diagonal-label = if has-diagonal-label { channels.at(widths.len()) } else { none } 1022 | let channel-labels = if channels != none { 1023 | if has-diagonal-label { channels.slice(0, widths.len()) } else { channels } 1024 | } else { 1025 | (widths.map(w => "")) 1026 | } 1027 | 1028 | // Use actual widths values to determine band sizes 1029 | let (ox, oy) = get-depth-offsets(d) 1030 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1031 | let start-x = x 1032 | let total-width = widths.fold(0, (acc, w) => acc + w) 1033 | 1034 | // Draw front face as colored bands 1035 | let cumulative-x = start-x 1036 | let alpha-front = 100% - layer-opacity * 100% 1037 | for (j, ch) in channel-labels.enumerate() { 1038 | let band-width = widths.at(j) 1039 | let band-x = cumulative-x 1040 | 1041 | draw-band-front-face(band-x, y-offset, band-width, h, fill-color, bandfill-color, alpha-front, layer-show-relu) 1042 | 1043 | // Display channel label under each band 1044 | let band-center-x = band-x + band-width / 2 1045 | content((band-center-x, y-offset - 0.15), 1046 | [#text(size: scaled-font(font-sizes.channel-number), str(ch))]) 1047 | 1048 | cumulative-x += band-width 1049 | } 1050 | 1051 | // Draw front face outer edges (only the perimeter) 1052 | line((start-x, y-offset), (start-x, y-offset + h), stroke: dyn-strokes.solid) 1053 | line((start-x + total-width, y-offset), (start-x + total-width, y-offset + h), stroke: dyn-strokes.solid) 1054 | line((start-x, y-offset + h), (start-x + total-width, y-offset + h), stroke: dyn-strokes.solid) 1055 | line((start-x, y-offset), (start-x + total-width, y-offset), stroke: 
dyn-strokes.solid) 1056 | 1057 | // Draw top face segmented by band 1058 | cumulative-x = start-x 1059 | for (j, ch) in channel-labels.enumerate() { 1060 | let band-width = widths.at(j) 1061 | let band-x = cumulative-x 1062 | 1063 | draw-band-top-face(band-x, y-offset, band-width, h, ox, oy, fill-color, bandfill-color, layer-show-relu) 1064 | 1065 | cumulative-x += band-width 1066 | } 1067 | 1068 | // Draw right face 1069 | let right-face-color = if layer-show-relu { bandfill-color } else { fill-color } 1070 | line((start-x + total-width, y-offset), (start-x + total-width + ox, y-offset + oy), 1071 | (start-x + total-width + ox, y-offset + h + oy), (start-x + total-width, y-offset + h), 1072 | close: true, 1073 | fill: right-face-color.darken(darken-amounts.right).transparentize(opacity-values.right-face), 1074 | stroke: dyn-strokes.solid) 1075 | 1076 | // Draw image on top of right face if provided 1077 | if img != none { 1078 | draw-isometric-image(start-x, y-offset, total-width, h, ox, oy, img) 1079 | } 1080 | 1081 | // Draw all edges for band divisions (once each) 1082 | cumulative-x = start-x 1083 | for (j, ch) in channel-labels.enumerate() { 1084 | let band-width = widths.at(j) 1085 | let band-x = cumulative-x 1086 | 1087 | draw-band-separator-edges(band-x, y-offset, h, ox, oy, band-width, j == 0, fill-color) 1088 | 1089 | cumulative-x += band-width 1090 | } 1091 | 1092 | // Draw outer edges of the block (excluding right face edges which are already drawn) 1093 | line((start-x, y-offset + h), (start-x + ox, y-offset + h + oy), stroke: dyn-strokes.solid) 1094 | line((start-x + ox, y-offset + h + oy), (start-x + total-width + ox, y-offset + h + oy), stroke: dyn-strokes.solid) 1095 | line((start-x + total-width, y-offset + h), (start-x + total-width + ox, y-offset + h + oy), stroke: dyn-strokes.solid) 1096 | 1097 | prev-x = start-x + total-width 1098 | prev-depth-offset = ox 1099 | x = start-x + total-width 1100 | let center-x = start-x + total-width / 2 1101 | 
1102 | // Display label below channel numbers 1103 | if label != none { 1104 | content((center-x, y-offset - 0.5), 1105 | [#text(size: scaled-font(font-sizes.layer-label), weight: "bold", label)]) 1106 | } 1107 | 1108 | // Display xlabel if provided 1109 | if xlabel != none { 1110 | content((center-x, y-offset - 0.8), 1111 | [#text(size: scaled-font(font-sizes.layer-label), xlabel)]) 1112 | } 1113 | 1114 | // Display ylabel and zlabel if provided 1115 | if ylabel-val != none { 1116 | content((start-x - 0.4, y-offset + h/2), anchor: "east", 1117 | [#text(size: scaled-font(font-sizes.layer-label), str(ylabel-val))]) 1118 | } 1119 | if zlabel-val != none { 1120 | content((start-x + total-width + ox + 0.4, y-offset + h/2 + oy/2), anchor: "west", 1121 | [#text(size: scaled-font(font-sizes.layer-label), str(zlabel-val))]) 1122 | } 1123 | 1124 | // Display diagonal label if provided (along bottom-right depth edge) 1125 | if diagonal-label != none { 1126 | let diag-start-x = start-x + total-width 1127 | let diag-start-y = y-offset 1128 | let diag-mid-x = diag-start-x + ox / 2.5 1129 | let diag-mid-y = diag-start-y + oy / 2.5 1130 | content((diag-mid-x, diag-mid-y - 0.23), angle: depth-angle-deg, 1131 | [#text(size: scaled-font(font-sizes.channel-number), str(diagonal-label))]) 1132 | } 1133 | 1134 | // Track position if named 1135 | if name != none { 1136 | layer-positions.insert(name, ( 1137 | x: start-x, y: y-offset, w: total-width, h: h, ox: ox, oy: oy, type: "conv", 1138 | anchors: get-layer-anchors(start-x, y-offset, total-width, h, ox, oy), 1139 | pool-offset: 0 // Will be updated if next layer is a pool 1140 | )) 1141 | } 1142 | 1143 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1144 | prev-pool-width = 0 1145 | 1146 | // Register legend entry 1147 | let layer-legend = l.at("legend", default: default-legend-labels.at(l.type)) 1148 | if not legend-entries.any(e => e.key == l.type) { 1149 | legend-entries.push((key: l.type, label: layer-legend, color: 
fill-color, bandfill: bandfill-color, show-relu: layer-show-relu, opacity: layer-opacity)) 1150 | } 1151 | } 1152 | 1153 | // POOLING LAYER - delegates to custom with pool-specific positioning 1154 | else if l.type == "pool" { 1155 | let h = l.at("height", default: 4) 1156 | let d = l.at("depth", default: 4) 1157 | l.insert("height", h) 1158 | l.insert("depth", d) 1159 | let w = 0.1 1160 | let name = l.at("name", default: none) 1161 | let fill-color = l.at("fill", default: colors.pool) 1162 | let layer-opacity = l.at("opacity", default: 0.75) 1163 | let layer-show-connection = l.at("show-connection", default: true) 1164 | let connection-label = l.at("connection-label", default: none) 1165 | let label = l.at("label", default: none) 1166 | let channels = l.at("channels", default: none) 1167 | let img = l.at("image", default: none) 1168 | let layer-offset = l.at("offset", default: none) 1169 | let (ox, oy) = get-depth-offsets(d) 1170 | let y-offset = prev-center-y - h / 2 - oy / 2 1171 | let pool-x = if layer-offset != none { x + layer-offset } else { x + prev-depth-offset / 2 - ox / 2 } 1172 | 1173 | if img == "default" { 1174 | img = image("bird.jpg") 1175 | } 1176 | 1177 | box-3d(pool-x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1178 | 1179 | draw-channels-labels(channels, pool-x + w/2, pool-x + w, y-offset, ox, oy) 1180 | 1181 | if label != none { 1182 | content((pool-x + w/2, y-offset - 0.5), 1183 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1184 | } 1185 | 1186 | if i > 0 { 1187 | let prev-layer = layers.at(i - 1) 1188 | let prev-name = prev-layer.at("name", default: none) 1189 | if prev-name != none and prev-name in layer-positions { 1190 | let prev-pos = layer-positions.at(prev-name) 1191 | layer-positions.insert(prev-name, ( 1192 | ..prev-pos, 1193 | pool-offset: w 1194 | )) 1195 | } 1196 | } 1197 | 1198 | if name != none { 1199 | layer-positions.insert(name, ( 1200 | x: 
pool-x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "pool", 1201 | anchors: get-layer-anchors(pool-x, y-offset, w, h, ox, oy), 1202 | pool-offset: 0 1203 | )) 1204 | } 1205 | 1206 | prev-x = pool-x + w 1207 | prev-depth-offset = ox 1208 | if layer-offset != none { 1209 | x += layer-offset + w 1210 | } else { 1211 | x = pool-x + w 1212 | } 1213 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1214 | prev-pool-width = 0 1215 | 1216 | // Register legend entry 1217 | let layer-legend = l.at("legend", default: default-legend-labels.at("pool")) 1218 | if not legend-entries.any(e => e.key == "pool") { 1219 | legend-entries.push((key: "pool", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1220 | } 1221 | } 1222 | 1223 | // UNPOOLING LAYER - delegates to custom with unpool-specific positioning 1224 | else if l.type == "unpool" { 1225 | let h = l.at("height", default: 4) 1226 | let d = l.at("depth", default: 4) 1227 | l.insert("height", h) 1228 | l.insert("depth", d) 1229 | let w = 0.1 1230 | let name = l.at("name", default: none) 1231 | let fill-color = l.at("fill", default: colors.unpool) 1232 | let layer-show-connection = l.at("show-connection", default: true) 1233 | let connection-label = l.at("connection-label", default: none) 1234 | let layer-opacity = l.at("opacity", default: 0.75) 1235 | let label = l.at("label", default: none) 1236 | let channels = l.at("channels", default: none) 1237 | let img = l.at("image", default: none) 1238 | let layer-offset = l.at("offset", default: none) 1239 | let (ox, oy) = get-depth-offsets(d) 1240 | let y-offset = prev-center-y - h / 2 - oy / 2 1241 | let unpool-x = if layer-offset != none { x + layer-offset } else { x + prev-depth-offset / 2 - ox / 2 } 1242 | 1243 | if img == "default" { 1244 | img = image("bird.jpg") 1245 | } 1246 | 1247 | box-3d(unpool-x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1248 | 
1249 | // Display channels labels 1250 | draw-channels-labels(channels, unpool-x + w/2, unpool-x + w, y-offset, ox, oy) 1251 | 1252 | if label != none { 1253 | content((unpool-x + w/2, y-offset - 0.5), 1254 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1255 | } 1256 | 1257 | // Track position if named 1258 | if name != none { 1259 | layer-positions.insert(name, ( 1260 | x: unpool-x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "unpool", 1261 | anchors: get-layer-anchors(unpool-x, y-offset, w, h, ox, oy) 1262 | )) 1263 | } 1264 | 1265 | prev-x = unpool-x + w 1266 | prev-depth-offset = ox 1267 | if layer-offset != none { 1268 | x += layer-offset + w 1269 | } else { 1270 | x = unpool-x + w 1271 | } 1272 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1273 | prev-pool-width = 0 1274 | 1275 | // Register legend entry 1276 | let layer-legend = l.at("legend", default: default-legend-labels.at("unpool")) 1277 | if not legend-entries.any(e => e.key == "unpool") { 1278 | legend-entries.push((key: "unpool", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1279 | } 1280 | } 1281 | 1282 | // DECONVOLUTIONAL LAYER - delegates to custom with deconv-specific defaults 1283 | else if l.type == "deconv" { 1284 | let h = l.at("height", default: 5) 1285 | let d = l.at("depth", default: 5) 1286 | l.insert("height", h) 1287 | l.insert("depth", d) 1288 | let w = l.at("width", default: 0.3) 1289 | let label = l.at("label", default: "") 1290 | let name = l.at("name", default: none) 1291 | let fill-color = l.at("fill", default: colors.deconv) 1292 | let layer-opacity = l.at("opacity", default: 0.7) 1293 | let layer-show-connection = l.at("show-connection", default: true) 1294 | let connection-label = l.at("connection-label", default: none) 1295 | let channels = l.at("channels", default: none) 1296 | let img = l.at("image", default: none) 1297 | let (ox, oy) = get-depth-offsets(d) 1298 | let 
y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1299 | 1300 | if img == "default" { 1301 | img = image("bird.jpg") 1302 | } 1303 | 1304 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1305 | 1306 | // Display channels labels 1307 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 1308 | 1309 | if label != none { 1310 | content((x + w/2, y-offset - 0.5), 1311 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1312 | } 1313 | 1314 | // Track position if named 1315 | if name != none { 1316 | layer-positions.insert(name, ( 1317 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "deconv", 1318 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy) 1319 | )) 1320 | } 1321 | 1322 | prev-x = x + w 1323 | prev-depth-offset = ox 1324 | x += w 1325 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1326 | prev-pool-width = 0 1327 | 1328 | // Register legend entry 1329 | let layer-legend = l.at("legend", default: default-legend-labels.at("deconv")) 1330 | if not legend-entries.any(e => e.key == "deconv") { 1331 | legend-entries.push((key: "deconv", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1332 | } 1333 | } 1334 | 1335 | // CONCATENATION LAYER - delegates to custom with concat-specific defaults 1336 | else if l.type == "concat" { 1337 | let h = l.at("height", default: 3) 1338 | let d = l.at("depth", default: 3) 1339 | l.insert("height", h) 1340 | l.insert("depth", d) 1341 | let w = l.at("width", default: 0.15) 1342 | let label = l.at("label", default: "") 1343 | let name = l.at("name", default: none) 1344 | let fill-color = l.at("fill", default: colors.concat) 1345 | let layer-opacity = l.at("opacity", default: 0.7) 1346 | let layer-show-connection = l.at("show-connection", default: true) 1347 | let connection-label = l.at("connection-label", default: none) 1348 | let channels = 
l.at("channels", default: none) 1349 | let img = l.at("image", default: none) 1350 | let (ox, oy) = get-depth-offsets(d) 1351 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1352 | 1353 | if img == "default" { 1354 | img = image("bird.jpg") 1355 | } 1356 | 1357 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1358 | 1359 | // Display channels labels 1360 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 1361 | 1362 | if label != none { 1363 | content((x + w/2, y-offset - 0.5), 1364 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1365 | } 1366 | 1367 | // Track position if named 1368 | if name != none { 1369 | layer-positions.insert(name, ( 1370 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "concat", 1371 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy) 1372 | )) 1373 | } 1374 | 1375 | prev-x = x + w 1376 | prev-depth-offset = ox 1377 | x += w 1378 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1379 | prev-pool-width = 0 1380 | 1381 | // Register legend entry 1382 | let layer-legend = l.at("legend", default: default-legend-labels.at("concat")) 1383 | if not legend-entries.any(e => e.key == "concat") { 1384 | legend-entries.push((key: "concat", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1385 | } 1386 | } 1387 | 1388 | // GLOBAL AVERAGE POOLING - delegates to custom with gap-specific defaults 1389 | else if l.type == "gap" { 1390 | let h = l.at("height", default: 1.5) 1391 | let d = l.at("depth", default: 1.5) 1392 | l.insert("height", h) 1393 | l.insert("depth", d) 1394 | let w = 0.3 1395 | let label = l.at("label", default: "") 1396 | let name = l.at("name", default: none) 1397 | let fill-color = l.at("fill", default: colors.gap) 1398 | let layer-opacity = l.at("opacity", default: 0.7) 1399 | let layer-show-connection = l.at("show-connection", default: 
true) 1400 | let connection-label = l.at("connection-label", default: none) 1401 | let channels = l.at("channels", default: none) 1402 | let img = l.at("image", default: none) 1403 | let (ox, oy) = get-depth-offsets(d) 1404 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1405 | 1406 | if img == "default" { 1407 | img = image("bird.jpg") 1408 | } 1409 | 1410 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1411 | 1412 | // Display channels labels 1413 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 1414 | 1415 | if label != none { 1416 | content((x + w/2, y-offset - 0.5), 1417 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1418 | } 1419 | 1420 | // Track position if named 1421 | if name != none { 1422 | layer-positions.insert(name, ( 1423 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "gap", 1424 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy) 1425 | )) 1426 | } 1427 | 1428 | prev-x = x + w 1429 | prev-depth-offset = ox 1430 | x += w 1431 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1432 | prev-pool-width = 0 1433 | 1434 | // Register legend entry 1435 | let layer-legend = l.at("legend", default: default-legend-labels.at("gap")) 1436 | if not legend-entries.any(e => e.key == "gap") { 1437 | legend-entries.push((key: "gap", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1438 | } 1439 | } 1440 | 1441 | // FULLY CONNECTED - delegates to custom with fc-specific defaults 1442 | else if l.type == "fc" { 1443 | let h = l.at("height", default: 3) 1444 | let d = l.at("depth", default: 0.4) 1445 | l.insert("height", h) 1446 | l.insert("depth", d) 1447 | let w = 0.2 1448 | let label = l.at("label", default: "") 1449 | let name = l.at("name", default: none) 1450 | let fill-color = l.at("fill", default: colors.fc) 1451 | let layer-opacity = l.at("opacity", default: 
0.7) 1452 | let layer-show-connection = l.at("show-connection", default: true) 1453 | let connection-label = l.at("connection-label", default: none) 1454 | let channels = l.at("channels", default: none) 1455 | let img = l.at("image", default: none) 1456 | let (ox, oy) = get-depth-offsets(d) 1457 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1458 | 1459 | if img == "default" { 1460 | img = image("bird.jpg") 1461 | } 1462 | 1463 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1464 | 1465 | // Display channels labels 1466 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 1467 | 1468 | if label != none { 1469 | content((x + w/2, y-offset - 0.5), 1470 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1471 | } 1472 | 1473 | // Track position if named 1474 | if name != none { 1475 | layer-positions.insert(name, ( 1476 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: "fc", 1477 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy) 1478 | )) 1479 | } 1480 | 1481 | prev-x = x + w 1482 | prev-depth-offset = ox 1483 | x += w 1484 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1485 | prev-pool-width = 0 1486 | 1487 | // Register legend entry 1488 | let layer-legend = l.at("legend", default: default-legend-labels.at("fc")) 1489 | if not legend-entries.any(e => e.key == "fc") { 1490 | legend-entries.push((key: "fc", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1491 | } 1492 | } 1493 | 1494 | // SUM NODE - uses unique circle rendering (not box-based like custom) 1495 | else if l.type == "sum" { 1496 | let radius = l.at("radius", default: 0.4) 1497 | let symbol = l.at("symbol", default: "+") 1498 | let label = l.at("label", default: none) 1499 | let name = l.at("name", default: none) 1500 | let fill-color = l.at("fill", default: colors.sum) 1501 | let layer-show-connection = 
l.at("show-connection", default: true) 1502 | let connection-label = l.at("connection-label", default: none) 1503 | let layer-opacity = l.at("opacity", default: 1.0) 1504 | let channels = l.at("channels", default: none) 1505 | 1506 | // Center x accounts for depth offset of previous arrow 1507 | let center-x = x + radius + prev-depth-offset / 2 1508 | let center-y = arrow-axis-y 1509 | 1510 | let dyn-stroke = dynamic-color-strokes(fill-color) 1511 | dyn-stroke.solid.paint = dyn-stroke.solid.paint.darken(10%) // slightly darker stroke than for other layers 1512 | dyn-stroke.solid.thickness = dyn-stroke.solid.thickness * 1.5 // slightly thicker stroke than for other layers 1513 | fill-color = fill-color.transparentize((1-layer-opacity)*100%) 1514 | 1515 | circle((center-x, center-y), radius: radius, 1516 | fill: gradient.radial( 1517 | fill-color.lighten(50%), fill-color, fill-color.darken(30%), 1518 | center: (50%, 50%), radius: 50%, 1519 | focal-center: (35%, 35%), focal-radius: 5% 1520 | ), 1521 | stroke: dyn-stroke.solid) 1522 | 1523 | if symbol != none { 1524 | let symbole-size = scaled-font(font-sizes.label * 2.5) 1525 | content((center-x, center-y), 1526 | [#v(-0.185 * symbole-size)#text(size: symbole-size, weight: "bold", fill: dyn-stroke.solid.paint, symbol)]) 1527 | } 1528 | 1529 | // Display channels labels (below and optionally on diagonal) 1530 | if channels != none { 1531 | let (ox, oy) = get-depth-offsets(radius * 2) 1532 | draw-channels-labels(channels, center-x, center-x + radius, center-y - radius, ox, oy) 1533 | } 1534 | 1535 | // Display label below the sum node 1536 | if label != none { 1537 | content((center-x, center-y - 1.5 * radius), 1538 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1539 | } 1540 | 1541 | prev-x = center-x + radius 1542 | prev-depth-offset = 0 1543 | x += radius * 3 1544 | 1545 | if name != none { 1546 | let (ox, oy) = get-depth-offsets(radius * 2) 1547 | layer-positions.insert(name, ( 1548 | x: x - 
radius * 2, y: center-y - radius, w: radius * 2, h: radius * 2, ox: ox, oy: oy, 1549 | type: "sum", radius: radius, center-x: center-x, 1550 | anchors: get-layer-anchors(x - radius * 2, center-y - radius, radius * 2, radius * 2, ox, oy), 1551 | pool-offset: 0 1552 | )) 1553 | } 1554 | 1555 | prev-center-y = center-y 1556 | prev-pool-width = 0 1557 | 1558 | // Register legend entry 1559 | let layer-legend = l.at("legend", default: default-legend-labels.at("sum")) 1560 | if not legend-entries.any(e => e.key == "sum") { 1561 | legend-entries.push((key: "sum", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1562 | } 1563 | } 1564 | 1565 | // CONVOLUTIONAL SOFTMAX (Combined layer) 1566 | else if l.type == "convsoftmax" { 1567 | let h = l.at("height", default: 4) 1568 | let d = l.at("depth", default: 4) 1569 | l.insert("height", h) 1570 | l.insert("depth", d) 1571 | let w = l.at("width", default: 0.1) 1572 | let layer-show-connection = l.at("show-connection", default: true) 1573 | let connection-label = l.at("connection-label", default: none) 1574 | let label = l.at("label", default: "") 1575 | let name = l.at("name", default: none) 1576 | let fill-color = l.at("fill", default: colors.convsoftmax) 1577 | let layer-opacity = l.at("opacity", default: 0.5) 1578 | let channels = l.at("channels", default: none) 1579 | let img = l.at("image", default: none) 1580 | let (ox, oy) = get-depth-offsets(d) 1581 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1582 | 1583 | if img == "default" { 1584 | img = image("bird.jpg") 1585 | } 1586 | 1587 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, show-right: true, image: img) 1588 | 1589 | // Display channels labels 1590 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 1591 | 1592 | if label != none { 1593 | content((x + w/2, y-offset - 0.5), 1594 | [#text(size: scaled-font(font-sizes.label), weight: "bold", 
label)]) 1595 | } 1596 | 1597 | // Track position if named 1598 | if name != none { 1599 | layer-positions.insert(name, ( 1600 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, 1601 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy) 1602 | )) 1603 | } 1604 | 1605 | prev-x = x + w 1606 | prev-depth-offset = ox 1607 | x += w 1608 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1609 | prev-pool-width = 0 1610 | 1611 | // Register legend entry 1612 | let layer-legend = l.at("legend", default: default-legend-labels.at("convsoftmax")) 1613 | if not legend-entries.any(e => e.key == "convsoftmax") { 1614 | legend-entries.push((key: "convsoftmax", label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1615 | } 1616 | } 1617 | 1618 | // SOFTMAX / OUTPUT - delegates to custom with softmax/output-specific defaults 1619 | else if l.type == "softmax" or l.type == "output" { 1620 | let h = l.at("height", default: 3) 1621 | let d = l.at("depth", default: 0.4) 1622 | l.insert("height", h) 1623 | l.insert("depth", d) 1624 | let w = 0.2 1625 | let label = l.at("label", default: if l.type == "softmax" { "Softmax" } else { "Output" }) 1626 | let name = l.at("name", default: none) 1627 | let layer-show-connection = l.at("show-connection", default: true) 1628 | let connection-label = l.at("connection-label", default: none) 1629 | let classes = l.at("classes", default: none) 1630 | let channels = l.at("channels", default: none) 1631 | let fill-color = l.at("fill", default: if l.type == "softmax" { colors.softmax } else { colors.output }) 1632 | let layer-opacity = l.at("opacity", default: 0.5) 1633 | let img = l.at("image", default: none) 1634 | let (ox, oy) = get-depth-offsets(d) 1635 | let y-offset = get-y-offset-for-center-on-axis(h, d, arrow-axis-y) 1636 | 1637 | if img == "default" { 1638 | img = image("bird.jpg") 1639 | } 1640 | 1641 | box-3d(x, y-offset, w, h, d, fill-color, opacity: layer-opacity, show-left: true, 
show-right: true, image: img) 1642 | 1643 | // Display channels labels (preferred over classes) 1644 | if channels != none { 1645 | draw-channels-labels(channels, x + w/2, x + w, y-offset, ox, oy) 1646 | } else if classes != none { 1647 | content((x + w/2, y-offset - 0.3), 1648 | [#text(size: scaled-font(font-sizes.output-number), str(classes))]) 1649 | } 1650 | if label != none { 1651 | content((x + w/2, y-offset - 0.6), 1652 | [#text(size: scaled-font(font-sizes.label), weight: "bold", label)]) 1653 | } 1654 | 1655 | // Track position if named 1656 | if name != none { 1657 | layer-positions.insert(name, ( 1658 | x: x, y: y-offset, w: w, h: h, ox: ox, oy: oy, type: l.type, 1659 | anchors: get-layer-anchors(x, y-offset, w, h, ox, oy) 1660 | )) 1661 | } 1662 | 1663 | prev-x = x + w 1664 | prev-depth-offset = ox 1665 | x += w 1666 | prev-center-y = get-perspective-center-y(y-offset, h, oy) 1667 | prev-pool-width = 0 1668 | 1669 | // Register legend entry 1670 | let layer-legend = l.at("legend", default: default-legend-labels.at(l.type)) 1671 | if not legend-entries.any(e => e.key == l.type) { 1672 | legend-entries.push((key: l.type, label: layer-legend, color: fill-color, bandfill: fill-color, show-relu: false, opacity: layer-opacity)) 1673 | } 1674 | } 1675 | } 1676 | 1677 | // After all layers are drawn, calculate arrow segment midpoints for ALL named layer pairs 1678 | // This ensures skip connections between non-consecutive layers can find their anchor points 1679 | for (i, l) in layers.enumerate() { 1680 | let curr-name = l.at("name", default: none) 1681 | if curr-name != none and curr-name in layer-positions { 1682 | // Find the previous named layer (skip over unnamed layers like pool/unpool) 1683 | let prev-name = none 1684 | for j in range(i - 1, -1, step: -1) { 1685 | let candidate-name = layers.at(j).at("name", default: none) 1686 | if candidate-name != none and candidate-name in layer-positions { 1687 | prev-name = candidate-name 1688 | break 1689 | } 1690 
| } 1691 | 1692 | // If we found a previous named layer, calculate the arrow segment 1693 | if prev-name != none { 1694 | let prev-pos = layer-positions.at(prev-name) 1695 | let curr-pos = layer-positions.at(curr-name) 1696 | 1697 | // Use true_east and add pool-offset if there's a pool after the previous layer 1698 | let pool-offset = prev-pos.at("pool-offset", default: 0) 1699 | let arrow-start = (prev-pos.anchors.true_east.at(0) + pool-offset, prev-pos.anchors.true_east.at(1)) 1700 | let arrow-end = curr-pos.anchors.true_west 1701 | 1702 | // Calculate midpoint of the arrow segment 1703 | let mid-x = (arrow-start.at(0) + arrow-end.at(0)) / 2 1704 | let mid-y = arrow-start.at(1) 1705 | 1706 | // Store for skip connections - these will override any stored during drawing 1707 | arrow-segments.insert(prev-name + "-out", ( 1708 | start: arrow-start, 1709 | mid: (mid-x, mid-y), 1710 | x: mid-x, 1711 | y: mid-y 1712 | )) 1713 | arrow-segments.insert(curr-name + "-in", ( 1714 | end: arrow-end, 1715 | mid: (mid-x, mid-y), 1716 | x: mid-x, 1717 | y: mid-y 1718 | )) 1719 | } 1720 | } 1721 | } 1722 | 1723 | for conn in connections { 1724 | let from-name = conn.at("from") 1725 | let to-name = conn.at("to") 1726 | let conn-type = conn.at("type", default: "skip") 1727 | let conn-mode = conn.at("mode", default: "flat") 1728 | let conn-pos = conn.at("pos", default: 1.25) 1729 | let conn-label = conn.at("label", default: none) 1730 | let conn-opacity = conn.at("opacity", default: 0.7) 1731 | let touch-layer = conn.at("touch-layer", default: false) 1732 | 1733 | if from-name in layer-positions and to-name in layer-positions { 1734 | let from-pos = layer-positions.at(from-name) 1735 | let to-pos = layer-positions.at(to-name) 1736 | 1737 | // Use arrow segment midpoints if available, otherwise fall back to layer edges 1738 | let from-anchor-key = from-name + "-out" 1739 | let to-anchor-key = to-name + "-in" 1740 | 1741 | // Check if the from layer has a pool attached but we're not 
departing from the pool itself 1742 | let from-has-pool = from-pos.at("pool-offset", default: 0) > 0 1743 | let from-type = from-pos.at("type", default: none) 1744 | let departing-from-layer-with-pool = from-has-pool and from-type != "pool" 1745 | 1746 | // Use true midpoint of arrow segment after from layer (uses stored start point) 1747 | let from-anchor = if departing-from-layer-with-pool { 1748 | // Special case: departing from a layer with attached pool (but not the pool itself) 1749 | // Use specific edges of the east side based on connection mode 1750 | let base-x = from-pos.x + from-pos.w 1751 | let base-y = from-pos.y 1752 | let h = from-pos.h 1753 | let ox = from-pos.ox 1754 | let oy = from-pos.oy 1755 | 1756 | if conn-mode == "air" { 1757 | // Middle of top diagonal edge of east side 1758 | (base-x + ox/2, base-y + h + oy/2) 1759 | } else if conn-mode == "depth" { 1760 | // Middle of left edge of east side 1761 | (base-x, base-y + h/2 + oy/2) 1762 | } else { 1763 | // "flat" - Middle of bottom edge of east side 1764 | (base-x + ox/2, base-y + oy/2) 1765 | } 1766 | } else if from-anchor-key in arrow-segments { 1767 | let seg = arrow-segments.at(from-anchor-key) 1768 | // Use the arrow's actual start point for x (depth-adjusted) 1769 | (seg.mid.at(0), seg.mid.at(1)) 1770 | } else { 1771 | from-pos.anchors.true_east 1772 | } 1773 | 1774 | // Determine target anchor point 1775 | let to-type = to-pos.at("type", default: none) 1776 | let to-anchor = if touch-layer { 1777 | // Special case: arrive at specific edge of west side of destination layer 1778 | let base-x = to-pos.x 1779 | let base-y = to-pos.y 1780 | let h = to-pos.h 1781 | let ox = to-pos.ox 1782 | let oy = to-pos.oy 1783 | 1784 | if conn-mode == "air" { 1785 | // Middle of top diagonal edge of west side 1786 | (base-x + ox/2, base-y + h + oy/2) 1787 | } else if conn-mode == "depth" { 1788 | // Middle of left edge of west side 1789 | (base-x, base-y + h/2 + oy/2) 1790 | } else { 1791 | // "flat" - 
Middle of bottom edge of west side 1792 | (base-x + ox/2, base-y + oy/2) 1793 | } 1794 | } else if to-type == "sum" { 1795 | // For sum layers, use the stored center-x (which already accounts for depth offset) 1796 | let center-x = to-pos.center-x 1797 | let center-y = to-pos.y + to-pos.radius 1798 | let center = (center-x, center-y) 1799 | let radius = to-pos.at("radius", default: 0.4) 1800 | if conn-mode == "flat" { 1801 | (center.at(0), center.at(1) - radius) 1802 | } else if conn-mode == "air" { 1803 | (center.at(0), center.at(1) + radius) 1804 | } else if conn-mode == "depth" { 1805 | let angle = 225 * calc.pi / 180 1806 | (center.at(0) + radius * calc.cos(angle), center.at(1) + radius * calc.sin(angle)) 1807 | } else { 1808 | (center.at(0), center.at(1) - radius) 1809 | } 1810 | } else if to-anchor-key in arrow-segments { 1811 | let seg = arrow-segments.at(to-anchor-key) 1812 | // Use the arrow's midpoint (both x and y) 1813 | seg.mid 1814 | } else { 1815 | to-pos.anchors.true_west 1816 | } 1817 | 1818 | if conn-type == "skip" { 1819 | let conn-layers = conn.at("layers", default: none) 1820 | 1821 | if conn-mode == "flat" { 1822 | let down-y = from-anchor.at(1) - conn-pos 1823 | let waypoint1 = (from-anchor.at(0), down-y) 1824 | let waypoint2 = (to-anchor.at(0), down-y) 1825 | 1826 | draw-connection-path(((from-anchor, waypoint1), (waypoint1, waypoint2), (waypoint2, to-anchor)), opacity: conn-opacity, layers: conn-layers, layer-positions-ref: layer-positions, show-relu: show-relu) 1827 | 1828 | if conn-label != none { 1829 | content(((waypoint1.at(0) + waypoint2.at(0)) / 2, down-y - 0.3), 1830 | [#text(size: scaled-font(font-sizes.layer-label), conn-label)]) 1831 | } 1832 | } else if conn-mode == "depth" { 1833 | let (ox, oy) = get-depth-offsets(conn-pos * 2.5) 1834 | let waypoint1 = (from-anchor.at(0) - ox, from-anchor.at(1) - oy) 1835 | // For sum circles, adjust waypoint2 x-coordinate to account for south-west arrival 1836 | let waypoint2-x = if to-type == 
"sum" { 1837 | // Compensate for the south-west arrival offset (radius * cos(225°)) 1838 | let radius = to-pos.at("radius", default: 0.4) 1839 | let angle = 225 * calc.pi / 180 1840 | to-anchor.at(0) - ox - radius * calc.cos(angle) 1841 | } else { 1842 | to-anchor.at(0) - ox 1843 | } 1844 | let waypoint2 = (waypoint2-x, from-anchor.at(1) - oy) 1845 | 1846 | draw-connection-path(((from-anchor, waypoint1), (waypoint1, waypoint2), (waypoint2, to-anchor)), opacity: conn-opacity, layers: conn-layers, layer-positions-ref: layer-positions, show-relu: show-relu) 1847 | 1848 | if conn-label != none { 1849 | content(((waypoint1.at(0) + waypoint2.at(0)) / 2, waypoint1.at(1) - 0.3), 1850 | [#text(size: scaled-font(font-sizes.layer-label), conn-label)]) 1851 | } 1852 | } else if conn-mode == "air" { 1853 | let up-y = arrow-axis-y + conn-pos 1854 | let down-y = from-anchor.at(1) - conn-pos 1855 | let waypoint1 = (from-anchor.at(0), up-y) 1856 | let waypoint2 = (to-anchor.at(0), up-y) 1857 | 1858 | draw-connection-path(((from-anchor, waypoint1), (waypoint1, waypoint2), (waypoint2, to-anchor)), opacity: conn-opacity, layers: conn-layers, layer-positions-ref: layer-positions, show-relu: show-relu) 1859 | 1860 | if conn-label != none { 1861 | content(((waypoint1.at(0) + waypoint2.at(0)) / 2, up-y + 0.28), 1862 | [#text(size: scaled-font(font-sizes.layer-label), conn-label)]) 1863 | } 1864 | } 1865 | } 1866 | } 1867 | } 1868 | 1869 | if show-legend { 1870 | // Position legend after the last layer, accounting for its width and depth 1871 | let legend-x = prev-x + prev-depth-offset + 1.0 1872 | let legend-item-height = 0.4 1873 | let legend-box-size = 0.3 1874 | 1875 | // Count total legend entries to calculate vertical centering 1876 | let entry-count = legend-entries.len() 1877 | 1878 | // Calculate total legend height: title + spacing + (entries * item-height) 1879 | let legend-total-height = 0.5 + entry-count * legend-item-height 1880 | 1881 | // Center legend vertically around 
arrow-axis-y 1882 | let legend-y = arrow-axis-y + legend-total-height / 2 1883 | 1884 | content((legend-x - 0.05, legend-y + 0.15), 1885 | anchor: "north-west", 1886 | [#text(size: scaled-font(font-sizes.legend-title), weight: "bold", legend-title)]) 1887 | 1888 | legend-y -= 0.6 1889 | 1890 | // Render all legend entries in order of appearance 1891 | for entry in legend-entries { 1892 | let item-stroke = dynamic-color-strokes(entry.color) 1893 | let alpha = 100% - entry.at("opacity", default: 1.0) * 100% 1894 | 1895 | if entry.at("show-relu", default: false) { 1896 | // Draw split rectangle: 2/3 fill color (left), 1/3 bandfill color (right) 1897 | let split-x = legend-x + legend-box-size * 2 / 3 1898 | rect((legend-x, legend-y), (split-x, legend-y + legend-box-size), 1899 | fill: entry.color.transparentize(alpha), stroke: none) 1900 | rect((split-x, legend-y), (legend-x + legend-box-size, legend-y + legend-box-size), 1901 | fill: entry.bandfill.transparentize(alpha), stroke: none) 1902 | // Draw outline 1903 | rect((legend-x, legend-y), (legend-x + legend-box-size, legend-y + legend-box-size), 1904 | fill: none, stroke: item-stroke.solid) 1905 | } else { 1906 | // Draw solid rectangle 1907 | rect((legend-x, legend-y), (legend-x + legend-box-size, legend-y + legend-box-size), 1908 | fill: entry.color.transparentize(alpha), stroke: item-stroke.solid) 1909 | } 1910 | 1911 | content((legend-x + legend-box-size + 0.2, legend-y - 0.013 + legend-box-size / 2), anchor: "west", 1912 | [#text(size: scaled-font(font-sizes.legend-item), entry.label)]) 1913 | 1914 | legend-y -= legend-item-height 1915 | } 1916 | } 1917 | })} --------------------------------------------------------------------------------