├── LICENSE ├── README.md ├── addition ├── index.html ├── init-add-big-feature.js ├── init-add-connections.js ├── init.js ├── style.css └── util-add.js ├── attribution_graph ├── cg.css ├── gridsnap │ ├── gridsnap.css │ └── init-gridsnap.js ├── init-cg-button-container.js ├── init-cg-clerp-list.js ├── init-cg-feature-detail.js ├── init-cg-feature-scatter.js ├── init-cg-link-graph.js ├── init-cg-node-connections.js ├── init-cg-subgraph.js ├── init-cg.js └── util-cg.js ├── feature_examples ├── feature-examples.css ├── init-feature-examples-list.js ├── init-feature-examples-logits.js └── init-feature-examples.js ├── index.html ├── prettier.config.js ├── style.css └── util.js /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Anthropic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # attribution-graphs-frontend 2 | 3 | Snapshot of the frontend code in [On the Biology of a Large Language Model](https://transformer-circuits.pub/2025/attribution-graphs/biology.html) and [Circuit Tracing: Revealing Computational Graphs in Language Models](https://transformer-circuits.pub/2025/attribution-graphs/methods.html). 4 | 5 | To run: 6 | 7 | ``` 8 | git clone git@github.com:anthropics/attribution-graphs-frontend.git 9 | cd attribution-graphs-frontend 10 | npx hot-server 11 | ``` 12 | -------------------------------------------------------------------------------- /addition/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Global Addition Weights 9 | 10 |
11 | 12 |
13 |
14 |
15 |
16 |
17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /addition/init-add-big-feature.js: -------------------------------------------------------------------------------- 1 | window.initBigAddFeature = function({visState, renderAll, allFeatures}){ 2 | var sel = d3.select('.add-big-feature').html('') 3 | 4 | var numSel = sel.append('div.feature-num.section-title') 5 | var heatmapSel = sel.append('div.heatmap.operand') 6 | 7 | sel.append('div.section-title').text('Token Predictions') 8 | var logitSel = sel.append('div.logit') 9 | 10 | sel.append('div.section-title').text('Encoder UMAP') 11 | utilAdd.drawUmap(sel.append('div.umap'), {allFeatures, visState, renderAll}) 12 | 13 | sel.append('div.section-title').text('Decoder UMAP') 14 | utilAdd.drawUmap(sel.append('div.umap'), {allFeatures, visState, renderAll, type: 'dec'}) 15 | 16 | // sel.append('div.section-title').text('Joint UMAP') 17 | // utilAdd.drawUmap(sel.append('div.umap'), {allFeatures, visState, renderAll, type: 'joint'}) 18 | 19 | sel.append('div.link').st({marginTop: 20}) 20 | .append('a').html('← Circuit Tracing § Addition Case Study') 21 | .at({href: '../../methods.html#graphs-addition'}) 22 | 23 | 24 | renderAll.clickIdx.fns.push(async () => { 25 | heatmapSel.html('') 26 | 27 | var idx = visState.clickIdx 28 | numSel.text(d => `L${allFeatures.idxLookup[idx].layer}#${d3.format('07d')(idx)}`) 29 | 30 | utilAdd.drawHeatmap(heatmapSel, idx, {isBig: true}) 31 | utilAdd.drawLogits(logitSel, idx, {isBig: true}) 32 | }) 33 | 34 | } 35 | 36 | window.init?.() 37 | -------------------------------------------------------------------------------- /addition/init-add-connections.js: -------------------------------------------------------------------------------- 1 | window.initAddConnections = function({visState, renderAll, allFeatures, type='inputs'}){ 2 | var sel = 
d3.select(`.add-connections.${type}`).html('') 3 | 4 | var isInputs = type == 'inputs' 5 | 6 | var headerSel = sel.append('div.sticky.connection-header') 7 | var headerLeftSel = headerSel.append('div.section-title.sticky') 8 | .text(isInputs ? 'Negative Input Features' : 'Positive Output Features') 9 | var headerRightSel = headerSel.append('div.section-title.sticky') 10 | .text(isInputs ? 'Positive Input Features' : 'Negative Output Features') 11 | 12 | var featureContainerSel = sel.append('div.feature-container') 13 | 14 | renderAll.clickIdx.fns.push(() => { 15 | var features = allFeatures.idxLookup[visState.clickIdx][type] 16 | .map(d => ({...d, layer: allFeatures.idxLookup[d.idx].layer})) 17 | 18 | var featureWidth = 100 19 | var gap = 20 20 | 21 | var availableWidth = (window.innerWidth - 320 - 90)/2 // middle col and gaps 22 | var nCols = 2 * Math.max(1, Math.floor(availableWidth/(2*(featureWidth + gap)))) // even number directly, min 1 per side 23 | var nCols = Math.max(4, nCols) 24 | 25 | d3.select('.add-main').st({minWidth: d3.select('.add-main').st({minWidth: ''}).node().offsetWidth}) 26 | 27 | d3.nestBy(features, d => d.strength > 0).forEach(signGroup => { 28 | var isPos = signGroup[0].strength > 0 29 | d3.sort(signGroup, d => -Math.abs(d.strength)).forEach((d, i) => { 30 | d.j = Math.floor(i/(nCols/2)) 31 | d.i = (isPos ^ isInputs ? 
0 : nCols/2) + i%(nCols/2) 32 | }) 33 | }) 34 | 35 | sel.st({ 36 | width: featureWidth*nCols + gap*(nCols - 1) + 3, 37 | height: (d3.max(features, d => d.j) + 1)*(featureWidth + gap) 38 | }) 39 | 40 | headerLeftSel.st({ 41 | width: (featureWidth * nCols/2 + gap * (nCols/2 - 1)) + 'px', 42 | display: 'inline-block' 43 | }) 44 | headerRightSel.st({ 45 | width: (featureWidth * nCols/2 + gap * (nCols/2 - 1)) + 'px', 46 | display: 'inline-block', 47 | marginLeft: gap + 'px' 48 | }) 49 | 50 | var featureSel = featureContainerSel.html('').appendMany('div.feature', features) 51 | .call(utilAdd.attachFeatureEvents, {visState, renderAll}) 52 | .translate(d => [d.i*(featureWidth + gap), d.j*(featureWidth + gap)]) 53 | 54 | featureSel.append('div').each(function(d){ 55 | utilAdd.drawHeatmap(d3.select(this), d.idx, {isDelay: d.j > 6}) 56 | }) 57 | 58 | var featureLabelSel = featureSel.append('div.feature-label') 59 | featureLabelSel.append('span') 60 | .text(d => `L${allFeatures.idxLookup[d.idx].layer}#${d3.format('07d')(d.idx)}`) 61 | featureLabelSel.append('span.strength') 62 | .st({ 63 | background: d => utilAdd.color(d.strength), 64 | color: d => Math.abs(d.strength) < 0.6 ? 
'#000' : '#fff', 65 | }) 66 | .text(d => d3.format('+.2f')(d.strength)) 67 | }) 68 | 69 | renderAll.hoverIdx.fns.push(() => { 70 | sel.selectAll('.feature').classed('hovered', d => d.idx == visState.hoverIdx) 71 | }) 72 | } 73 | 74 | 75 | 76 | window.init?.() 77 | -------------------------------------------------------------------------------- /addition/init.js: -------------------------------------------------------------------------------- 1 | window.init = async function () { 2 | var features_enriched = await util.getFile('/data/addition/features_enriched.json') 3 | 4 | var allFeatures = window.allFeatures = features_enriched.features 5 | allFeatures.idxLookup = Object.fromEntries(allFeatures.map(d => [d.idx, d])) 6 | console.log(allFeatures) 7 | 8 | window.visState = window.visState || { 9 | clickIdx: util.params.get('clickIdx') || 17574692, 10 | hoverIdx: null 11 | } 12 | if (visState.clickIdx == 'undefined') visState.clickIdx = 17574692 13 | 14 | var renderAll = util.initRenderAll(['clickIdx', 'hoverIdx']) 15 | util.attachRenderAllHistory(renderAll, ['hoverIdx']) 16 | 17 | 18 | initBigAddFeature({visState, renderAll, allFeatures}) 19 | initAddConnections({visState, renderAll, allFeatures, type: 'inputs'}) 20 | initAddConnections({visState, renderAll, allFeatures, type: 'outputs'}) 21 | 22 | renderAll.clickIdx() 23 | } 24 | 25 | 26 | window.init() 27 | -------------------------------------------------------------------------------- /addition/style.css: -------------------------------------------------------------------------------- 1 | body{ 2 | margin: 0px; 3 | min-width: 1400px; 4 | } 5 | 6 | .link a{ 7 | color: #000; 8 | /* text-decoration: none; */ 9 | font-size: 12px; 10 | } 11 | 12 | .add-main { 13 | margin-top: 20px; 14 | display: flex; 15 | justify-content: center; 16 | gap: 30px; 17 | position: relative; 18 | height: 100vh; 19 | overflow-y: auto; 20 | min-width: 800px; 21 | 22 | 23 | .section-title{ 24 | margin-bottom: 10px; 25 | border-bottom: 1px 
solid #ddd; 26 | } 27 | 28 | .sticky{ 29 | top: 0; 30 | position: sticky; 31 | z-index: 100; 32 | background: #fff; 33 | } 34 | 35 | .add-big-feature{ 36 | position: sticky; 37 | top: 0; 38 | } 39 | 40 | .connection-header{ 41 | background: rgba(0,0,0,0); 42 | } 43 | .connection-header > div{ 44 | position: relative; 45 | left: 2px; 46 | top: -3px; 47 | padding-top: 3px; 48 | outline: 1px solid #fff; 49 | } 50 | 51 | .feature-container { 52 | position: relative; 53 | font-size: 8px; 54 | 55 | .feature { 56 | position: absolute; 57 | width: 100px; 58 | cursor: pointer; 59 | 60 | .feature-label { 61 | color: #777; 62 | text-align: center; 63 | display: flex; 64 | justify-content: space-between; 65 | position: relative; 66 | top: -2px; 67 | left: 2px; 68 | 69 | .strength { 70 | padding: 1px 3px; 71 | position: relative; 72 | right: -1px; 73 | } 74 | } 75 | } 76 | 77 | .hovered .feature-label { 78 | text-decoration: underline; 79 | color: #000; 80 | } 81 | } 82 | 83 | .umap{ 84 | circle{ 85 | cursor: pointer; 86 | } 87 | 88 | circle.hover{ 89 | stroke-width: 1px; 90 | animation: throb 0.5s ease-in-out infinite alternate; 91 | } 92 | circle.hover.unselected{ 93 | stroke: #000; 94 | } 95 | } 96 | 97 | } 98 | 99 | @keyframes throb { 100 | from { stroke-width: 2px; } 101 | to { stroke-width: 2.4px; } 102 | } 103 | 104 | .operand { 105 | .domain{ 106 | display: none; 107 | } 108 | 109 | .tick { 110 | text { 111 | font-size: 10px; 112 | fill: #777; 113 | } 114 | 115 | path { 116 | stroke: #eee; 117 | stroke-width: 1px; 118 | } 119 | } 120 | } 121 | 122 | .tooltip .section-title { 123 | margin-bottom: 10px; 124 | } 125 | -------------------------------------------------------------------------------- /addition/util-add.js: -------------------------------------------------------------------------------- 1 | window.utilAdd = (function(){ 2 | async function drawHeatmap(sel, id, {isBig, isDelay, s}){ 3 | s = s ?? (isBig ? 3 : 1) 4 | 5 | var margin = isBig ? 
{right: 0, top: 2, bottom: 40} : {top: 0, left: 2, bottom: 2, right: 0} 6 | var c = d3.conventions({ 7 | sel: sel.html('').classed('operand', 1), 8 | margin, 9 | width: s*100, 10 | height: s*100, 11 | layers: 'sc', 12 | }) 13 | 14 | // add axis 15 | c.x.domain([0, 100]) 16 | c.y.domain([0, 100]) 17 | 18 | var tickValues = d3.range(0, 110, isBig ? 10 : 20) 19 | var tickFormat = isBig ? d => d : d => '' 20 | c.xAxis.tickValues(tickValues).tickFormat(tickFormat).tickPadding(-2) 21 | c.yAxis.tickValues(tickValues).tickFormat(tickFormat).tickPadding(-2) 22 | 23 | c.drawAxis() 24 | c.svg.selectAll('.tick').selectAll('line').remove() 25 | c.svg.selectAll('.x .tick').append('path').at({d: `M 0 0 V ${-c.height}`}) 26 | c.svg.selectAll('.y .tick').append('path').at({d: `M 0 0 H ${c.width}`}) 27 | 28 | // load and draw data 29 | if (isDelay) await util.sleep(32) 30 | var gridData = await util.getFile(`/data/addition/heatmap/${id}.npy`) 31 | 32 | var maxVal = d3.max(gridData.data) 33 | maxVal = .15 34 | var colorScale = d3.scaleSequential(d3.interpolateOranges).domain([0, 1.4*maxVal]).clamp(1) 35 | 36 | for (var i = 0; i < 100*100; i++){ 37 | var v = gridData.data[i] 38 | if (v == 0) continue 39 | 40 | var row = Math.ceil(100 - i/100 - 1) 41 | var col = i % 100 42 | 43 | c.layers[1].fillStyle = colorScale(v) 44 | c.layers[1].fillRect(col*s, row*s, s, s) 45 | } 46 | } 47 | 48 | async function drawLogits(sel, id, {isBig, isDelay, s}){ 49 | s = s ?? (isBig ? 3 : 1) 50 | 51 | var margin = isBig ? {right: 0, top: 0, bottom: 40} : {top: 0, left: 2, bottom: 2, right: 0} 52 | var c = d3.conventions({ 53 | sel: sel.html('').classed('operand', 1).st({marginTop: -2}), 54 | margin, 55 | width: s*100, 56 | height: s*10, 57 | layers: 'sc', 58 | }) 59 | 60 | // add axis 61 | c.x.domain([0, 100]) 62 | c.y.domain([0, 10]) 63 | 64 | var tickValues = d3.range(0, 100, isBig ? 10 : 20) 65 | var tickFormat = isBig ? d => d : d => '' 66 | c.xAxis.tickValues(tickValues).tickFormat(d => '_' + (d ? 
d : '00')).tickPadding(-2) 67 | c.yAxis.tickValues([0, 4, 8]).tickFormat(d => d + '_ _').tickPadding(-2) 68 | 69 | c.drawAxis() 70 | c.svg.selectAll('.tick').selectAll('line').remove() 71 | c.svg.selectAll('.x .tick').append('path').at({d: `M 0 5 V ${0}`}) 72 | c.svg.selectAll('.x .tick').select('text').translate(5, 1) 73 | c.svg.selectAll('.y .tick').append('path').at({d: `M -5 0 H ${0}`}) 74 | c.svg.selectAll('.y .tick').select('text').translate(-5, 0) 75 | 76 | // load and draw data 77 | if (isDelay) await util.sleep(32) 78 | var gridData = await util.getFile(`/data/addition/effects/${id}.npy`) 79 | 80 | var mean = d3.mean(gridData.data) 81 | values = gridData.data.map(d => d - mean) 82 | 83 | var maxVal = d3.max(values) 84 | var colorScale = d3.scaleDiverging(d3.interpolatePRGn).domain([maxVal, 0, -maxVal]).clamp(1) 85 | 86 | for (var i = 0; i < 100*10; i++){ 87 | var v = values[i] 88 | var row = Math.ceil(10 - i/100 - 1) 89 | var col = i % 100 90 | 91 | 92 | c.layers[1].fillStyle = colorScale(v) 93 | c.layers[1].fillRect(col*s, row*s, s, s) 94 | } 95 | } 96 | 97 | 98 | function drawUmap(sel, {allFeatures, visState, renderAll, type='enc'}){ 99 | var c = d3.conventions({ 100 | sel: sel.html(''), 101 | width: 300, 102 | height: 200, 103 | margin: {left: 0, top: 0, right: 0, bottom: 0}, 104 | }) 105 | 106 | var points = allFeatures 107 | .map(d => ({ 108 | idx: d.idx, 109 | x: d['umap_' + type][0], 110 | y: d['umap_' + type][1], 111 | d, 112 | })) 113 | .filter(d => d.d.inputs.length + d.d.outputs.length > 0) 114 | 115 | c.x.domain(d3.extent(points, d => d.x)) 116 | c.y.domain(d3.extent(points, d => d.y)) 117 | 118 | var pointSel = c.svg.appendMany('circle', points) 119 | .translate(d => [c.x(d.x), c.y(d.y)]) 120 | .at({r: 2, fill: '#000', fillOpacity: .2, stroke: '#000'}) 121 | .call(attachFeatureEvents, {visState, renderAll}) 122 | 123 | renderAll.hoverIdx.fns.push(() => { 124 | pointSel.classed('hover', d => d.idx == visState.hoverIdx) 125 | }) 126 | 127 | 
renderAll.clickIdx.fns.push(() => { 128 | var clickFeature = allFeatures.idxLookup[visState.clickIdx] 129 | var idx2strength = {} 130 | clickFeature.inputs.forEach(d => idx2strength[d.idx] = d.strength) 131 | clickFeature.outputs.forEach(d => idx2strength[d.idx] = d.strength) 132 | 133 | points.forEach(d => { 134 | d.isClicked = d.idx == visState.clickIdx 135 | d.strength = d.isClicked ? 9999 : idx2strength[d.idx] || 0 136 | d.fill = d.isClicked ? '#000' : d.strength ? utilAdd.color(d.strength) : '#fff' 137 | }) 138 | 139 | pointSel.at({ 140 | fill: d => d.fill, 141 | r: d => d.strength ? 4 : 1, 142 | fillOpacity: d => d.idx == visState.clickIdx || d.fill != '#fff' ? 1 : 0, 143 | stroke: d => { 144 | if (d.idx == visState.clickIdx) return '#000' 145 | if (d.fill == '#fff') return 'rgba(0,0,0,0.2)' 146 | return d3.rgb(d.fill).darker(3) 147 | } 148 | }).classed('unselected', d => d.fill == '#fff') 149 | }) 150 | } 151 | 152 | function attachFeatureEvents(sel, {visState, renderAll}){ 153 | sel 154 | .call(d3.attachTooltip) 155 | .on('mouseover', (e, d) => { 156 | d = d.d || d 157 | 158 | visState.hoverIdx = d.idx 159 | renderAll.hoverIdx() 160 | 161 | var ttSel = d3.select('.tooltip').html('').st({padding: 20, paddingBottom: 0, paddingTop: 10}) 162 | 163 | ttSel.append('div.section-title') 164 | .text(`L${d.layer}#${d3.format('07d')(d.idx)}`) 165 | utilAdd.drawHeatmap(ttSel.append('div'), d.idx, {isBig: true, s: 2}) 166 | 167 | ttSel.append('div.section-title').text('Token Predictions') 168 | utilAdd.drawLogits(ttSel.append('div'), d.idx, {isBig: true, s: 2}) 169 | }) 170 | .on('mouseleave', () => { 171 | visState.hoverIdx = null 172 | renderAll.hoverIdx() 173 | }) 174 | .on('click', (e, d) => { 175 | visState.clickIdx = d.idx 176 | renderAll.clickIdx() 177 | }) 178 | } 179 | 180 | return { 181 | drawHeatmap, 182 | drawLogits, 183 | drawUmap, 184 | color: d3.scaleDiverging(d3.interpolatePRGn).domain([-1, 0, 1]), 185 | attachFeatureEvents 186 | } 187 | })() 188 | 189 | 
window.init?.() 190 | -------------------------------------------------------------------------------- /attribution_graph/cg.css: -------------------------------------------------------------------------------- 1 | 2 | .cg{ 3 | .feature-detail{ 4 | margin-top: 10px; 5 | } 6 | 7 | 8 | line-height: normal !important; 9 | 10 | text { 11 | cursor: default; 12 | /* text-shadow: 0 1px 0 #fff, 1px 0 0 #fff, 0 -1px 0 #fff, -1px 0 0 #fff; */ 13 | user-select: none; 14 | } 15 | 16 | svg { 17 | overflow: visible; 18 | } 19 | 20 | .h3 { 21 | margin: 0 0 6px 0; 22 | font-size: 15px; 23 | font-weight: 600; 24 | } 25 | 26 | .h4 { 27 | margin: 0 0 0 0; 28 | font-size: 13px; 29 | font-weight: 500; 30 | } 31 | 32 | .axis { 33 | .domain { display: none; } 34 | 35 | text { fill: #777; } 36 | } 37 | .prompt-token text { 38 | fill: #777; 39 | text-shadow: 0 1px 0 #fff, 1px 0 0 #fff, 0 -1px 0 #fff, -1px 0 0 #fff; 40 | } 41 | 42 | .gridsnap-container { 43 | .grid-item { 44 | outline: 0px solid #fff; 45 | } 46 | } 47 | 48 | 49 | .link-graph, .feature-scatter, .feature-umap, .feature-diff-scatter, .link-diff-scatter { 50 | canvas{ 51 | pointer-events: none;; 52 | } 53 | overflow: visible; 54 | .node { 55 | cursor: pointer; 56 | } 57 | 58 | 59 | .node.clicked { 60 | stroke-width: 1.5px; 61 | stroke: #f0f; 62 | } 63 | 64 | .node.pinned { 65 | stroke: #000; 66 | stroke-width: 1.8px; 67 | } 68 | 69 | .node.hidden { 70 | opacity: .5; 71 | r: .5; 72 | } 73 | 74 | .node.pinned.clicked { 75 | stroke-width: 2.3px; 76 | stroke: #f0f; 77 | /* stroke: url(#pinned-clicked-gradient); */ 78 | } 79 | } 80 | 81 | .link-graph .node.hidden { 82 | opacity: 1 !important; 83 | font-size: 3.5px; 84 | } 85 | 86 | /* This is here for the diff view */ 87 | .feature-detail { 88 | height: 100%; 89 | display: flex; 90 | flex-direction: column; 91 | overflow: hidden; 92 | 93 | .h4 { 94 | padding-bottom: 5px; 95 | border-bottom: solid 1px #ccc; 96 | } 97 | 98 | .no-selected-feature{ 99 | margin: 0 0 0 0; 100 | font-size: 
13px; 101 | font-weight: 500; 102 | } 103 | 104 | /* Header section */ 105 | .feature-header { 106 | flex-shrink: 0; 107 | text-overflow: ellipsis; 108 | white-space: nowrap; 109 | overflow-x: hidden; 110 | padding: 5px 0 0 0; 111 | } 112 | 113 | .header-top-row { 114 | display: flex; 115 | align-items: center; 116 | column-gap: 10px; 117 | height: 20px; 118 | } 119 | 120 | .feature-title { 121 | flex: 0 0 auto; 122 | font-size: 13px; 123 | font-weight: 500; 124 | cursor: pointer; 125 | margin-left: 1px; 126 | padding: 1px 1px; 127 | 128 | /* disables feature link */ 129 | a { 130 | pointer-events: none; 131 | text-decoration: none; 132 | } 133 | } 134 | .feature-title.hovered{ 135 | text-decoration: underline; 136 | text-decoration-color: #f0f; 137 | } 138 | .feature-title.pinned{ 139 | outline: 1px solid #000; 140 | } 141 | 142 | .feature-link { 143 | flex: 0 0 auto; 144 | font-size: 13px; 145 | margin-left: 5px; 146 | cursor: pointer; 147 | a { 148 | color: inherit; 149 | text-decoration: none; 150 | } 151 | &:hover { 152 | text-decoration: underline; 153 | } 154 | } 155 | 156 | /* CLERP sections */ 157 | .pp-clerp { 158 | flex: 1 1 auto; 159 | white-space: nowrap; 160 | overflow: hidden; 161 | text-overflow: ellipsis; 162 | text-align: right; 163 | } 164 | 165 | .pclerp { 166 | display: flex; 167 | padding: 0 0 6px 0; 168 | overflow: hidden; 169 | text-overflow: ellipsis; 170 | align-items: baseline; 171 | 172 | > div { 173 | flex: 1 0 0; 174 | overflow: hidden; 175 | text-overflow: ellipsis; 176 | } 177 | } 178 | 179 | .edit-clerp-button { 180 | flex: 0 0 auto; 181 | overflow: hidden; 182 | } 183 | 184 | button{ 185 | font-size: 10px; 186 | padding: 0px 3px; 187 | } 188 | 189 | .clerp-edit { 190 | padding: 10px 0; 191 | flex-direction: column; 192 | gap: 8px; 193 | font-size: 12px; 194 | display: flex; 195 | } 196 | 197 | 198 | /* Examples section */ 199 | .feature-examples-container { 200 | flex: 1; 201 | overflow-y: auto; 202 | } 203 | 204 | /* Logits 
section */ 205 | .logits-container { 206 | padding-top: 6px; 207 | margin-bottom: 10px; 208 | font-size: 11px; 209 | display: flex; 210 | flex: 0 0 auto; 211 | flex-direction: column; 212 | overflow-y: hidden; 213 | 214 | .effects { 215 | margin-top: 6px; 216 | } 217 | 218 | .sign { 219 | display: flex; 220 | align-items: center; 221 | gap: 3px; 222 | overflow: hidden; 223 | } 224 | 225 | .label { 226 | flex: 0 0 80px; 227 | color: hsl(0 0 0 / 0.5); 228 | } 229 | .rows { 230 | flex: 1 1 0; 231 | display: flex; 232 | gap: 3px; 233 | overflow: hidden; 234 | } 235 | .row { 236 | display: flex; 237 | flex-direction: column; 238 | align-items: flex-start; 239 | margin-top: 1px; 240 | padding-bottom: 1px; 241 | } 242 | 243 | .key { 244 | font-family: monospace; 245 | flex-grow: 1; 246 | color: hsl(0 0 0 / 0.5); 247 | background: hsl(0 0 0 / 0.08); 248 | border-radius: 4px; 249 | padding: 1px 3px; 250 | height: 14px; 251 | } 252 | 253 | .value { 254 | display: none; 255 | font-size: 9px; 256 | color: hsl(0 0 0 / 0.4) 257 | } 258 | } 259 | } 260 | 261 | 262 | .metadata { 263 | overflow-y: auto; 264 | font-family: monospace; 265 | font-size: 10px; 266 | color: #777; 267 | 268 | .row { 269 | padding-left: 15px; 270 | position: relative; 271 | } 272 | 273 | .toggle { 274 | position: absolute; 275 | left: 2px; 276 | cursor: pointer; 277 | user-select: none; 278 | } 279 | 280 | .children { 281 | padding-left: 15px; 282 | } 283 | } 284 | 285 | 286 | .button-container { 287 | display: flex; 288 | padding-left: 2px; 289 | overflow: hidden; 290 | 291 | .link-type-buttons, .toggle-buttons { 292 | display: flex; 293 | margin-right: 10px; 294 | 295 | div { 296 | margin: 0; 297 | padding: 8px; 298 | border: 1px solid #ccc; 299 | background: #fff; 300 | cursor: pointer; 301 | margin-left: -1px; 302 | user-select: none; 303 | display: flex; 304 | align-items: center; 305 | text-align: center; 306 | font-size: 12px; 307 | 308 | &:hover{ 309 | background: #eee; 310 | } 311 | 312 | 
&.active{ 313 | background: #000; 314 | color: #fff; 315 | border-color: #000; 316 | } 317 | } 318 | } 319 | 320 | .toggle-button:not(:last-child) { 321 | border-right: none; 322 | } 323 | 324 | .toggle-button:hover { 325 | background-color: #e0e0e0; 326 | } 327 | } 328 | 329 | 330 | 331 | 332 | .clerp-list{ 333 | overflow-y: auto; 334 | overflow-x: scroll; 335 | padding: 0px; 336 | 337 | .feature-indicator { 338 | width: 10px; 339 | height: 10px; 340 | border-radius: 50%; 341 | margin-right: 5px; 342 | display: inline-block; 343 | } 344 | } 345 | 346 | .clerp-list { 347 | font-size: 12px; 348 | } 349 | 350 | .feature-scatter, .feature-umap{ 351 | overflow: hidden; 352 | display: flex; 353 | flex-direction: column; 354 | height: 100%; 355 | 356 | .chart-container { 357 | flex-grow: 1; 358 | width: 100%; 359 | height: 100%; 360 | 361 | > div { 362 | width: 100%; 363 | height: 100%; 364 | } 365 | } 366 | select{ 367 | font-family: system-ui; 368 | font-size: 12px; 369 | max-width: 100%; 370 | } 371 | } 372 | 373 | .feature-umap{ 374 | .button { 375 | display: inline-block; 376 | cursor: pointer; 377 | user-select: none; 378 | opacity: .4; 379 | margin-right: 20px; 380 | font-size: 12px; 381 | } 382 | 383 | .button.active { 384 | opacity: 1; 385 | } 386 | 387 | > div:first-child { 388 | display: flex; 389 | flex-direction: row; 390 | justify-content: center; 391 | } 392 | } 393 | 394 | .subgraph { 395 | overflow: hidden; 396 | /* outline: 1px solid #eee; */ 397 | 398 | marker{ 399 | orient: auto; 400 | } 401 | 402 | svg{ 403 | z-index: -999999999; 404 | } 405 | 406 | .weight-label{ 407 | /* text-shadow: #fff .2px 0 .2px, #fff -.2px 0 .2px, #fff 0 .2px .2px, #fff 0 -.2px .2px; */ 408 | font-size: 9px; 409 | font-family: system-ui; 410 | text-anchor: middle; 411 | alignment-baseline: bottom; 412 | dy: -.5em; 413 | opacity: .8; 414 | fill: #999; 415 | } 416 | .layer-label{ 417 | color: #999; 418 | opacity: .8; 419 | margin-right: 2px; 420 | } 421 | 422 | 
.supernode-container { 423 | position: absolute; 424 | /* border: 1px solid #ddd; */ 425 | font-size: 9px; 426 | line-height: 1em; 427 | cursor: pointer; 428 | pointer-events: none; 429 | height: 1px; 430 | width: 1px; 431 | 432 | .node-text-container{ 433 | /* overflow: hidden; */ 434 | text-align: center; 435 | position: relative; 436 | top: 22px; 437 | text-shadow: 0 1px 0 #fff, 1px 0 0 #fff, 0 -1px 0 #fff, -1px 0 0 #fff, 1px 1px 0 #fff; 438 | pointer-events: all; 439 | color: #777; 440 | /* font-size: 12px; */ 441 | } 442 | 443 | .temp-edit { 444 | width: 100px; 445 | font: inherit; 446 | border: none; 447 | padding: 0; 448 | margin: 0; 449 | background: none; 450 | outline: none; 451 | display: inline; 452 | } 453 | 454 | .clicked-weight { 455 | position: absolute; 456 | right: 0px; 457 | padding: 1px 1px; 458 | font-size: 8px; 459 | overflow: visible; 460 | } 461 | .clicked-weight.source { 462 | bottom: -12px; 463 | } 464 | .clicked-weight.target { 465 | top: -12px; 466 | } 467 | 468 | .member-circles { 469 | pointer-events: all; 470 | position: absolute; 471 | top: 0; 472 | left: 50%; 473 | transform: translate(-50%); 474 | border: 1px solid #ddd; 475 | border-radius: 4px; 476 | padding: 4px; 477 | display: flex; 478 | background: #E5E3D7; 479 | outline: 1px solid #6F6D5E; 480 | } 481 | 482 | .member-circle { 483 | width: 8px; 484 | height: 8px; 485 | border-radius: 50%; 486 | background: #fff; 487 | display: inline-block; 488 | border: 1px solid #000; 489 | outline: .5px solid #000; 490 | z-index: 100 !important; 491 | } 492 | 493 | .not-clt-feature.member-circle{ 494 | border-radius: 1.5px; 495 | } 496 | 497 | .member-circle:first-child { 498 | margin-left: 0; 499 | } 500 | } 501 | 502 | .ungroup-btn { 503 | display: none; 504 | } 505 | 506 | &.is-grouping .ungroup-btn { 507 | display: block; 508 | } 509 | 510 | 511 | .hovered.node .member-circles, .hovered.member-circle { 512 | border-color: rgba(0,0,0,0); 513 | outline: 1.5px dotted #f0f; 514 | z-index: 
100; 515 | } 516 | 517 | .clicked.node .member-circles, .clicked.member-circle { 518 | border-color: rgba(0,0,0,0); 519 | outline: 1.5px solid #f0f; 520 | } 521 | 522 | .grouping-selected .member-circles{ 523 | outline: 2px solid #0ff; 524 | } 525 | 526 | .checkbox-container{ 527 | user-select: none; 528 | position: absolute; 529 | bottom: 0px; 530 | color: #999; 531 | 532 | input{ 533 | accent-color: #999; 534 | } 535 | 536 | label:not(:first-child){ 537 | margin-left: 10px; 538 | display: inline-block; 539 | } 540 | } 541 | 542 | .is-supernode { 543 | padding-bottom: 15px; 544 | } 545 | 546 | 547 | .ungroup-btn { 548 | position: absolute; 549 | left: 0px; 550 | bottom: 0px; 551 | cursor: pointer; 552 | color: #0bb; 553 | z-index: 100; 554 | font-size: 12px; 555 | padding: 2px; 556 | 557 | &:hover { 558 | opacity: 1; 559 | color: #000; 560 | background: #0ff; 561 | } 562 | } 563 | 564 | } 565 | 566 | .feature-row { 567 | display: flex; 568 | align-items: center; 569 | gap: 4px; 570 | font-size: 11px; 571 | overflow: hidden; 572 | padding: 1px 3px; 573 | margin: 1px 0; 574 | border-radius: 8px; 575 | cursor: pointer; 576 | 577 | svg { 578 | width: 10px; 579 | height: 10px; 580 | margin-right: 4px; 581 | } 582 | .default-icon { 583 | fill: none; 584 | stroke: grey; 585 | stroke-width: .7; 586 | display: block; 587 | } 588 | 589 | .ctx-offset { 590 | flex: 1 0 5px; 591 | opacity: 0.5; 592 | } 593 | 594 | .layer { 595 | flex: 1 0 25px; 596 | opacity: 0.5; 597 | } 598 | 599 | .label { 600 | flex: 1 1 100%; 601 | overflow: hidden; 602 | white-space: nowrap; 603 | text-overflow: ellipsis; 604 | } 605 | .weight { 606 | flex: 1 0 35px; 607 | font-variant-numeric: tabular-nums; 608 | } 609 | &.hovered { 610 | outline: 3px dotted #f0f; 611 | z-index: 99; 612 | z-index: 100; 613 | position: relative; 614 | } 615 | &.clicked { 616 | outline: 1px solid #f0f; 617 | z-index: 999; 618 | } 619 | &.pinned .default-icon { 620 | stroke: #000 !important; 621 | stroke-width: 1.8px; 622 
| } 623 | &.clicked circle { 624 | stroke: magenta; 625 | stroke-width: 1px; 626 | } 627 | 628 | &.hidden { 629 | opacity: .5; 630 | font-size: 7px; 631 | padding-top: 0px; 632 | padding-bottom: 0px; 633 | } 634 | 635 | } 636 | 637 | .feature-examples { 638 | 639 | .example-quantile { 640 | margin-bottom: 5px !important; 641 | } 642 | & .example-quantile { 643 | margin-top: 10px; 644 | 645 | .quantile-title { 646 | font-weight: 500 !important; 647 | font-size: 13px; 648 | /* margin-bottom: 5px; */ 649 | width: 100%; 650 | padding-bottom: 5px; 651 | border-bottom: solid 1px #eee; 652 | display: block; 653 | } 654 | } 655 | 656 | .example-2-col > div{ 657 | margin-bottom: 20px;; 658 | } 659 | 660 | .feature-example-logits{ 661 | .ctx-container{ 662 | border-bottom: solid 0px #fff !important; 663 | margin-bottom: -4px; 664 | } 665 | } 666 | } 667 | 668 | .node-connections { 669 | .header-top-row{ 670 | user-select: none; 671 | } 672 | .feature-icon { 673 | color: #fff; 674 | -webkit-text-stroke: 1.5px #f0f; 675 | text-align: center; 676 | vertical-align: middle; 677 | margin-right: 5px; 678 | position: relative; 679 | top: -2px; 680 | font-size: 10px; 681 | } 682 | .feature-title{ 683 | font-weight: 500 !important; 684 | font-size: 13px; 685 | margin-bottom: 10px; 686 | white-space: nowrap; 687 | overflow: hidden; 688 | text-overflow: ellipsis; 689 | } 690 | .pinned .feature-icon{ 691 | -webkit-text-stroke: 2.5px #f0f; 692 | } 693 | 694 | .connections { 695 | flex: 1 0 auto; 696 | display: flex; 697 | overflow-y: hidden; 698 | overflow-x: visible; 699 | gap: 200px; 700 | 701 | .features { 702 | display: flex; 703 | flex: 1 1 auto; 704 | flex-direction: column; 705 | min-width: 20px; 706 | width: 100%; 707 | } 708 | 709 | 710 | .effects { 711 | font-size: 11px; 712 | flex: 1 0 0; 713 | display: flex; 714 | flex-direction: column; 715 | overflow-y: scroll; 716 | overflow-x: visible !important; 717 | /* border-bottom: solid 1px #eee; */ 718 | padding-top: 8px; 719 | } 
720 | 721 | .h4 { 722 | padding-bottom: 5px; 723 | border-bottom: solid 1px #eee; 724 | } 725 | 726 | .h4 label { 727 | display: flex; 728 | align-items: center; 729 | gap: 6px; 730 | } 731 | 732 | input { 733 | margin: 0; 734 | } 735 | } 736 | .sum-table table { 737 | width: 100%; 738 | border-collapse: collapse; 739 | font-weight: 300; 740 | color: rgb(85, 85, 85); 741 | line-height: 1em; 742 | margin-top: 4px; 743 | border-width: 0px !important; 744 | margin-bottom: 0px; 745 | line-height: 1em; 746 | 747 | th, td { 748 | text-align: center; 749 | font-size: 11px !important; 750 | border-bottom-color: #fff !important; 751 | line-height: 1em; 752 | } 753 | 754 | th:not(:first-child), td:not(:first-child){ 755 | padding-left: 0px !important; 756 | padding: 0px 10px !important; 757 | } 758 | 759 | th { 760 | font-weight: 400; 761 | } 762 | 763 | tr{ 764 | border-color: #fff !important; 765 | } 766 | 767 | tr.data td{ 768 | font-variant-numeric: tabular-nums; 769 | } 770 | } 771 | } 772 | } 773 | -------------------------------------------------------------------------------- /attribution_graph/gridsnap/gridsnap.css: -------------------------------------------------------------------------------- 1 | 2 | .gridsnap-container { 3 | position: relative; 4 | width: 100%; 5 | height: 100%; 6 | 7 | .grid-item { 8 | position: absolute; 9 | outline: 1px solid #eee; 10 | } 11 | 12 | .dragging { 13 | opacity: .5; 14 | z-index: 100; 15 | } 16 | 17 | .resize-handle, 18 | .move-handle { 19 | position: absolute; 20 | bottom: 0; 21 | width: 20px; 22 | height: 20px; 23 | font-size: 12px; 24 | user-select: none; 25 | display: flex; 26 | align-items: center; 27 | justify-content: center; 28 | background: #000; 29 | color: #fff; 30 | display: none; 31 | pointer-events: none; 32 | z-index: 100000000; 33 | 34 | .grid-item:hover & { 35 | display: flex; 36 | pointer-events: auto; 37 | } 38 | } 39 | 40 | .resize-handle { 41 | right: 0; 42 | cursor: se-resize; 43 | } 44 | 45 | .move-handle { 
46 | right: 20px; 47 | cursor: move; 48 | } 49 | 50 | .preview { 51 | position: absolute; 52 | background: yellow; 53 | opacity: .3; 54 | pointer-events: none; 55 | z-index: 10; 56 | } 57 | 58 | .grid-contents{ 59 | width: 100% !important; 60 | height: 100% !important; 61 | } 62 | } 63 | 64 | .gridsnap:not(.is-edit-mode) { 65 | .resize-handle, .move-handle { 66 | display: none !important; 67 | } 68 | } 69 | 70 | 71 | 72 | 73 | .gridsnap-container.dragging{ 74 | .grid-item{ 75 | outline: 3px solid #000; 76 | } 77 | .grid-contents{ 78 | pointer-events: none; 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /attribution_graph/gridsnap/init-gridsnap.js: -------------------------------------------------------------------------------- 1 | window.initGridsnap = function({ 2 | gridData = [], 3 | maxX = null, 4 | isFullScreenY = true, 5 | pad = 10, 6 | gridSizeY = 80, 7 | isFillContainer, 8 | sel = d3.select('.gridsnap'), 9 | repositionFn = null, 10 | serializedGrid = '', 11 | } = {}){ 12 | var gridsnap = {gridData, serializeGrid, deserializeGrid} 13 | 14 | gridData.forEach((d, i) => { 15 | d.next = {...d.cur} 16 | d.class = d.class === undefined ? 
i : d.class 17 | }) 18 | 19 | var maxX = maxX || d3.max(gridData, d => d.cur.x + d.cur.w) 20 | function calcgridSizeX(){ 21 | return (sel.node().offsetWidth)/maxX 22 | } 23 | var gridSizeX = calcgridSizeX() 24 | 25 | // TODO: bubble events first 26 | function calcGridSizeY() { 27 | if (!isFillContainer) return gridSizeY 28 | return sel.node().offsetHeight/(d3.max(gridData, d => d.cur.y + d.cur.h) || 1) 29 | } 30 | 31 | gridSizeY = calcGridSizeY() 32 | 33 | var resizeKey = 'resize.gridsnap' + serializedGrid 34 | d3.select(window).on(resizeKey, util.throttle(() => { 35 | var newX = calcgridSizeX() 36 | var newY = calcGridSizeY() 37 | if (newX == gridSizeX && newY == gridSizeY) return 38 | 39 | gridSizeX = newX 40 | gridSizeY = newY 41 | renderPositions() 42 | 43 | 44 | gridItemSel.each(d => d.resizeFn?.()) 45 | }, 500)) 46 | 47 | var gridsnapSel = sel.html('').append('div.gridsnap-container') 48 | 49 | var gridItemSel = gridsnapSel.appendMany('div.grid-item', gridData) 50 | .each(function(d){ d.sel = d3.select(this).append('div.grid-contents').classed(d.class, 1) }) 51 | 52 | gridItemSel.append('div.move-handle') 53 | .text('✣') 54 | .call(makeDragFn(false)) 55 | gridItemSel.append('div.resize-handle') 56 | .text('↘') 57 | .call(makeDragFn(true)) 58 | 59 | var previewSel = gridsnapSel.append('div.preview.grid-item') 60 | 61 | function makeDragFn(isResize) { 62 | return d3.drag() 63 | .subject((ev, d) => ({ 64 | x: (d.cur.x + (isResize ? d.cur.w : 0))*gridSizeX, 65 | y: (d.cur.y + (isResize ? 
d.cur.h : 0))*gridSizeY 66 | })) 67 | .container(function(){ return this.parentNode.parentNode }) 68 | .on('start', function(ev, d){ 69 | gridData.forEach(d => d.next = {...d.cur}) 70 | d.dragStart = {...d.cur} 71 | 72 | gridsnapSel.classed('dragging', 1) 73 | d3.select(this.parentNode).classed('dragging', 1) 74 | previewSel.st({'display': ''}) 75 | }) 76 | .on('end', (ev, d) => { 77 | gridData.forEach(d => d.cur = {...d.next}) 78 | 79 | gridsnapSel.classed('dragging', 0) 80 | gridItemSel.classed('dragging', 0) 81 | previewSel.st({'display': 'none'}) 82 | renderPositions() 83 | 84 | if (isResize) d.resizeFn?.() 85 | }) 86 | .on('drag', isResize ? resize : drag) 87 | 88 | function drag(ev, d) { 89 | d.cur.x = ev.x/gridSizeX 90 | d.cur.y = ev.y/gridSizeY 91 | 92 | pushGrid(d) 93 | renderPositions(d) 94 | } 95 | 96 | function resize(ev, d) { 97 | d.cur.x = d.dragStart.x 98 | d.cur.y = d.dragStart.y 99 | 100 | d.cur.w = ev.x/gridSizeX - d.dragStart.x 101 | d.cur.h = ev.y/gridSizeY - d.dragStart.y 102 | if (d.cur.w < 0) { 103 | d.cur.x += d.cur.w 104 | d.cur.w = d.dragStart.w - d.cur.w 105 | } 106 | if (d.cur.h < 0) { 107 | d.cur.y += d.cur.h 108 | d.cur.h = d.dragStart.h - d.cur.h 109 | } 110 | 111 | pushGrid(d) 112 | renderPositions(d) 113 | } 114 | } 115 | 116 | 117 | function pushGrid(active) { 118 | if (active) active.next = snapToGrid(active.cur) 119 | 120 | var sortedGridData = d3.sort(gridData, d => d != active) 121 | sortedGridData = d3.sort(sortedGridData, d => d == active ? 
d.cur.y : d.next.y) 122 | 123 | var topArray = d3.range(maxX).map(d => 0) 124 | sortedGridData.forEach(d => { 125 | var {x, y, w, h} = d.next 126 | d.next.y = d3.max(d3.range(w), i => topArray[x + i]) 127 | d3.range(w).forEach(i => topArray[x + i] = d.next.y + h) 128 | }) 129 | 130 | function snapToGrid({x, y, w, h}) { 131 | var rv = {x: Math.max(0, Math.round(x)), y: Math.max(0, Math.round(y)), w: Math.max(1, Math.round(w)), h: Math.max(1, Math.round(h))} 132 | if (rv.x + rv.w > maxX) rv.x = Math.max(0, maxX - rv.w) 133 | return rv 134 | } 135 | } 136 | 137 | function renderPositions(active){ 138 | gridItemSel.call(renderGridItem, 'next') 139 | 140 | if (active){ 141 | gridItemSel.filter(d => d == active).call(renderGridItem, 'cur') 142 | previewSel.datum(active).call(renderGridItem, 'next') 143 | } else{ 144 | if (!isFillContainer){ 145 | var maxY = Math.max(maxY, d3.max(gridData, d => d.next.y + d.next.h)) 146 | gridsnapSel.st({height: Math.max(isFullScreenY ? window.innerHeight : 0, maxY*gridSizeY + pad) + 'px'}) 147 | } 148 | } 149 | 150 | repositionFn?.(gridsnap) 151 | 152 | function renderGridItem(itemSel, key) { 153 | itemSel 154 | .translate(d => [d[key].x*gridSizeX + pad/2, d[key].y*gridSizeY + pad/2].map(Math.round)) 155 | .st({ 156 | width: d => Math.round(Math.max(0, d[key].w*gridSizeX - pad)), // negative sizes bug out 157 | height: d => Math.round(Math.max(0, d[key].h*gridSizeY - pad)) 158 | }) 159 | } 160 | } 161 | 162 | function serializeGrid(){ 163 | return gridData.map(d => { 164 | var {x, y, w, h} = d.cur 165 | return `${d.class}${[x, y, w, h].map(d3.format('02d')).join('')}` 166 | }).join('_') 167 | } 168 | 169 | function deserializeGrid(serializedGrid){ 170 | serializedGrid?.split('_').forEach(str => { 171 | var match = str.match(/^([\w-]+)(\d{8})$/) 172 | if (!match) return 173 | var [_, className, coords] = match 174 | 175 | var [x, y, w, h] = d3.range(4).map(i => +coords.substr(i*2, 2)) 176 | var gridItem = gridData.find(d => d.class == 
className) 177 | if (gridItem) gridItem.next = {x, y, w, h} 178 | }) 179 | 180 | pushGrid() 181 | gridData.forEach(d => d.cur = {...d.next}) 182 | renderPositions() 183 | } 184 | 185 | deserializeGrid(serializedGrid) 186 | 187 | return gridsnap 188 | } 189 | 190 | window.init?.() 191 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-button-container.js: -------------------------------------------------------------------------------- 1 | window.initCgButtonContainer = function({visState, renderAll, cgSel}){ 2 | var buttonContainer = cgSel.select('.button-container').html('') 3 | .st({marginBottom: '10px'}) 4 | 5 | var linkTypeSel= buttonContainer.append('div.link-type-buttons') 6 | .appendMany('div', ['input', 'output', 'either', 'both']) 7 | .text(d => d[0].toUpperCase() + d.slice(1).toLowerCase()) 8 | .on('click', (ev, d) => { 9 | visState.linkType = d 10 | renderAll.linkType() 11 | }) 12 | 13 | renderAll.linkType.fns.push(() => { 14 | linkTypeSel.classed('active', d => d === visState.linkType) 15 | }) 16 | var showAllSel = buttonContainer.append('div.toggle-buttons') 17 | .append('div').text('Show all links') 18 | .on('click', () => { 19 | visState.isShowAllLinks = visState.isShowAllLinks ? 
'' : '1' 20 | renderAll.isShowAllLinks() 21 | }) 22 | 23 | renderAll.isShowAllLinks.fns.push(() => { 24 | showAllSel.classed('active', visState.isShowAllLinks) 25 | }) 26 | 27 | var clearButtonsSel = buttonContainer.append('div.toggle-buttons') 28 | .appendMany('div', ['Clear pinned', 'Clear clicked']) 29 | .text(d => d) 30 | .on('click', (ev, d) => { 31 | if (d == 'Clear pinned') { 32 | visState.pinnedIds = [] 33 | renderAll.pinnedIds() 34 | } else { 35 | visState.clickedId = '' 36 | renderAll.clickedId() 37 | } 38 | }) 39 | 40 | cgSel.on('keydown.esc-check', ev => { 41 | if (ev.key == 'Escape') { 42 | visState.clickedId = '' 43 | renderAll.clickedId() 44 | } 45 | }) 46 | 47 | var resetGridSel = buttonContainer.append('div.toggle-buttons') 48 | .append('div').text('Reset grid') 49 | .on('click', () => { 50 | util.params.set('gridsnap', '') // TODO: this won't work with baked in features 51 | window.location.reload() 52 | }) 53 | 54 | var onSyncValue = visState.isSyncEnabled || '1' 55 | var syncButtonSel = buttonContainer.append('div.toggle-buttons') 56 | .append('div').text('Enable sync') 57 | .on('click', () => { 58 | visState.isSyncEnabled = visState.isSyncEnabled ? 
'' : onSyncValue 59 | renderAll.isSyncEnabled() 60 | }) 61 | 62 | renderAll.isSyncEnabled.fns.push(() => { 63 | syncButtonSel.classed('active', visState.isSyncEnabled) 64 | }); 65 | 66 | } 67 | 68 | window.init?.() 69 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-clerp-list.js: -------------------------------------------------------------------------------- 1 | window.initCgClerpList = function({visState, renderAll, data, cgSel}){ 2 | let itemSel, weightValSel; 3 | let clerpListSel = cgSel.select('.clerp-list'); 4 | 5 | const tokenValues = data.metadata.prompt_tokens; 6 | const featureById = d3.group(data.features, d => d.featureId); 7 | 8 | function render() { 9 | let nodesByTokenByLayer = d3.group(data.nodes, d => d.ctx_idx, d => d.streamIdx); 10 | 11 | const finalData = Array.from(nodesByTokenByLayer.entries()) 12 | .sort((a, b) => a[0] - b[0]) 13 | .map(d => { 14 | const layers = Array.from(d[1]) 15 | return { 16 | token: tokenValues[d[0]], 17 | values: layers.sort((a, b) => a[0] - b[0]), 18 | } 19 | }) 20 | 21 | clerpListSel.html('') 22 | .st({ padding: '2px' }); 23 | 24 | const featuresSel = clerpListSel.append('div.features') 25 | .st({columns: '220px', columnFill: 'auto', height: '100%'}); 26 | 27 | const tokenSel = featuresSel.appendMany('div.token', finalData) 28 | .st({ 29 | position: 'relative', 30 | borderTop: 'solid 1px hsl(0 0 0 / 0.4)', 31 | }) 32 | .at({ title: d => d.token }); 33 | 34 | const tokenLabelSel = tokenSel.append('div') 35 | .st({ 36 | fontSize: 11, 37 | color: 'hsl(0 0 0 /0.4)', 38 | fontWeight: '400', 39 | pointerEvents: 'none', 40 | padding: '2px', 41 | zIndex: 1e6, 42 | textWrap: 'nowrap', 43 | overflow: 'hidden', 44 | textOverflow: 'ellipsis', 45 | textAlign: 'center', 46 | marginTop: 5, 47 | }); 48 | 49 | tokenLabelSel.append('span').text('“'); 50 | tokenLabelSel.append('span') 51 | .text(d => util.ppToken(d.token)) 52 | .st({ 53 | background: 'hsl(55deg 0% 85% / 
0.6)', 54 | borderRadius: 4, 55 | padding: '0 2px', 56 | color: 'black', 57 | fontWeight: '700', 58 | }); 59 | tokenLabelSel.append('span').text('”'); 60 | 61 | const layerSel = tokenSel 62 | .appendMany('div.layer', d => d.values) 63 | .st({ position: 'relative' }); 64 | 65 | const nodeSel = layerSel.appendMany('div.node', d => d[1].entries().map(d => d[1])); 66 | 67 | itemSel = nodeSel.append('div.feature-row') 68 | .classed('clicked', e => e.nodeId == visState.clickedId) 69 | .classed('pinned', d => visState.pinnedIds.includes(d.nodeId)); 70 | 71 | utilCg.renderFeatureRow(itemSel, visState, renderAll); 72 | 73 | } 74 | renderAll.hClerpUpdate.fns.push(render); 75 | renderAll.clickedId.fns.push(render); 76 | renderAll.hoveredId.fns.push(() => itemSel?.classed('hovered', e => e.featureId == visState.hoveredId)); 77 | renderAll.pinnedIds.fns.push(() => itemSel?.classed('pinned', d => visState.pinnedIds.includes(d.nodeId))); 78 | 79 | renderAll.clickedId.fns.push(() => { 80 | if (!itemSel || visState.isDev) return 81 | 82 | var hNode = itemSel.filter(d => d.featureId == visState.clickedId).node() 83 | if (!hNode) return 84 | var cNode = clerpListSel.node() 85 | 86 | var scrollTop = hNode.offsetTop - cNode.offsetTop - cNode.clientHeight / 2 + hNode.clientHeight / 2 87 | scrollTop = d3.clamp(0, scrollTop, cNode.scrollHeight - cNode.clientHeight) 88 | if (scrollTop < cNode.scrollTop - cNode.clientHeight / 2 || scrollTop > cNode.scrollTop + cNode.clientHeight / 2) { 89 | cNode.scrollTop = scrollTop; 90 | } 91 | }); 92 | 93 | } 94 | 95 | window.init?.() 96 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-feature-detail.js: -------------------------------------------------------------------------------- 1 | window.initCgFeatureDetail = async function({visState, renderAll, data, cgSel}){ 2 | var sel = cgSel.select('.feature-detail').html('') 3 | if (!sel.node()) return 4 | 5 | // var headerSel = 
sel.append('div.feature-header') 6 | var logitsSel = sel.append('div.logits-container') 7 | var examplesSel = sel.append('div.feature-examples-container') 8 | var featureExamples = await window.initFeatureExamples({ 9 | containerSel: examplesSel, 10 | showLogits: true, 11 | // showLogits: !data.nodes.some(d => d.top_logit_effects) // we show logits ourselves frozen above the feature vis, don't also show it inside 12 | }) 13 | 14 | let editOpen = false; 15 | 16 | // throttle to prevent lag when mousing over 17 | var renderFeatureExamples = util.throttleDebounce(featureExamples.renderFeature, 100) 18 | 19 | function renderFeatureDetail() { 20 | logitsSel.html('').st({display:''}) 21 | 22 | // display hovered then clicked nodes, with fallbacks for supernode 23 | var d = null 24 | // var d = data.nodes.find(e => e.nodeId === visState.hoveredNodeId) 25 | if (!d) d = data.nodes.find(e => e.nodeId === visState.clickedId) 26 | if (!d){ 27 | var featureId = visState.hoveredId 28 | if (!featureId || featureId.includes('supernode')){ 29 | // headerSel.html('') 30 | // .append('div.no-selected-feature').text("Click or hover to see a feature's examples") 31 | examplesSel.st({opacity: 0}) 32 | return 33 | } 34 | return 35 | } 36 | 37 | var label = d.isTmpFeature ? d.featureId : 38 | visState.isHideLayer ? 
`#F${d.featureIndex}` : 39 | `${utilCg.layerLocationLabel(d.layer, d.probe_location_idx)}/${d.featureIndex}` 40 | 41 | if (d.isError || d.feature_type == 'embedding' || d.feature_type == 'logit'){ 42 | if (d.isError) addLogits(d) 43 | if (d.feature_type=='logit') addEmbeddings(d) 44 | 45 | // headerSel.html('').append('div.header-top-row').append('div.feature-title') 46 | // .text(d.ppClerp) 47 | examplesSel.st({opacity: 0}) 48 | } else if (d.feature_type == 'cross layer transcoder') { 49 | addLogits(d) 50 | addEmbeddings(d) 51 | // var headerTopRowSel = headerSel.html('').append('div.header-top-row') 52 | // headerTopRowSel.append('div.feature-title') 53 | // .html(`Feature ${label}`) 54 | 55 | // headerTopRowSel.append('div.pp-clerp') 56 | // .text(d.ppClerp) 57 | // .at({title: d.ppClerp}) 58 | 59 | if (visState.isEditMode && false){ 60 | headerTopRowSel.append('button.edit-clerp-button') 61 | .text('Edit') 62 | .on('click', toggleEdit) 63 | 64 | function toggleEdit() { 65 | editOpen = !editOpen; 66 | hClerpEditSel.st({display: editOpen ? 'flex' : 'none'}) 67 | if (editOpen) { 68 | headerSel.select('input').node()?.focus(); 69 | } 70 | } 71 | 72 | const hClerpEditSel = headerSel.append('div.clerp-edit') 73 | .st({ display: editOpen ? 
'flex' : 'none' }); 74 | 75 | const hClerpSel = hClerpEditSel.append('div') 76 | .st({ display: 'flex' }); 77 | hClerpSel.append('div') 78 | .st({flex: '0 0 50px'}) 79 | .text(`🧑💻`); 80 | hClerpSel.append('input').data([d]) 81 | .at({ value: d.localClerp }) 82 | .st({flex: '1 0 0', whiteSpace: 'normal', fontSize: 12}) 83 | .on('change', ev => renderAll.hClerpUpdate([d, ev.target.value])) 84 | 85 | const rClerpSel = hClerpEditSel.append('div') 86 | .st({ display: 'flex' }); 87 | rClerpSel.append('div') 88 | .st({flex: '0 0 50px'}) 89 | .text(`🧑☁️`); 90 | rClerpSel.append('div') 91 | .text(d.remoteClerp) 92 | .st({flex: '1 0', whiteSpace: 'normal'}) 93 | 94 | const clerpSel = hClerpEditSel.append('div') 95 | .st({ display: 'flex' }); 96 | clerpSel.append('div') 97 | .st({flex: '0 0 50px'}) 98 | .text(`🤖💬`); 99 | clerpSel.append('div') 100 | .text(d.clerp) 101 | .st({ flex: '1 0', whiteSpace: 'normal' }) 102 | } 103 | 104 | featureExamples.loadFeature(data.metadata.scan, d.featureIndex) 105 | renderFeatureExamples(data.metadata.scan, d.featureIndex) 106 | examplesSel.st({opacity: 1}) 107 | } else { 108 | headerSel.html(`${label}`) 109 | logitsSel.html('No logit data') 110 | examplesSel.st({opacity: 0}) 111 | } 112 | 113 | // add pinned/click state and toggle to feature-title 114 | // headerSel.select('div.feature-title') 115 | // .classed('pinned', d.nodeId && visState.pinnedIds.includes(d.nodeId)) 116 | // .classed('hovered', visState.clickedId == d.nodeId) 117 | // .on('click', ev => { 118 | // utilCg.clickFeature(visState, renderAll, d, ev.metaKey || ev.ctrlKey) 119 | 120 | // if (visState.clickedId) return 121 | // // double render to toggle on hoveredId, could expose more of utilCg.clickFeature to prevent 122 | // visState.hoveredId = d.featureId 123 | // renderAll.hoveredId() 124 | // }) 125 | 126 | } 127 | 128 | function addLogits(d) { 129 | return 130 | if (!d || !d.top_logit_effects) return logitsSel.html('').st({display: 'none'}) 131 | // Add logit effects 
section 132 | let logitRowContainerSel = logitsSel.st({display: ''}) 133 | .append('div.effects') 134 | .appendMany('div.sign', [d.top_logit_effects, d.bottom_logit_effects].filter(d => d)) 135 | logitRowContainerSel.append('div.label').text((d, i) => i ? 'Bottom Outputs' : 'Top Outputs') 136 | logitRowContainerSel.append('div.rows') 137 | .appendMany('div.row', d => d) 138 | .append('span.key').text(d => d) 139 | } 140 | function addEmbeddings(d) { 141 | return 142 | // Add embedding effects section 143 | if (d.top_embedding_effects || d.bottom_embedding_effects) { 144 | let embeddingRowContainerSel = logitsSel 145 | .append('div.effects') 146 | .appendMany('div.sign', [d.top_embedding_effects, d.bottom_embedding_effects].filter(d => d)) 147 | embeddingRowContainerSel.append('div.label').text((d, i) => i ? 'Bottom Inputs' : 'Top Inputs') 148 | embeddingRowContainerSel.append('div.rows').appendMany('div.row', d => d) 149 | .append('span.key').text(d => d) 150 | } 151 | } 152 | 153 | renderAll.hClerpUpdate.fns.push(renderFeatureDetail) 154 | renderAll.clickedId.fns.push(renderFeatureDetail) 155 | renderAll.hoveredId.fns.push(renderFeatureDetail) 156 | renderAll.pinnedIds.fns.push(renderFeatureDetail) 157 | 158 | renderFeatureDetail() 159 | } 160 | 161 | window.init?.() 162 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-feature-scatter.js: -------------------------------------------------------------------------------- 1 | // possible improvements: 2 | // - persist to url in diff view 3 | // - support multiple scatter plots 4 | // - color by x 5 | // - facet by x 6 | // - brush to filter 7 | // - scatter plot of links 8 | 9 | window.initCgFeatureScatter = function({visState, renderAll, data, cgSel}){ 10 | var nodes = data.nodes//.filter(d => !d.isLogit) 11 | 12 | var numericCols = Object.entries(nodes[0]) 13 | .filter(([k, v]) => typeof v != 'object' && typeof v != 'function' && !utilCg.keysToSkip.has(k) 
&& isFinite(v)) 14 | .map(([k]) => k) 15 | 16 | var xKey = util.params.get('feature_scatter_x') || 'ctx_idx' 17 | var yKey = util.params.get('feature_scatter_y') || 'target_influence' 18 | function addSelect(isX){ 19 | var options = isX ? ['Distribution'].concat(numericCols) : numericCols 20 | selectSel.append('select').st({marginRight: 10}) 21 | .on('change', function(){ 22 | isX ? xKey = options[this.selectedIndex] : yKey = options[this.selectedIndex] 23 | isX ? util.params.set('feature_scatter_x', xKey) : util.params.set('feature_scatter_y', yKey) 24 | renderScales() 25 | }) 26 | .appendMany('option', options) 27 | .text(d => d) 28 | .at({value: d => d}) 29 | .filter(d => isX && d == xKey || !isX && d == yKey).at({selected: 'selected'}) 30 | } 31 | 32 | var sel = cgSel.select('.feature-scatter').html('') 33 | var selectSel = sel.append('div.select-container').st({marginLeft: 35}) 34 | var chartSel = sel.append('div.chart-container') 35 | addSelect(1) 36 | addSelect(0) 37 | 38 | function renderScales(){ 39 | var c = d3.conventions({ 40 | sel: chartSel.html('').append('div'), 41 | margin: {left: 35, bottom: 30, top: 2, right: 6}, 42 | }) 43 | 44 | if (xKey == 'Distribution'){ 45 | d3.sort(d3.sort(nodes, d => +d[yKey]), d => d.feature_type) 46 | .forEach((d, i) => d.Distribution = i/nodes.length) 47 | } 48 | 49 | c.x.domain(d3.extent(nodes, d => +d[xKey])).nice() 50 | c.y.domain(d3.extent(nodes, d => +d[yKey])).nice() 51 | 52 | c.yAxis.ticks(3) 53 | c.xAxis.ticks(5) 54 | c.drawAxis() 55 | util.ggPlot(c) 56 | util.addAxisLabel(c, xKey + ' →', yKey + ' →', '', 0, 5) 57 | 58 | var nodeSel = c.svg.appendMany('text.node', nodes) 59 | .translate(d => [c.x(d[xKey]) ?? -2, c.y(d[yKey]) ?? 
c.height + 2]) 60 | .text(d => utilCg.featureTypeToText(d.feature_type)) 61 | .at({ 62 | fontSize: 7, 63 | stroke: '#000', 64 | strokeWidth: .2, 65 | textAnchor: 'middle', 66 | dominantBaseline: 'central', 67 | fill: 'rgba(0,0,0,.1)' 68 | }) 69 | .call(utilCg.addFeatureTooltip) 70 | .call(utilCg.addFeatureEvents(visState, renderAll)) 71 | 72 | // TODO: add hover circle? 73 | utilCg.updateFeatureStyles(nodeSel) 74 | renderAll.hoveredId.fns['featureScatter'] = () => utilCg.updateFeatureStyles(nodeSel) 75 | renderAll.clickedId.fns['featureScatter'] = () => utilCg.updateFeatureStyles(nodeSel) 76 | renderAll.pinnedIds.fns['featureScatter'] = () => utilCg.updateFeatureStyles(nodeSel) 77 | renderAll.hiddenIds.fns['featureScatter'] = () => utilCg.updateFeatureStyles(nodeSel) 78 | } 79 | 80 | // TODO: awkward, maybe gridsnap/widget inits need to be restructured? 81 | if (!sel.datum().resizeFn) renderScales() 82 | sel.datum().resizeFn = renderScales 83 | } 84 | 85 | window.init?.() 86 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-link-graph.js: -------------------------------------------------------------------------------- 1 | window.initCgLinkGraph = function({visState, renderAll, data, cgSel}){ 2 | var {nodes, links, metadata} = data 3 | 4 | var c = d3.conventions({ 5 | sel: cgSel.select('.link-graph').html(''), 6 | margin: {left: visState.isHideLayer ? 
0 : 30, bottom: 85}, 7 | layers: 'sccccs', 8 | }) 9 | 10 | c.svgBot = c.layers[0] 11 | var allCtx = { 12 | allLinks: c.layers[1], 13 | pinnedLinks: c.layers[2], 14 | bgLinks: c.layers[3], 15 | clickedLinks: c.layers[4] 16 | } 17 | c.svg = c.layers[5] 18 | 19 | // Count max number of nodes at each context to create a polylinear x scale 20 | var earliestCtxWithNodes = d3.min(nodes, d => d.ctx_idx) 21 | var cumsum = 0 22 | var ctxCounts = d3.range(d3.max(nodes, d => d.ctx_idx) + 1).map(ctx_idx => { 23 | if (ctx_idx >= earliestCtxWithNodes) { 24 | var group = nodes.filter(d => d.ctx_idx === ctx_idx) 25 | var maxCount = d3.max([1, d3.max(d3.nestBy(group, d => d.streamIdx), e => e.length)]) 26 | cumsum += maxCount 27 | } 28 | return {ctx_idx, maxCount, cumsum} 29 | }) 30 | 31 | var xDomain = [-1].concat(ctxCounts.map(d => d.ctx_idx)) 32 | var xRange = [0].concat(ctxCounts.map(d => d.cumsum * c.width / cumsum)) 33 | c.x = d3.scaleLinear().domain(xDomain.map(d => d + 1)).range(xRange) 34 | 35 | var yNumTicks= visState.isHideLayer ? data.byStream.length : 19 36 | c.y = d3.scaleBand(d3.range(yNumTicks), [c.height, 0]) 37 | 38 | c.yAxis = d3.axisLeft(c.y) 39 | .tickValues(d3.range(yNumTicks)) 40 | .tickFormat(i => { 41 | if (i % 2) return 42 | 43 | return i == 18 ? 'Lgt' : i == 0 ? 'Emb' : 'L' + i 44 | var label = data.byStream[i][0].layerLocationLabel 45 | var layer = +label.replace('L', '') 46 | return isFinite(layer) && layer % 2 ? 
'' : label 47 | }) 48 | 49 | c.svgBot.append('rect').at({width: c.width, height: c.height, fill: '#F5F4EE'}) 50 | c.svgBot.append('g').appendMany('rect', [0, yNumTicks - 1]) 51 | .at({width: c.width, height: c.y.bandwidth(), y: c.y, fill: '#F0EEE7'}) 52 | 53 | c.svgBot.append('g').appendMany('path', d3.range(-1, yNumTicks - 1)) 54 | .translate(d => [0, c.y(d + 1)]) 55 | .at({d: `M0,0H${c.width}`, stroke: 'white', strokeWidth: .5}) 56 | 57 | c.drawAxis(c.svgBot) 58 | c.svgBot.select('.x').remove() 59 | c.svgBot.selectAll('.y line').remove() 60 | if (visState.isHideLayer) c.svgBot.select('.y').remove() 61 | 62 | // Spread nodes across each context 63 | // d.width is the total amount of px space in each column 64 | ctxCounts.forEach((d, i) => d.width = c.x(d.ctx_idx + 1) - c.x(ctxCounts[i].ctx_idx)) 65 | 66 | // if default to 8px padding right, if pad right to center singletons 67 | var padR = Math.min(8, d3.min(ctxCounts.slice(1), d => d.width/2)) + 0 68 | 69 | // find the tightest spacing between nodes and use for all ctx (but don't go below 20) 70 | ctxCounts.forEach(d => d.minS = (d.width - padR)/d.maxCount) 71 | var overallS = Math.max(20, d3.min(ctxCounts, d => d.minS)) 72 | 73 | // apply to nodes 74 | d3.nestBy(nodes, d => [d.ctx_idx, d.streamIdx].join('-')).forEach(ctxLayer => { 75 | var ctxWidth = c.x(ctxLayer[0].ctx_idx + 1) - c.x(ctxLayer[0].ctx_idx) - padR 76 | var s = Math.min(overallS, ctxWidth/ctxLayer.length) 77 | 78 | // sorting by pinned stacks all the links on top of each other 79 | // ctxLayer = d3.sort(ctxLayer, d => visState.pinnedIds.includes(d.nodeId) ? -1 : 1) 80 | ctxLayer = d3.sort(ctxLayer, d => -d.logitPct) 81 | ctxLayer.forEach((d, i) => { 82 | d.xOffset = d.feature_type === 'logit' ? 
ctxWidth - (padR/2 + i*s) : ctxWidth - (padR/2 + i*s) 83 | d.yOffset = 0 84 | }) 85 | }) 86 | nodes.forEach(d => d.pos = [ 87 | c.x(d.ctx_idx) + d.xOffset, 88 | c.y(d.streamIdx) + c.y.bandwidth()/2 + d.yOffset 89 | ]) 90 | 91 | 92 | // hover poitns 93 | var maxHoverDistance = 30 94 | c.sel 95 | .on('mousemove', (ev) => { 96 | if (ev.shiftKey) return 97 | var [mouseX, mouseY] = d3.pointer(ev) 98 | var [closestNode, closestDistance] = findClosestPoint(mouseX - c.margin.left, mouseY - c.margin.top, nodes) 99 | if (closestDistance > maxHoverDistance) { 100 | utilCg.unHoverFeature(visState, renderAll) 101 | utilCg.hideTooltip() 102 | } else if (visState.hoveredId !== closestNode) { 103 | utilCg.hoverFeature(visState, renderAll, closestNode) 104 | utilCg.showTooltip(ev, closestNode) 105 | } 106 | }) 107 | .on('mouseleave', (ev) => { 108 | if (ev.shiftKey) return 109 | utilCg.unHoverFeature(visState, renderAll) 110 | utilCg.hideTooltip() 111 | }) 112 | .on('click', (ev) => { 113 | var [mouseX, mouseY] = d3.pointer(ev) 114 | var [closestNode, closestDistance] = findClosestPoint(mouseX - c.margin.left, mouseY - c.margin.top, nodes) 115 | if (closestDistance > maxHoverDistance) { 116 | visState.clickedId = null 117 | visState.clickedCtxIdx = null 118 | renderAll.clickedId() 119 | } else { 120 | utilCg.clickFeature(visState, renderAll, closestNode, ev.metaKey || ev.ctrlKey) 121 | } 122 | }) 123 | 124 | function findClosestPoint(mouseX, mouseY, points) { 125 | if (points.length === 0) return null 126 | 127 | let closestPoint = points[0] 128 | let closestDistance = distance(mouseX, mouseY, closestPoint.pos[0], closestPoint.pos[1]) 129 | 130 | for (let i = 1; i < points.length; i++){ 131 | var point = points[i] 132 | var dist = distance(mouseX, mouseY, point.pos[0], point.pos[1]) 133 | if (dist < closestDistance){ 134 | closestPoint = point 135 | closestDistance = dist 136 | } 137 | } 138 | return [closestPoint, closestDistance] 139 | 140 | function distance(x1, y1, x2, y2) { 
141 | return Math.sqrt(Math.pow(x2 - x1, 2) + Math.pow(y2 - y1, 2)) 142 | } 143 | } 144 | 145 | // set up dom 146 | var nodeSel = c.svg.appendMany('text.node', nodes) 147 | .translate(d => d.pos) 148 | .text(d => utilCg.featureTypeToText(d.feature_type)) 149 | .at({ 150 | fontSize: 9, 151 | fill: d => d.nodeColor, 152 | stroke: '#000', 153 | strokeWidth: .5, 154 | textAnchor: 'middle', 155 | dominantBaseline: 'central', 156 | }) 157 | // .call(utilCg.addFeatureTooltip) 158 | // .call(utilCg.addFeatureEvents(visState, renderAll, ev => ev.shiftKey)) 159 | 160 | var hoverSel = c.svg.appendMany('circle', nodes) 161 | .translate(d => d.pos) 162 | .at({r: 6, cy: .5, stroke: '#f0f', strokeWidth: 2, strokeDasharray: '2 2', fill: 'none', display: 'xnone', pointEvents: 'none'}) 163 | 164 | links.forEach(d => { 165 | var [x1, y1] = d.sourceNode.pos 166 | var [x2, y2] = d.targetNode.pos 167 | d.pathStr = `M${x1},${y1}L${x2},${y2}` 168 | }) 169 | 170 | 171 | function drawLinks(links, ctx, strokeWidthOffset=0, colorOverride){ 172 | ctx.clearRect(-c.margin.left, -c.margin.top, c.totalWidth, c.totalHeight) 173 | d3.sort(links, d => d.strokeWidth).forEach(d => { 174 | ctx.beginPath() 175 | ctx.moveTo(d.sourceNode.pos[0], d.sourceNode.pos[1]) 176 | ctx.lineTo(d.targetNode.pos[0], d.targetNode.pos[1]) 177 | ctx.strokeStyle = colorOverride || d.color 178 | ctx.lineWidth = d.strokeWidth + strokeWidthOffset 179 | ctx.stroke() 180 | }) 181 | } 182 | 183 | function filterLinks(featureIds){ 184 | var filteredLinks = [] 185 | 186 | featureIds.forEach(nodeId => { 187 | nodes.filter(n => n.nodeId == nodeId).forEach(node => { 188 | if (visState.linkType == 'input' || visState.linkType == 'either') { 189 | Array.prototype.push.apply(filteredLinks, node.sourceLinks) 190 | } 191 | if (visState.linkType == 'output' || visState.linkType == 'either') { 192 | Array.prototype.push.apply(filteredLinks, node.targetLinks) 193 | } 194 | if (visState.linkType == 'both') { 195 | 
Array.prototype.push.apply(filteredLinks, node.sourceLinks.filter( 196 | link => visState.pinnedIds.includes(link.sourceNode.nodeId) 197 | )) 198 | Array.prototype.push.apply(filteredLinks, node.targetLinks.filter( 199 | link => visState.pinnedIds.includes(link.targetNode.nodeId) 200 | )) 201 | } 202 | }) 203 | }) 204 | 205 | return filteredLinks 206 | } 207 | 208 | drawLinks(links, allCtx.allLinks, 0, 'rgba(0,0,0,.05)') 209 | // renderAll.isShowAllLinks.fns['linkGraph'] = () => c.sel.select('canvas').st({display: visState.isShowAllLinks ? '' : 'none'}) 210 | 211 | function renderPinnedIds(){ 212 | drawLinks(visState.clickedId ? [] : filterLinks(visState.pinnedIds), allCtx.pinnedLinks) 213 | nodeSel.classed('pinned', d => visState.pinnedIds.includes(d.nodeId)) 214 | } 215 | renderAll.pinnedIds.fns['linkGraph'] = renderPinnedIds 216 | 217 | function renderHiddenIds(){ 218 | var hiddenIdSet = new Set(visState.hiddenIds) 219 | nodeSel.classed('hidden', d => hiddenIdSet.has(d.featureId)) 220 | } 221 | renderAll.hiddenIds.fns['linkGraph'] = renderHiddenIds 222 | 223 | function renderClicked(){ 224 | var clickedLinks = [] 225 | // if (visState.clickedId) { 226 | // clickedLinks = links.filter(link => 227 | // link.sourceNode.nodeId === visState.clickedId || 228 | // link.targetNode.nodeId === visState.clickedId 229 | // ) 230 | // } 231 | 232 | drawLinks(clickedLinks, allCtx.bgLinks, .05, '#000') 233 | drawLinks(clickedLinks, allCtx.clickedLinks) 234 | nodeSel.classed('clicked', e => e.nodeId === visState.clickedId) 235 | 236 | drawLinks(visState.clickedId ? 
[] : filterLinks(visState.pinnedIds), allCtx.pinnedLinks) 237 | 238 | 239 | nodeSel 240 | .at({fill: '#fff'}) 241 | .filter(d => d.tmpClickedLink?.tmpColor) 242 | .at({fill: d => d.tmpClickedLink.tmpColor}) 243 | .raise() 244 | } 245 | 246 | renderAll.clickedId.fns['linkGraph'] = renderClicked 247 | renderAll.linkType.fns['linkGraph'] = () => { 248 | renderPinnedIds() 249 | renderClicked() 250 | } 251 | renderAll.hoveredId.fns['linkGraph'] = () => { 252 | hoverSel.st({display: e => e.featureId == visState.hoveredId ? '' : 'none'}) 253 | } 254 | 255 | // Add x axis text/lines 256 | var promptTicks = data.metadata.prompt_tokens.slice(earliestCtxWithNodes).map((token, i) =>{ 257 | var ctx_idx = i + earliestCtxWithNodes 258 | var mNodes = nodes.filter(d => d.ctx_idx == ctx_idx) 259 | var hasEmbed = mNodes.some(d => d.feature_type == 'embedding') 260 | return {token, ctx_idx, mNodes, hasEmbed} 261 | }) 262 | 263 | var xTickSel = c.svgBot.appendMany('g.prompt-token', promptTicks) 264 | .translate(d => [c.x(d.ctx_idx + 1), c.height]) 265 | 266 | xTickSel.append('path').at({d: `M0,0v${-c.height}`, stroke: '#fff',strokeWidth: 1}) 267 | xTickSel.filter(d => d.hasEmbed).append('path').at({ 268 | stroke: '#B0AEA6', 269 | d: `M-${padR + 3.5},${-c.y.bandwidth()/2 + 6}V${8}`, 270 | }) 271 | 272 | xTickSel.filter(d => d.hasEmbed).append('g').translate([-12, 8]) 273 | .append('text').text(d => d.token) 274 | .at({ 275 | x: -5, 276 | y: 2, 277 | textAnchor: 'end', 278 | transform: 'rotate(-45)', 279 | dominantBaseline: 'middle', 280 | fontSize: 12, 281 | // fontSize: (d, i) => c.x(i+1) - c.x(i) < 15 ? 
9 : 14, 282 | }) 283 | 284 | var logitTickSel = c.svgBot.append('g.axis').appendMany('g', nodes.filter(d => d.feature_type == 'logit')) 285 | .translate(d => d.pos) 286 | logitTickSel.append('path').at({ 287 | stroke: '#B0AEA6', 288 | d: `M0,${-6}V${-c.y.bandwidth()/2 - 6}`, 289 | }) 290 | logitTickSel.append('g').translate([-5, -c.y.bandwidth()/2 - 8]) 291 | .append('text').text(d => d.logitToken) 292 | .at({ 293 | x: 5, 294 | y: 2, 295 | textAnchor: 'start', 296 | transform: 'rotate(-45)', 297 | dominantBaseline: 'middle', 298 | fontSize: 12, 299 | // fontSize: (d, i) => c.x(i+1) - c.x(i) < 15 ? 9 : 14, 300 | }) 301 | 302 | 303 | utilCg.addPinnedClickedGradient(c.svg) 304 | } 305 | 306 | window.init?.() 307 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-node-connections.js: -------------------------------------------------------------------------------- 1 | window.initCgNodeConnections = function({visState, renderAll, data, cgSel}){ 2 | 3 | var nodeConnectionsSel = cgSel.select('.node-connections') 4 | var headerSel = null 5 | var clickedNode = null 6 | var featureSel = null 7 | 8 | function render() { 9 | nodeConnectionsSel.html('').st({display: 'flex', flexDirection: 'column'}) 10 | clickedNode = data.nodes.find(e => e.nodeId === visState.clickedId) 11 | headerSel = nodeConnectionsSel.append('div.header-top-row.section-title').st({marginBottom: 20}).datum(clickedNode) 12 | 13 | if (!clickedNode) return headerSel.text('Click a feature on the left for details') 14 | 15 | addHeaderRow(headerSel) 16 | 17 | var types = [ 18 | { id: 'input', title: 'Input Features'}, 19 | { id: 'output', title: 'Output Features' } 20 | ] 21 | types.forEach(type => type.sections = ['Positive', 'Negative'].map(title => { 22 | var nodes = data.nodes.filter(d => { 23 | var weight = type.id === 'input' ? d.tmpClickedSourceLink?.weight : d.tmpClickedTargetLink?.weight 24 | return title == 'Positive' ? 
weight > 0 : weight < 0 25 | }) 26 | nodes = d3.sort(nodes, d => -(type.id === 'input' ? d.tmpClickedSourceLink?.pctInput : d.tmpClickedTargetLink?.pctInput)) 27 | return {title, nodes} 28 | })) 29 | 30 | var typesSel = nodeConnectionsSel.append('div.connections') 31 | .st({flex: '1 0 auto', display: 'flex', overflow: 'hidden', gap: '20px'}) 32 | .appendMany('div.features', types) 33 | .classed('output', d => d.id === 'output') 34 | .classed('input', d => d.id === 'input') 35 | 36 | typesSel.append('div.section-title').text(d => d.title) 37 | .st({marginBottom: 0}) 38 | // addInputSumTable(typesSel.filter(d => d.id == 'input').append('div.sum-table').append('div.section')) 39 | 40 | var featuresContainerSel = typesSel.append('div.effects') 41 | 42 | var sectionSel = featuresContainerSel.appendMany('div.section', d => d.sections) 43 | featureSel = sectionSel.appendMany('div.feature-row', d => d.nodes) 44 | .classed('clicked', e => e.featureId == visState.clickedId) 45 | 46 | classPinned() 47 | classHidden() 48 | 49 | typesSel.each(function(type) { 50 | d3.select(this).selectAll('.feature-row') 51 | .call(utilCg.renderFeatureRow, visState, renderAll, type.id === 'input' ? 'tmpClickedSourceLink' : 'tmpClickedTargetLink') 52 | }) 53 | } 54 | 55 | function addHeaderRow(headerSel){ 56 | if (!clickedNode) return 57 | 58 | headerSel.append('text') 59 | .text(clickedNode.feature_type == 'cross layer transcoder' ? 
'F#' + d3.format('08')(clickedNode.feature) : ' ') 60 | .st({display: 'inline-block', marginRight: 5, 'font-variant-numeric': 'tabular-nums', width: 82}) 61 | headerSel.append('span.feature-icon').text(utilCg.featureTypeToText(clickedNode.feature_type)) 62 | headerSel.append('span.feature-title').text(clickedNode.ppClerp) 63 | 64 | // add cmd-click toggle to title 65 | headerSel.on('click', ev => { 66 | utilCg.clickFeature(visState, renderAll, clickedNode, ev.metaKey || ev.ctrlKey) 67 | 68 | if (visState.clickedId) return 69 | // double render to toggle on hoveredId, could expose more of utilCg.clickFeature to prevent 70 | visState.hoveredId = clickedNode.featureId 71 | renderAll.hoveredId() 72 | }) 73 | } 74 | 75 | function addInputSumTable(sumSel){ 76 | return // TODO: turn back on? 77 | 78 | var clickedNode = data.nodes.idToNode[visState.clickedId] 79 | if (!clickedNode) return 80 | var inputSum = d3.sum(clickedNode.sourceLinks, d => Math.abs(d.weight)) 81 | 82 | var tableSel = sumSel.append('table') 83 | tableSel.appendMany('th', ['% Feat', '% Err', '% Emb']).text(d => d) 84 | 85 | var trSel = tableSel.appendMany('tr.data', [ 86 | {str: '←', links: clickedNode.sourceLinks.filter(d => d.sourceNode.ctx_idx != clickedNode.ctx_idx)}, 87 | {str: '↓', links: clickedNode.sourceLinks.filter(d => d.sourceNode.ctx_idx == clickedNode.ctx_idx)} 88 | ]) 89 | 90 | trSel.append('td').text(d => d.str).at({title: d => d.str == '↓' ? 
'cur token' : 'prev token'}) 91 | trSel.appendMany('td', d => { 92 | var rv = [ 93 | d.links.filter(e => !e.sourceNode.isError && e.sourceNode.feature_type != 'embedding'), 94 | d.links.filter(e => e.sourceNode.isError), 95 | d.links.filter(e => e.sourceNode.feature_type == 'embedding'), 96 | ] 97 | 98 | if (rv.flat().length != d.links.length) console.error("Non-feature/error/embedding node present") 99 | return rv 100 | }) 101 | .html(links => { 102 | var pos = d3.sum(links.filter(d => d.weight > 0), d => d.weight)/inputSum 103 | var neg = d3.sum(links.filter(d => d.weight < 0), d => d.weight)/inputSum 104 | 105 | var outStr = `+${d3.format('.2f')(pos)}   −${d3.format('.2f')(neg)}` 106 | return outStr.replaceAll('0.', '.').replaceAll('−−', '−') 107 | }) 108 | } 109 | 110 | renderAll.hClerpUpdate.fns.push(render) 111 | renderAll.hoveredId.fns['nodeconnections'] = () => featureSel?.classed('hovered', e => e.featureId == visState.hoveredId) 112 | renderAll.pinnedIds.fns['nodeconnections'] = classPinned 113 | renderAll.hiddenIds.fns['nodeconnections'] = classHidden 114 | renderAll.clickedId.fns['nodeconnections'] = render 115 | 116 | 117 | function classPinned(){ 118 | var pinnedIdSet = new Set(visState.pinnedIds) 119 | featureSel?.classed('pinned', d => pinnedIdSet.has(d.nodeId)) 120 | headerSel?.classed('pinned', d => d && d.nodeId && visState.pinnedIds.includes(d.nodeId)) 121 | .select('span.feature-title').text(d => d.ppClerp) 122 | } 123 | 124 | function classHidden(){ 125 | var hiddenIdSet = new Set(visState.hiddenIds) 126 | featureSel?.classed('hidden', d => hiddenIdSet.has(d.featureId)) 127 | } 128 | } 129 | 130 | window.init?.() 131 | -------------------------------------------------------------------------------- /attribution_graph/init-cg-subgraph.js: -------------------------------------------------------------------------------- 1 | window.initCgSubgraph = function ({visState, renderAll, data, cgSel}) { 2 | var subgraphSel = cgSel.select('.subgraph') 3 | 
subgraphSel.datum().resizeFn = renderSubgraph 4 | 5 | var nodeIdToNode = {} 6 | var sgNodes = [] 7 | var sgLinks = [] 8 | 9 | let nodeSel = null 10 | let memberNodeSel = null 11 | let simulation = null 12 | 13 | var nodeWidth = 75 14 | var nodeHeight = 25 15 | 16 | function supernodesToUrl() { 17 | // util.params.set('supernodes', JSON.stringify(subgraphState.supernodes)) 18 | } 19 | 20 | var subgraphState = visState.subgraph = visState.subgraph || { 21 | sticky: true, 22 | dagrefy: true, 23 | supernodes: visState.supernodes || [], 24 | activeGrouping: { 25 | isActive: false, 26 | selectedNodeIds: new Set(), 27 | } 28 | } 29 | 30 | d3.select('body') 31 | .on('keydown.grouping' + data.metadata.slug, ev => { 32 | if (ev.repeat) return 33 | if (!visState.isEditMode || ev.key != 'g') return 34 | subgraphState.activeGrouping.isActive = true 35 | styleNodes() 36 | 37 | subgraphSel.classed('is-grouping', true) 38 | }) 39 | .on('keyup.grouping' + data.metadata.slug, ev => { 40 | if (!visState.isEditMode || ev.key != 'g') return 41 | if (subgraphState.activeGrouping.selectedNodeIds.size > 1){ 42 | var allSelectedIds = [] 43 | var prevSupernodeLabel = '' 44 | subgraphState.activeGrouping.selectedNodeIds.forEach(id => { 45 | var node = nodeIdToNode[id] 46 | if (!node?.memberNodeIds) return allSelectedIds.push(id) 47 | prevSupernodeLabel = node.ppClerp 48 | 49 | // if a supernode is selected, remove the previous super node 50 | subgraphState.supernodes = subgraphState.supernodes.filter(([label, ...nodeIds]) => 51 | !nodeIds.every(d => node.memberNodeIds.includes(d)) 52 | ) 53 | // and adds its member nodes to selection 54 | node.memberNodeIds.forEach(id => allSelectedIds.push(id)) 55 | }) 56 | 57 | var label = prevSupernodeLabel || allSelectedIds 58 | .map(id => nodeIdToNode[id]?.ppClerp) 59 | .find(d => d) || 'supernode' 60 | subgraphState.supernodes.push([label, ...new Set(allSelectedIds)]) 61 | supernodesToUrl() 62 | } 63 | subgraphState.activeGrouping.isActive = false 64 | 
subgraphState.activeGrouping.selectedNodeIds.clear() 65 | renderSubgraph() 66 | 67 | subgraphSel.classed('is-grouping', false) 68 | }) 69 | 70 | let {nodes, links} = data 71 | 72 | function renderSubgraph() { 73 | var c = d3.conventions({ 74 | sel: subgraphSel.html(''), 75 | margin: {top: 26, bottom: 5, left: visState.isHideLayer ? 0 : 30}, 76 | layers: 'sd', 77 | }) 78 | // subgraphSel.st({borderTop: '1px solid #eee'}) 79 | 80 | c.svg.append('text.section-title').text('Subgraph').translate(-16, 1) 81 | c.svg.append('g.border-path').append('path') 82 | .at({stroke: '#eee', d: 'M 0 -10 H ' + c.width}) 83 | 84 | 85 | var [svg, div] = c.layers 86 | 87 | // // set up arrowheads 88 | // svg.appendMany('marker', [{id: 'mid-negative', color: '#40004b'},{id: 'mid-positive', color: '#00441b'}]) 89 | // .at({id: d => d.id, orient: 'auto', refX: .1, refY: 1}) // marker-height/marker-width? 90 | // .append('path') 91 | // .at({d: 'M0,0 V2 L1,1 Z', fill: d => d.color}) 92 | 93 | 94 | // pick out the subgraph and do supernode surgery 95 | nodes.forEach(d => d.supernodeId = null) 96 | var pinnedIds = visState.pinnedIds.slice(0, 200) // max of 200 nodes 97 | var pinnedNodes = nodes.filter(d => pinnedIds.includes(d.nodeId)) 98 | 99 | // create supernodes and mark their children 100 | nodeIdToNode = Object.fromEntries(pinnedNodes.map(d => [d.nodeId, d])) 101 | var supernodes = subgraphState.supernodes 102 | .map(([label, ...nodeIds], i) => { 103 | var nodeId = nodeIdToNode[label] ? 
`supernode-${i}` : label 104 | var memberNodes = nodeIds.map(id => nodeIdToNode[id]).filter(d => d) 105 | memberNodes.forEach(d => d.supernodeId = nodeId) 106 | 107 | var rv = { 108 | nodeId, 109 | featureId: `supernode-${i}`, 110 | ppClerp: label, 111 | layer: d3.mean(memberNodes, d => +d.layer), 112 | ctx_idx: d3.mean(memberNodes, d => d.ctx_idx), 113 | ppLayer: d3.extent(memberNodes, d => +d.layer).join('—'), 114 | streamIdx: d3.mean(memberNodes, d => d.streamIdx), 115 | memberNodeIds: nodeIds, 116 | memberNodes, 117 | isSuperNode: true, 118 | } 119 | nodeIdToNode[rv.nodeId] = rv 120 | 121 | return rv 122 | }) 123 | .filter(d => d.memberNodes.length) 124 | 125 | // update clerps — fragile hack if hClerpUpdate changes 126 | nodes.forEach(d => d.ppClerp = d.clerp) 127 | supernodes.forEach(({ppClerp, memberNodes}) => { 128 | if (memberNodes.length == 1 && ppClerp == memberNodes[0].ppClerp) return 129 | 130 | memberNodes.forEach(d => { 131 | d.ppClerp = `[${ppClerp}]` + (ppClerp != d.ppClerp ? ' ' + d.ppClerp : '') 132 | }) 133 | }) 134 | 135 | // inputAbsSumExternalSn: the abs sum of input links from outside the supernode 136 | pinnedNodes.forEach(d => { 137 | d.inputAbsSumExternalSn = d3.sum(d.sourceLinks, e => { 138 | if (!e.sourceNode.supernodeId) return Math.abs(e.weight) 139 | return e.sourceNode.supernodeId == d.supernodeId ? 
0 : Math.abs(e.weight) 140 | }) 141 | d.sgSnInputWeighting = d.inputAbsSumExternalSn/d.inputAbsSum 142 | }) 143 | 144 | // subgraph plots pinnedNodes not in a supernode and supernodes 145 | sgNodes = pinnedNodes.filter(d => !d.supernodeId).concat(supernodes) 146 | sgNodes.forEach(d => { 147 | // for supernodes, sum up values from member nodes 148 | if (d.isSuperNode) { 149 | d.inputAbsSum = d3.sum(d.memberNodes, e => e.inputAbsSum) 150 | d.inputAbsSumExternalSn = d3.sum(d.memberNodes, e => e.inputAbsSumExternalSn) 151 | } else { 152 | d.memberNodes = [d] 153 | } 154 | 155 | var sum = d3.sum(d.memberNodes, e => e.sgSnInputWeighting) 156 | d.memberNodes.forEach(e => e.sgSnInputWeighting = e.sgSnInputWeighting/sum) 157 | }) 158 | 159 | // select subgraph links 160 | sgLinks = links 161 | .filter(d => nodeIdToNode[d.sourceNode.nodeId] && nodeIdToNode[d.targetNode.nodeId]) 162 | .map(d => ({ 163 | source: d.sourceNode.nodeId, 164 | target: d.targetNode.nodeId, 165 | weight: d.weight, 166 | color: d.pctInputColor, 167 | ogLink: d, 168 | })) 169 | 170 | // then remap source/target to supernodes 171 | sgLinks.forEach(link => { 172 | if (nodeIdToNode[link.source]?.supernodeId) link.source = nodeIdToNode[link.source].supernodeId 173 | if (nodeIdToNode[link.target]?.supernodeId) link.target = nodeIdToNode[link.target].supernodeId 174 | }) 175 | 176 | // finally combine parallel links and remove self-links 177 | sgLinks = d3.nestBy(sgLinks, d => d.source + '-' + d.target) 178 | .map(links => { 179 | var weight = d3.sum(links, link => { 180 | var {inputAbsSumExternalSn, sgSnInputWeighting} = link.ogLink.targetNode 181 | return link.weight/inputAbsSumExternalSn*sgSnInputWeighting 182 | }) 183 | 184 | return { 185 | source: links[0].source, 186 | target: links[0].target, 187 | weight, 188 | color: utilCg.pctInputColorFn(weight), 189 | pctInput: weight, 190 | pctInputColor: utilCg.pctInputColorFn(weight), 191 | ogLinks: links 192 | } 193 | }) 194 | .filter(d => d.source !== 
d.target) 195 | sgLinks = d3.sort(sgLinks, d => Math.abs(d.weight)) 196 | 197 | let xScale = d3.scaleLinear() 198 | .domain(d3.extent(sgNodes.map(d => d.ctx_idx))) 199 | .range([0, c.width*3/4]) 200 | let yScale = d3.scaleLinear() 201 | .domain(d3.extent(sgNodes.map(d => d.streamIdx)).toReversed()) 202 | .range([0, c.height - nodeHeight]) 203 | 204 | // d3.force is impure, need copy 205 | // Also want to persist these positions across node changes 206 | const existingNodes = window.selForceNodes && new Map(window.selForceNodes.map(n => [n.node.nodeId, n])) 207 | window.selForceNodes = sgNodes.map(node => { 208 | const existing = existingNodes?.get(node.nodeId) 209 | return { 210 | x: existing ? existing.x : xScale(node.ctx_idx), 211 | y: existing ? existing.y : yScale(node.streamIdx), 212 | fx: existing?.fx, 213 | fy: existing?.fy, 214 | nodeId: node.nodeId, // for addFeatureEvents 215 | featureId: node.featureId, // for addFeatureEvents 216 | node, 217 | sortedSlug: d3.sort(node.memberNodes.map(d => d.featureIndex).join(' ')), 218 | } 219 | }) 220 | 221 | 222 | var selForceNodes = window.selForceNodes = d3.sort(window.selForceNodes, d => d.sortedSlug) 223 | window._exportSubgraphPos = function(){ 224 | return selForceNodes.map(d => [d.x/c.width*1000, d.y/c.height*1000].map(Math.round)).flat().join(' ') 225 | } 226 | 227 | if (simulation) simulation.stop() 228 | simulation = d3.forceSimulation(selForceNodes) 229 | .force('link', d3.forceLink(sgLinks).id(d => d.node.nodeId)) 230 | .force('charge', d3.forceManyBody()) 231 | .force('collide', d3.forceCollide(Math.sqrt(nodeHeight ** 2 + nodeWidth ** 2) / 2)) 232 | .force('container', forceContainer([[-10, 0], [c.width - nodeHeight, c.height - nodeHeight]])) 233 | .force('x', d3.forceX(d => xScale(d.node.ctx_idx)).strength(.1)) 234 | .force('y', d3.forceY(d => yScale(d.node.streamIdx)).strength(2)) 235 | 236 | var svgPaths = svg.appendMany('path.link-path', sgLinks).at({ 237 | fill: 'none', 238 | markerMid: d => 
d.weight > 0 ? 'url(#mid-positive)' : 'url(#mid-negative)', 239 | strokeWidth: d => Math.abs(d.weight)*15, 240 | stroke: d => d.color, 241 | opacity: 0.8, 242 | strokeLinecap: 'round', 243 | }) 244 | 245 | var edgeLabels = svg.appendMany('text.weight-label', sgLinks) 246 | // .text(d => d3.format('+.2f')(d.weight)) 247 | 248 | simulation.on('tick', renderForce) 249 | 250 | var drag = d3.drag() 251 | .on('drag', (ev) => { 252 | // Only when actually dragging, mark as no longer dagre positioned and restart sim 253 | ev.subject.dagrePositioned = false 254 | if (!ev.active) simulation.alphaTarget(0.3).restart() 255 | ev.subject.fx = ev.subject.x = ev.x 256 | ev.subject.fy = ev.subject.y = ev.y 257 | renderForce() 258 | }) 259 | .on('end', (ev) => { 260 | if (!ev.active) simulation.alphaTarget(0) 261 | if (!subgraphState.sticky && !ev.subject.dagrePositioned){ 262 | ev.subject.fx = null 263 | ev.subject.fy = null 264 | } 265 | }) 266 | 267 | nodeSel = div 268 | .appendMany('div.supernode-container', selForceNodes) 269 | .translate(d => [d.x, d.y]) 270 | .st({width: nodeWidth, height: nodeHeight}) 271 | .call(utilCg.addFeatureEvents(visState, renderAll, ev => ev.shiftKey)) 272 | .on('click.group', (ev, d) => { 273 | var {isActive, selectedNodeIds} = subgraphState.activeGrouping 274 | if (!isActive) return 275 | 276 | // If it's a child node, use its parent supernode's ID instead 277 | var nodeId = d.supernodeId || d.nodeId 278 | selectedNodeIds.has(nodeId) ? 
selectedNodeIds.delete(nodeId) : selectedNodeIds.add(nodeId) 279 | 280 | styleNodes() 281 | ev.stopPropagation() 282 | ev.preventDefault() 283 | }) 284 | .call(drag) 285 | 286 | selForceNodes.forEach(d => { 287 | if (!d.node.memberNodes) d.node.memberNodes = [d.node] 288 | }) 289 | 290 | var supernodeSel = nodeSel//.filter(d => d.node.isSuperNode) 291 | .classed('is-supernode', true) 292 | .st({height: nodeHeight + 12}) 293 | 294 | memberNodeSel = supernodeSel.append('div.member-circles') 295 | .st({ 296 | width: d => d.node.memberNodes.length <= 4 ? 'auto' : 'calc(32px + 12px)', 297 | gap: d => d.node.memberNodes.length <= 4 ? 4 : 0, 298 | }) 299 | .appendMany('div.member-circle', d => d.node.memberNodes) 300 | .classed('not-clt-feature', d => d.feature_type != 'cross layer transcoder') 301 | .st({marginLeft: function(d, i) { 302 | var n = this.parentNode.childNodes.length 303 | return n <= 4 ? 0 : i == 0 ? 0 : -((n - 4)*8)/(n - 1) 304 | }}) 305 | .call(utilCg.addFeatureEvents(visState, renderAll, ev => ev.shiftKey)) 306 | .on('click.stop-parent', ev => { 307 | if (!subgraphState.activeGrouping.isActive) ev.stopPropagation() 308 | }) 309 | .on('mouseover.stop-parent', ev => ev.stopPropagation()) 310 | .at({title: d => d.ppClerp}) 311 | 312 | if (visState.isEditMode) { 313 | // TODO: enable 314 | supernodeSel.select('.member-circles') 315 | .filter(d => d.node.isSuperNode) 316 | .append('div.ungroup-btn') 317 | .text('×').st({top: 2, left: -15, position: 'absolute'}) 318 | .on('click', (ev, d) => { 319 | ev.stopPropagation() 320 | 321 | subgraphState.supernodes = subgraphState.supernodes.filter(([label, ...nodeIds]) => 322 | !nodeIds.every(id => d.node.memberNodeIds.includes(id)) 323 | ) 324 | supernodesToUrl() 325 | renderSubgraph() 326 | }) 327 | } 328 | 329 | var nodeTextSel = nodeSel.append('div.node-text-container') 330 | nodeTextSel.append('span') 331 | .text(d => d.node.ppClerp) 332 | .on('click', (ev, d) => { 333 | if (!visState.isEditMode) return 334 | if 
(!d.node.isSuperNode) return 335 | // TODO: enable? 336 | return 337 | ev.stopPropagation() 338 | 339 | var spanSel = d3.select(ev.target).st({display: 'none'}) 340 | var input = d3.select(spanSel.node().parentNode).append('input') 341 | .at({class: 'temp-edit', value: spanSel.text()}) 342 | .on('blur', save) 343 | .on('keydown', ev => { 344 | if (ev.key === 'Enter'){ 345 | save() 346 | input.node().blur() 347 | } 348 | ev.stopPropagation() 349 | }) 350 | 351 | input.node().focus() 352 | 353 | function save(){ 354 | var idx = subgraphState.supernodes.findIndex(([label, ...nodeIds]) => 355 | nodeIds.every(id => d.node.memberNodeIds.includes(id)) 356 | ) 357 | if (idx >= 0){ 358 | subgraphState.supernodes[idx][0] = input.node().value || 'supernode' 359 | supernodesToUrl() 360 | renderSubgraph() 361 | } 362 | } 363 | }) 364 | 365 | 366 | nodeTextSel.each(function(d) { 367 | d.textHeight = this.getBoundingClientRect().height || -8 368 | }) 369 | 370 | nodeSel.append('div.clicked-weight.source') 371 | nodeSel.append('div.clicked-weight.target') 372 | styleNodes() 373 | 374 | 375 | var checkboxes = Object.entries({ 376 | sticky: () => { 377 | // simulation.alphaTarget(0.3).restart() 378 | if (!subgraphState.sticky) unsticky() 379 | }, 380 | dagrefy: () => { 381 | subgraphState.dagrefy ? 
dagrefy() : selForceNodes.forEach(d => d.dagrePositioned = null) 382 | }, 383 | }).map(([key, fn]) => ({key, fn})) 384 | 385 | 386 | if (visState.isEditMode && false) { 387 | div.append('div.checkbox-container').translate([-c.margin.left, c.margin.bottom]) 388 | .appendMany('label', checkboxes).append('input') 389 | .at({type: 'checkbox'}) 390 | .property('checked', d => subgraphState[d.key]) 391 | .on('change', function(ev, d){ 392 | subgraphState[d.key] = this.checked 393 | d.fn() 394 | }) 395 | .parent().append('span').text(d => d.key) 396 | } 397 | 398 | checkboxes.forEach(d => d.fn()) 399 | 400 | function unsticky(){ 401 | selForceNodes.forEach(d => (d.fx = d.fy = null)) 402 | simulation.alphaTarget(0.3).restart() 403 | if (subgraphState.dagrefy) { 404 | subgraphState.dagrefy = false 405 | d3.select('.checkbox-container').selectAll('input').filter(d => d.key == 'dagrefy').property('checked', 0) 406 | checkboxes.find(d => d.key == 'dagrefy').fn() 407 | } 408 | } 409 | 410 | function dagrefy(){ 411 | if (visState.sg_pos){ 412 | var nums = visState.sg_pos.split(' ').map(d => +d) 413 | selForceNodes.forEach((d, i) => { 414 | d.fx = d.x = nums[i*2 + 0]/1000*c.width 415 | d.fy = d.y = nums[i*2 + 1]/1000*c.height 416 | }) 417 | 418 | nodeSel.translate(d => [d.x, d.y]) 419 | styleNodes() 420 | renderEdges() 421 | 422 | visState.og_sg_pos = visState.sg_pos 423 | delete visState.sg_pos 424 | } 425 | if (visState.og_sg_pos) return 426 | 427 | 428 | var g = new window.dagre.graphlib.Graph() 429 | g.setGraph({rankdir: 'BT', nodesep: 20, ranksep: 20}) 430 | g.setDefaultEdgeLabel(() => ({})) 431 | 432 | sgLinks.forEach(d =>{ 433 | if (Math.abs(d.weight) < .003) return 434 | // set width and height to make dagre return x and y for edges 435 | g.setEdge(d.source.nodeId, d.target.nodeId, {width: 1, height: 1, labelpos: 'c', weight: Math.abs(d.weight)}) 436 | }) 437 | sgNodes.forEach(d => { 438 | g.setNode(d.nodeId, {width: nodeWidth, height: nodeHeight}) 439 | }) 440 | 441 | 
window.dagre.layout(g) 442 | 443 | // rescale to fit container 444 | var xs = d3.scaleLinear([0, g.graph().width], [0, Math.min(c.width, g.graph().width)]) 445 | var ys = d3.scaleLinear([0, g.graph().height], [0, Math.min(c.height, g.graph().height)]) 446 | 447 | // flip to make ctx_idx left to right and streamIdx bottom to top 448 | var w0 = d3.mean(selForceNodes, d => g.node(d.nodeId).x*d.node.ctx_idx) 449 | var w1 = d3.mean(selForceNodes, d => -g.node(d.nodeId).x*d.node.ctx_idx) 450 | if (w0 < w1) xs.range(xs.range().reverse()) 451 | 452 | var w0 = d3.mean(selForceNodes, d => g.node(d.nodeId).y*d.node.streamIdx) 453 | var w1 = d3.mean(selForceNodes, d => -g.node(d.nodeId).y*d.node.streamIdx) 454 | if (w0 < w1) ys.range(ys.range().reverse()) 455 | 456 | for (var node of window.selForceNodes) { 457 | var pos = g.node(node.nodeId) 458 | node.fx = node.x = xs(pos.x) - nodeWidth/2 459 | node.fy = node.y = ys(pos.y) - nodeHeight/2 460 | node.dagrePositioned = true 461 | } 462 | 463 | // var curveFactory = d3.line(d => d.x, d => d.y).curve(d3.curveBasis) 464 | // svgPaths.at({d: d => { 465 | // var points = g.edge(d.source.nodeId, d.target.nodeId)?.points 466 | // if (!points) return '' 467 | // return curveFactory(points.map(p => ({x: xs(p.x), y: ys(p.y)}))) 468 | // }}) 469 | renderEdges() 470 | 471 | // edgeLabels.translate(d => { 472 | // var pos = g.edge(d.source.nodeId, d.target.nodeId) 473 | // if (!pos) return [-100, -100] 474 | // return [xs(pos.x), ys(pos.y)] 475 | // }) 476 | styleNodes() 477 | } 478 | 479 | function renderForce(){ 480 | nodeSel.translate(d => [d.x, d.y]) 481 | 482 | renderEdges() 483 | 484 | edgeLabels 485 | .filter(d => !(d.source.dagrePositioned && d.target.dagrePositioned)) 486 | .translate(d => [ 487 | (d.source.x + d.target.x) / 2 + nodeWidth / 2, 488 | (d.source.y + d.target.y) / 2 + nodeHeight / 2 489 | ]) 490 | } 491 | 492 | function renderEdges(){ 493 | 494 | // TODO: use actual strokeWidth to spread 495 | d3.nestBy(sgLinks, d => 
d.source).forEach(links => { 496 | // if (links[0].source.nodeId == '6_12890134_-0') debugger 497 | var numSlots = links[0].source.node.memberNodes.length 498 | var totalWidth = (Math.min(4, numSlots))*8 499 | d3.sort(links, d => Math.atan2(d.target.y - d.source.y, d.target.x - d.source.x)) 500 | .forEach((d, i) => d.sourceOffsetX = (i - links.length/2)*totalWidth/links.length) 501 | }) 502 | 503 | d3.nestBy(sgLinks, d => d.target).forEach(links => { 504 | var numSlots = links[0].target.node.memberNodes.length 505 | var totalWidth = (Math.min(4, numSlots) + 1)*3 506 | d3.sort(links, d => -Math.atan2(d.source.y - d.target.y, d.source.x - d.target.x)) 507 | .forEach((d, i) => d.targetOffsetX = (i - links.length/2)*totalWidth/links.length) 508 | }) 509 | 510 | svgPaths.at({ 511 | d: d => { 512 | var x0 = d.source.x + nodeWidth/2 + d.sourceOffsetX 513 | var y0 = d.source.y 514 | var x1 = d.target.x + nodeWidth/2 + d.targetOffsetX 515 | var y1 = d.target.y + d.target.textHeight + 28 516 | 517 | return `M${x0},${y0} L${x1},${y1}` 518 | }, 519 | }) 520 | } 521 | } 522 | 523 | 524 | function styleNodes() { 525 | if (!nodeSel) return 526 | 527 | nodeSel 528 | .classed('clicked', d => d.nodeId == visState.clickedId) 529 | .classed('hovered', d => d.featureId == visState.hoveredId) 530 | .st({zIndex: d => Math.round(d.x*20 + d.y) + 1000}) 531 | .classed('grouping-selected', d => subgraphState.activeGrouping.selectedNodeIds.has(d.nodeId)) 532 | 533 | memberNodeSel 534 | .classed('clicked', d => d.nodeId == visState.clickedId) 535 | .classed('hovered', d => d.featureId == visState.hoveredId) 536 | .st({ 537 | background: d => d.tmpClickedLink?.pctInputColor, 538 | color: d => utilCg.bgColorToTextColor(d.tmpClickedLink?.pctInputColor) 539 | }) 540 | // .at({title: d => d3.format('.1%')(d.tmpClickedLink?.pctInput)}) 541 | 542 | 543 | 544 | // style clicked links using supernode adjusted graph when possible 545 | sgNodes.forEach(d => { 546 | d.tmpClickedSgSource = 
d.tmpClickedLink?.sourceNode == d ? d.tmpClickedLink : null 547 | d.tmpClickedSgTarget = d.tmpClickedLink?.targetNode == d ? d.tmpClickedLink : null 548 | }) 549 | 550 | if (visState.clickedId) { 551 | sgLinks.forEach(d => { 552 | if (d.source.nodeId == visState.clickedId) nodeIdToNode[d.target.nodeId].tmpClickedSgTarget = d 553 | if (d.target.nodeId == visState.clickedId) nodeIdToNode[d.source.nodeId].tmpClickedSgSource = d 554 | }) 555 | } 556 | 557 | // nodeSel.selectAll('.clicked-weight.source') 558 | // .st({display: d => d.node.tmpClickedSgSource ? '' : 'none'}) 559 | // .filter(d => d.node.tmpClickedSgSource) 560 | // .text(d => d3.format('.1%')(d.node.tmpClickedSgSource.pctInput)) 561 | // .st({ 562 | // background: d => d.node.tmpClickedSgSource.pctInputColor, 563 | // color: d => utilCg.bgColorToTextColor(d.node.tmpClickedSgSource.pctInputColor) 564 | // }) 565 | 566 | // nodeSel.selectAll('.clicked-weight.target') 567 | // .st({display: d => d.node.tmpClickedSgTarget ? '' : 'none'}) 568 | // .filter(d => d.node.tmpClickedSgTarget) 569 | // .text(d => d3.format('.1%')(d.node.tmpClickedSgTarget.pctInput)) 570 | // .st({ 571 | // background: d => d.node.tmpClickedSgTarget.pctInputColor, 572 | // color: d => utilCg.bgColorToTextColor(d.node.tmpClickedSgTarget.pctInputColor) 573 | // }) 574 | } 575 | 576 | renderAll.hClerpUpdate.fns['subgraph'] = renderSubgraph 577 | renderAll.pinnedIds.fns['subgraph'] = renderSubgraph 578 | renderAll.clickedId.fns['subgraph'] = styleNodes 579 | renderAll.hoveredId.fns['subgraph'] = styleNodes 580 | 581 | // https://github.com/1wheel/d3-force-container/blob/master/src/force-container.js 582 | function forceContainer(bbox) { 583 | var nodes, strength = 1 584 | 585 | function force(alpha) { 586 | var i, 587 | n = nodes.length, 588 | node, 589 | x = 0, 590 | y = 0 591 | 592 | for (i = 0; i < n; ++i) { 593 | node = nodes[i], x = node.x, y = node.y 594 | 595 | if (x < bbox[0][0]) node.vx += (bbox[0][0] - x)*alpha 596 | if (y < 
bbox[0][1]) node.vy += (bbox[0][1] - y)*alpha 597 | if (x > bbox[1][0]) node.vx += (bbox[1][0] - x)*alpha 598 | if (y > bbox[1][1]) node.vy += (bbox[1][1] - y)*alpha 599 | } 600 | } 601 | 602 | force.initialize = function(_){ 603 | nodes = _ 604 | } 605 | 606 | force.bbox = function(_){ 607 | return arguments.length ? (bbox = +_, force) : bbox 608 | } 609 | force.strength = function(_){ 610 | return arguments.length ? (strength = +_, force) : strength 611 | } 612 | 613 | return force 614 | } 615 | } 616 | 617 | window.init?.() 618 | -------------------------------------------------------------------------------- /attribution_graph/init-cg.js: -------------------------------------------------------------------------------- 1 | window.initCg = async function (sel, slug, {clickedId, clickedIdCb, isModal, isGridsnap} = {}){ 2 | var data = await util.getFile(`/graph_data/${slug}.json`) 3 | console.log(data) 4 | 5 | var visState = { 6 | pinnedIds: [], 7 | hiddenIds: [], 8 | hoveredId: null, 9 | hoveredNodeId: null, 10 | hoveredCtxIdx: null, 11 | clickedId: null, 12 | clickedCtxIdx: null, 13 | linkType: 'both', 14 | isShowAllLinks: '', 15 | isSyncEnabled: '', 16 | subgraph: null, 17 | isEditMode: 1, 18 | isHideLayer: data.metadata.scan == util.scanSlugToName.h35 || data.metadata.scan == util.scanSlugToName.moc, 19 | sg_pos: '', 20 | isModal: true, 21 | isGridsnap, 22 | ...data.qParams 23 | } 24 | 25 | if (visState.clickedId?.includes('supernode')) delete visState.clickedId 26 | if (clickedId && clickedId != 'null' && !clickedId.includes('supernode-')) visState.clickedId = clickedId 27 | if (!visState.clickedId || visState.clickedId == 'null' || visState.clickedId == 'undefined') visState.clickedId = data.nodes.find(d => d.isLogit)?.nodeId 28 | 29 | if (visState.pinnedIds.replace) visState.pinnedIds = visState.pinnedIds.split(',') 30 | if (visState.hiddenIds.replace) visState.hiddenIds = visState.hiddenIds.split(',') 31 | 32 | await utilCg.formatData(data, visState) 33 | 
34 | var renderAll = util.initRenderAll(['hClerpUpdate', 'clickedId', 'hiddenIds', 'pinnedIds', 'linkType', 'isShowAllLinks', 'features', 'isSyncEnabled', 'shouldSortByWeight', 'hoveredId']) 35 | 36 | function colorNodes() { 37 | data.nodes.forEach(d => d.nodeColor = '#fff') 38 | } 39 | colorNodes() 40 | 41 | // global link color — the color scale skips #fff so links are visible 42 | // TODO: weight by input sum instead 43 | function colorLinks() { 44 | var absMax = d3.max(data.links, d => d.absWeight) 45 | var _linearAbsScale = d3.scaleLinear().domain([-absMax, absMax]) 46 | var _linearPctScale = d3.scaleLinear().domain([-.4, .4]) 47 | var _linearTScale = d3.scaleLinear().domain([0, .5, .5, 1]).range([0, .5 - .001, .5 + .001, 1]) 48 | 49 | var widthScale = d3.scaleSqrt().domain([0, 1]).range([.00001, 3]) 50 | 51 | utilCg.pctInputColorFn = d => d3.interpolatePRGn(_linearTScale(_linearPctScale(d))) 52 | 53 | data.links.forEach(d => { 54 | // d.color = d3.interpolatePRGn(_linearTScale(_linearAbsScale(d.weight))) 55 | d.strokeWidth = widthScale(Math.abs(d.pctInput)) 56 | d.pctInputColor = utilCg.pctInputColorFn(d.pctInput) 57 | d.color = d3.interpolatePRGn(_linearTScale(_linearPctScale(d.pctInput))) 58 | }) 59 | } 60 | colorLinks() 61 | 62 | renderAll.hClerpUpdate.fns.push(() => utilCg.hClerpUpdateFn(null, data)) 63 | 64 | renderAll.hoveredId.fns.push(() => { 65 | // use hovered node if possible, otherwise use last occurence of feature 66 | var targetCtxIdx = visState.hoveredCtxIdx ?? 
999 67 | var hoveredNodes = data.nodes.filter(n => n.featureId == visState.hoveredId) 68 | var node = d3.sort(hoveredNodes, d => Math.abs(d.ctx_idx - targetCtxIdx))[0] 69 | visState.hoveredNodeId = node?.nodeId 70 | }) 71 | 72 | // set tmpClickedLink w/ strength of all the links connected the clickedNode 73 | renderAll.clickedId.fns.push(() => { 74 | clickedIdCb?.(visState.clickedId) 75 | 76 | var node = data.nodes.idToNode[visState.clickedId] 77 | if (!node){ 78 | // for a clicked supernode, sum over memberNode links to make tmpClickedLink 79 | if (visState.clickedId?.startsWith('supernode-')) { 80 | node = { 81 | nodeId: visState.clickedId, 82 | memberNodes: visState.subgraph.supernodes[+visState.clickedId.split('-')[1]] 83 | .slice(1) 84 | .map(id => data.nodes.idToNode[id]) 85 | } 86 | node.memberSet = new Set(node.memberNodes.map(d => d.nodeId)) 87 | 88 | function combineLinks(links, isSrc) { 89 | return d3.nestBy(links, d => isSrc ? d.sourceNode.nodeId : d.targetNode.nodeId) 90 | .map(links => ({ 91 | source: isSrc ? links[0].sourceNode.nodeId : visState.clickedId, 92 | target: isSrc ? visState.clickedId : links[0].targetNode.nodeId, 93 | sourceNode: isSrc ? links[0].sourceNode : node, 94 | targetNode: isSrc ? 
node : links[0].targetNode, 95 | weight: d3.sum(links, d => d.weight), 96 | absWeight: Math.abs(d3.sum(links, d => d.weight)) 97 | })) 98 | } 99 | 100 | node.sourceLinks = combineLinks(node.memberNodes.flatMap(d => d.sourceLinks), true) 101 | node.targetLinks = combineLinks(node.memberNodes.flatMap(d => d.targetLinks), false) 102 | } else { 103 | return data.nodes.forEach(d => { 104 | d.tmpClickedLink = null 105 | d.tmpClickedSourceLink = null 106 | d.tmpClickedTargetLink = null 107 | }) 108 | } 109 | } 110 | 111 | var connectedLinks = [...node.sourceLinks, ...node.targetLinks] 112 | var max = d3.max(connectedLinks, d => d.absWeight) 113 | var colorScale = d3.scaleSequential(d3.interpolatePRGn).domain([-max*1.1, max*1.1]) 114 | 115 | // allowing supernode links means each node can have a both tmpClickedSourceLink and tmpClickedTargetLink 116 | // currently we render bidirectional links where possible, falling back to the target side links otherwises 117 | var nodeIdToSourceLink = {} 118 | var nodeIdToTargetLink = {} 119 | var featureIdToLink = {} 120 | connectedLinks.forEach(link => { 121 | if (link.sourceNode === node) { 122 | nodeIdToTargetLink[link.targetNode.nodeId] = link 123 | featureIdToLink[link.targetNode.featureId] = link 124 | link.tmpClickedCtxOffset = link.targetNode.ctx_idx - node.ctx_idx 125 | } 126 | if (link.targetNode === node) { 127 | nodeIdToSourceLink[link.sourceNode.nodeId] = link 128 | featureIdToLink[link.sourceNode.featureId] = link 129 | link.tmpClickedCtxOffset = link.sourceNode.ctx_idx - node.ctx_idx 130 | } 131 | // link.tmpColor = colorScale(link.pctInput) 132 | link.tmpColor = link.pctInputColor 133 | }) 134 | 135 | data.nodes.forEach(d => { 136 | var link = nodeIdToSourceLink[d.nodeId] || nodeIdToTargetLink[d.nodeId] 137 | d.tmpClickedLink = link 138 | d.tmpClickedSourceLink = nodeIdToSourceLink[d.nodeId] 139 | d.tmpClickedTargetLink = nodeIdToTargetLink[d.nodeId] 140 | }) 141 | 142 | data.features.forEach(d => { 143 | var link = 
featureIdToLink[d.featureId] 144 | d.tmpClickedLink = link 145 | }) 146 | }) 147 | 148 | function initGridsnap() { 149 | var gridData = [ 150 | // {cur: {x: 0, y: 0, w: 6, h: 1}, class: 'button-container'}, 151 | {cur: {x: 0, y: 8, w: 8, h: 8}, class: 'subgraph'}, 152 | {cur: {x: 8, y: 1, w: 6, h: 6}, class: 'node-connections'}, 153 | {cur: {x: 8, y: 6, w: 6, h: 10}, class: 'feature-detail'}, 154 | {cur: {x: 0, y: 0, w: 8, h: 8}, class: 'link-graph', resizeFn: makeResizeFn(initCgLinkGraph)}, 155 | // {cur: {x: 0, y: 18, w: 6, h: 7}, class: 'clerp-list'}, 156 | // {cur: {x: 6, y: 30, w: 4, h: 7}, class: 'feature-scatter'}, 157 | // {cur: {x: 0, y: 30, w: 3, h: 8}, class: 'metadata'}, 158 | ].filter(d => d) 159 | 160 | var initFns = [ 161 | // initCgButtonContainer, 162 | initCgSubgraph, 163 | initCgLinkGraph, 164 | initCgNodeConnections, 165 | initCgFeatureDetail, 166 | // initCgClerpList, 167 | // initCgFeatureScatter, 168 | ].filter(d => d) 169 | 170 | var gridsnapSel = sel.html('').append('div.gridsnap.cg') 171 | .classed('is-edit-mode', visState.isGridsnap) 172 | if (visState.isModal) gridsnapSel.st({width: '100%', height: '100%'}) 173 | 174 | 175 | window.initGridsnap({ 176 | gridData, 177 | gridSizeY: 50, 178 | pad: 10, 179 | sel: gridsnapSel, 180 | isFullScreenY: false, 181 | isFillContainer: visState.isModal, 182 | serializedGrid: data.qParams.gridsnap 183 | }) 184 | 185 | initFns.forEach(fn => fn({visState, renderAll, data, cgSel: sel})) 186 | 187 | function makeResizeFn(fn){ 188 | return () => { 189 | fn({visState, renderAll, data, cgSel: sel.select('.gridsnap.cg')}) 190 | Object.values(renderAll).forEach(d => d()) 191 | } 192 | } 193 | } 194 | 195 | initGridsnap() 196 | renderAll.hClerpUpdate() 197 | renderAll.isShowAllLinks() 198 | renderAll.linkType() 199 | renderAll.clickedId() 200 | renderAll.pinnedIds() 201 | renderAll.features() 202 | renderAll.isSyncEnabled() 203 | renderAll.hoveredId() 204 | } 205 | 206 | window.init?.() 207 | 
-------------------------------------------------------------------------------- /attribution_graph/util-cg.js: -------------------------------------------------------------------------------- 1 | window.utilCg = (function(){ 2 | function clerpUUID(d){ 3 | return '🤖' + d.featureIndex 4 | } 5 | 6 | function parseClerpUUID(str){ 7 | var [featureIndex] = str.split('🤖') 8 | return {featureIndex} 9 | } 10 | 11 | function loadDatapath(urlStr){ 12 | try { 13 | var url = new URL(urlStr) 14 | urlStr = url.searchParams.get('datapath') ?? urlStr 15 | } catch {} 16 | urlStr = urlStr?.replace('index.html', 'data.json').split('?')[0] || 'data.json' 17 | 18 | try { 19 | return util.getFile(urlStr) 20 | } catch (exc) { 21 | d3.select('body') 22 | .html(`Couldn't load data from ${urlStr}: ${exc}. Maybe you need to specify a ?datapath= argument?`) 23 | .st({color: '#c00', fontSize: '150%', padding: '1em', whiteSpace: 'pre-wrap'}) 24 | throw exc 25 | } 26 | } 27 | 28 | function saveHClerpsToLocalStorage(hClerps) { 29 | const key = 'local-clerp' 30 | const hClerpArray = Array.from(hClerps.entries()).filter(d => d[1]) 31 | localStorage.setItem(key, JSON.stringify(hClerpArray)); 32 | } 33 | 34 | function getHClerpsFromLocalStorage() { 35 | const key = 'local-clerp' 36 | // We want to set on load here so that any page load will fix the key. 
37 | if (localStorage.getItem(key) === null) localStorage.setItem(key, '[]') 38 | const hClerpArray = JSON.parse(localStorage.getItem(key)) 39 | .filter(d => d[0] != clerpUUID({})) 40 | return new Map(hClerpArray) 41 | } 42 | 43 | async function deDupHClerps() { 44 | const remoteClerps = [] 45 | let remoteMap = new Map(remoteClerps.map(d => { 46 | let key = clerpUUID(d); 47 | let clerp = d.clerp; 48 | return [key, clerp]; 49 | })); 50 | 51 | let localClerps = getHClerpsFromLocalStorage() 52 | let featureLookup = {} 53 | data.features.forEach(d => featureLookup[clerpUUID(d)] = d) 54 | 55 | // update feature data with current spreadsheet 56 | // (why is this behind the "copy" button?) 57 | Array.from(remoteMap).forEach(([key, value]) => { 58 | if (featureLookup[key]) featureLookup[key].remoteClerp = value 59 | }) 60 | 61 | const deDupArray = Array.from(localClerps) 62 | .filter(([key, localClerp]) => { 63 | let remote = remoteMap.get(key); 64 | 65 | // keep only local clerps that are different from remote 66 | if (!remote) return true 67 | // gdoc to local storage mangles quotes, don't force strict equality 68 | function slugify(d){ return d ? 
d.replace(/['"]/g, '').trim() : ''} 69 | if (slugify(remote) != slugify(localClerp)) return true 70 | 71 | // if local changes are on remote, delete localClerp and set remoteClerp 72 | localClerps.delete(key) 73 | if (featureLookup[key]) featureLookup[key].localClerp = '' 74 | }) 75 | 76 | // copy feature.remoteClerp to node.remoteClerp 77 | data.nodes?.forEach(node => { 78 | var feature = data.features.idToFeature[node.featureId] 79 | node.remoteClerp = feature.remoteClerp 80 | node.localClerp = feature.localClerp 81 | }) 82 | 83 | saveHClerpsToLocalStorage(new Map(deDupArray)) 84 | return new Map(deDupArray); 85 | } 86 | 87 | function tabifyHClerps(hClerps) { 88 | return [] 89 | } 90 | 91 | function hClerpUpdateFn(params, data){ 92 | const localClerps = getHClerpsFromLocalStorage(); 93 | if (params) { 94 | const [node, hClerp] = params; 95 | localClerps.set(clerpUUID(node), hClerp) 96 | saveHClerpsToLocalStorage(localClerps); 97 | } 98 | 99 | data.features.forEach(node => { 100 | node.localClerp = localClerps.get(clerpUUID(node)) 101 | node.ppClerp = node.localClerp || node.remoteClerp || node.clerp; 102 | }) 103 | 104 | data.nodes?.forEach(node => { 105 | var feature = data.features.idToFeature[node.featureId] 106 | if (!feature) return 107 | node.localClerp = feature.localClerp 108 | node.ppClerp = feature.ppClerp 109 | }) 110 | } 111 | 112 | // Adds virtual logit node showing A-B logit difference based on url param logitDiff=⍽tokenA⍽__vs__⍽tokenB⍽ 113 | function addVirtualDiff(data){ 114 | // Filter out any previous virtual nodes/links 115 | var nodes = data.nodes.filter(d => !d.isJsVirtual) 116 | var links = data.links.filter(d => !d.isJsVirtual) 117 | nodes.forEach(d => d.logitToken = d.clerp?.split(`"`)[1]?.split(`" k(p=`)[0]) 118 | 119 | var [logitAStr, logitBStr] = util.params.get('logitDiff')?.split('__vs__') || [] 120 | if (!logitAStr || !logitBStr) return {nodes, links} 121 | var logitANode = nodes.find(d => d.logitToken == logitAStr) 122 | var 
logitBNode = nodes.find(d => d.logitToken == logitBStr) 123 | if (!logitANode || !logitBNode) return {nodes, links} 124 | 125 | var virtualId = `virtual-diff-${logitAStr}-vs-${logitBStr}` 126 | var diffNode = { 127 | ...logitANode, 128 | node_id: virtualId, 129 | jsNodeId: virtualId, 130 | feature: virtualId, 131 | isJsVirtual: true, 132 | logitToken: `${logitAStr} - ${logitBStr}`, 133 | clerp: `Logit diff: ${logitAStr} - ${logitBStr}`, 134 | } 135 | nodes.push(diffNode) 136 | 137 | var targetLinks = links.filter(d => d.target == logitANode.node_id || d.target == logitBNode.node_id) 138 | d3.nestBy(targetLinks, d => d.source).map(sourceLinks => { 139 | var linkA = sourceLinks.find(d => d.target == logitANode.node_id) 140 | var linkB = sourceLinks.find(d => d.target == logitBNode.node_id) 141 | 142 | links.push({ 143 | source: sourceLinks[0].source, 144 | target: diffNode.node_id, 145 | weight: (linkA?.weight || 0) - (linkB?.weight || 0), 146 | isJsVirtual: true 147 | }) 148 | }) 149 | 150 | return {nodes, links} 151 | } 152 | 153 | // Decorates and mutates data.json 154 | // - Adds pointers between node and links 155 | // - Deletes very common features 156 | // - Adds data.features and data.byStream 157 | async function formatData(data, visState){ 158 | var {metadata} = data 159 | var {nodes, links} = addVirtualDiff(data) 160 | 161 | var py_node_id_to_node = {} 162 | var idToNode = {} 163 | var maxLayer = d3.max(nodes.filter(d => d.feature_type != 'logit'), d => +d.layer) 164 | nodes.forEach((d, i) => { 165 | // To make hover state work across prompts, drop ctx from node id 166 | d.featureId = `${d.layer}_${d.feature}` 167 | 168 | d.active_feature_idx = d.feature 169 | d.nodeIndex = i 170 | 171 | if (d.feature_type == 'logit') d.layer = maxLayer + 1 172 | 173 | // TODO: does this handle error nodes correctly? 
174 | if (d.feature_type == 'unexplored node' && !d.layer != 'E'){ 175 | d.feature_type = 'cross layer transcoder' 176 | } 177 | 178 | // count from end to align last token on diff prompts 179 | d.ctx_from_end = data.metadata.prompt_tokens.length - d.ctx_idx 180 | 181 | // add clerp to embed and error nodes 182 | if (d.feature_type.includes('error')){ 183 | d.isError = true 184 | 185 | if (!d.featureId.includes('__err_idx_')) d.featureId = d.featureId + '__err_idx_' + d.ctx_from_end 186 | 187 | if (d.feature_type == 'mlp reconstruction error'){ 188 | d.clerp = `Err: mlp “${util.ppToken(data.metadata.prompt_tokens[d.ctx_idx])}"` 189 | } 190 | } else if (d.feature_type == 'embedding'){ 191 | d.clerp = `Emb: “${util.ppToken(data.metadata.prompt_tokens[d.ctx_idx])}"` 192 | } 193 | 194 | d.url = d.vis_link 195 | d.isFeature = true 196 | 197 | d.targetLinks = [] 198 | d.sourceLinks = [] 199 | 200 | // TODO: switch to featureIndex in graphgen 201 | d.featureIndex = d.feature 202 | 203 | d.nodeId = d.jsNodeId 204 | if (d.feature_type == 'logit' && d.clerp) d.logitPct= +d.clerp.split('(p=')[1].split(')')[0] 205 | idToNode[d.nodeId] = d 206 | py_node_id_to_node[d.node_id] = d 207 | }) 208 | 209 | // delete features that occur in than 2/3 of tokens 210 | // TODO: more principled way of filtering them out — maybe by feature density? 
211 | var deletedFeatures = [] 212 | var byFeatureId = d3.nestBy(nodes, d => d.featureId) 213 | byFeatureId.forEach(feature => { 214 | if (feature.length > metadata.prompt_tokens.length*2/3){ 215 | deletedFeatures.push(feature) 216 | feature.forEach(d => { 217 | delete idToNode[d.nodeId] 218 | delete py_node_id_to_node[d.node_id] 219 | }) 220 | } 221 | }) 222 | if (deletedFeatures.length) console.log({deletedFeatures}) 223 | nodes = nodes.filter(d => idToNode[d.nodeId]) 224 | nodes = d3.sort(nodes, d => +d.layer) 225 | 226 | links = links.filter(d => py_node_id_to_node[d.source] && py_node_id_to_node[d.target]) 227 | 228 | // connect links to nodes 229 | links.forEach(link => { 230 | link.sourceNode = py_node_id_to_node[link.source] 231 | link.targetNode = py_node_id_to_node[link.target] 232 | 233 | link.linkId = link.sourceNode.nodeId + '__' + link.targetNode.nodeId 234 | 235 | link.sourceNode.targetLinks.push(link) 236 | link.targetNode.sourceLinks.push(link) 237 | link.absWeight = Math.abs(link.weight) 238 | }) 239 | links = d3.sort(links, d => d.absWeight) 240 | 241 | 242 | nodes.forEach(d => { 243 | d.inputAbsSum = d3.sum(d.sourceLinks, e => Math.abs(e.weight)) 244 | d.sourceLinks.forEach(e => e.pctInput = e.weight/d.inputAbsSum) 245 | d.inputError = d3.sum(d.sourceLinks.filter(e => e.sourceNode.isError), e => Math.abs(e.weight)) 246 | d.pctInputError = d.inputError/d.inputAbsSum 247 | }) 248 | 249 | // convert layer/probe_location_idx to a streamIdx used to position nodes 250 | var byStream = d3.nestBy(nodes, d => [d.layer, d.probe_location_idx] + '') 251 | byStream = d3.sort(byStream, d => d[0].probe_location_idx) 252 | byStream = d3.sort(byStream, d => d[0].layer == 'E' ? -1 : +d[0].layer) 253 | byStream.forEach((stream, streamIdx) => { 254 | stream.forEach(d => { 255 | d.streamIdx = streamIdx 256 | d.layerLocationLabel = layerLocationLabel(d.layer, d.probe_location_idx) 257 | 258 | if (!visState.isHideLayer) d.streamIdx = isFinite(d.layer) ? 
+d.layer : 0 259 | }) 260 | }) 261 | 262 | // add target_logit_effect__ columns for each logit 263 | var logitNodeMap = new Map(nodes.filter(d => d.isLogit).map(d => [d.node_id, d.logitToken])) 264 | nodes.forEach(node => { 265 | node.targetLinks.forEach(link => { 266 | if (!logitNodeMap.has(link.target)) return 267 | node[`target_logit_effect__${logitNodeMap.get(link.target)}`] = link.weight 268 | }) 269 | }) 270 | 271 | // add ppClerp 272 | await Promise.all(nodes.map(async d => { 273 | if (!d.clerp) d.clerp = '' 274 | d.remoteClerp = '' 275 | })) 276 | 277 | // condense nodes into features, using last occurence of feature if necessary to point to a node 278 | var features = d3.nestBy(nodes.filter(d => d.isFeature), d => d.featureId) 279 | .map(d => ({ 280 | featureId: d[0].featureId, 281 | feature_type: d[0].feature_type, 282 | clerp: d[0].clerp, 283 | remoteClerp: d[0].remoteClerp, 284 | layer: d[0].layer, 285 | streamIdx: d[0].streamIdx, 286 | probe_location_idx: d[0].probe_location_idx, 287 | featureIndex: d[0].featureIndex, 288 | top_logit_effects: d[0].top_logit_effects, 289 | bottom_logit_effects: d[0].bottom_logit_effects, 290 | top_embedding_effects: d[0].top_embedding_effects, 291 | bottom_embedding_effects: d[0].bottom_embedding_effects, 292 | url: d[0].url, 293 | lastNodeId: d.at(-1).nodeId, 294 | isLogit: d[0].isLogit, 295 | isError: d[0].isError, 296 | feature_type: d[0].feature_type, 297 | })) 298 | 299 | nodes.idToNode = idToNode 300 | features.idToFeature = Object.fromEntries(features.map(d => [d.featureId, d])) 301 | links.idToLink = Object.fromEntries(links.map(d => [d.linkId, d])) 302 | 303 | Object.assign(data, {nodes, features, links, byStream}) 304 | } 305 | 306 | function initBcSync({visState, renderAll}){ 307 | var bcStateSync = window.bcSync = window.bcSync || new BroadcastChannel('state-sync') 308 | 309 | function broadcastState(){ 310 | if (!visState.isSyncEnabled) return 311 | bcStateSync.postMessage({ 312 | pinnedIds: 
visState.pinnedIds, 313 | hiddenIds: visState.hiddenIds, 314 | clickedId: visState.clickedId, 315 | hoveredId: visState.hoveredId, 316 | pageUUID: visState.pageUUID, 317 | isSyncEnabled: visState.isSyncEnabled, 318 | }) 319 | } 320 | 321 | renderAll.pinnedIds.fns.push(ev => { if (!ev?.skipBroadcast) broadcastState() }) 322 | renderAll.hiddenIds.fns.push(ev => { if (!ev?.skipBroadcast) broadcastState() }) 323 | renderAll.clickedId.fns.push(ev => { if (!ev?.skipBroadcast) broadcastState() }) 324 | renderAll.hoveredId.fns.push(ev => { if (!ev?.skipBroadcast) broadcastState() }) 325 | 326 | bcStateSync.onmessage = ev => { 327 | if (!visState.isSyncEnabled) return 328 | if (visState.isSyncEnabled != ev.data.isSyncEnabled) return 329 | if (ev.data.pageUUID == visState.pageUUID) return 330 | 331 | if (JSON.stringify(visState.pinnedIds) != JSON.stringify(ev.data.pinnedIds)){ 332 | visState.pinnedIds = ev.data.pinnedIds 333 | renderAll.pinnedIds({skipBroadcast: true}) 334 | } 335 | 336 | if (JSON.stringify(visState.hiddenIds) != JSON.stringify(ev.data.hiddenIds)){ 337 | visState.hiddenIds = ev.data.hiddenIds 338 | renderAll.hiddenIds({skipBroadcast: true}) 339 | } 340 | 341 | if (visState.clickedId != ev.data.clickedId){ 342 | visState.clickedId = ev.data.clickedId 343 | renderAll.clickedId({skipBroadcast: true}) 344 | } 345 | 346 | if (visState.hoveredId != ev.data.hoveredId){ 347 | visState.hoveredId = ev.data.hoveredId 348 | renderAll.hoveredId({skipBroadcast: true}) 349 | } 350 | } 351 | } 352 | 353 | function addFeatureEvents(visState, renderAll) { 354 | return function(selection) { 355 | selection 356 | .on('mouseover', (ev, d) => { 357 | if (ev.shiftKey) return 358 | if (visState.subgraph?.activeGrouping.isActive) return 359 | ev.preventDefault() 360 | hoverFeature(visState, renderAll, d) 361 | }) 362 | .on('mouseout', (ev, d) => { 363 | if (ev.shiftKey) return 364 | if (visState.subgraph?.activeGrouping.isActive) return 365 | ev.preventDefault() 366 | 
unHoverFeature(visState, renderAll) 367 | }) 368 | .on('click', (ev, d) => { 369 | if (visState.subgraph?.activeGrouping.isActive) return 370 | clickFeature(visState, renderAll, d, ev.metaKey || ev.ctrlKey) 371 | }) 372 | } 373 | } 374 | 375 | function hoverFeature(visState, renderAll, d) { 376 | if (d.nodeId.includes('supernode-')) return 377 | 378 | if (visState.hoveredId != d.featureId) { 379 | visState.hoveredId = d.featureId 380 | visState.hoveredCtxIdx = d.ctx_idx 381 | renderAll.hoveredId() 382 | } 383 | } 384 | 385 | function unHoverFeature(visState, renderAll) { 386 | if (visState.hoveredId) { 387 | visState.hoveredId = null 388 | visState.hoveredCtxIdx = null 389 | setTimeout(() => { 390 | if (!visState.hoveredId) renderAll.hoveredId() 391 | }) 392 | } 393 | } 394 | function togglePinned(visState, renderAll, d) { 395 | var index = visState.pinnedIds.indexOf(d.nodeId) 396 | if (index == -1) { 397 | visState.pinnedIds.push(d.nodeId) 398 | } else { 399 | visState.pinnedIds.splice(index, 1) 400 | } 401 | renderAll.pinnedIds() 402 | } 403 | 404 | function clickFeature(visState, renderAll, d, metaKey){ 405 | console.log(d) 406 | if (d.nodeId.includes('supernode-')) return 407 | 408 | if (metaKey && visState.isEditMode) { 409 | togglePinned(visState, renderAll, d) 410 | } else { 411 | if (visState.clickedId == d.nodeId) { 412 | visState.clickedId = null 413 | visState.clickedCtxIdx = null 414 | } else { 415 | visState.clickedId = d.nodeId 416 | visState.clickedCtxIdx = d.ctx_idx 417 | } 418 | visState.hoveredId = null 419 | visState.hoveredCtxIdx = null 420 | renderAll.clickedId() 421 | } 422 | } 423 | 424 | function showTooltip(ev, d) { 425 | let tooltipSel = d3.select('.tooltip'), 426 | x = ev.clientX, 427 | y = ev.clientY, 428 | bb = tooltipSel.node().getBoundingClientRect(), 429 | left = d3.clamp(20, (x-bb.width/2), window.innerWidth - bb.width - 20), 430 | top = innerHeight > y + 20 + bb.height ? 
y + 20 : y - bb.height - 20; 431 | 432 | let tooltipHtml = !ev.metaKey ? (d.ppClerp || `F#${d.feature}`) : Object.keys(d) 433 | .filter(str => typeof d[str] != 'object' && typeof d[str] != 'function' && !keysToSkip.has(str)) 434 | .map(str => { 435 | var val = d[str] 436 | if (typeof val == 'number' && !Number.isInteger(val)) val = val.toFixed(6) 437 | return `
${str}: ${val}
` 438 | }) 439 | .join('') 440 | 441 | tooltipSel 442 | .style('left', left +'px') 443 | .style('top', top + 'px') 444 | .html(tooltipHtml) 445 | .classed('tooltip-hidden', false) 446 | } 447 | 448 | function addFeatureTooltip(selection){ 449 | selection 450 | .call(d3.attachTooltip, d3.select('.tooltip'), []) 451 | .on('mouseover.tt', (ev, d) => { 452 | var tooltipHtml = !ev.metaKey ? d.ppClerp : Object.keys(d) 453 | .filter(str => typeof d[str] != 'object' && typeof d[str] != 'function' && !keysToSkip.has(str)) 454 | .map(str => { 455 | var val = d[str] 456 | if (typeof val == 'number' && !Number.isInteger(val)) val = val.toFixed(6) 457 | return `
${str}: ${val}
` 458 | }) 459 | .join('') 460 | 461 | d3.select('.tooltip').html(tooltipHtml) 462 | }) 463 | } 464 | 465 | function hideTooltip() { 466 | d3.select('.tooltip').classed('tooltip-hidden', true); 467 | } 468 | 469 | function updateFeatureStyles(nodeSel){ 470 | nodeSel.call(classAndRaise('hovered', e => e.featureId == visState.hoveredId)) 471 | 472 | var pinnedIdSet = new Set(visState.pinnedIds) 473 | nodeSel.call(classAndRaise('pinned', d => pinnedIdSet.has(d.nodeId))) 474 | 475 | var hiddenIdSet = new Set(visState.hiddenIds) 476 | nodeSel.call(classAndRaise('hidden', d => hiddenIdSet.has(d.featureId))) 477 | 478 | if (nodeSel.datum().nodeId){ 479 | nodeSel.call(classAndRaise('clicked', e => e.nodeId === visState.clickedId)) 480 | } else { 481 | nodeSel.call(classAndRaise('clicked', d => d.featureId == visState.clickedId)) 482 | } 483 | } 484 | 485 | function classAndRaise(className, filterFn) { 486 | return sel => { 487 | sel 488 | .classed(className, 0) 489 | .filter(filterFn) 490 | .classed(className, 1) 491 | .raise() 492 | } 493 | } 494 | 495 | var keysToSkip = new Set([ 496 | 'node_id', 'jsNodeId', 'nodeId', 'layerLocationLabel', 'remoteClerp', 'localClerp', 497 | 'tmpClickedTargetLink', 'tmpClickedLink', 'tmpClickedSourceLink', 498 | 'pos', 'xOffset', 'yOffset', 'sourceLinks', 'targetLinks', 'url', 'vis_link', 'run_idx', 499 | 'featureId', 'active_feature_idx', 'nodeIndex', 'isFeature', 'Distribution', 500 | 'clerp', 'ppClerp', 'is_target_logit', 'token_prob', 'reverse_ctx_idx', 'ctx_from_end', 'feature', 'logitToken', 501 | 'featureIndex', 'streamIdx', 'nodeColor', 'umap_enc_x', 'umap_enc_y', 'umap_dec_x', 'umap_dec_y', 'umap_concat_x', 'umap_concat_y', 502 | ]) 503 | 504 | 505 | function layerLocationLabel(layer, location) { 506 | if (layer == 'E') return 'Emb' 507 | if (layer == 'E1') return 'Lgt' 508 | if (location === -1) return 'logit' 509 | 510 | // TODO: is stream probe_location_idx no longer be saved out? 
511 | // NOTE: For now, location is literally ProbePointLocation 512 | return `L${layer}` 513 | } 514 | 515 | var memoize = fn => { 516 | var cache = new Map() 517 | return (...args) => { 518 | var key = JSON.stringify(args) 519 | if (cache.has(key)) return cache.get(key) 520 | var result = fn(...args) 521 | cache.set(key, result) 522 | return result 523 | } 524 | } 525 | 526 | var bgColorToTextColor = memoize((backgroundColor, light='#fff', dark='#000') => { 527 | if (!backgroundColor) return '' 528 | var hsl = d3.hsl(backgroundColor) 529 | return hsl.l > 0.55 ? dark : light 530 | }) 531 | 532 | // gradient for hover && pinned state 533 | function addPinnedClickedGradient(svg){ 534 | svg.append('defs').html(` 535 | 536 | 537 | 538 | 539 | 540 | 541 | `) 542 | } 543 | 544 | function renderFeatureRow(sel, visState, renderAll, linkKey='tmpClickedLink'){ 545 | sel.st({ 546 | background: d => d[linkKey]?.tmpColor, 547 | color: d => bgColorToTextColor(d[linkKey]?.tmpColor, '#eee', '#555'), 548 | }) 549 | 550 | // add events in a timeout to avoid connection clicks leading to an instant hover 551 | setTimeout(() => sel.call(addFeatureEvents(visState, renderAll)), 16) 552 | 553 | let featureIconSel = sel.append('svg') 554 | .at({width: 10, height: 10}) 555 | 556 | let featureIcon = featureIconSel.append('g') 557 | 558 | featureIcon.append('g.default-icon').append('text') 559 | .text(d => featureTypeToText(d.feature_type)) 560 | .at({ 561 | fontSize: 9, 562 | textAnchor: 'middle', 563 | dominantBaseline: 'central', 564 | dx: 5, 565 | dy: 4, 566 | }) 567 | .at({fill: d => d[linkKey]?.tmpColor}) 568 | 569 | 570 | sel.append('div.label') 571 | .text(d => d.ppClerp) 572 | .at({ title: d => d.ppClerp }) 573 | 574 | sel 575 | .filter(d => d[linkKey] && d[linkKey].tmpClickedCtxOffset != 0) 576 | .append('div.ctx-offset') 577 | .text(d => d[linkKey].tmpClickedCtxOffset < 0 ? 
'←' : '→') 578 | 579 | if (!visState.isHideLayer){ 580 | sel.append('div.layer') 581 | .text(d => d.layerLocationLabel ?? layerLocationLabel(d.layer, d.probe_location_idx)); 582 | } 583 | 584 | sel.append('div.weight') 585 | .text(d => d[linkKey] ? d3.format('+.3f')(d[linkKey].pctInput) : '') 586 | } 587 | 588 | function featureTypeToText(type){ 589 | if (type == 'logit') return '■' 590 | if (type == 'embedding') return '■' 591 | if (type === 'mlp reconstruction error') return '◆' 592 | return '●' 593 | 594 | } 595 | 596 | 597 | return { 598 | loadDatapath, 599 | formatData, 600 | initBcSync, 601 | addFeatureEvents, 602 | hoverFeature, 603 | unHoverFeature, 604 | clickFeature, 605 | togglePinned, 606 | layerLocationLabel, 607 | keysToSkip, 608 | addFeatureTooltip, 609 | showTooltip, 610 | hideTooltip, 611 | updateFeatureStyles, 612 | memoize, 613 | bgColorToTextColor, 614 | addPinnedClickedGradient, 615 | renderFeatureRow, 616 | saveHClerpsToLocalStorage, 617 | getHClerpsFromLocalStorage, 618 | hClerpUpdateFn, 619 | deDupHClerps, 620 | tabifyHClerps, 621 | featureTypeToText, 622 | } 623 | })() 624 | 625 | window.init?.() 626 | -------------------------------------------------------------------------------- /feature_examples/feature-examples.css: -------------------------------------------------------------------------------- 1 | .feature-examples { 2 | display: flex; 3 | flex-direction: column; 4 | position: relative; 5 | 6 | .example-2-col { 7 | flex: 1; 8 | overflow: visible; 9 | overflow-x: hidden; 10 | position: relative; 11 | z-index: 100; 12 | } 13 | 14 | .chart-row{ 15 | display: grid; 16 | grid-template-columns: 2fr 1fr; 17 | &:has(> :only-child) { 18 | grid-template-columns: 1fr; 19 | } 20 | 21 | gap: 20px; 22 | line-height: 0px; 23 | } 24 | 25 | .feature-example-logits{ 26 | color: hsl(0 0 0 / 0.5); 27 | font-size: 10px; 28 | font-family: sans-serif; 29 | line-height: 14px; 30 | position: sticky; 31 | top: 0px; 32 | background: #fff; 33 | z-index: 100000; 
34 | margin-bottom: -4px; 35 | outline: 2px solid #fff; 36 | 37 | .section-title{ 38 | margin-bottom: 3px; 39 | } 40 | 41 | .token { 42 | display: inline-block; 43 | margin-left: 0.2em; 44 | background-color: #eee; 45 | color: #666; 46 | padding: 1px 3px !important; 47 | border-radius: 4px; 48 | font-family: monospace; 49 | font-size: 11px; 50 | white-space: pre; 51 | overflow-y: hidden; 52 | vertical-align: middle; 53 | max-height: 13px; 54 | margin-right: 2px; 55 | } 56 | 57 | .logit-label { 58 | display: inline-block; 59 | margin-right: 7px; 60 | color: #aaa; 61 | padding-right: 3px; 62 | } 63 | } 64 | 65 | .example-2-col:not(:last-child) { 66 | margin-right: 10px; 67 | } 68 | 69 | .example-quantile { 70 | margin: 0; 71 | margin-top: 20px; 72 | } 73 | 74 | .train_token_ind{ 75 | outline: 1px solid #000; 76 | } 77 | 78 | .ctx-container { 79 | font-family: monospace; 80 | font-size: 11px; 81 | line-height: 11px; 82 | padding: 6px 0px 4px; 83 | border-bottom: 1px solid #eee; 84 | border-left: 2px solid rgba(0, 0, 0, 0); 85 | 86 | overflow: visible; 87 | white-space: nowrap; 88 | position: relative; 89 | a, .dataset-source{ 90 | margin-right: 7px; 91 | color: #aaa; 92 | display: inline-block; 93 | } 94 | 95 | .dataset-source{ 96 | padding-right: 3px; 97 | border-right: 1px solid #eee 98 | } 99 | } 100 | 101 | .is-repeated-datapoint{ 102 | opacity: .2; 103 | } 104 | .is-repeated-datapoint:before { 105 | content: '⟳ '; 106 | } 107 | 108 | .ctx-container .token { 109 | border-radius: 4px; 110 | padding: 1px; 111 | cursor: default; 112 | white-space: pre; 113 | } 114 | 115 | .ctx-container .token.active { 116 | outline: 2px solid #000; 117 | overflow: visible; 118 | } 119 | 120 | svg{ 121 | overflow: visible; 122 | } 123 | 124 | &.is-stale-output { 125 | td span, .text-wrapper, .token { 126 | opacity: 0; 127 | } 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /feature_examples/init-feature-examples-list.js: 
// --------------------------------------------------------------------------------
/**
 * Renders the activation-example list for the current feature: one section per
 * quantile in visState.feature.examples_quantiles, each listing its example
 * contexts with tokens shaded by activation and scrolled so the training token
 * is horizontally centered.
 *
 * @param {object} opts
 * @param {object} opts.renderAll - render registry; a callback is pushed onto renderAll.feature.fns
 * @param {object} opts.visState - shared state; visState.feature is read on each render
 * @param {object} opts.sel - d3 selection containing a `.feature-example-list` element
 */
window.initFeatureExamplesList = function({renderAll, visState, sel}){
  var listSel = sel.select('.feature-example-list')

  // A literal "\xNN" escape anywhere in a token (no g flag, so .test() keeps no state).
  var HEX_ESCAPE_RE = /\\x[0-9a-f]{2}/
  // A literal "\xNN" escape at the very start of the remaining string.
  var LEADING_HEX_ESCAPE_RE = /^\\x[0-9a-f]{2}/
  var textEncoder = new TextEncoder()
  // fatal: true makes decode() throw on incomplete/invalid UTF-8 instead of emitting U+FFFD.
  var utf8Decoder = new TextDecoder('utf-8', {fatal: true})

  renderAll.feature.fns.push(async () => {
    if (visState.feature.isDead) return listSel.html(`Feature ${visState.feature.featureIndex} failed to load`)

    // All quantiles are rendered into a single column.
    var cols = [visState.feature.examples_quantiles]
    listSel.html('')
      .appendMany('div.example-2-col', cols)
      .appendMany('div', d => d)
      .each(drawQuantile)

    // Prepend an empty div to the first column.
    listSel.select('.example-2-col').append('div').lower()
  })

  /** Draws one quantile: a title row followed by its example contexts. */
  function drawQuantile(quantile){
    var quantileSel = d3.select(this)

    quantileSel.append('div.example-quantile')
      .append('span.quantile-title').text(quantile.quantile_name + ' ')

    quantileSel.appendMany('div.ctx-container', quantile.examples).each(drawExample)
  }

  /**
   * Converts a token that may contain literal "\xNN" escapes into raw bytes:
   * each escape contributes one byte, every other character is UTF-8 encoded.
   *
   * @param {string} token
   * @returns {number[]} byte values
   */
  function maybeHexEscapedToBytes(token) {
    var bytes = []
    var rest = token
    while (rest.length) {
      if (LEADING_HEX_ESCAPE_RE.test(rest)) {
        bytes.push(parseInt(rest.slice(2, 4), 16))
        rest = rest.slice(4)
      } else {
        // Take a whole code point, not a single UTF-16 unit: encoding half of a
        // surrogate pair (e.g. an emoji) would produce U+FFFD bytes.
        var ch = String.fromCodePoint(rest.codePointAt(0))
        bytes.push(...textEncoder.encode(ch))
        rest = rest.slice(ch.length)
      }
    }
    return bytes
  }

  /**
   * Replaces tokens containing "\xNN" escapes with their decoded UTF-8 text,
   * merging a run of adjacent tokens when the bytes only form valid UTF-8 together.
   * A merged entry's act is the max activation over the run; tokens that cannot
   * be decoded pass through unchanged.
   *
   * @param {string[]} tokens
   * @param {number[]} acts - activation per token, parallel to tokens
   * @returns {{token: string, act: number, minIndex: number, maxIndex: number}[]}
   */
  function mergeHexEscapedMax(tokens, acts) {
    var ret = []
    var i = 0
    while (i < tokens.length) {
      var merged = HEX_ESCAPE_RE.test(tokens[i]) ? decodeRun(tokens, acts, i) : null
      if (merged) {
        ret.push(merged)
        i = merged.maxIndex + 1
      } else {
        ret.push({token: tokens[i], act: acts[i], minIndex: i, maxIndex: i})
        i++
      }
    }
    return ret
  }

  // Returns the shortest run tokens[i..j] (at most 5 tokens, starting with
  // tokens[i] alone) whose combined bytes decode as UTF-8, or null if none does.
  function decodeRun(tokens, acts, i) {
    var buf = maybeHexEscapedToBytes(tokens[i])
    var maxAct = acts[i]
    var end = Math.min(i + 5, tokens.length)
    for (var j = i; j < end; j++) {
      if (j > i) {
        maxAct = Math.max(maxAct, acts[j])
        buf.push(...maybeHexEscapedToBytes(tokens[j]))
      }
      try {
        var text = utf8Decoder.decode(new Uint8Array(buf))
        return {token: text, act: maxAct, minIndex: i, maxIndex: j}
      } catch (e) {
        if (!(e instanceof TypeError)) throw e
        // Incomplete or invalid UTF-8 so far; extend the run by one more token.
      }
    }
    return null
  }

  /** Draws one example context, shading active tokens and centering the training token. */
  function drawExample(exp){
    var exampleSel = d3.select(this).append('div')
      .st({opacity: exp.is_repeated_datapoint ? .4 : 1})
    var textSel = exampleSel.append('div.text-wrapper')

    var tokenData = mergeHexEscapedMax(exp.tokens, exp.tokens_acts_list)
    var tokenSel = textSel.appendMany('span.token', tokenData)
      .text(d => d.token)

    // Only tokens with a truthy (non-zero) activation get a background color.
    tokenSel
      .filter(d => d.act)
      .st({background: d => visState.feature.colorScale(d.act)})

    // The training token may sit inside a merged run, so match by index range.
    var centerNode = tokenSel
      .filter(d => d.minIndex <= exp.train_token_ind && exp.train_token_ind <= d.maxIndex)
      .classed('train_token_ind', 1)
      .node()

    if (!centerNode) return
    var leftOffset = (exampleSel.node().offsetWidth - centerNode.offsetWidth)/2 - centerNode.offsetLeft
    textSel.translate([leftOffset, 0])
  }
}

window.init?.()
// --------------------------------------------------------------------------------
// /feature_examples/init-feature-examples-logits.js:
// --------------------------------------------------------------------------------
/**
 * Renders the feature's top/bottom logit tokens under a "Token Predictions" title.
 * The container is cleared on every render, so a feature without logits (for
 * example one that failed to load) does not keep showing the previous feature's rows.
 *
 * @param {object} opts
 * @param {object} opts.renderAll - render registry; a callback is pushed onto renderAll.feature.fns
 * @param {object} opts.visState - shared state; reads visState.feature.top_logits / bottom_logits
 * @param {object} opts.sel - d3 selection containing a `.feature-example-logits` element
 */
window.initFeatureExamplesLogits = function({renderAll, visState, sel}) {
  renderAll.feature.fns.push(async () => {
    let { top_logits, bottom_logits } = visState.feature;

    var containerSel = sel.select('.feature-example-logits').html('')
    if (!top_logits?.length && !bottom_logits?.length) return;

    containerSel.append('div.section-title').text('Token Predictions')

    for (let [rowName, logits] of [['Top', top_logits], ['Bottom', bottom_logits]]) {
      if (!logits?.length) continue;

      let rowSel = containerSel.append('div.ctx-container')
      rowSel.append('div.logit-label').text(rowName)
      rowSel.appendMany('span.token', logits).text(d => d)
    }
  })
}

window.init?.()
// --------------------------------------------------------------------------------
// /feature_examples/init-feature-examples.js:
// --------------------------------------------------------------------------------
/**
 * Builds a feature-examples panel (logits and/or example list) inside containerSel.
 *
 * @param {object} opts
 * @param {object} opts.containerSel - d3 selection; its contents are replaced
 * @param {boolean} [opts.showLogits=true] - render the top/bottom logits section
 * @param {boolean} [opts.showExamples=true] - render the activation-example list
 * @param {boolean} [opts.hideStaleOutputs=false] - add `is-stale-output` while a feature loads
 * @returns {{loadFeature: Function, renderFeature: Function}}
 */
window.initFeatureExamples = function({containerSel, showLogits=true, showExamples=true, hideStaleOutputs=false}){
  var visState = {
    isDev: 0,
    showLogits,
    showExamples,
    hideStaleOutputs,

    activeToken: null,
    feature: null,
    featureIndex: -1,
    scan: null,

    chartRowTop: 16,
    chartRowHeight: 82,
  }

  // Set up the DOM and register the render fns of the enabled sub-panels.
  var sel = containerSel.html('').append('div.feature-examples')
  if (visState.showLogits) sel.append('div.feature-example-logits')
  if (visState.showExamples) sel.append('div.feature-example-list')
  var renderAll = util.initRenderAll(['feature'])

  if (visState.showLogits) window.initFeatureExamplesLogits({renderAll, visState, sel})
  if (visState.showExamples) window.initFeatureExamplesList({renderAll, visState, sel})

  return {loadFeature, renderFeature}

  /**
   * Loads a feature and renders it, unless a newer renderFeature call has
   * requested a different (scan, featureIndex) in the meantime.
   * @returns {Promise<object>} the loaded feature, rendered or not
   */
  async function renderFeature(scan, featureIndex){
    if (visState.hideStaleOutputs) sel.classed('is-stale-output', 1)

    // Record the latest request; a response for anything else is stale and is not rendered.
    visState.scan = scan
    visState.featureIndex = featureIndex
    var feature = await loadFeature(scan, featureIndex)
    if (feature.scan == visState.scan && feature.featureIndex == visState.featureIndex){
      visState.feature = feature
      renderAll.feature()
      if (visState.hideStaleOutputs) sel.classed('is-stale-output', 0)
    }

    return feature
  }

  /**
   * Fetches /features/{scan}/{featureIndex}.json and decorates it with a color
   * scale and its identifiers. A failed load yields {isDead: true, statistics: {}},
   * which the example list renders as a "failed to load" message.
   */
  async function loadFeature(scan, featureIndex){
    let feature
    try {
      feature = await util.getFile(`/features/${scan}/${featureIndex}.json`)
    } catch {
      feature = {isDead: true, statistics: {}}
    }

    // Activations map onto an orange ramp; values above 1.4 saturate.
    feature.colorScale = d3.scaleSequential(d3.interpolateOranges)
      .domain([0, 1.4]).clamp(1)

    feature.featureIndex = featureIndex
    feature.scan = scan

    return feature
  }
}

window.init?.()
// --------------------------------------------------------------------------------
// /index.html:
// --------------------------------------------------------------------------------
// 1 | 2 | 3 | 4 | 5 | Attribution Graphs 6 | 7 |
8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 84 | -------------------------------------------------------------------------------- /prettier.config.js: -------------------------------------------------------------------------------- 1 | const config = { 2 | /* Prettier formats only files that carry an @format or @prettier pragma comment. */ requirePragma: true, 3 | }; 4 | 5 | module.exports = config; 6 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | html { 2 | margin: 0px; 3 | background: #fff; 4 | font-family: system-ui; 5 | } 6 | 7 | body { 8 | margin: 10px; 9 | } 10 | 11 | .tooltip { /* Fixed-position tooltip parked off-screen (top: -1000px) until script sets its left/top; ignores pointer events unless script re-enables them. */ 12 | top: -1000px; 13 | position: fixed; 14 | padding: 5px; 15 | background: rgba(255, 255, 255, 1); 16 | border: 1px solid lightgray; 17 | pointer-events: none; 18 | z-index: 10000; 19 | font-size: 10px; 20 | line-height: 1.3em; 21 | } 22 | 23 | .tooltip-hidden { /* Hidden state toggled from script: fades out over .3s after a .1s delay. */ 24 | opacity: 0; 25 | transition: all .3s; 26 | transition-delay: .1s; 27 | } 28 | 29 | 30 | text{ 31 | cursor: default; 32 | } 33 | 34 | svg{ 35 | overflow: visible; 36 | max-width: 100%; 37 | height: auto; 38 | } 39 | 40 | .cg-div{ /* Full-bleed band: the -10px side margins and width: calc(100% + 20px) cancel the 10px body margin. */ 41 | background: #FAFAFA; 42 | height: 800px; 43 | border-top: 1px solid #EEE; 44 | border-bottom: 1px solid #EEE; 45 | padding: 40px 10px; 46 | margin: 60px -10px; 47 | width: calc(100% + 20px); 48 | grid-column: 1 / -1; 49 | position: relative; 50 | box-sizing: border-box; 51 | overflow-x: hidden; 52 | 53 | > .gridsnap{ 54 | max-width: 1800px; 55 | margin: 0 auto; 56 | } 57 | 58 | .feature-example-logits{ /* Match the band's #FAFAFA background. */ 59 | background: #FAFAFA !important; 60 | outline-color: #FAFAFA !important; 61 | } 62 | } 63 | 64 | d-article{ 65 | overflow: visible; 66 | } 67 | 68 | 69 | .vis-link{ 70 | margin: 0px auto; 71 | font-size: 12px; 72 | margin-top: -15px; 73 | margin-bottom: 1em; 74 | } 75 | 76 | .full_graph_link{ 77 | cursor: pointer; 78 | text{ 79 | cursor: pointer; 80 |
} 81 | 82 | rect{ 83 | transition: all 200ms; 84 | } 85 | 86 | &:hover rect{ 87 | fill-opacity: 0.2; 88 | } 89 | } 90 | 91 | .graph-prompt-select { /* appearance: none hides the native arrow; the inline SVG background-image draws a replacement chevron. */ 92 | max-width: 40%; 93 | max-height: 4.8em; 94 | white-space: normal; 95 | display: flex; 96 | align-items: center; 97 | text-overflow: ellipsis; 98 | overflow: hidden; 99 | padding: 8px; 100 | padding-right: 25px; 101 | border-radius: 5px; 102 | 103 | background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 20 20' stroke-width='.5px' fill='none' %3E%3Cpath stroke='%23000' d='M8 10l2 2 2-2'/%3E%3C/svg%3E"); 104 | background-position: right 0px center; 105 | background-repeat: no-repeat; 106 | background-size: 40px; 107 | 108 | /* Removes the dropdown arrow */ 109 | appearance: none; 110 | 111 | option { 112 | display: -webkit-box; 113 | -webkit-line-clamp: 3; 114 | -webkit-box-orient: vertical; 115 | overflow: hidden; 116 | } 117 | } 118 | 119 | .section-title, .quantile-title{ 120 | font-weight: 500 !important; 121 | font-size: 13px; 122 | margin-bottom: 5px; 123 | font-family: system-ui; 124 | color: #000; 125 | padding-bottom: 5px; 126 | border-bottom: solid 1px #eee; 127 | display: block; 128 | line-height: 14px; 129 | } 130 | 131 | .quantile-title{ 132 | margin-bottom: 0px !important; 133 | } 134 | 135 | 136 | 137 | .link a{ 138 | color: #000; 139 | /* text-decoration: none; */ 140 | font-size: 12px; 141 | } 142 | 143 | .nav { 144 | display: flex; 145 | gap: 10px; 146 | flex-wrap: wrap; 147 | margin-top: 20px; 148 | 149 | } 150 | 151 | 152 | .nav-button { 153 | padding: 8px 16px; 154 | border: 1px solid #ccc; 155 | border-radius: 4px; 156 | cursor: pointer; 157 | background: white; 158 | font-size: 14px; 159 | } 160 | 161 | .nav-button:hover { 162 | border-color: #000; 163 | } 164 | 165 | .nav-button.active { 166 | background: #000; 167 | color: white; 168 | border-color: #000; 169 | } 170 | 171 | 172 | .cg{ /* Viewport height minus 150px (presumably room for surrounding page controls; confirm). */ 173 | height: calc(100vh - 150px); 174 | } 175 |
// --------------------------------------------------------------------------------
// /util.js:
// --------------------------------------------------------------------------------
/**
 * Shared helpers exposed as `window.util`: URL query-param access, cached file
 * loading, d3 chart styling, render-fn registries, timing utilities, the
 * feature-examples hover tooltip, and attribution-graph link/select wiring.
 */
window.util = (function () {
  /**
   * Percent-decodes `str`, returning it unchanged when it is not valid
   * percent-encoding (decodeURIComponent throws URIError on e.g. a bare "%").
   */
  function safeDecode(str) {
    try {
      return decodeURIComponent(str)
    } catch (e) {
      if (e instanceof URIError) return str
      throw e
    }
  }

  // Query-string helpers. `set` percent-encodes values on top of URLSearchParams'
  // own encoding, so `get`/`getAll` decode one extra time to round-trip them.
  var params = (function(){
    var rv = {}

    // Returns the decoded value for `key`, or null when the key is absent.
    rv.get = key => {
      var url = new URL(window.location)
      var searchParams = new URLSearchParams(url.search)

      var str = searchParams.get(key)
      return str && safeDecode(str)
    }

    // Returns all query params as a {key: decodedValue} object.
    rv.getAll = () => {
      var url = new URL(window.location)
      var searchParams = new URLSearchParams(url.search)

      var values = {}
      for (const [key, value] of searchParams.entries()) {
        values[key] = safeDecode(value)
      }
      return values
    }

    // Sets `key` (or deletes it when value === null) via history.replaceState,
    // so no new history entry is created.
    rv.set = (key, value) => {
      var url = new URL(window.location)
      var searchParams = new URLSearchParams(url.search)

      if (value === null) {
        searchParams.delete(key)
      } else {
        searchParams.set(key, encodeURIComponent(value))
      }

      url.search = searchParams.toString()
      history.replaceState(null, '', url)
    }

    return rv
  })()

  /**
   * Fetches and parses a file, memoizing the promise per resolved URL in
   * window.__datacache so concurrent and repeated callers share one request.
   *
   * Paths starting with "/" are resolved against the published
   * transformer-circuits.pub attribution-graphs root. The parser is chosen by
   * extension: csv -> d3.csvParse rows, npy -> npyjs.parse, json -> parsed JSON,
   * jsonl -> array of parsed non-empty lines, anything else -> text.
   *
   * @param {string} path
   * @returns {Promise<any>} rejects with an Error on a non-2xx response or a parse
   *   failure; failed loads are evicted from the cache so a later call retries.
   */
  async function getFile(path) {
    var __datacache = window.__datacache = window.__datacache || {}

    if (path.startsWith('/')) {
      path = 'https://transformer-circuits.pub/2025/attribution-graphs' + path
    }

    if (!__datacache[path]) {
      var pending = __fetch().catch(err => {
        // Don't pin a failure (e.g. a transient network error) in the cache forever.
        if (__datacache[path] === pending) delete __datacache[path]
        throw err
      })
      __datacache[path] = pending
    }
    return __datacache[path]

    async function __fetch() {
      var res = await fetch(path, {cache: 'force-cache'})
      if (!res.ok) {
        var resText = await res.text()
        console.log(resText, res)
        throw new Error(`${res.status} error`)
      }

      var type = path.replaceAll('..', '').split('.').at(-1)
      if (type == 'csv') {
        return d3.csvParse(await res.text())
      } else if (type == 'npy') {
        return npyjs.parse(await res.arrayBuffer())
      } else if (type == 'json') {
        return await res.json()
      } else if (type == 'jsonl') {
        var text = await res.text()
        return text.split(/\r?\n/).filter(d => d).map(line => JSON.parse(line))
      } else {
        return await res.text()
      }
    }
  }

  // Adds centered x/y axis labels and an optional title to chart `c` (expects c.svg, c.width, c.height).
  function addAxisLabel(c, xText, yText, title='', xOffset=0, yOffset=0, titleOffset=0){
    c.svg.select('.x').append('g')
      .translate([c.width/2, xOffset + 25])
      .append('text.axis-label')
      .text(xText)
      .at({textAnchor: 'middle', fill: '#000'})

    c.svg.select('.y')
      .append('g')
      .translate([yOffset -30, c.height/2])
      .append('text.axis-label')
      .text(yText)
      .at({textAnchor: 'middle', fill: '#000', transform: 'rotate(-90)'})

    c.svg
      .append('g.axis').at({fontFamily: 'sans-serif'})
      .translate([c.width/2, titleOffset -10])
      .append('text.axis-label.axis-title')
      .text(title)
      .at({textAnchor: 'middle', fill: '#000'})
  }

  // ggplot-like styling: gray (or black when c.isBlack) background rect, no axis domains, plus gridlines.
  function ggPlot(c){
    c.svg.append('rect.bg-rect')
      .at({width: c.width, height: c.height, fill: c.isBlack ? '#000' : '#EAECED'}).lower()
    c.svg.selectAll('.domain').remove()

    // Tick-text offsets and gridline paths are applied by ggPlotUpdate.
    ggPlotUpdate(c)
  }

  // Re-applies ggPlot tick styling after axes are redrawn: replaces tick lines
  // with full-width/height gridline paths and nudges tick text.
  function ggPlotUpdate(c){
    c.svg.selectAll('.tick').selectAll('line').remove()

    c.svg.selectAll('.x text').at({y: 4})
    c.svg.selectAll('.x .tick')
      .selectAppend('path').at({d: 'M 0 0 V -' + c.height, stroke: c.isBlack ? '#444' : '#fff', strokeWidth: 1})

    c.svg.selectAll('.y text').at({x: -3})
    c.svg.selectAll('.y .tick')
      .selectAppend('path').at({d: 'M 0 0 H ' + c.width, stroke: c.isBlack? '#444' : '#fff', strokeWidth: 1})
  }

  // Creates one callable per label; calling rv[label](ev) runs every fn pushed onto rv[label].fns.
  function initRenderAll(fnLabels){
    var rv = {}
    fnLabels.forEach(label => {
      rv[label] = (ev) => Object.values(rv[label].fns).forEach(d => d(ev))
      rv[label].fns = []
    })

    return rv
  }

  /**
   * Mirrors renderAll into browser history. After each renderAll[key]() (for keys
   * not in skipKeys) it pushes a history entry with visState[key] in the URL; on
   * back/forward it restores changed keys from the entry and re-renders them.
   *
   * NOTE(review): reads the free identifier `visState`, which is not defined in
   * this file; a global `visState` must exist wherever this is used.
   */
  function attachRenderAllHistory(renderAll, skipKeys=['hoverId', 'hoverIdx']) {
    Object.keys(renderAll).forEach(key => {
      renderAll[key].fns.push(() => {
        if (skipKeys.includes(key)) return
        var simpleVisState = {...visState}
        skipKeys.forEach(skipKey => delete simpleVisState[skipKey])

        var url = new URL(window.location)
        // Skip when the URL already holds this value (e.g. while restoring from popstate).
        if (visState[key] == url.searchParams.get(key)) return
        url.searchParams.set(key, simpleVisState[key])
        history.pushState(simpleVisState, '', url)
      })
    })

    d3.select(window).on('popstate.updateState', ev => {
      if (!ev.state) return
      ev.preventDefault()
      Object.keys(renderAll).forEach(key => {
        if (skipKeys.includes(key)) return
        if (visState[key] == ev.state[key]) return
        visState[key] = ev.state[key]
        renderAll[key]()
      })
    })
  }

  // Calls fn at most once per `delay` ms; calls made during the cooldown are dropped.
  function throttle(fn, delay){
    var lastCall = 0
    return (...args) => {
      if (Date.now() - lastCall < delay) return
      lastCall = Date.now()
      fn(...args)
    }
  }

  // Calls fn once `delay` ms have passed without another call.
  function debounce(fn, delay) {
    var timeout
    return (...args) => {
      clearTimeout(timeout)
      timeout = setTimeout(() => fn(...args), delay)
    }
  }

  // Throttle that also schedules a trailing call, so the last arguments are always applied.
  function throttleDebounce(fn, delay) {
    var lastCall = 0
    var timeoutId

    return function (...args) {
      clearTimeout(timeoutId)
      var remainingDelay = delay - (Date.now() - lastCall)
      if (remainingDelay <= 0) {
        lastCall = Date.now()
        fn.apply(this, args)
      } else {
        timeoutId = setTimeout(() => {
          lastCall = Date.now()
          fn.apply(this, args)
        }, remainingDelay)
      }
    }
  }

  // Resolves after `ms` milliseconds.
  function sleep(ms) {
    return new Promise(resolve => setTimeout(resolve, ms))
  }

  // Memoizes fn by the JSON serialization of its arguments.
  function cache(fn){
    var cache = {}
    return function(...args){
      var key = JSON.stringify(args)
      if (!(key in cache)) cache[key] = fn.apply(this, args)
      return cache[key]
    }
  }

  var featureExamplesTooltipSel
  var featureExamples
  // Features queued for background prefetch via getNearby; consumed newest-first.
  var featureQueue = []

  /**
   * Attaches the shared feature-examples hover tooltip to the elements of `sel`.
   *
   * @param sel - d3 selection of elements that show the tooltip on hover
   * @param {Function} getFeatureParams - d => {scan, featureIndex} or {scan, featureIndices}
   * @param {Function} [getNearby] - d => array of {scan, featureIndex} to prefetch in the background
   */
  function attachFeatureExamplesTooltip(sel, getFeatureParams, getNearby){
    if (!featureExamplesTooltipSel){
      // First call: build the singleton tooltip and its global listeners.
      featureExamplesTooltipSel = d3.select('body')
        .selectAppend('div.tooltip.feature-examples-tooltip.tooltip-hidden')
        .on('mouseover', mousemove)
        .on('mousemove', mousemove)
        .on('mouseleave', mouseout)

      // A click/tap outside the tooltip and its anchor hides the tooltip.
      d3.select('body').on('click.feature-tooltip', ev => {
        if (ev.target.closest('.feature-examples-tooltip') || ev.target.closest('.feature-examples-tooltipped')) return
        mouseout()
      })

      d3.select(window).on('scroll.feature-examples-tooltip', () => {
        if (featureExamplesTooltipSel.isFading || featureExamplesTooltipSel.isFaded) return
        mouseout()
      })

      featureExamplesTooltipSel.append('div.feature-nav')
      featureExamples = window.initFeatureExamples({
        containerSel: featureExamplesTooltipSel.append('div'),
        hideStaleOutputs: true,
      })

      if (window.__feature_tooltip_queue_timer) window.__feature_tooltip_queue_timer.stop()
      // NOTE(review): d3.timer's second argument is only a start delay; after 250ms
      // this callback runs on every animation frame, loading one queued feature per frame.
      window.__feature_tooltip_queue_timer = d3.timer(() => {
        if (!featureQueue.length) return
        var feature = featureQueue.pop()
        featureExamples.loadFeature(feature.scan, feature.featureIndex)
      }, 250)
    }

    sel
      .on('mousemove.feature-examples-tooltip', mousemove)
      .on('mouseleave.feature-examples-tooltip', mouseout)
      .on('mouseenter.feature-examples-tooltip', function(ev, d){
        setTimeout(mousemove, 0)

        // Skip repositioning when just bouncing in and out of the current feature.
        if (featureExamplesTooltipSel.cur == d && !featureExamplesTooltipSel.classed('tooltip-hidden')) return
        featureExamplesTooltipSel.cur = d

        featureExamplesTooltipSel.node().scrollTop = 0

        featureExamplesTooltipSel.isFaded = false
        featureExamplesTooltipSel.classed('tooltip-hidden', 0)

        // Requires either featureIndex or featureIndices.
        var {scan, featureIndex, featureIndices} = getFeatureParams(d)
        featureIndices = featureIndices ?? [featureIndex]
        featureIndex = featureIndex ?? featureIndices[0]

        var buttonSel = featureExamplesTooltipSel.select('.feature-nav').html('')
          .appendMany('div.button', featureIndices)
          .text((_, i) => 'Feature ' + (i + 1))
          .classed('active', idx => idx == featureIndex)
          .on('click', (ev, idx) => {
            featureExamples.renderFeature(scan, idx)
            buttonSel.classed('active', idx2 => idx2 == idx)
          })

        featureExamples.renderFeature(scan, featureIndex)

        d3.selectAll('.feature-examples-tooltipped').classed('feature-examples-tooltipped', 0)
        d3.select(this).classed('feature-examples-tooltipped', 1)

        // Center on the cursor horizontally; place above or below the anchor, whichever has more room.
        var snBB = this.getBoundingClientRect()
        var ttBB = featureExamplesTooltipSel.node().getBoundingClientRect()
        var left = d3.clamp(20, (ev.clientX-ttBB.width/2), window.innerWidth - ttBB.width - 20)
        var top = snBB.top > innerHeight - snBB.bottom ?
          snBB.top - ttBB.height - 10 :
          snBB.bottom + 10
        featureExamplesTooltipSel.st({left, top, pointerEvents: 'all'})

        var nearby = getNearby?.(d) ?? []
        nearby.forEach(e => featureQueue.push(e))
      })

    // Cancels any pending fade-out.
    function mousemove(){
      if (window.__ttfade) window.__ttfade.stop()
      featureExamplesTooltipSel.isFading = false
      featureExamplesTooltipSel.isFaded = false
    }

    // Hides the tooltip after 250ms unless mousemove cancels it first.
    function mouseout(){
      if (featureExamplesTooltipSel.isFading) return

      if (window.__ttfade) window.__ttfade.stop()
      featureExamplesTooltipSel.isFading = true
      window.__ttfade = d3.timeout(() => {
        featureExamplesTooltipSel.classed('tooltip-hidden', 1).st({pointerEvents: 'none'})
        d3.selectAll('.feature-examples-tooltipped').classed('feature-examples-tooltipped', 0)
        featureExamplesTooltipSel.isFading = false
        featureExamplesTooltipSel.isFaded = true
      }, 250)
    }
  }

  // Renders a prompt <select> over all graphs in graph-metadata.json plus the selected graph below it.
  async function initGraphSelect(sel, cgSlug){
    var {graphs} = await util.getFile('/data/graph-metadata.json')

    var selectSel = sel.html('').append('select.graph-prompt-select')
      .on('change', function() {
        cgSlug = this.value
        util.params.set('slug', this.value)
        render()
      })

    var cgSel = sel.append('div.cg-container')

    selectSel.appendMany('option', graphs)
      .text(d => {
        var scanName = util.nameToPrettyPrint[d.scan] || d.scan
        var prefix = d.title_prefix ? d.title_prefix + ' ' : ''
        return prefix + scanName + ' — ' + d.prompt
      })
      .attr('value', d => d.slug)
      .property('selected', d => d.slug === cgSlug)

    function render() {
      initCg(cgSel.html(''), cgSlug, {
        isModal: true,
      })

      var m = graphs.find(g => g.slug == cgSlug)
      if (!m) return
      selectSel.at({title: m.prompt})
    }
    render()
  }

  // Wires a graph link: prefetch on hover; on click open the modal graph viewer (new tab on small screens).
  function attachCgLinkEvents(sel, cgSlug, figmaSlug){
    sel
      .on('mouseover', () => {
        // Best-effort prefetch: errors are ignored here, and failed loads are not cached.
        util.getFile(`/graph_data/${cgSlug}.json`).catch(() => {})
      })
      .on('click', (ev) => {
        ev.preventDefault()

        if (window.innerWidth < 900 || window.innerHeight < 500) {
          return window.open(`./static_js/attribution_graphs/index.html?slug=${cgSlug}`, '_blank')
        }

        d3.select('body').classed('modal-open', true)
        var contentSel = d3.select('modal').classed('is-active', 1)
          .select('.modal-content').html('')

        util.initGraphSelect(contentSel, cgSlug)

        util.params.set('slug', cgSlug)
        if (figmaSlug) history.replaceState(null, '', '#' + figmaSlug)
      })
  }

  // TODO: tidy
  function ppToken(d){
    return d
  }

  function ppClerp(d){
    return d
  }


  var scanSlugToName = {
    'h35': 'jackl-circuits-runs-1-4-sofa-v3_0',
    '18l': 'jackl-circuits-runs-1-1-druid-cp_0',
    'moc': 'jackl-circuits-runs-12-19-valet-m_0'
  }

  var nameToPrettyPrint = {
    'jackl-circuits-runs-1-4-sofa-v3_0': 'Haiku',
    'jackl-circuits-runs-1-1-druid-cp_0': '18L',
    'jackl-circuits-runs-12-19-valet-m_0': 'Model Organism',
    'jackl-circuits-runs-1-12-rune-cp3_0': '18L PLTs',
  }


  return {
    scanSlugToName,
    nameToPrettyPrint,
    params,
    getFile,
    addAxisLabel,
    ggPlot,
    ggPlotUpdate,
    initRenderAll,
    attachRenderAllHistory,
    throttle,
    debounce,
    throttleDebounce,
    sleep,
    cache,
    initGraphSelect,
    attachCgLinkEvents,
    ppToken,
    ppClerp,
    attachFeatureExamplesTooltip,
  }
})()

window.init?.()
// --------------------------------------------------------------------------------