├── .gitignore ├── README.md ├── package.json ├── public ├── favicon.ico ├── index.html └── manifest.json ├── src ├── index.css └── index.js └── yarn.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/ignore-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | 6 | # testing 7 | /coverage 8 | 9 | # production 10 | /build 11 | 12 | # misc 13 | .DS_Store 14 | .env.local 15 | .env.development.local 16 | .env.test.local 17 | .env.production.local 18 | 19 | npm-debug.log* 20 | yarn-debug.log* 21 | yarn-error.log* 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Convolution Visualizer 2 | 3 | Live at [Convolution Visualizer](https://ezyang.github.io/convolution-visualizer/index.html). 4 | 5 | Made with the help of our fine friends at [React](https://reactjs.org/) 6 | and [D3.js](https://d3js.org/). 7 | 8 | ### Things to do 9 | 10 | Want to play around with the code? Clone this repository and run `yarn 11 | start` to start a development instance. The main code lives in 12 | `src/index.js`. This [React manual](https://github.com/facebookincubator/create-react-app/blob/master/packages/react-scripts/template/README.md) may be of interest. 13 | 14 | Here are some project ideas: 15 | 16 | * Tweak the CSS so that the weight and output matrices 17 | are displayed to the right of the input if there is space. 18 | * Add a slider for adjusting speed of the animation. 19 | * Add a slider which specifies the animation timestep you 20 | are on; this way, you can run the animation forward and 21 | backward by dragging the slider. 22 | * Add output size and output padding sliders. When these 23 | sliders are adjusted, you recompute the input size using 24 | the transposed convolution formula. 
 25 | * Add an onClick handler, which pins your selection at 26 | the current mouse location until another click 27 | occurs (disabling the hover behavior). 28 | * Add a mode which, when enabled, labels cells with variables and 29 | renders the mathematical formula to compute the output 30 | cell you are moused over. 31 | * Render code for PyTorch (or your favorite framework) which performs the 32 | selected convolution. 33 | * Add more exotic convolution types like circular convolution. 34 | * Add a "true" convolution mode, where the weights are flipped 35 | before multiplication. 36 | * Support bigger input sizes than 16 (decreasing the size of 37 | the squares when inputs are large), and optimize the code so that it 38 | still runs quickly in these cases. 39 | * Support asymmetric inputs/kernels/strides/dilations. 40 | 41 | Bigger projects: 42 | 43 | * Create an in-browser canvas application, which convolves 44 | an input image against a displayed filter. Bonus points 45 | if your canvas supports painting capabilities. 46 | * Design a visualization which demonstrates the principles 47 | of group convolution, allowing you to slide from standard 48 | to depthwise convolution. 
49 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "convolution-visualizer", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "d3-scale-chromatic": "^1.2.0", 7 | "d3v4": "^4.2.2", 8 | "gh-pages": "^1.1.0", 9 | "react": "^16.2.0", 10 | "react-dom": "^16.2.0", 11 | "react-scripts": "1.1.1" 12 | }, 13 | "homepage": "https://ezyang.github.io/convolution-visualizer", 14 | "scripts": { 15 | "start": "react-scripts start", 16 | "build": "react-scripts build", 17 | "test": "react-scripts test --env=jsdom", 18 | "eject": "react-scripts eject", 19 | "predeploy": "yarn build", 20 | "deploy": "gh-pages -d build" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezyang/convolution-visualizer/35ded63e432e2629ba5c6a64aa09df5e74c9a8d9/public/favicon.ico -------------------------------------------------------------------------------- /public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 21 | Convolution Visualizer 22 | 23 | 24 | 27 | Fork me on GitHub 28 |
29 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "Convolution Visualizer", 3 | "name": "Convolution Visualizer", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | } 10 | ], 11 | "start_url": "./index.html", 12 | "display": "standalone", 13 | "theme_color": "#000000", 14 | "background_color": "#ffffff" 15 | } 16 | -------------------------------------------------------------------------------- /src/index.css: -------------------------------------------------------------------------------- 1 | h1 { margin-bottom: 0 } 2 | 3 | .author { margin-left: 2em; } 4 | 5 | p { 6 | max-width: 80em; 7 | } 8 | 9 | body { 10 | font: 14px "Century Gothic", Futura, sans-serif; 11 | margin: 20px; 12 | margin-right:40px; 13 | } 14 | 15 | .form { 16 | margin-bottom: 1em; 17 | float: left; 18 | } 19 | 20 | .viewport { 21 | margin-left: 16em; 22 | } 23 | 24 | .grid-container { 25 | margin-bottom: 1em; 26 | } 27 | 28 | table { 29 | border-collapse: collapse; 30 | } 31 | 32 | td { 33 | background: #fff; 34 | border: 1px solid #999; 35 | height: 34px; 36 | width: 34px; 37 | } 38 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import * as d3 from 'd3v4'; 4 | import './index.css'; 5 | 6 | /** 7 | * An HTML5 range slider and associated raw text input. 8 | * 9 | * Properties: 10 | * - min: The minimum allowed value for the slider range 11 | * - max: The maximum allowed value for the slider range 12 | * - value: The current value of the slider 13 | * - disabled: Whether or not to disable the slider. 
A slider 14 | * is automatically disabled when min == max. 15 | * - onChange: Callback when the value of this slider changes. 16 | */ 17 | function Slider(props) { 18 | const max = parseInt(props.max, 10); 19 | const min = parseInt(props.min, 10); 20 | const maxLength = max ? Math.ceil(Math.log10(max)) : 1; 21 | const disabled = props.disabled || min >= max; 22 | return ( 23 | 24 | 28 | 34 | 35 | ); 36 | } 37 | 38 | /** 39 | * Create a 1-dimensional array of size 'length', where the 'i'th entry 40 | * is initialized to 'f(i)', or 'undefined' if 'f' is not passed. 41 | */ 42 | function array1d(length, f) { 43 | return Array.from({length: length}, f ? ((v, i) => f(i)) : undefined); 44 | } 45 | 46 | /** 47 | * Create a 2-dimensional array of size 'height' x 'width', where the 'i','j' entry 48 | * is initialized to 'f(i, j)', or 'undefined' if 'f' is not passed. 49 | */ 50 | function array2d(height, width, f) { 51 | return Array.from({length: height}, (v, i) => Array.from({length: width}, f ? ((w, j) => f(i, j)) : undefined)); 52 | } 53 | 54 | /** 55 | * The classic convolution output size formula for a single dimension. 56 | * 57 | * The derivation for many special cases is worked out in: 58 | * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html 59 | */ 60 | function computeOutputSize(input_size, weight_size, padding, dilation, stride) { 61 | return Math.floor((input_size + 2 * padding - dilation * (weight_size - 1) - 1) / stride + 1); 62 | } 63 | 64 | /** 65 | * Test if a set of parameters is valid. 
/**
 * Decide whether a combination of convolution parameters is usable.
 *
 * The output size is computed independently for the height and the width
 * (padding and dilation are shared by both axes); the combination is
 * accepted only when both output sizes are strictly positive.
 */
function paramsOK(input_h, input_w, weight_h, weight_w, padding, dilation, stride_h, stride_w) {
  const outputSizes = [
    computeOutputSize(input_h, weight_h, padding, dilation, stride_h),
    computeOutputSize(input_w, weight_w, padding, dilation, stride_w),
  ];
  return outputSizes.every((size) => size > 0);
}


// maxWhile and minWhile below are deliberately naive linear scans; they
// are used to derive the allowed bounds of one parameter while all the
// other parameters are held fixed.

/**
 * Scan upwards from 'start' towards 'end' (inclusive) and stop at the
 * first integer for which 'pred' fails.
 *
 * Returns the integer just below the first failure, or 'end' when every
 * scanned integer satisfies 'pred' (including the case where the range is
 * empty because start > end). Note that when 'pred(start)' itself fails
 * the result is 'start - 1'.
 */
function maxWhile(start, end, pred) {
  let candidate = start;
  while (candidate <= end) {
    if (!pred(candidate)) {
      return candidate - 1;
    }
    candidate += 1;
  }
  return end;
}

/**
 * Scan downwards from 'start' towards 'end' (inclusive) and stop at the
 * first integer for which 'pred' fails.
 *
 * Returns the integer just above the first failure, or 'end' when every
 * scanned integer satisfies 'pred' (including the case where the range is
 * empty because start < end). Note that when 'pred(start)' itself fails
 * the result is 'start + 1'.
 */
function minWhile(start, end, pred) {
  let candidate = start;
  while (candidate >= end) {
    if (!pred(candidate)) {
      return candidate + 1;
    }
    candidate -= 1;
  }
  return end;
}

/**
 * Blend 'color' towards white using linear RGB interpolation.
 *
 * 'p' is the blend position with 0 <= p <= 1: 0 yields 'color' itself
 * and 1 yields white.
 */
function whiten(color, p) {
  const towardWhite = d3.interpolateRgb(color, "white");
  return towardWhite(p);
}

/**
 * Top-level component for the entire visualization. This component
 * controls top level parameters like input sizes, but not the mouse
 * interaction with the actual visualized grids.
 */
116 | */ 117 | class App extends React.Component { 118 | constructor(props) { 119 | super(props); 120 | this.state = { 121 | input_height: 5, 122 | input_width: 5, 123 | weight_height: 3, 124 | weight_width: 3, 125 | padding: 0, 126 | dilation: 1, 127 | stride_height: 1, 128 | stride_width: 1, 129 | // State to control the UI mode 130 | inputShape: 'square', 131 | kernelShape: 'square', 132 | strideShape: 'square', 133 | }; 134 | } 135 | 136 | // React controlled components clobber saved browser state, so 137 | // instead we manually save/load our state from localStorage. 138 | 139 | componentDidMount() { 140 | const state = localStorage.getItem("state"); 141 | if (state) { 142 | this.setState(JSON.parse(state)); 143 | } 144 | } 145 | 146 | componentDidUpdate() { 147 | localStorage.setItem("state", JSON.stringify(this.state)); 148 | } 149 | 150 | // A smarter handler for dimension changes that respects the current shape mode. 151 | handleDimensionChange = (type, dimension) => (e) => { 152 | const r = parseInt(e.target.value, 10); 153 | if (isNaN(r)) return; 154 | 155 | // TODO: transposed convolution 156 | // FIX: Correctly map the 'type' string to its corresponding state key 157 | let shapeKey; 158 | if (type === 'input') shapeKey = 'inputShape'; 159 | else if (type === 'weight') shapeKey = 'kernelShape'; 160 | else if (type === 'stride') shapeKey = 'strideShape'; 161 | 162 | const shape = this.state[shapeKey]; 163 | 164 | if (shape === 'square') { 165 | // In square mode, the slider controls both height and width 166 | this.setState({ 167 | [`${type}_height`]: r, 168 | [`${type}_width`]: r, 169 | }); 170 | } else { 171 | // In rectangular mode, sliders are independent 172 | this.setState({ 173 | [`${type}_${dimension}`]: r 174 | }); 175 | } 176 | }; 177 | 178 | // Handles the user switching between "Square" and "Rectangular" 179 | handleShapeChange = (type) => (e) => { 180 | const newShape = e.target.value; 181 | const key = `${type}Shape`; 182 | 183 | if 
(newShape === 'square') { 184 | // When switching back to square, make width equal to height 185 | const height = this.state[`${type}_height`]; 186 | this.setState({ 187 | [key]: newShape, 188 | [`${type}_width`]: height, 189 | }); 190 | } else { 191 | this.setState({ [key]: newShape }); 192 | } 193 | }; 194 | 195 | render() { 196 | const { input_height, input_width, weight_height, weight_width, padding, dilation, stride_height, stride_width, inputShape, kernelShape, strideShape } = this.state; 197 | 198 | const padded_input_height = input_height + padding * 2; 199 | const padded_input_width = input_width + padding * 2; 200 | 201 | const output_height = computeOutputSize(input_height, weight_height, padding, dilation, stride_height); 202 | const output_width = computeOutputSize(input_width, weight_width, padding, dilation, stride_width); 203 | 204 | const output = array2d(output_height, output_width, (i, j) => array2d(weight_height, weight_width)); 205 | 206 | for (let h_out = 0; h_out < output_height; h_out++) { 207 | for (let w_out = 0; w_out < output_width; w_out++) { 208 | for (let h_kern = 0; h_kern < weight_height; h_kern++) { 209 | for (let w_kern = 0; w_kern < weight_width; w_kern++) { 210 | const h_im = h_out * stride_height + h_kern * dilation; 211 | const w_im = w_out * stride_width + w_kern * dilation; 212 | output[h_out][w_out][h_kern][w_kern] = h_im * padded_input_width + w_im; 213 | } 214 | } 215 | } 216 | } 217 | 218 | // Make an extended params dictionary with our new computed values 219 | // to pass to the inner component. 220 | const params = Object.assign({ 221 | padded_input_height: padded_input_height, 222 | padded_input_width: padded_input_width, 223 | output_height: output_height, 224 | output_width: output_width, 225 | output: output, 226 | }, this.state); 227 | 228 | const onChange = (state_key) => (e) => { 229 | const r = parseInt(e.target.value, 10); 230 | // Text inputs can sometimes temporarily be in invalid states. 
231 | // If it's not a valid number, refuse to set it. 232 | if (!isNaN(r)) { 233 | this.setState({[state_key]: r}); 234 | } 235 | }; 236 | 237 | // An arbitrary constant I found aesthetically pleasing. 238 | const max_input_size = 16; 239 | 240 | return ( 241 |
242 |

Convolution Visualizer

243 |
Edward Z. Yang
244 |

 245 | This interactive visualization demonstrates how various convolution parameters 246 | affect shapes and data dependencies between the input, weight and 247 | output matrices. Hovering over an input/output will highlight the 248 | corresponding output/input, while hovering over a weight 249 | will highlight which inputs were multiplied into that weight to 250 | compute an output. (Strictly speaking, the operation visualized 251 | here is a correlation, not a convolution, as a true 252 | convolution flips its weights before performing a correlation. 253 | However, most deep learning frameworks still call these convolutions, 254 | and in the end it's all the same to gradient descent.) 255 |

256 |
257 |
258 | Input size: 259 | 263 | 264 | {inputShape === 'square' && ( 265 |
266 | paramsOK(x, x, weight_height, weight_width, padding, dilation, stride_height, stride_width))} 268 | max={max_input_size} 269 | value={input_height} 270 | onChange={this.handleDimensionChange('input', 'height')} 271 | /> 272 |
273 | )} 274 | 275 | {inputShape === 'rectangular' && ( 276 | 277 |
278 |
279 | paramsOK(x, input_width, weight_height, weight_width, padding, dilation, stride_height, stride_width))} 281 | max={max_input_size} 282 | value={input_height} 283 | onChange={this.handleDimensionChange('input', 'height')} 284 | /> 285 |
286 |
287 |
288 | paramsOK(input_height, x, weight_height, weight_width, padding, dilation, stride_height, stride_width))} 290 | max={max_input_size} 291 | value={input_width} 292 | onChange={this.handleDimensionChange('input', 'width')} 293 | /> 294 |
295 |
296 | )} 297 |
298 |
299 | Kernel size: 300 | 304 | 305 | {kernelShape === 'square' && ( 306 |
307 | paramsOK(input_height, input_width, x, x, padding, dilation, stride_height, stride_width))} 310 | value={weight_height} 311 | onChange={this.handleDimensionChange('weight', 'height')} 312 | /> 313 |
314 | )} 315 | 316 | {kernelShape === 'rectangular' && ( 317 | 318 |
319 |
320 | paramsOK(input_height, input_width, x, weight_width, padding, dilation, stride_height, stride_width))} 323 | value={weight_height} 324 | onChange={this.handleDimensionChange('weight', 'height')} 325 | /> 326 |
327 |
328 |
329 | paramsOK(input_height, input_width, weight_height, x, padding, dilation, stride_height, stride_width))} 332 | value={weight_width} 333 | onChange={this.handleDimensionChange('weight', 'width')} 334 | /> 335 |
336 |
337 | )} 338 |
339 |
340 | Padding: 341 | paramsOK(input_height, input_width, weight_height, weight_width, x, dilation, stride_height, stride_width))} 342 | max={maxWhile(0, 100, (x) => paramsOK(input_height, input_width, weight_height, weight_width, x, dilation, stride_height, stride_width))} 343 | value={padding} 344 | onChange={onChange("padding")} 345 | /> 346 |
347 |
348 | Stride: 349 | 353 | 354 | {strideShape === 'square' && ( 355 |
356 | paramsOK(input_height, input_width, weight_height, weight_width, padding, dilation, x, x))} 359 | value={stride_height} 360 | onChange={this.handleDimensionChange('stride', 'height')} 361 | /> 362 |
363 | )} 364 | 365 | {strideShape === 'rectangular' && ( 366 | 367 |
368 |
369 | paramsOK(input_height, input_width, weight_height, weight_width, padding, dilation, x, stride_width))} 372 | value={stride_height} 373 | onChange={this.handleDimensionChange('stride', 'height')} 374 | /> 375 |
376 |
377 |
378 | paramsOK(input_height, input_width, weight_height, weight_width, padding, dilation, stride_height, x))} 381 | value={stride_width} 382 | onChange={this.handleDimensionChange('stride', 'width')} 383 | /> 384 |
385 |
386 | )} 387 |
388 |
389 | Dilation: 390 | paramsOK(input_height, input_width, weight_height, weight_width, padding, x, stride_height, stride_width))} 392 | value={dilation} 393 | onChange={onChange("dilation")} 394 | disabled={weight_height === 1 && weight_width === 1} 395 | /> 396 |
397 |
398 | 399 |
400 | ); 401 | } 402 | } 403 | 404 | /** 405 | * The viewport into the actual meat of the visualization, the 406 | * matrices. This component controls the state for hovering 407 | * and the animation. 408 | */ 409 | class Viewport extends React.Component { 410 | constructor(props) { 411 | super(props); 412 | this.state = { 413 | // Which matrix are we hovering over? 414 | hoverOver: undefined, 415 | // Which coordinate are we hovering over? Origin 416 | // is the top-left corner. 417 | hoverH: undefined, 418 | hoverW: undefined, 419 | // What is our animation timestep? A monotonically 420 | // increasing integer. 421 | counter: 0 422 | }; 423 | } 424 | 425 | // Arrange for counter to increment by one after a fixed 426 | // time interval: 427 | 428 | tick() { 429 | this.setState({counter: this.state.counter + 1}); 430 | } 431 | componentDidMount() { 432 | this.interval = setInterval(this.tick.bind(this), 1000); // 1 second 433 | } 434 | componentWillUnmount() { 435 | clearInterval(this.interval); 436 | } 437 | 438 | render() { 439 | const { input_height, input_width, padded_input_height, padded_input_width, 440 | weight_height, weight_width, output_height, output_width, 441 | output, padding, stride_height, stride_width } = this.props; 442 | 443 | let hoverOver = this.state.hoverOver; 444 | let hoverH = this.state.hoverH; 445 | let hoverW = this.state.hoverW; 446 | 447 | // The primary heavy lifting of the render() function is to 448 | // define colorizer functions for each matrix, such that 449 | // 450 | // colorizer(i, j) = color of the cell at i, j 451 | // 452 | let inputColorizer = undefined; 453 | let weightColorizer = undefined; 454 | let outputColorizer = undefined; 455 | 456 | // After colorizing an input cell, apply darkening if the cell falls 457 | // within the padding. This function is responsible for rendering 458 | // the dark padding border; if you replace this with a passthrough 459 | // to f no dark padding border will be rendered. 
460 | function inputColorizerWrapper(f) { 461 | return (i, j) => { 462 | let r = f(i, j); 463 | if (typeof r === "undefined") { 464 | r = d3.color("white"); 465 | } else { 466 | r = d3.color(r); 467 | } 468 | if (i < padding || i >= input_height + padding || j < padding || j >= input_width + padding) { 469 | r = r.darker(2.5); 470 | } 471 | return r; 472 | }; 473 | } 474 | 475 | // Given the animation timestep, determine the output coordinates 476 | // of our animated stencil. 477 | const flat_animated = this.state.counter % (output_height * output_width); 478 | const animatedH = Math.floor(flat_animated / output_width); 479 | const animatedW = flat_animated % output_width; 480 | 481 | // If the user is not hovering over any matrix, render "as if" 482 | // they were hovering over the animated output coordinate. 483 | if (!hoverOver) { 484 | hoverOver = "output"; 485 | hoverH = animatedH; 486 | hoverW = animatedW; 487 | } 488 | 489 | // If the user is hovering over the input matrix, render "as if' 490 | // they were hovering over the output coordinate, such that the 491 | // top-left corner of the stencil is attached to the cursor. 492 | if (hoverOver === "input") { 493 | hoverOver = "output"; 494 | hoverH = Math.min(Math.floor(hoverH / stride_height), output_height - 1); 495 | hoverW = Math.min(Math.floor(hoverW / stride_width), output_width - 1); 496 | } 497 | 498 | // Generate the color interpolator for generating the kernels. 499 | // This particular scale was found via experimentation with various 500 | // start/endpoints and different interpolation schemes. 
For more 501 | // documentation on these D3 functions, see: 502 | // 503 | // - https://github.com/d3/d3-interpolate 504 | // - https://github.com/d3/d3-color 505 | // 506 | // Some notes on what I was going for, from an aesthetic perspective: 507 | // 508 | // - The most important constraint is that all colors produced by the 509 | // interpolator need to be saturated enough so they are not confused 510 | // with the "animation" shadow. 511 | // - I wanted the interpolation to be smooth, despite this being a 512 | // discrete setting where an ordinal color scheme could be 513 | // employed. (Also I couldn't get the color schemes to work lol.) 514 | // 515 | // If you are a visualization expert and have a pet 2D color 516 | // interpolation scheme, please try swapping it in here and seeing 517 | // how it goes. 518 | const xScale = d3.scaleSequential(d3.interpolateLab('#d7191c', '#2c7bb6')) 519 | .domain([-1, weight_height]); 520 | 521 | // The yScale (Red->Green) is driven by the column index `j`. 522 | const yScale = d3.scaleSequential(d3.interpolateLab('#d7191c', d3.color('#1a9641').brighter(1))) 523 | .domain([-1, weight_width]); 524 | 525 | const max_dim = Math.max(weight_height, weight_width); 526 | 527 | function xyScale(i, j) { // i for height index, j for width index 528 | // Get the end-point colors for this specific cell's gradient 529 | const color1 = xScale(i); 530 | const color2 = yScale(j); 531 | 532 | // The interpolation factor determines the mix between color1 and color2 533 | const factor = (max_dim > 1) ? (j - i) / (max_dim - 1) : 0.5; 534 | 535 | // We need to normalize the factor to be in the [0, 1] range for the interpolator. 536 | // The original factor is roughly in [-1, 1], so this mapping works. 
537 | const normalizedFactor = (factor + 1) / 2; 538 | 539 | return d3.color(d3.interpolateLab(color1, color2)(normalizedFactor)); 540 | } 541 | 542 | // Given an output coordinate 'hoverH, hoverW', compute a mapping 543 | // from inputs to the weight coordinates which multiplied with 544 | // that input. 545 | // 546 | // Result: 547 | // r[flat_input_index] = [weight_height, weight_width] 548 | function compute_input_multiplies_with_weight(hoverH, hoverW) { 549 | const input_multiplies_with_weight = array1d(padded_input_height * padded_input_width); 550 | if (hoverH >= 0 && hoverH < output_height && hoverW >= 0 && hoverW < output_width) { 551 | for (let h_weight = 0; h_weight < weight_height; h_weight++) { 552 | for (let w_weight = 0; w_weight < weight_width; w_weight++) { 553 | const flat_input = output[hoverH][hoverW][h_weight][w_weight]; 554 | if (typeof flat_input === "undefined") continue; 555 | input_multiplies_with_weight[flat_input] = [h_weight, w_weight]; 556 | } 557 | } 558 | } 559 | return input_multiplies_with_weight; 560 | } 561 | 562 | // The user is hovering over the output matrix (or the input matrix) 563 | if (hoverOver === "output") { 564 | outputColorizer = (i, j) => { 565 | const base = d3.color('#666') 566 | // If this output is selected, display it as dark grey 567 | if (hoverH === i && hoverW === j) { 568 | return base; 569 | } 570 | 571 | // Otherwise, if the output is animated, display it as a lighter 572 | // gray 573 | if (animatedH === i && animatedW === j) { 574 | return whiten(base, 0.8); 575 | } 576 | }; 577 | 578 | const input_multiplies_with_weight = compute_input_multiplies_with_weight(hoverH, hoverW); 579 | const animated_input_multiplies_with_weight = compute_input_multiplies_with_weight(animatedH, animatedW); 580 | 581 | inputColorizer = inputColorizerWrapper((i, j) => { 582 | // If this input was used to compute the selected output, render 583 | // it the same color as the corresponding entry in the weight 584 | // matrix 
which it was multiplied against. 585 | const r = input_multiplies_with_weight[i * padded_input_width + j]; 586 | if (r) { 587 | return xyScale(r[0], r[1]); 588 | } 589 | 590 | // Otherwise, if the input was used to compute the animated 591 | // output, render it as a lighter version of the weight color it was 592 | // multiplied against. 593 | const s = animated_input_multiplies_with_weight[i * padded_input_width + j]; 594 | if (s) { 595 | return whiten(xyScale(s[0], s[1]), 0.8); 596 | } 597 | }); 598 | 599 | // The weight matrix displays the full 2D color scale 600 | weightColorizer = (i, j) => { 601 | return xyScale(i, j); 602 | }; 603 | 604 | // The user is hovering over the weight matrix 605 | } else if (hoverOver === "weight") { 606 | 607 | weightColorizer = (i, j) => { 608 | // If this weight is selected, render its color 609 | if (hoverH === i && hoverW === j) { 610 | return xyScale(hoverH, hoverW); 611 | } 612 | }; 613 | 614 | // Compute a mapping from flat input index to output coordinates which 615 | // this input multiplied with the selected weight to produce. 616 | const input_produces_output = array1d(padded_input_height * padded_input_width); 617 | for (let h_out = 0; h_out < output_height; h_out++) { 618 | for (let w_out = 0; w_out < output_width; w_out++) { 619 | const flat_input = output[h_out][w_out][hoverH][hoverW]; 620 | if (typeof flat_input === "undefined") continue; 621 | input_produces_output[flat_input] = [h_out, w_out]; 622 | } 623 | } 624 | 625 | const animated_input_multiplies_with_weight = compute_input_multiplies_with_weight(animatedH, animatedW); 626 | 627 | inputColorizer = inputColorizerWrapper((i, j) => { 628 | // We are only rendering inputs which multiplied against a given 629 | // weight, so render all inputs the same color as the selected 630 | // weight. 631 | const color = xyScale(hoverH, hoverW); 632 | 633 | // If this input cell was multiplied by the selected weight to 634 | // produce the animated output, darken it. 
This shows the 635 | // current animation step's "contribution" to the colored 636 | // inputs. 637 | const s = animated_input_multiplies_with_weight[i * padded_input_width + j]; 638 | if (s) { 639 | if (s[0] === hoverH && s[1] === hoverW) { 640 | return color.darker(1); 641 | } 642 | } 643 | 644 | // If this input cell was multiplied by the selected weight to 645 | // produce *some* output, render it as the weight's color. 646 | const r = input_produces_output[i * padded_input_width + j]; 647 | if (r) { 648 | // BUT, if the input cell is part of the current animation 649 | // stencil, lighten it so that we can still see the stencil. 650 | if (s) { 651 | return whiten(color, 0.2); 652 | } 653 | return color; 654 | } 655 | 656 | // If this input cell is part of the animated stencil (and 657 | // it is not part of the solid block of color), render a shadow 658 | // of the stencil so we can still see it. 659 | if (s) { 660 | return whiten(xyScale(s[0], s[1]), 0.8); 661 | } 662 | }); 663 | 664 | // The output matrix is a solid color of the selected weight. 665 | outputColorizer = (i, j) => { 666 | const color = xyScale(hoverH, hoverW); 667 | // If the output is the animated one, darken it, so we can 668 | // see the animation. 669 | if (i === animatedH && j === animatedW) { 670 | return color.darker(1); 671 | } 672 | return color; 673 | }; 674 | } 675 | 676 | return ( 677 |
678 |
679 | Input ({input_height} × {input_width}): 680 | { 683 | this.setState({hoverOver: "input", hoverH: i, hoverW: j}); 684 | }} 685 | onMouseLeave={(e, i, j) => { 686 | this.setState({hoverOver: undefined, hoverH: undefined, hoverW: undefined}); 687 | }} 688 | /> 689 |
690 |
691 | Weight ({weight_height} × {weight_width}): 692 | { 695 | this.setState({hoverOver: "weight", hoverH: i, hoverW: j}); 696 | }} 697 | onMouseLeave={(e, i, j) => { 698 | this.setState({hoverOver: undefined, hoverH: undefined, hoverW: undefined}); 699 | }} 700 | /> 701 |
702 |
703 | Output ({output_height} × {output_width}): 704 | { 707 | this.setState({hoverOver: "output", hoverH: i, hoverW: j}); 708 | }} 709 | onMouseLeave={(e, i, j) => { 710 | this.setState({hoverOver: undefined, hoverH: undefined, hoverW: undefined}); 711 | }} 712 | /> 713 |
714 |
715 | ); 716 | } 717 | } 718 | 719 | /** 720 | * A rectangular matrix grid which we render our matrix animations. 721 | * 722 | * Properties: 723 | * - height: The height of the matrix 724 | * - width: The width of the matrix 725 | * - colorizer: A function f(i, j), returning the color of the i,j cell 726 | * - onMouseEnter: A callback invoked f(event, i, j) when the i,j cell is 727 | * entered by a mouse. 728 | * - onMouseLeave: A callback invoked f(event, i, j) when the i,j cell is 729 | * left by a mouse. 730 | */ 731 | function Grid(props) { 732 | const height = parseInt(props.height, 10) || 0; 733 | const width = parseInt(props.width, 10) || 0; 734 | 735 | if (height <= 0 || width <= 0) { 736 | return
(empty)
; 737 | } 738 | 739 | const grid = array2d(height, width); 740 | const xgrid = grid.map((row, i) => { 741 | const xrow = row.map((e, j) => { 742 | // Use of colorizer this way means we force recompute of all tiles 743 | const color = props.colorizer ? props.colorizer(i, j) : undefined; 744 | return props.onMouseEnter(e, i, j)) : undefined} 748 | onMouseLeave={props.onMouseLeave ? 749 | ((e) => props.onMouseLeave(e, i, j)) : undefined} /> 750 | }); 751 | return {xrow}; 752 | }); 753 | return {xgrid}
; 754 | } 755 | 756 | // ======================================== 757 | 758 | ReactDOM.render( 759 | , 760 | document.getElementById('root') 761 | ); 762 | --------------------------------------------------------------------------------