├── .gitignore ├── README.md ├── package.json ├── public ├── favicon.ico ├── index.html └── manifest.json ├── src ├── index.css └── index.js └── yarn.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/ignore-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | 6 | # testing 7 | /coverage 8 | 9 | # production 10 | /build 11 | 12 | # misc 13 | .DS_Store 14 | .env.local 15 | .env.development.local 16 | .env.test.local 17 | .env.production.local 18 | 19 | npm-debug.log* 20 | yarn-debug.log* 21 | yarn-error.log* 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Stride Visualizer 2 | 3 | Live at [Stride Visualizer](https://ezyang.github.io/stride-visualizer/index.html). 4 | 5 | Made with the help of our fine friends at [React](https://reactjs.org/) 6 | and [D3.js](https://d3js.org/). 
7 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "stride-visualizer", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "d3-scale-chromatic": "^1.2.0", 7 | "d3v4": "^4.2.2", 8 | "gh-pages": "^1.1.0", 9 | "react": "^16.2.0", 10 | "react-dom": "^16.2.0", 11 | "react-scripts": "1.1.1" 12 | }, 13 | "homepage": "https://ezyang.github.io/stride-visualizer", 14 | "scripts": { 15 | "start": "react-scripts start", 16 | "build": "react-scripts build", 17 | "test": "react-scripts test --env=jsdom", 18 | "eject": "react-scripts eject", 19 | "predeploy": "yarn build", 20 | "deploy": "gh-pages -d build" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ezyang/stride-visualizer/03cdfe0cf062a283f0cd3837ca67aea8214cad17/public/favicon.ico -------------------------------------------------------------------------------- /public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 21 | Stride Visualizer 22 | 23 | 24 | 27 | Fork me on GitHub 28 |
29 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "Stride Visualizer", 3 | "name": "Stride Visualizer", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | } 10 | ], 11 | "start_url": "./index.html", 12 | "display": "standalone", 13 | "theme_color": "#000000", 14 | "background_color": "#ffffff" 15 | } 16 | -------------------------------------------------------------------------------- /src/index.css: -------------------------------------------------------------------------------- 1 | h1 { margin-bottom: 0 } 2 | 3 | .author { margin-left: 2em; } 4 | 5 | p { 6 | max-width: 80em; 7 | } 8 | 9 | body { 10 | font: 14px "Century Gothic", Futura, sans-serif; 11 | margin: 20px; 12 | margin-right:40px; 13 | } 14 | 15 | .form { 16 | margin-bottom: 1em; 17 | float: left; 18 | } 19 | 20 | .viewport { 21 | margin-left: 30em; 22 | } 23 | 24 | .grid-container { 25 | margin-bottom: 1em; 26 | } 27 | 28 | table { 29 | border-collapse: collapse; 30 | } 31 | 32 | td { 33 | background: #fff; 34 | border: 1px solid #999; 35 | height: 34px; 36 | width: 34px; 37 | } 38 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import * as d3 from 'd3v4'; 4 | import './index.css'; 5 | 6 | /** 7 | * An HTML5 range slider and associated raw text input. 8 | * 9 | * Properties: 10 | * - min: The minimum allowed value for the slider range 11 | * - max: The maximum allowed value for the slider range 12 | * - value: The current value of the slider 13 | * - disabled: Whether or not to disable the slider. A slider 14 | * is automatically disabled when min == max. 
15 | * - onChange: Callback when the value of this slider changes. 16 | */ 17 | function Slider(props) { 18 | const max = parseInt(props.max, 10); 19 | const min = parseInt(props.min, 10); 20 | const maxLength = max ? Math.ceil(Math.log10(Math.abs(max))) : 1; /* NOTE(review): Math.ceil(Math.log10(n)) is one less than the digit count when n is an exact power of ten (10 gives 1, 100 gives 2) and is 0 when n is 1. */ 21 | const disabled = props.disabled || min >= max; /* NOTE(review): the markup returned below is not present in this copy of the source. */ 22 | return ( 23 | 24 | 28 | 34 | 35 | ); 36 | } 37 | 38 | /** 39 | * Create a 1-dimensional array of size 'length', where the 'i'th entry 40 | * is initialized to 'f(i)', or 'undefined' if 'f' is not passed. 41 | */ 42 | function array1d(length, f) { 43 | return Array.from({length: length}, f ? ((v, i) => f(i)) : undefined); 44 | } 45 | 46 | /** 47 | * Create a 2-dimensional array of size 'height' x 'width', where the 'i','j' entry 48 | * is initialized to 'f(i, j)', or 'undefined' if 'f' is not passed. 49 | */ 50 | function array2d(height, width, f) { 51 | return Array.from({length: height}, (v, i) => Array.from({length: width}, f ? ((w, j) => f(i, j)) : undefined)); 52 | } 53 | /** Create a 3-dimensional array of size 'depth' x 'height' x 'width', where the 'i','j','k' entry is initialized to 'f(i, j)', or 'undefined' if 'f' is not passed. NOTE(review): the innermost index k is not forwarded to 'f', unlike array1d and array2d which pass every index; confirm whether 'f(i, j, k)' was intended. */ 54 | function array3d(depth, height, width, f) { 55 | return Array.from({length: depth}, (v, i) => 56 | Array.from({length: height}, (v, j) => 57 | Array.from({length: width}, 58 | f ? ((w, k) => f(i, j)) : undefined))); 59 | } 60 | 61 | // We use the next two functions (maxWhile and minWhile) to 62 | // inefficiently compute the bounds for various parameters 63 | // given fixed values for other parameters. 64 | 65 | /** 66 | * Given a predicate 'pred' and a starting integer 'start', 67 | * find the largest integer i >= start such that 'pred(i)' 68 | * is true OR end, whichever is smaller. The scan stops at the first i for which 'pred(i)' is false and returns i - 1, so the result is start - 1 when 'pred(start)' is false; if start > end nothing is tested and end is returned. 69 | */ 70 | function maxWhile(start, end, pred) { 71 | for (let i = start; i <= end; i++) { 72 | if (pred(i)) continue; 73 | return i - 1; 74 | } 75 | return end; 76 | } 77 | 78 | /** 79 | * Given a predicate 'pred' and a starting integer 'start', 80 | * find the smallest integer i <= start such that 'pred(i)' 81 | * is true OR end, whichever is larger.
/**
 * Walk downward from 'start' to 'end' (both inclusive) while 'pred' holds
 * and return the smallest integer reached.
 *
 * The walk stops at the first i for which 'pred(i)' is false and returns
 * i + 1, so the result is 'start + 1' when 'pred(start)' itself is false.
 * When 'pred' holds all the way down, or when start < end so that nothing
 * is tested, 'end' is returned.
 */
function minWhile(start, end, pred) {
  for (let i = start; i >= end; i--) {
    if (!pred(i)) return i + 1;
  }
  return end;
}

/**
 * Compute the lowest and highest storage positions, relative to the
 * storage offset, that a view_height x view_width view with the given
 * strides can touch.
 *
 * NB: both watermarks are INCLUSIVE. For example, if all strides are 0
 * the result is [0, 0]: only the element at the offset itself is accessed.
 *
 * A view with no rows or no columns touches no memory at all. That case
 * is reported as [0, -1] (low > high), which paramsOK reads as the empty
 * range.
 *
 * Returns a two-element array [low_watermark, high_watermark].
 */
function watermarks(view_height, view_width, stride_height, stride_width) {
  // The empty view needs an explicit guard: summing (size - 1) * stride
  // only happens to give high < low for some stride combinations (for
  // instance a 0 x 4 view with strides 1, 1 would otherwise yield [0, 2]).
  if (view_height <= 0 || view_width <= 0) return [0, -1];

  // Positive strides push the furthest element up ...
  let high_watermark = 0;
  if (stride_height > 0) high_watermark += (view_height - 1) * stride_height;
  if (stride_width > 0) high_watermark += (view_width - 1) * stride_width;

  // ... negative strides push it down.
  let low_watermark = 0;
  if (stride_height < 0) low_watermark += (view_height - 1) * stride_height;
  if (stride_width < 0) low_watermark += (view_width - 1) * stride_width;

  return [low_watermark, high_watermark];
}

/**
 * Decide whether a view described by (storage_offset, view size, strides)
 * stays inside a storage of storage_height * storage_width elements.
 *
 * Returns true when every position the view can touch lies in
 * [0, storage_size), and also when the view is empty.
 */
function paramsOK(storage_height, storage_width, storage_offset, view_height, view_width, stride_height, stride_width) {
  const [low, high] = watermarks(view_height, view_width, stride_height, stride_width);

  // Empty range: nothing is accessed, so any storage and offset will do.
  if (high < low) return true;

  const storage_size = storage_height * storage_width;
  return low + storage_offset >= 0 && high + storage_offset < storage_size;
}

/**
 * Top-level component for the entire visualization. This component
 * controls top level parameters like input sizes, but not the mouse
 * interaction with the actual visualized grids.
 */
123 | */ 124 | class App extends React.Component { 125 | constructor(props) { 126 | super(props); 127 | this.state = { 128 | storage_height: 4, 129 | storage_width: 4, 130 | storage_offset: 0, 131 | view_height: 4, 132 | view_width: 4, 133 | stride_height: 4, 134 | stride_width: 1, 135 | }; 136 | } 137 | 138 | // React controlled components clobber saved browser state, so 139 | // instead we manually save/load our state from localStorage. 140 | 141 | componentDidMount() { 142 | const state = localStorage.getItem("stride-visualizer"); 143 | if (state) { 144 | this.setState(JSON.parse(state)); /* NOTE(review): the stored string is parsed and merged into state as-is; JSON.parse throws on malformed text and nothing here catches it. */ 145 | } 146 | } 147 | 148 | componentDidUpdate() { 149 | localStorage.setItem("stride-visualizer", JSON.stringify(this.state)); 150 | } 151 | 152 | render() { 153 | const storage_height = this.state.storage_height; 154 | const storage_width = this.state.storage_width; 155 | const storage_offset = this.state.storage_offset; 156 | const view_height = this.state.view_height; 157 | const view_width = this.state.view_width; 158 | const stride_height = this.state.stride_height; 159 | const stride_width = this.state.stride_width; 160 | 161 | const onChange = (state_key) => { 162 | return (e) => { 163 | const r = parseInt(e.target.value, 10); 164 | // Text inputs can sometimes temporarily be in invalid states. 165 | // If it's not a valid number, refuse to set it. 166 | if (typeof r !== "undefined") { /* NOTE(review): parseInt returns NaN, whose typeof is number, for unparsable text and never returns undefined, so this condition is always true and NaN can be stored in state. */ 167 | this.setState({[state_key]: r}); 168 | } 169 | }; 170 | }; 171 | 172 | const max_storage = 64; 173 | const max_size = 8; 174 | const max_stride = 8; 175 | 176 | return ( 177 |
178 |

Stride Visualizer

179 |
Edward Z. Yang
180 |

181 | Strides specify a factor by which an index is multiplied when computing its 182 | index into an array. Strides are surprisingly versatile and can be used 183 | to program a large number of access patterns: 184 |

185 | 193 |
194 |
195 | Storage size: 196 | paramsOK(x, storage_width, storage_offset, view_height, view_width, stride_height, stride_width))} 197 | max={max_storage} 198 | value={storage_height} 199 | onChange={onChange("storage_height")} 200 | /> 201 | paramsOK(storage_height, x, storage_offset, view_height, view_width, stride_height, stride_width))} 202 | max={max_storage} 203 | value={storage_width} 204 | onChange={onChange("storage_width")} 205 | /> 206 |
207 |
208 | Storage offset: 209 | { /* These formulas don't handle the size = 0 boundary case correctly */ } 210 | 215 |
216 |
217 | View size: 218 | paramsOK(storage_height, storage_width, storage_offset, x, view_width, stride_height, stride_width))} 220 | value={view_height} 221 | onChange={onChange("view_height")} 222 | /> 223 | paramsOK(storage_height, storage_width, storage_offset, view_height, x, stride_height, stride_width))} 225 | value={view_width} 226 | onChange={onChange("view_width")} 227 | /> 228 |
229 |
230 | View stride: 231 | paramsOK(storage_height, storage_width, storage_offset, view_height, view_width, x, stride_width))} 232 | max={maxWhile(0, max_stride, (x) => paramsOK(storage_height, storage_width, storage_offset, view_height, view_width, x, stride_width))} 233 | value={stride_height} 234 | onChange={onChange("stride_height")} 235 | /> 236 | paramsOK(storage_height, storage_width, storage_offset, view_height, view_width, stride_height, x))} 237 | max={maxWhile(0, max_stride, (x) => paramsOK(storage_height, storage_width, storage_offset, view_height, view_width, stride_height, x))} 238 | value={stride_width} 239 | onChange={onChange("stride_width")} 240 | /> 241 |
242 |
243 | 244 |
245 | ); 246 | } 247 | } 248 | 249 | /** 250 | * The viewport into the actual meat of the visualization, the 251 | * tensors. This component controls the state for hovering 252 | * and the animation. 253 | */ 254 | class Viewport extends React.Component { 255 | constructor(props) { 256 | super(props); 257 | this.state = { 258 | // Which matrix are we hovering over? 259 | hoverOver: undefined, 260 | // Which coordinate are we hovering over? Origin 261 | // is the top-left corner. 262 | hoverH: undefined, 263 | hoverW: undefined, 264 | // What is our animation timestep? A monotonically 265 | // increasing integer. 266 | counter: 0 267 | }; 268 | } 269 | 270 | // Arrange for counter to increment by one after a fixed 271 | // time interval: 272 | 273 | tick() { 274 | this.setState({counter: this.state.counter + 1}); 275 | } 276 | componentDidMount() { 277 | this.interval = setInterval(this.tick.bind(this), 500); // 0.5 second 278 | } 279 | componentWillUnmount() { 280 | clearInterval(this.interval); 281 | } 282 | 283 | render() { 284 | const storage_height = this.props.storage_height; 285 | const storage_width = this.props.storage_width; 286 | const storage_offset = this.props.storage_offset; 287 | const view_height = this.props.view_height; 288 | const view_width = this.props.view_width; 289 | const stride_height = this.props.stride_height; 290 | const stride_width = this.props.stride_width; 291 | 292 | let hoverOver = this.state.hoverOver; 293 | let hoverH = this.state.hoverH; 294 | let hoverW = this.state.hoverW; 295 | 296 | // The primary heavy lifting of the render() function is to 297 | // define colorizer functions for each matrix, such that 298 | // 299 | // colorizer(i, j) = color of the cell at i, j 300 | // 301 | let storageColorizer = undefined; 302 | let viewColorizer = undefined; 303 | 304 | // Given the animation timestep, determine the output coordinates 305 | // of our animated stencil.
306 | const animatedH = this.state.counter % view_height; /* NOTE(review): this is NaN when view_height is 0. */ 307 | 308 | // Don't have a good thing for this yet 309 | if (hoverOver === "storage") hoverOver = false; 310 | 311 | // If the user is not hovering over any matrix, render "as if" 312 | // they were hovering over the animated output coordinate. 313 | if (!hoverOver) { 314 | hoverOver = "output"; 315 | hoverH = animatedH; 316 | hoverW = undefined; 317 | } 318 | 319 | const scale = d3.scaleSequential(d3.interpolateLab('#d7191c', '#2c7bb6')).domain([0, view_width]) 320 | 321 | /* 322 | // The easy colorizers 323 | storageColorizer = (i, j) => { 324 | return xyScale(i, j); 325 | }; 326 | 327 | viewColorizer = (i, j) => { 328 | const loc = storage_offset + i * stride_height + j * stride_width; 329 | return xyScale(Math.floor(loc / storage_width), loc % storage_width); 330 | }; 331 | */ 332 | 333 | if (hoverOver === "output" || true) { /* NOTE(review): the trailing || true makes this branch unconditional. */ 334 | storageColorizer = (i, j) => { 335 | const flat = i * storage_width + j; 336 | for (let k = 0; k < view_width; k++) { 337 | if (hoverH * stride_height + k * stride_width + storage_offset === flat) return scale(k); 338 | } 339 | return "white"; 340 | } 341 | viewColorizer = (i, j) => { 342 | if (hoverH !== i) return "white"; 343 | return scale(stride_width ? 
j : 0); 344 | }; 345 | } 346 | 347 | // The user is hovering over the output matrix (or the input matrix) 348 | /* NOTE(review): this disabled block uses names (outputColorizer, inputColorizer, weightColorizer, xyScale, whiten, animatedW) that have no definition in the visible source. 349 | if (hoverOver === "output") { 350 | outputColorizer = (i, j) => { 351 | const base = d3.color('#666') 352 | // If this output is selected, display it as dark grey 353 | if (hoverH === i && hoverW === j) { 354 | return base; 355 | } 356 | 357 | // Otherwise, if the output is animated, display it as a lighter 358 | // gray 359 | if (animatedH === i && animatedW === j) { 360 | return whiten(base, 0.8); 361 | } 362 | }; 363 | 364 | const input_multiplies_with_weight = compute_input_multiplies_with_weight(hoverH, hoverW); 365 | const animated_input_multiplies_with_weight = compute_input_multiplies_with_weight(animatedH, animatedW); 366 | 367 | inputColorizer = inputColorizerWrapper((i, j) => { 368 | // If this input was used to compute the selected output, render 369 | // it the same color as the corresponding entry in the weight 370 | // matrix which it was multiplied against. 371 | const r = input_multiplies_with_weight[i * padded_input_size + j]; 372 | if (r) { 373 | return xyScale(r[0], r[1]); 374 | } 375 | 376 | // Otherwise, if the input was used to compute the animated 377 | // output, render it as a lighter version of the weight color it was 378 | // multiplied against. 
379 | const s = animated_input_multiplies_with_weight[i * padded_input_size + j]; 380 | if (s) { 381 | return whiten(xyScale(s[0], s[1]), 0.8); 382 | } 383 | }); 384 | 385 | // The weight matrix displays the full 2D color scale 386 | weightColorizer = (i, j) => { 387 | return xyScale(i, j); 388 | }; 389 | 390 | // The user is hovering over the weight matrix 391 | } else if (hoverOver === "weight") { 392 | 393 | weightColorizer = (i, j) => { 394 | // If this weight is selected, render its color 395 | if (hoverH === i && hoverW === j) { 396 | return xyScale(hoverH, hoverW); 397 | } 398 | }; 399 | 400 | // Compute a mapping from flat input index to output coordinates which 401 | // this input multiplied with the selected weight to produce. 402 | const input_produces_output = array1d(padded_input_size * padded_input_size); 403 | for (let h_out = 0; h_out < output_size; h_out++) { 404 | for (let w_out = 0; w_out < output_size; w_out++) { 405 | const flat_input = output[h_out][w_out][hoverH][hoverW]; 406 | if (typeof flat_input === "undefined") continue; 407 | input_produces_output[flat_input] = [h_out, w_out]; 408 | } 409 | } 410 | 411 | const animated_input_multiplies_with_weight = compute_input_multiplies_with_weight(animatedH, animatedW); 412 | 413 | inputColorizer = inputColorizerWrapper((i, j) => { 414 | // We are only rendering inputs which multiplied against a given 415 | // weight, so render all inputs the same color as the selected 416 | // weight. 417 | const color = xyScale(hoverH, hoverW); 418 | 419 | // If this input cell was multiplied by the selected weight to 420 | // produce the animated output, darken it. This shows the 421 | // current animation step's "contribution" to the colored 422 | // inputs. 
423 | const s = animated_input_multiplies_with_weight[i * padded_input_size + j]; 424 | if (s) { 425 | if (s[0] === hoverH && s[1] === hoverW) { 426 | return color.darker(1); 427 | } 428 | } 429 | 430 | // If this input cell was multiplied by the selected weight to 431 | // produce *some* output, render it as the weight's color. 432 | const r = input_produces_output[i * padded_input_size + j]; 433 | if (r) { 434 | // BUT, if the input cell is part of the current animation 435 | // stencil, lighten it so that we can still see the stencil. 436 | if (s) { 437 | return whiten(color, 0.2); 438 | } 439 | return color; 440 | } 441 | 442 | // If this input cell is part of the animated stencil (and 443 | // it is not part of the solid block of color), render a shadow 444 | // of the stencil so we can still see it. 445 | if (s) { 446 | return whiten(xyScale(s[0], s[1]), 0.8); 447 | } 448 | }); 449 | 450 | // The output matrix is a solid color of the selected weight. 451 | outputColorizer = (i, j) => { 452 | const color = xyScale(hoverH, hoverW); 453 | // If the output is the animated one, darken it, so we can 454 | // see the animation. 455 | if (i === animatedH && j === animatedW) { 456 | return color.darker(1); 457 | } 458 | return color; 459 | }; 460 | } 461 | */ 462 | 463 | return ( 464 |
465 |
466 | Storage ({storage_height} × {storage_width}): 467 | { 470 | this.setState({hoverOver: "storage", hoverH: i, hoverW: j}); 471 | }} 472 | onMouseLeave={(e, i, j) => { 473 | this.setState({hoverOver: undefined, hoverH: undefined, hoverW: undefined}); 474 | }} 475 | /> 476 |
477 |
478 | View ({view_height} × {view_width}): 479 | { 482 | this.setState({hoverOver: "view", hoverH: i, hoverW: j}); 483 | }} 484 | onMouseLeave={(e, i, j) => { 485 | this.setState({hoverOver: undefined, hoverH: undefined, hoverW: undefined}); 486 | }} 487 | /> 488 |
489 |
490 | ); 491 | } 492 | } 493 | 494 | /** 495 | * A matrix grid in which we render our matrix animations. 496 | * 497 | * Properties: 498 | * - height: height of the matrix 499 | * - width: width of the matrix 500 | * - colorizer: A function f(i, j), returning the color of the i,j cell 501 | * - onMouseEnter: A callback invoked f(event, i, j) when the i,j cell is 502 | * entered by a mouse. 503 | * - onMouseLeave: A callback invoked f(event, i, j) when the i,j cell is 504 | * left by a mouse. 505 | */ 506 | function Grid(props) { 507 | const height = parseInt(props.height, 10); 508 | const width = parseInt(props.width, 10); 509 | const grid = array2d(height, width); 510 | const xgrid = grid.map((row, i) => { 511 | const xrow = row.map((e, j) => { 512 | // Use of colorizer this way means we force recompute of all tiles 513 | const color = props.colorizer ? props.colorizer(i, j) : undefined; 514 | return props.onMouseEnter(e, i, j)) : undefined} 518 | onMouseLeave={props.onMouseLeave ? 519 | ((e) => props.onMouseLeave(e, i, j)) : undefined} /> 520 | }); 521 | return {xrow}; 522 | }); 523 | return {xgrid}
; 524 | } 525 | 526 | // ======================================== 527 | 528 | ReactDOM.render( 529 | , 530 | document.getElementById('root') 531 | ); 532 | --------------------------------------------------------------------------------