├── LICENSE ├── README.md ├── bgrep4.png ├── build ├── gpapp.js ├── gputils.js └── slider.js ├── distancematrix.js ├── gp.css ├── halfmoon9.ico ├── index.html ├── site.css └── src ├── gpapp.js ├── gputils.js └── slider.js /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Tomi Peltola 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Gaussian process regression demo in Javascript 2 | ============================================== 3 | 4 | This repository contains code for a Gaussian process regression demonstration in Javascript. See http://www.tmpl.fi/gp/ for the live version. 5 | 6 | The code depends on the following javascript libraries, which are not included in this repository: 7 | 8 | * [React](http://facebook.github.io/react/) 9 | * [D3.js](http://d3js.org) 10 | * [Numeric.js](http://www.numericjs.com/) 11 | 12 | 13 | Simulation of continuous trajectories 14 | ------------------------------------- 15 | 16 | Continuous trajectories are simulated using Hamiltonian Monte Carlo (HMC) with partial momentum refreshment and analytically solved dynamics for the Gaussian posterior distribution. 17 | 18 | For an excellent HMC reference, see: Radford M. Neal, _MCMC using Hamiltonian dynamics_. [arXiv:1206.1901](http://arxiv.org/abs/1206.1901), 2012. 19 | 20 | Contact 21 | ------- 22 | 23 | E-mail: tomi.peltola@aalto.fi 24 | 25 | License 26 | ------- 27 | 28 | MIT. See `LICENSE`. 29 | -------------------------------------------------------------------------------- /bgrep4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/to-mi/gp-demo-js/21fac4e7d24847a0d347e8f3621f674e37d715a1/bgrep4.png -------------------------------------------------------------------------------- /build/gpapp.js: -------------------------------------------------------------------------------- 1 | /** @jsx React.DOM */ 2 | 3 | var GPApp = React.createClass({displayName: 'GPApp', 4 | getInitialState: function() { 5 | return { GPs: [new GP(0, [1,0.2], 1, [], [], [])], 6 | newGPParam: 1.0, 7 | newGPNoise: 0.2, 8 | newGPcf: 0, 9 | newGPavailableIDs: [10, 9, 8, 7, 6, 5, 4, 3, 2], 10 | alfa: 0.3, 11 | stepSize: 3.14, 12 | NSteps: 15, 13 | addTrPoints: false, 14 | trPointsX: [], 15 | trPointsY: [], 16 | dmTr: [], 17 | dmTeTr: [], 18 | samplingState: 0, // 0 = stopped, 1 = discrete, 2 = continuous 19 | oldSamplingState: 0, 20 | showSamples: true, 21 | showMeanAndVar: false 22 | } 23 | }, 24 | setAlfa: function(newVal) { this.setState({ alfa: newVal }); }, 25 | setStepSize: function(newVal) { this.setState({ stepSize: newVal }); }, 26 | setNSteps: function(newVal) { this.setState({ NSteps: newVal }); }, 27 | toggleAddTrPoints: function() { 28 | if (this.state.addTrPoints){ 29 | // added training points 30 | var dmTr = computeDistanceMatrix(this.state.trPointsX, this.state.trPointsX); 31 | var dmTeTr = computeDistanceMatrix(tePointsX, this.state.trPointsX); 32 | 33 | var newGPs = recomputeProjections(this.state.GPs, dmTr, dmTeTr, this.state.trPointsY); 34 | 35 | this.setState({ addTrPoints: !this.state.addTrPoints, GPs: newGPs, dmTr: dmTr, dmTeTr: dmTeTr, samplingState: this.state.oldSamplingState }); 36 | } else { 37 | // beginning to add training points 38 | this.setState({ addTrPoints: !this.state.addTrPoints, oldSamplingState: this.state.samplingState, samplingState: 0 }); 39 | } 40 | }, 41 | clearTrPoints: function() { this.setState({ trPointsX: [], trPointsY: [] }); }, 42 | toggleShowMeanAndVar: function() { if (!this.state.addTrPoints) this.setState({ showMeanAndVar: !this.state.showMeanAndVar }); }, 43 | toggleShowSamples: function() { 44 | if (!this.state.addTrPoints) { 45 | if (this.state.showSamples) { 46 | this.setState({ samplingState: 0, showSamples: false }); 47 | } else { 48 | this.setState({ samplingState: this.state.oldSamplingState, showSamples: true }); 49 | } 50 | } 51 | }, 52 | setNewGPParam: function(newVal) { this.setState({ newGPParam: newVal }); }, 53 | setNewGPNoise: function(newVal) { this.setState({ newGPNoise: newVal }); }, 54 | setNewGPcf: function(event) { this.setState({ newGPcf: event.target.value }); }, 55 | addGP: function() { 56 | if (this.state.newGPavailableIDs.length < 1) return; 57 | var id = this.state.newGPavailableIDs.pop(); 58 | var newGPs = this.state.GPs.concat([new GP(this.state.newGPcf, [this.state.newGPParam, this.state.newGPNoise], id, this.state.dmTr, this.state.dmTeTr, this.state.trPointsY)]); 59 | this.setState({ GPs: newGPs, newGPavailableIDs: this.state.newGPavailableIDs }); 60 | }, 61 | delGP: function(id) { 62 | return (function() { 63 | var newGPs = this.state.GPs; 64 | var delIdx = newGPs.findIndex(function (g) { return g.id == id; }); 65 | if (delIdx >= 0) { 66 | newGPs.splice(delIdx, 1); 67 | this.state.newGPavailableIDs.push(id); 68 | this.setState({ GPs: newGPs }); 69 | } 70 | }).bind(this); 71 | }, 72 | addTrPoint: function(x, y) { 73 | if (x >= -5 && x <= 5 && y >= -3 && y <= 3){ 74 | var newTrPointsX = this.state.trPointsX.concat([x]); 75 | var newTrPointsY = this.state.trPointsY.concat([y]); 76 | this.setState({ trPointsX: newTrPointsX, trPointsY: newTrPointsY }); 77 | } 78 | }, 79 | stopSampling: function() { this.setState({ samplingState: 0, oldSamplingState: 0 }); }, 80 | startDiscreteSampling: function() { this.setState({ samplingState: 1, oldSamplingState: 1 }); }, 81 | startContinuousSampling: function() { this.setState({ samplingState: 2, oldSamplingState: 2 }); }, 82 | render: function() { 83 | var sliderOptAlfa = { width: 200, height: 9, min: 0, max: 1 }; 84 | var sliderOptStepSize = { width: 200, height: 9, min: 0, max: 2*Math.PI }; 85 | var sliderOptNSteps = { width: 200, height: 9, min: 1, max: 100, step: 1 }; 86 | var sliderOptGPParam = { width: 200, height: 9, min: 0.01, max: 5 }; 87 | var sliderOptGPNoise = { width: 200, height: 9, min: 0, max: 2 }; 88 | var delGP = this.delGP; 89 | var gpoptions = cfs.map(function (c) { 90 | return (React.DOM.option( {value:c.id}, c.name)); 91 | }); 92 | return ( 93 | React.DOM.div( {id:"gp"}, 94 | GPAxis( {state:this.state, addTrPoint:this.addTrPoint} ), 95 | React.DOM.div( {id:"controls"}, 96 | React.DOM.input( {type:"checkbox", checked:this.state.showMeanAndVar, onChange:this.toggleShowMeanAndVar} ),"Show mean and credible intervals"+' '+ 97 | " ",React.DOM.input( {type:"checkbox", checked:this.state.showSamples, onChange:this.toggleShowSamples} ),"Show samples",React.DOM.br(null ), 98 | React.DOM.button( {onClick:this.startDiscreteSampling, disabled:this.state.samplingState === 1 || this.state.addTrPoints || !this.state.showSamples}, "sample independently"), 99 | React.DOM.button( {onClick:this.startContinuousSampling, disabled:this.state.samplingState === 2 || this.state.addTrPoints || !this.state.showSamples}, "sample continuous trajectories"), 100 | React.DOM.button( {onClick:this.stopSampling, disabled:this.state.samplingState === 0 || this.state.addTrPoints}, "stop sampling"), 101 | React.DOM.br(null ), 102 | this.state.addTrPoints ? React.DOM.span( {className:"info"}, "click on the figure to add an observation " ) : '', 103 | React.DOM.button( {onClick:this.toggleAddTrPoints}, this.state.addTrPoints ? "done" : "add observations"), 104 | this.state.addTrPoints ? React.DOM.button( {onClick:this.clearTrPoints}, "clear") : '' 105 | ), 106 | React.DOM.div( {id:"opts"}, 107 | React.DOM.h2(null, "Trajectory simulation settings"), 108 | React.DOM.table(null, 109 | React.DOM.tr(null, React.DOM.td(null, "Momentum refreshment"),React.DOM.td(null, Slider( {value:this.state.alfa, setValue:this.setAlfa, opt:sliderOptAlfa} ), " ", this.state.alfa.toFixed(2))), 110 | React.DOM.tr(null, React.DOM.td(null, "Path length"),React.DOM.td(null, Slider( {value:this.state.stepSize, setValue:this.setStepSize, opt:sliderOptStepSize} ), " ", this.state.stepSize.toFixed(2))), 111 | React.DOM.tr(null, React.DOM.td(null, "Number of steps in path"),React.DOM.td(null, Slider( {value:this.state.NSteps, setValue:this.setNSteps, opt:sliderOptNSteps} ), " ", this.state.NSteps)) 112 | ) 113 | ), 114 | React.DOM.div( {id:"gplist"}, 115 | React.DOM.div( {id:"addgp"}, 116 | React.DOM.h2(null, "Add new process"), 117 | React.DOM.table(null, 118 | React.DOM.tr(null, React.DOM.td(null, "Covariance function"),React.DOM.td(null, React.DOM.select( {value:this.state.newGPcf, onChange:this.setNewGPcf}, gpoptions))), 119 | React.DOM.tr(null, React.DOM.td(null, "Length scale"),React.DOM.td(null, Slider( {value:this.state.newGPParam, setValue:this.setNewGPParam, opt:sliderOptGPParam} ), " ", this.state.newGPParam.toFixed(2))), 120 | React.DOM.tr(null, React.DOM.td(null, "Noise"),React.DOM.td(null, Slider( {value:this.state.newGPNoise, setValue:this.setNewGPNoise, opt:sliderOptGPNoise} ), " ", this.state.newGPNoise.toFixed(2))) 121 | ), 122 | React.DOM.button( {onClick:this.addGP, disabled:this.state.newGPavailableIDs.length <= 0}, "add") 123 | ), 124 | React.DOM.h2(null, "Process list"), 125 | React.DOM.table(null, 126 | React.DOM.thead(null, 127 | React.DOM.tr(null, React.DOM.th(null, "id"),React.DOM.th(null, "covariance"),React.DOM.th(null, "length scale"),React.DOM.th(null, "noise"),React.DOM.th(null)) 128 | ), 129 | GPList( {GPs:this.state.GPs, delGP:this.delGP} ) 130 | ) 131 | ) 132 | ) 133 | ) 134 | } 135 | }); 136 | 137 | React.renderComponent( 138 | GPApp(null ), 139 | document.getElementById('gp-outer') 140 | ); 141 | -------------------------------------------------------------------------------- /build/gputils.js: -------------------------------------------------------------------------------- 1 | /** @jsx React.DOM */ 2 | var tePointsX = numeric.linspace(-5, 5, numeric.dim(distmatTe)[0]); 3 | var randn = d3.random.normal(); 4 | function randnArray(size){ 5 | var zs = new Array(size); 6 | for (var i = 0; i < size; i++) { 7 | zs[i] = randn(); 8 | } 9 | return zs; 10 | } 11 | 12 | // ids must be in order of the array 13 | var cfs = [ 14 | {'id': 0, 15 | 'name': 'Exponentiated quadratic', 16 | 'f': function(r, params) { 17 | return numeric.exp(numeric.mul(-0.5 / (params[0] * params[0]), numeric.pow(r, 2))); 18 | } 19 | }, 20 | {'id': 1, 21 | 'name': 'Exponential', 22 | 'f': function(r, params) { 23 | return numeric.exp(numeric.mul(-0.5 / params[0], r)); 24 | } 25 | }, 26 | {'id': 2, 27 | 'name': 'Matern 3/2', 28 | 'f': function(r, params) { 29 | var tmp = numeric.mul(Math.sqrt(3.0) / params[0], r); 30 | return numeric.mul(numeric.add(1.0, tmp), numeric.exp(numeric.neg(tmp))); 31 | } 32 | }, 33 | {'id': 3, 34 | 'name': 'Matern 5/2', 35 | 'f': function(r, params) { 36 | var tmp = numeric.mul(Math.sqrt(5.0) / params[0], r); 37 | var tmp2 = numeric.div(numeric.mul(tmp, tmp), 3.0); 38 | return numeric.mul(numeric.add(numeric.add(1, tmp), tmp2), numeric.exp(numeric.neg(tmp))); 39 | } 40 | }, 41 | {'id': 4, 42 | 'name': 'Rational quadratic (alpha=1)', 43 | 'f': function(r, params) { 44 | return numeric.pow(numeric.add(1.0, numeric.div(numeric.pow(r, 2), 2.0 * params[0] * params[0])), -1); 45 | } 46 | }, 47 | {'id': 5, 48 | 'name': 'Piecewise polynomial (q=0)', 49 | 'f': function(r, params) { 50 | var tmp = numeric.sub(1.0, numeric.div(r, params[0])); 51 | var dims = numeric.dim(tmp); 52 | for (var i = 0; i < dims[0]; i++){ 53 | for (var j = 0; j < dims[1]; j++){ 54 | tmp[i][j] = tmp[i][j] > 0.0 ? tmp[i][j] : 0.0; 55 | } 56 | } 57 | return tmp; 58 | } 59 | }, 60 | {'id': 6, 61 | 'name': 'Piecewise polynomial (q=1)', 62 | 'f': function(r, params) { 63 | var tmp1 = numeric.div(r, params[0]); 64 | var tmp = numeric.sub(1.0, tmp1); 65 | var dims = numeric.dim(tmp); 66 | for (var i = 0; i < dims[0]; i++){ 67 | for (var j = 0; j < dims[1]; j++){ 68 | tmp[i][j] = tmp[i][j] > 0.0 ? tmp[i][j] : 0.0; 69 | } 70 | } 71 | return numeric.mul(numeric.pow(tmp, 3), numeric.add(numeric.mul(3.0, tmp1), 1.0)); 72 | } 73 | }, 74 | {'id': 7, 75 | 'name': 'Periodic (period=pi)', 76 | 'f': function(r, params) { 77 | return numeric.exp(numeric.mul(-2.0/(params[0]*params[0]), numeric.pow(numeric.sin(r), 2))); 78 | } 79 | }, 80 | {'id': 8, 81 | 'name': 'Periodic (period=1)', 82 | 'f': function(r, params) { 83 | return numeric.exp(numeric.mul(-2.0/(params[0]*params[0]), numeric.pow(numeric.sin(numeric.mul(Math.PI, r)), 2))); 84 | } 85 | } 86 | ]; 87 | 88 | function GP(cf, params, id, dmTr, dmTeTr, trY) { 89 | var M = numeric.dim(distmatTe)[1]; 90 | 91 | this.z = randnArray(M); 92 | this.p = randnArray(M); 93 | this.cf = cf; 94 | this.params = params; 95 | this.id = id; 96 | 97 | this.Kte = cfs[this.cf].f(distmatTe, params); 98 | 99 | var tmp = computeProjection(this.Kte, this.cf, this.params, dmTr, dmTeTr, trY); 100 | this.proj = tmp.proj; 101 | this.mu = tmp.mu; 102 | this.sd95 = tmp.sd95; 103 | } 104 | 105 | 106 | function computeProjection(Kte, cf, params, dmTr, dmTeTr, trY) { 107 | var Mtr = numeric.dim(dmTr)[0]; 108 | var Mte = numeric.dim(distmatTe)[0]; 109 | 110 | if (Mtr > 0){ 111 | var Kxx_p_noise = cfs[cf].f(dmTr, params); 112 | for (var i = 0; i < Mtr; i++){ 113 | Kxx_p_noise[i][i] += params[1]; 114 | } 115 | 116 | var svd1 = numeric.svd(Kxx_p_noise); 117 | for (var i = 0; i < Mtr; i++){ 118 | if (svd1.S[i] > numeric.epsilon){ 119 | svd1.S[i] = 1.0/svd1.S[i]; 120 | } else { 121 | svd1.S[i] = 0.0; 122 | } 123 | } 124 | 125 | var tmp = numeric.dot(cfs[cf].f(dmTeTr, params), svd1.U); 126 | // there seems to be a bug in numeric.svd: svd1.U and transpose(svd1.V) are not always equal for a symmetric matrix 127 | var mu = numeric.dot(tmp, numeric.mul(svd1.S, numeric.dot(numeric.transpose(svd1.U), trY))); 128 | var cov = numeric.dot(tmp, numeric.diag(numeric.sqrt(svd1.S))); 129 | cov = numeric.dot(cov, numeric.transpose(cov)); 130 | cov = numeric.sub(Kte, cov); 131 | var svd2 = numeric.svd(cov); 132 | for (var i = 0; i < Mte; i++){ 133 | if (svd2.S[i] < numeric.epsilon){ 134 | svd2.S[i] = 0.0; 135 | } 136 | } 137 | var proj = numeric.dot(svd2.U, numeric.diag(numeric.sqrt(svd2.S))); 138 | var sd95 = numeric.mul(1.98, numeric.sqrt(numeric.getDiag(numeric.dot(proj, numeric.transpose(proj))))); 139 | } else { 140 | var sd95 = numeric.mul(1.98, numeric.sqrt(numeric.getDiag(Kte))); 141 | var svd = numeric.svd(Kte); 142 | var proj = numeric.dot(svd.U, numeric.diag(numeric.sqrt(svd.S))); 143 | var mu = numeric.rep([Mte], 0); 144 | } 145 | 146 | return { proj: proj, mu: mu, sd95: sd95 }; 147 | } 148 | 149 | function recomputeProjections(GPs, dmTr, dmTeTr, trY) { 150 | for (var gpi = 0; gpi < GPs.length; gpi++){ 151 | var gp = GPs[gpi]; 152 | var tmp = computeProjection(gp.Kte, gp.cf, gp.params, dmTr, dmTeTr, trY); 153 | gp.proj = tmp.proj; 154 | gp.mu = tmp.mu; 155 | gp.sd95 = tmp.sd95; 156 | GPs[gpi] = gp; 157 | } 158 | 159 | return GPs; 160 | } 161 | 162 | function computeDistanceMatrix(xdata1, xdata2) { 163 | var dm = numeric.rep([xdata1.length,xdata2.length], 0); 164 | for (var i = 0; i < xdata1.length; i++){ 165 | for (var j = 0; j < xdata2.length; j++){ 166 | var val = Math.abs(xdata2[j] - xdata1[i]); 167 | dm[i][j] = val; 168 | } 169 | } 170 | return dm; 171 | } 172 | 173 | var GPAxis = React.createClass({displayName: 'GPAxis', 174 | render: function() { 175 | return (React.DOM.svg(null)); 176 | }, 177 | shouldComponentUpdate: function() { return false; }, 178 | drawTrPoints: function(pointsX, pointsY) { 179 | var x = this.scales.x; 180 | var y = this.scales.y; 181 | var p = this.trPoints.selectAll("circle.trpoints") 182 | .data(d3.zip(pointsX, pointsY)) 183 | .attr("cx", function(d) { return x(d[0]); }) 184 | .attr("cy", function(d) { return y(d[1]); }); 185 | p.enter().append("circle") 186 | .attr("class", "trpoints") 187 | .attr("r", 2) 188 | .attr("cx", function(d) { return x(d[0]); }) 189 | .attr("cy", function(d) { return y(d[1]); }); 190 | p.exit().remove(); 191 | }, 192 | animationId: 0, 193 | componentWillReceiveProps: function(props) { 194 | // bind events 195 | if (props.state.addTrPoints) { 196 | d3.select(this.getDOMNode()).on("click", this.addTrPoint); 197 | } else { 198 | d3.select(this.getDOMNode()).on("click", null); 199 | } 200 | // redraw training points 201 | this.drawTrPoints(props.state.trPointsX, props.state.trPointsY); 202 | 203 | this.drawMeanAndVar(props); 204 | 205 | if (this.props.state.showSamples !== props.state.showSamples){ 206 | this.drawPaths(props); 207 | } 208 | 209 | if (this.props.state.samplingState !== props.state.samplingState){ 210 | clearInterval(this.animationId); 211 | if (props.state.samplingState === 1){ 212 | this.animationId = setInterval((function() { this.updateState(); this.drawPaths(); }).bind(this), 500); 213 | } else if (props.state.samplingState === 2){ 214 | this.animationId = setInterval((function() { this.contUpdateState(); this.drawPaths(); }).bind(this), 50); 215 | } 216 | } 217 | }, 218 | addTrPoint: function() { 219 | var mousePos = d3.mouse(this.getDOMNode()); 220 | var x = this.scales.x; 221 | var y = this.scales.y; 222 | 223 | // x is transformed to a point on a grid of 200 points between -5 and 5 224 | this.props.addTrPoint(Math.round((x.invert(mousePos[0]-50)+5)/10*199)/199*10-5, y.invert(mousePos[1]-50)); 225 | }, 226 | updateState: function() { 227 | var M = numeric.dim(distmatTe)[1]; 228 | for (var i = 0; i < this.props.state.GPs.length; i++){ 229 | var gp = this.props.state.GPs[i]; 230 | gp.z = randnArray(M); 231 | } 232 | }, 233 | stepState: 0, 234 | contUpdateState: function() { 235 | var M = numeric.dim(distmatTe)[1]; 236 | var alfa = 1.0-this.props.state.alfa; 237 | var n_steps = this.props.state.NSteps; 238 | var t_step = this.props.state.stepSize / n_steps; 239 | this.stepState = this.stepState % n_steps; 240 | 241 | for (var i = 0; i < this.props.state.GPs.length; i++){ 242 | var gp = this.props.state.GPs[i]; 243 | 244 | // refresh momentum: p = alfa * p + sqrt(1 - alfa^2) * randn(size(p)) 245 | if (this.stepState == (n_steps-1)) 246 | gp.p = numeric.add(numeric.mul(alfa, gp.p), numeric.mul(Math.sqrt(1 - alfa*alfa), randnArray(M))); 247 | 248 | var a = gp.p.slice(0), 249 | b = gp.z.slice(0), 250 | c = numeric.mul(-1, gp.z.slice(0)), 251 | d = gp.p.slice(0); 252 | 253 | gp.z = numeric.add(numeric.mul(a, Math.sin(t_step)), numeric.mul(b, Math.cos(t_step))); 254 | gp.p = numeric.add(numeric.mul(c, Math.sin(t_step)), numeric.mul(d, Math.cos(t_step))); 255 | } 256 | this.stepState = this.stepState + 1; 257 | }, 258 | drawMeanAndVar: function(props) { 259 | var gpline = this.gpline; 260 | if (props.state.showMeanAndVar){ 261 | var gps = props.state.GPs; 262 | } else { 263 | var gps = []; 264 | } 265 | 266 | var paths = this.meanLines.selectAll("path").data(gps, function (d) { return d.id; }) 267 | .attr("d", function (d) { 268 | var datay = d.mu; 269 | return gpline(d3.zip(tePointsX, datay)); 270 | }); 271 | paths.enter().append("path").attr("d", function (d) { 272 | var datay = d.mu; 273 | return gpline(d3.zip(tePointsX, datay)); 274 | }) 275 | .attr("class", function(d) { 276 | return "muline line line"+d.id; 277 | }); 278 | paths.exit().remove(); 279 | 280 | var pathsUp = this.upSd95Lines.selectAll("path").data(gps, function (d) { return d.id; }) 281 | .attr("d", function (d) { 282 | var datay = numeric.add(d.mu, d.sd95); 283 | return gpline(d3.zip(tePointsX, datay)); 284 | }); 285 | pathsUp.enter().append("path").attr("d", function (d) { 286 | var datay = numeric.add(d.mu, d.sd95); 287 | return gpline(d3.zip(tePointsX, datay)); 288 | }) 289 | .attr("class", function(d) { 290 | return "sdline line line"+d.id; 291 | }); 292 | pathsUp.exit().remove(); 293 | 294 | var pathsDown = this.downSd95Lines.selectAll("path").data(gps, function (d) { return d.id; }) 295 | .attr("d", function (d) { 296 | var datay = numeric.sub(d.mu, d.sd95); 297 | return gpline(d3.zip(tePointsX, datay)); 298 | }); 299 | pathsDown.enter().append("path").attr("d", function (d) { 300 | var datay = numeric.sub(d.mu, d.sd95); 301 | return gpline(d3.zip(tePointsX, datay)); 302 | }) 303 | .attr("class", function(d) { 304 | return "sdline line line"+d.id; 305 | }); 306 | pathsDown.exit().remove(); 307 | }, 308 | drawPaths: function(props) { 309 | if (!props) var props = this.props; 310 | var gpline = this.gpline; 311 | if (props.state.showSamples){ 312 | var gps = props.state.GPs; 313 | } else { 314 | var gps = []; 315 | } 316 | var paths = this.lines.selectAll("path").data(gps, function (d) { return d.id; }) 317 | .attr("d", function (d) { 318 | var datay = numeric.add(numeric.dot(d.proj, d.z), d.mu); 319 | return gpline(d3.zip(tePointsX, datay)); 320 | }); 321 | paths.enter().append("path").attr("d", function (d) { 322 | var datay = numeric.add(numeric.dot(d.proj, d.z), d.mu); 323 | return gpline(d3.zip(tePointsX, datay)); 324 | }) 325 | .attr("class", function(d) { 326 | return "line line"+d.id; 327 | }); 328 | paths.exit().remove(); 329 | }, 330 | scales: { x: null, y: null }, 331 | componentDidMount: function() { 332 | var svg = d3.select(this.getDOMNode()); 333 | var height = svg.attr("height"), 334 | width = svg.attr("width"); 335 | if (!height) { 336 | height = 300; 337 | svg.attr("height", height); 338 | } 339 | if (!width) { 340 | width = 500; 341 | svg.attr("width", width); 342 | } 343 | var margin = 50; 344 | svg = svg.append("g") 345 | .attr("transform", "translate("+margin+","+margin+")"); 346 | this.svg = svg; 347 | var fig_height = height - 2*margin, 348 | fig_width = width - 2*margin; 349 | 350 | // helper functions 351 | var x = d3.scale.linear().range([0, fig_width]).domain([-5, 5]); 352 | var y = d3.scale.linear().range([fig_height, 0]).domain([-3, 3]); 353 | this.scales.x = x; 354 | this.scales.y = y; 355 | var xAxis = d3.svg.axis() 356 | .scale(x) 357 | .orient("bottom"); 358 | var yAxis = d3.svg.axis() 359 | .scale(y) 360 | .orient("left"); 361 | this.gpline = d3.svg.line() 362 | .x(function(d) { return x(d[0]); }) 363 | .y(function(d) { return y(d[1]); }); 364 | 365 | // axes 366 | svg.append("g") 367 | .attr("class", "x axis") 368 | .attr("transform", "translate(0,"+fig_height+")") 369 | .call(xAxis); 370 | 371 | svg.append("g") 372 | .attr("class", "y axis") 373 | .call(yAxis); 374 | 375 | this.meanLines = svg.append("g"); 376 | this.upSd95Lines = svg.append("g"); 377 | this.downSd95Lines = svg.append("g"); 378 | this.lines = svg.append("g"); 379 | this.trPoints = svg.append("g"); 380 | this.drawTrPoints(this.props.state.trPointsX, this.props.state.trPointsY); 381 | this.drawPaths(); 382 | } 383 | }); 384 | 385 | 386 | var GPList = React.createClass({displayName: 'GPList', 387 | render: function() { 388 | var delGP = this.props.delGP; 389 | var gplist = this.props.GPs.map(function (gp) { 390 | return (React.DOM.tr( {key:gp.id}, 391 | React.DOM.td( {className:"tr"+gp.id}, gp.id),React.DOM.td(null, cfs[gp.cf].name),React.DOM.td(null, gp.params[0].toFixed(2)),React.DOM.td(null, gp.params[1].toFixed(2)),React.DOM.td(null, React.DOM.button( {onClick:delGP(gp.id)}, "remove")) 392 | )); 393 | }); 394 | return (React.DOM.tbody(null, gplist)); 395 | } 396 | }); 397 | -------------------------------------------------------------------------------- /build/slider.js: -------------------------------------------------------------------------------- 1 | /** @jsx React.DOM */ 2 | var Slider = React.createClass({displayName: 'Slider', 3 | render: function() { 4 | // Just insert the svg-element and render rest in componentDidMount. 5 | // Marker location is updated in componentWillReceiveProps using d3. 6 | return (React.DOM.svg(null)); 7 | }, 8 | shouldComponentUpdate: function() { return false; }, // Never re-render. 9 | componentDidMount: function() { 10 | var val = this.props.value; 11 | var setVal = this.props.setValue; 12 | var opt = this.props.opt; 13 | 14 | // set defaults for options if not given 15 | if (!opt.width) opt.width = 200; 16 | if (!opt.height) opt.height = 20; 17 | if (!opt.step) { 18 | opt.round = function(x) { return x; } 19 | } else { 20 | opt.round = function(x) { return Math.round(x / opt.step) * opt.step; } 21 | } 22 | if (!opt.min) opt.min = opt.round(0); 23 | if (!opt.max) opt.max = opt.round(1); 24 | if (opt.min > opt.max) { 25 | var tmp = opt.min; 26 | opt.min = opt.max; 27 | opt.max = tmp; 28 | } 29 | if (val > opt.max) setVal(opt.max); 30 | if (val < opt.min) setVal(opt.min); 31 | 32 | // calculate range 33 | var markerRadius = opt.height * 0.5; 34 | var x1 = markerRadius; 35 | var x2 = opt.width - markerRadius; 36 | 37 | // d3 helpers 38 | var scale = d3.scale.linear() 39 | .domain([opt.min, opt.max]) 40 | .range([x1, x2]); 41 | this.scale = scale; 42 | var setValFromMousePos = function(x) { 43 | setVal(opt.round(scale.invert(Math.max(x1, Math.min(x2, x))))); 44 | }; 45 | var dragmove = function() { 46 | setValFromMousePos(d3.event.x); 47 | }; 48 | var drag = d3.behavior.drag().on("drag", dragmove); 49 | 50 | // bind d3 events and insert background line and marker 51 | var svg = d3.select(this.getDOMNode()); 52 | svg.attr("class", "slider") 53 | .attr("width", opt.width) 54 | .attr("height", opt.height) 55 | .on("click", function () { setValFromMousePos(d3.mouse(this)[0]); }); 56 | svg.append("line") 57 | .attr("x1", x1) 58 | .attr("x2", x2) 59 | .attr("y1", "50%") 60 | .attr("y2", "50%") 61 | .attr("class", "sliderbg"); 62 | this.marker = svg.append("circle") 63 | .attr("cy", "50%") 64 | .attr("r", markerRadius) 65 | .attr("class", "slidermv") 66 | .datum(val) 67 | .attr("cx", function (d) { return scale(d); }) 68 | .call(drag); 69 | }, 70 | componentWillReceiveProps: function(props) { 71 | // update the marker location on receiving new props 72 | var scale = this.scale; 73 | this.marker.datum(props.value) 74 | .attr("cx", function (d) { return scale(d); }); 75 | } 76 | }); 77 | 78 | -------------------------------------------------------------------------------- /gp.css: -------------------------------------------------------------------------------- 1 | .slider circle.slidermv { 2 | fill: #c00; 3 | stroke: none; 4 | } 5 | .slider line.sliderbg { 6 | stroke: #000; 7 | fill: none; 8 | stroke-width: 2px; 9 | stroke-linecap: round; 10 | } 11 | .axis path, 12 | .axis line { 13 | fill: none; 14 | stroke: #000; 15 | shape-rendering: crispEdges; 16 | stroke-width: 1.5px; 17 | } 18 | .trpoints { 19 | fill: #000; 20 | } 21 | .line { 22 | fill: none; 23 | stroke-width: 1.5px; 24 | } 25 | .muline { 26 | //stroke-dasharray: 10,10; 27 | opacity: 0.5; 28 | } 29 | .sdline { 30 | //stroke-dasharray: 5,10; 31 | opacity: 0.5; 32 | } 33 | .line1 { stroke: #204a87; } 34 | .tr1 { border-left: 24px solid #204a87; } 35 | .line2 { stroke: #ce5c00; } 36 | .tr2 { border-left: 24px solid #ce5c00; } 37 | .line3 { stroke: #ad7fa8; } 38 | .tr3 { border-left: 24px solid #ad7fa8; } 39 | .line4 { stroke: #4e9a06; } 40 | .tr4 { border-left: 24px solid #4e9a06; } 41 | .line5 { stroke: #c4a000; } 42 | .tr5 { border-left: 24px solid #c4a000; } 43 | .line6 { stroke: #a40000; } 44 | .tr6 { border-left: 24px solid #a40000; } 45 | .line7 { stroke: #2e3436; } 46 | .tr7 { border-left: 24px solid #2e3436; } 47 | .line8 { stroke: #e9b96e; } 48 | .tr8 { border-left: 24px solid #e9b96e; } 49 | .line9 { stroke: #8f5902; } 50 | .tr9 { border-left: 24px solid #8f5902; } 51 | .line10 { stroke: #5c3566; } 52 | .tr10 { border-left: 24px solid #5c3566; } 53 | td { 54 | padding: 0 5px; 55 | } 56 | th { 57 | border-bottom: 2px solid #000; 58 | } 59 | button[disabled=disabled], button:disabled { 60 | text-decoration: line-through; 61 | } 62 | input, button, select { 63 | font: inherit; 64 | } 65 | .info { 66 | color: #c00; 67 | } 68 | -------------------------------------------------------------------------------- /halfmoon9.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/to-mi/gp-demo-js/21fac4e7d24847a0d347e8f3621f674e37d715a1/halfmoon9.ico -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Gaussian process regression demo 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |
21 |

home

22 |

Gaussian process regression demo

23 |
24 | 25 |
26 | 27 | 28 | -------------------------------------------------------------------------------- /site.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Lora', serif; 3 | font-size: 14px; 4 | margin: 0; 5 | padding: 0; 6 | background: repeat url('bgrep4.png'); 7 | color: #222; 8 | } 9 | h1 { 10 | font-family: 'Special Elite', cursive; 11 | font-size: 28px; 12 | color: #222; 13 | } 14 | h1#sitetitle { 15 | font-size: 10px; 16 | } 17 | a { 18 | color: #222; 19 | } 20 | a:hover { 21 | text-decoration: none; 22 | } 23 | #main { 24 | margin: auto; 25 | width: 802px; 26 | } 27 | #tmain { 28 | margin: 0; 29 | margin-top: 28px; 30 | 31 | padding-top: 14px; 32 | padding-bottom: 14px; 33 | } 34 | -------------------------------------------------------------------------------- /src/gpapp.js: -------------------------------------------------------------------------------- 1 | /** @jsx React.DOM */ 2 | 3 | var GPApp = React.createClass({ 4 | getInitialState: function() { 5 | return { GPs: [new GP(0, [1,0.2], 1, [], [], [])], 6 | newGPParam: 1.0, 7 | newGPNoise: 0.2, 8 | newGPcf: 0, 9 | newGPavailableIDs: [10, 9, 8, 7, 6, 5, 4, 3, 2], 10 | alfa: 0.3, 11 | stepSize: 3.14, 12 | NSteps: 15, 13 | addTrPoints: false, 14 | trPointsX: [], 15 | trPointsY: [], 16 | dmTr: [], 17 | dmTeTr: [], 18 | samplingState: 0, // 0 = stopped, 1 = discrete, 2 = continuous 19 | oldSamplingState: 0, 20 | showSamples: true, 21 | showMeanAndVar: false 22 | } 23 | }, 24 | setAlfa: function(newVal) { this.setState({ alfa: newVal }); }, 25 | setStepSize: function(newVal) { this.setState({ stepSize: newVal }); }, 26 | setNSteps: function(newVal) { this.setState({ NSteps: newVal }); }, 27 | toggleAddTrPoints: function() { 28 | if (this.state.addTrPoints){ 29 | // added training points 30 | var dmTr = computeDistanceMatrix(this.state.trPointsX, this.state.trPointsX); 31 | var dmTeTr = computeDistanceMatrix(tePointsX, this.state.trPointsX); 32 | 33 | var newGPs = recomputeProjections(this.state.GPs, dmTr, dmTeTr, this.state.trPointsY); 34 | 35 | this.setState({ addTrPoints: !this.state.addTrPoints, GPs: newGPs, dmTr: dmTr, dmTeTr: dmTeTr, samplingState: this.state.oldSamplingState }); 36 | } else { 37 | // beginning to add training points 38 | this.setState({ addTrPoints: !this.state.addTrPoints, oldSamplingState: this.state.samplingState, samplingState: 0 }); 39 | } 40 | }, 41 | clearTrPoints: function() { this.setState({ trPointsX: [], trPointsY: [] }); }, 42 | toggleShowMeanAndVar: function() { if (!this.state.addTrPoints) this.setState({ showMeanAndVar: !this.state.showMeanAndVar }); }, 43 | toggleShowSamples: function() { 44 | if (!this.state.addTrPoints) { 45 | if (this.state.showSamples) { 46 | this.setState({ samplingState: 0, showSamples: false }); 47 | } else { 48 | this.setState({ samplingState: this.state.oldSamplingState, showSamples: true }); 49 | } 50 | } 51 | }, 52 | setNewGPParam: function(newVal) { this.setState({ newGPParam: newVal }); }, 53 | setNewGPNoise: function(newVal) { this.setState({ newGPNoise: newVal }); }, 54 | setNewGPcf: function(event) { this.setState({ newGPcf: event.target.value }); }, 55 | addGP: function() { 56 | if (this.state.newGPavailableIDs.length < 1) return; 57 | var id = this.state.newGPavailableIDs.pop(); 58 | var newGPs = this.state.GPs.concat([new GP(this.state.newGPcf, [this.state.newGPParam, this.state.newGPNoise], id, this.state.dmTr, this.state.dmTeTr, this.state.trPointsY)]); 59 | this.setState({ GPs: newGPs, newGPavailableIDs: this.state.newGPavailableIDs }); 60 | }, 61 | delGP: function(id) { 62 | return (function() { 63 | var newGPs = this.state.GPs; 64 | var delIdx = newGPs.findIndex(function (g) { return g.id == id; }); 65 | if (delIdx >= 0) { 66 | newGPs.splice(delIdx, 1); 67 | this.state.newGPavailableIDs.push(id); 68 | this.setState({ GPs: newGPs }); 69 | } 70 | }).bind(this); 71 | }, 72 | addTrPoint: function(x, y) { 73 | if (x >= -5 && x <= 5 && y >= -3 && y <= 3){ 74 | var newTrPointsX = this.state.trPointsX.concat([x]); 75 | var newTrPointsY = this.state.trPointsY.concat([y]); 76 | this.setState({ trPointsX: newTrPointsX, trPointsY: newTrPointsY }); 77 | } 78 | }, 79 | stopSampling: function() { this.setState({ samplingState: 0, oldSamplingState: 0 }); }, 80 | startDiscreteSampling: function() { this.setState({ samplingState: 1, oldSamplingState: 1 }); }, 81 | startContinuousSampling: function() { this.setState({ samplingState: 2, oldSamplingState: 2 }); }, 82 | render: function() { 83 | var sliderOptAlfa = { width: 200, height: 9, min: 0, max: 1 }; 84 | var sliderOptStepSize = { width: 200, height: 9, min: 0, max: 2*Math.PI }; 85 | var sliderOptNSteps = { width: 200, height: 9, min: 1, max: 100, step: 1 }; 86 | var sliderOptGPParam = { width: 200, height: 9, min: 0.01, max: 5 }; 87 | var sliderOptGPNoise = { width: 200, height: 9, min: 0, max: 2 }; 88 | var delGP = this.delGP; 89 | var gpoptions = cfs.map(function (c) { 90 | return (); 91 | }); 92 | return ( 93 |
94 | 95 |
96 | Show mean and credible intervals 97 |  Show samples
98 | 99 | 100 | 101 |
102 | {this.state.addTrPoints ? click on the figure to add an observation : ''} 103 | 104 | {this.state.addTrPoints ? : ''} 105 |
106 |
107 |

Trajectory simulation settings

108 | 109 | 110 | 111 | 112 |
Momentum refreshment {this.state.alfa.toFixed(2)}
Path length {this.state.stepSize.toFixed(2)}
Number of steps in path {this.state.NSteps}
113 |
114 |
115 |
116 |

Add new process

117 | 118 | 119 | 120 | 121 |
Covariance function
Length scale {this.state.newGPParam.toFixed(2)}
Noise {this.state.newGPNoise.toFixed(2)}
122 | 123 |
124 |

Process list

125 | 126 | 127 | 128 | 129 | 130 |
idcovariancelength scalenoise
131 |
132 |
133 | ) 134 | } 135 | }); 136 | 137 | React.renderComponent( 138 | , 139 | document.getElementById('gp-outer') 140 | ); 141 | -------------------------------------------------------------------------------- /src/gputils.js: -------------------------------------------------------------------------------- 1 | /** @jsx React.DOM */ 2 | var tePointsX = numeric.linspace(-5, 5, numeric.dim(distmatTe)[0]); 3 | var randn = d3.random.normal(); 4 | function randnArray(size){ 5 | var zs = new Array(size); 6 | for (var i = 0; i < size; i++) { 7 | zs[i] = randn(); 8 | } 9 | return zs; 10 | } 11 | 12 | // ids must be in order of the array 13 | var cfs = [ 14 | {'id': 0, 15 | 'name': 'Exponentiated quadratic', 16 | 'f': function(r, params) { 17 | return numeric.exp(numeric.mul(-0.5 / (params[0] * params[0]), numeric.pow(r, 2))); 18 | } 19 | }, 20 | {'id': 1, 21 | 'name': 'Exponential', 22 | 'f': function(r, params) { 23 | return numeric.exp(numeric.mul(-0.5 / params[0], r)); 24 | } 25 | }, 26 | {'id': 2, 27 | 'name': 'Matern 3/2', 28 | 'f': function(r, params) { 29 | var tmp = numeric.mul(Math.sqrt(3.0) / params[0], r); 30 | return numeric.mul(numeric.add(1.0, tmp), numeric.exp(numeric.neg(tmp))); 31 | } 32 | }, 33 | {'id': 3, 34 | 'name': 'Matern 5/2', 35 | 'f': function(r, params) { 36 | var tmp = numeric.mul(Math.sqrt(5.0) / params[0], r); 37 | var tmp2 = numeric.div(numeric.mul(tmp, tmp), 3.0); 38 | return numeric.mul(numeric.add(numeric.add(1, tmp), tmp2), numeric.exp(numeric.neg(tmp))); 39 | } 40 | }, 41 | {'id': 4, 42 | 'name': 'Rational quadratic (alpha=1)', 43 | 'f': function(r, params) { 44 | return numeric.pow(numeric.add(1.0, numeric.div(numeric.pow(r, 2), 2.0 * params[0] * params[0])), -1); 45 | } 46 | }, 47 | {'id': 5, 48 | 'name': 'Piecewise polynomial (q=0)', 49 | 'f': function(r, params) { 50 | var tmp = numeric.sub(1.0, numeric.div(r, params[0])); 51 | var dims = numeric.dim(tmp); 52 | for (var i = 0; i < dims[0]; i++){ 53 | for (var j = 0; j < dims[1]; j++){ 54 | tmp[i][j] = tmp[i][j] > 0.0 ? tmp[i][j] : 0.0; 55 | } 56 | } 57 | return tmp; 58 | } 59 | }, 60 | {'id': 6, 61 | 'name': 'Piecewise polynomial (q=1)', 62 | 'f': function(r, params) { 63 | var tmp1 = numeric.div(r, params[0]); 64 | var tmp = numeric.sub(1.0, tmp1); 65 | var dims = numeric.dim(tmp); 66 | for (var i = 0; i < dims[0]; i++){ 67 | for (var j = 0; j < dims[1]; j++){ 68 | tmp[i][j] = tmp[i][j] > 0.0 ? tmp[i][j] : 0.0; 69 | } 70 | } 71 | return numeric.mul(numeric.pow(tmp, 3), numeric.add(numeric.mul(3.0, tmp1), 1.0)); 72 | } 73 | }, 74 | {'id': 7, 75 | 'name': 'Periodic (period=pi)', 76 | 'f': function(r, params) { 77 | return numeric.exp(numeric.mul(-2.0/(params[0]*params[0]), numeric.pow(numeric.sin(r), 2))); 78 | } 79 | }, 80 | {'id': 8, 81 | 'name': 'Periodic (period=1)', 82 | 'f': function(r, params) { 83 | return numeric.exp(numeric.mul(-2.0/(params[0]*params[0]), numeric.pow(numeric.sin(numeric.mul(Math.PI, r)), 2))); 84 | } 85 | } 86 | ]; 87 | 88 | function GP(cf, params, id, dmTr, dmTeTr, trY) { 89 | var M = numeric.dim(distmatTe)[1]; 90 | 91 | this.z = randnArray(M); 92 | this.p = randnArray(M); 93 | this.cf = cf; 94 | this.params = params; 95 | this.id = id; 96 | 97 | this.Kte = cfs[this.cf].f(distmatTe, params); 98 | 99 | var tmp = computeProjection(this.Kte, this.cf, this.params, dmTr, dmTeTr, trY); 100 | this.proj = tmp.proj; 101 | this.mu = tmp.mu; 102 | this.sd95 = tmp.sd95; 103 | } 104 | 105 | 106 | function computeProjection(Kte, cf, params, dmTr, dmTeTr, trY) { 107 | var Mtr = numeric.dim(dmTr)[0]; 108 | var Mte = numeric.dim(distmatTe)[0]; 109 | 110 | if (Mtr > 0){ 111 | var Kxx_p_noise = cfs[cf].f(dmTr, params); 112 | for (var i = 0; i < Mtr; i++){ 113 | Kxx_p_noise[i][i] += params[1]; 114 | } 115 | 116 | var svd1 = numeric.svd(Kxx_p_noise); 117 | for (var i = 0; i < Mtr; i++){ 118 | if (svd1.S[i] > numeric.epsilon){ 119 | svd1.S[i] = 1.0/svd1.S[i]; 120 | } else { 121 | svd1.S[i] = 0.0; 122 | } 123 | } 124 | 125 | var tmp = numeric.dot(cfs[cf].f(dmTeTr, params), svd1.U); 126 | // there seems to be a bug in numeric.svd: svd1.U and transpose(svd1.V) are not always equal for a symmetric matrix 127 | var mu = numeric.dot(tmp, numeric.mul(svd1.S, numeric.dot(numeric.transpose(svd1.U), trY))); 128 | var cov = numeric.dot(tmp, numeric.diag(numeric.sqrt(svd1.S))); 129 | cov = numeric.dot(cov, numeric.transpose(cov)); 130 | cov = numeric.sub(Kte, cov); 131 | var svd2 = numeric.svd(cov); 132 | for (var i = 0; i < Mte; i++){ 133 | if (svd2.S[i] < numeric.epsilon){ 134 | svd2.S[i] = 0.0; 135 | } 136 | } 137 | var proj = numeric.dot(svd2.U, numeric.diag(numeric.sqrt(svd2.S))); 138 | var sd95 = numeric.mul(1.98, numeric.sqrt(numeric.getDiag(numeric.dot(proj, numeric.transpose(proj))))); 139 | } else { 140 | var sd95 = numeric.mul(1.98, numeric.sqrt(numeric.getDiag(Kte))); 141 | var svd = numeric.svd(Kte); 142 | var proj = numeric.dot(svd.U, numeric.diag(numeric.sqrt(svd.S))); 143 | var mu = numeric.rep([Mte], 0); 144 | } 145 | 146 | return { proj: proj, mu: mu, sd95: sd95 }; 147 | } 148 | 149 | function recomputeProjections(GPs, dmTr, dmTeTr, trY) { 150 | for (var gpi = 0; gpi < GPs.length; gpi++){ 151 | var gp = GPs[gpi]; 152 | var tmp = computeProjection(gp.Kte, gp.cf, gp.params, dmTr, dmTeTr, trY); 153 | gp.proj = tmp.proj; 154 | gp.mu = tmp.mu; 155 | gp.sd95 = tmp.sd95; 156 | GPs[gpi] = gp; 157 | } 158 | 159 | return GPs; 160 | } 161 | 162 | function computeDistanceMatrix(xdata1, xdata2) { 163 | var dm = numeric.rep([xdata1.length,xdata2.length], 0); 164 | for (var i = 0; i < xdata1.length; i++){ 165 | for (var j = 0; j < xdata2.length; j++){ 166 | var val = Math.abs(xdata2[j] - xdata1[i]); 167 | dm[i][j] = val; 168 | } 169 | } 170 | return dm; 171 | } 172 | 173 | var GPAxis = React.createClass({ 174 | render: function() { 175 | return (); 176 | }, 177 | shouldComponentUpdate: function() { return false; }, 178 | drawTrPoints: function(pointsX, pointsY) { 179 | var x = this.scales.x; 180 | var y = this.scales.y; 181 | var p = this.trPoints.selectAll("circle.trpoints") 182 | .data(d3.zip(pointsX, pointsY)) 183 | .attr("cx", function(d) { return x(d[0]); }) 184 | .attr("cy", function(d) { return y(d[1]); }); 185 | p.enter().append("circle") 186 | .attr("class", "trpoints") 187 | .attr("r", 2) 188 | .attr("cx", function(d) { return x(d[0]); }) 189 | .attr("cy", function(d) { return y(d[1]); }); 190 | p.exit().remove(); 191 | }, 192 | animationId: 0, 193 | componentWillReceiveProps: function(props) { 194 | // bind events 195 | if (props.state.addTrPoints) { 196 | d3.select(this.getDOMNode()).on("click", this.addTrPoint); 197 | } else { 198 | d3.select(this.getDOMNode()).on("click", null); 199 | } 200 | // redraw training points 201 | this.drawTrPoints(props.state.trPointsX, props.state.trPointsY); 202 | 203 | this.drawMeanAndVar(props); 204 | 205 | if (this.props.state.showSamples !== props.state.showSamples){ 206 | this.drawPaths(props); 207 | } 208 | 209 | if (this.props.state.samplingState !== props.state.samplingState){ 210 | clearInterval(this.animationId); 211 | if (props.state.samplingState === 1){ 212 | this.animationId = setInterval((function() { this.updateState(); this.drawPaths(); }).bind(this), 500); 213 | } else if (props.state.samplingState === 2){ 214 | this.animationId = setInterval((function() { this.contUpdateState(); this.drawPaths(); }).bind(this), 50); 215 | } 216 | } 217 | }, 218 | addTrPoint: function() { 219 | var mousePos = d3.mouse(this.getDOMNode()); 220 | var x = this.scales.x; 221 | var y = this.scales.y; 222 | 223 | // x is transformed to a point on a grid of 200 points between -5 and 5 224 | this.props.addTrPoint(Math.round((x.invert(mousePos[0]-50)+5)/10*199)/199*10-5, y.invert(mousePos[1]-50)); 225 | }, 226 | updateState: function() { 227 | var M = numeric.dim(distmatTe)[1]; 228 | for (var i = 0; i < this.props.state.GPs.length; i++){ 229 | var gp = this.props.state.GPs[i]; 230 | gp.z = randnArray(M); 231 | } 232 | }, 233 | stepState: 0, 234 | contUpdateState: function() { 235 | var M = numeric.dim(distmatTe)[1]; 236 | var alfa = 1.0-this.props.state.alfa; 237 | var n_steps = this.props.state.NSteps; 238 | var t_step = this.props.state.stepSize / n_steps; 239 | this.stepState = this.stepState % n_steps; 240 | 241 | for (var i = 0; i < this.props.state.GPs.length; i++){ 242 | var gp = this.props.state.GPs[i]; 243 | 244 | // refresh momentum: p = alfa * p + sqrt(1 - alfa^2) * randn(size(p)) 245 | if (this.stepState == (n_steps-1)) 246 | gp.p = numeric.add(numeric.mul(alfa, gp.p), numeric.mul(Math.sqrt(1 - alfa*alfa), randnArray(M))); 247 | 248 | var a = gp.p.slice(0), 249 | b = gp.z.slice(0), 250 | c = numeric.mul(-1, gp.z.slice(0)), 251 | d = gp.p.slice(0); 252 | 253 | gp.z = numeric.add(numeric.mul(a, Math.sin(t_step)), numeric.mul(b, Math.cos(t_step))); 254 | gp.p = numeric.add(numeric.mul(c, Math.sin(t_step)), numeric.mul(d, Math.cos(t_step))); 255 | } 256 | this.stepState = this.stepState + 1; 257 | }, 258 | drawMeanAndVar: function(props) { 259 | var gpline = this.gpline; 260 | if (props.state.showMeanAndVar){ 261 | var gps = props.state.GPs; 262 | } else { 263 | var gps = []; 264 | } 265 | 266 | var paths = this.meanLines.selectAll("path").data(gps, function (d) { return d.id; }) 267 | .attr("d", function (d) { 268 | var datay = d.mu; 269 | return gpline(d3.zip(tePointsX, datay)); 270 | }); 271 | paths.enter().append("path").attr("d", function (d) { 272 | var datay = d.mu; 273 | return gpline(d3.zip(tePointsX, datay)); 274 | }) 275 | .attr("class", function(d) { 276 | return "muline line line"+d.id; 277 | }); 278 | paths.exit().remove(); 279 | 280 | var pathsUp = this.upSd95Lines.selectAll("path").data(gps, function (d) { return d.id; }) 281 | .attr("d", function (d) { 282 | var datay = numeric.add(d.mu, d.sd95); 283 | return gpline(d3.zip(tePointsX, datay)); 284 | }); 285 | pathsUp.enter().append("path").attr("d", function (d) { 286 | var datay = numeric.add(d.mu, d.sd95); 287 | return gpline(d3.zip(tePointsX, datay)); 288 | }) 289 | .attr("class", function(d) { 290 | return "sdline line line"+d.id; 291 | }); 292 | pathsUp.exit().remove(); 293 | 294 | var pathsDown = this.downSd95Lines.selectAll("path").data(gps, function (d) { return d.id; }) 295 | .attr("d", function (d) { 296 | var datay = numeric.sub(d.mu, d.sd95); 297 | return gpline(d3.zip(tePointsX, datay)); 298 | }); 299 | pathsDown.enter().append("path").attr("d", function (d) { 300 | var datay = numeric.sub(d.mu, d.sd95); 301 | return gpline(d3.zip(tePointsX, datay)); 302 | }) 303 | .attr("class", function(d) { 304 | return "sdline line line"+d.id; 305 | }); 306 | pathsDown.exit().remove(); 307 | }, 308 | drawPaths: function(props) { 309 | if (!props) var props = this.props; 310 | var gpline = this.gpline; 311 | if (props.state.showSamples){ 312 | var gps = props.state.GPs; 313 | } else { 314 | var gps = []; 315 | } 316 | var paths = this.lines.selectAll("path").data(gps, function (d) { return d.id; }) 317 | .attr("d", function (d) { 318 | var datay = numeric.add(numeric.dot(d.proj, d.z), d.mu); 319 | return gpline(d3.zip(tePointsX, datay)); 320 | }); 321 | paths.enter().append("path").attr("d", function (d) { 322 | var datay = numeric.add(numeric.dot(d.proj, d.z), d.mu); 323 | return gpline(d3.zip(tePointsX, datay)); 324 | }) 325 | .attr("class", function(d) { 326 | return "line line"+d.id; 327 | }); 328 | paths.exit().remove(); 329 | }, 330 | scales: { x: null, y: null }, 331 | componentDidMount: function() { 332 | var svg = d3.select(this.getDOMNode()); 333 | var height = svg.attr("height"), 334 | width = svg.attr("width"); 335 | if (!height) { 336 | height = 300; 337 | svg.attr("height", height); 338 | } 339 | if (!width) { 340 | width = 500; 341 | svg.attr("width", width); 342 | } 343 | var margin = 50; 344 | svg = svg.append("g") 345 | .attr("transform", "translate("+margin+","+margin+")"); 346 | this.svg = svg; 347 | var fig_height = height - 2*margin, 348 | fig_width = width - 2*margin; 349 | 350 | // helper functions 351 | var x = d3.scale.linear().range([0, fig_width]).domain([-5, 5]); 352 | var y = d3.scale.linear().range([fig_height, 0]).domain([-3, 3]); 353 | this.scales.x = x; 354 | this.scales.y = y; 355 | var xAxis = d3.svg.axis() 356 | .scale(x) 357 | .orient("bottom"); 358 | var yAxis = d3.svg.axis() 359 | .scale(y) 360 | .orient("left"); 361 | this.gpline = d3.svg.line() 362 | .x(function(d) { return x(d[0]); }) 363 | .y(function(d) { return y(d[1]); }); 364 | 365 | // axes 366 | svg.append("g") 367 | .attr("class", "x axis") 368 | .attr("transform", "translate(0,"+fig_height+")") 369 | .call(xAxis); 370 | 371 | svg.append("g") 372 | .attr("class", "y axis") 373 | .call(yAxis); 374 | 375 | this.meanLines = svg.append("g"); 376 | this.upSd95Lines = svg.append("g"); 377 | this.downSd95Lines = svg.append("g"); 378 | this.lines = svg.append("g"); 379 | this.trPoints = svg.append("g"); 380 | this.drawTrPoints(this.props.state.trPointsX, this.props.state.trPointsY); 381 | this.drawPaths(); 382 | } 383 | }); 384 | 385 | 386 | var GPList = React.createClass({ 387 | render: function() { 388 | var delGP = this.props.delGP; 389 | var gplist = this.props.GPs.map(function (gp) { 390 | return ( 391 | {gp.id}{cfs[gp.cf].name}{gp.params[0].toFixed(2)}{gp.params[1].toFixed(2)} 392 | ); 393 | }); 394 | return ({gplist}); 395 | } 396 | }); 397 | -------------------------------------------------------------------------------- /src/slider.js: -------------------------------------------------------------------------------- 1 | /** @jsx React.DOM */ 2 | var Slider = React.createClass({ 3 | render: function() { 4 | // Just insert the svg-element and render rest in componentDidMount. 5 | // Marker location is updated in componentWillReceiveProps using d3. 6 | return (); 7 | }, 8 | shouldComponentUpdate: function() { return false; }, // Never re-render. 9 | componentDidMount: function() { 10 | var val = this.props.value; 11 | var setVal = this.props.setValue; 12 | var opt = this.props.opt; 13 | 14 | // set defaults for options if not given 15 | if (!opt.width) opt.width = 200; 16 | if (!opt.height) opt.height = 20; 17 | if (!opt.step) { 18 | opt.round = function(x) { return x; } 19 | } else { 20 | opt.round = function(x) { return Math.round(x / opt.step) * opt.step; } 21 | } 22 | if (!opt.min) opt.min = opt.round(0); 23 | if (!opt.max) opt.max = opt.round(1); 24 | if (opt.min > opt.max) { 25 | var tmp = opt.min; 26 | opt.min = opt.max; 27 | opt.max = tmp; 28 | } 29 | if (val > opt.max) setVal(opt.max); 30 | if (val < opt.min) setVal(opt.min); 31 | 32 | // calculate range 33 | var markerRadius = opt.height * 0.5; 34 | var x1 = markerRadius; 35 | var x2 = opt.width - markerRadius; 36 | 37 | // d3 helpers 38 | var scale = d3.scale.linear() 39 | .domain([opt.min, opt.max]) 40 | .range([x1, x2]); 41 | this.scale = scale; 42 | var setValFromMousePos = function(x) { 43 | setVal(opt.round(scale.invert(Math.max(x1, Math.min(x2, x))))); 44 | }; 45 | var dragmove = function() { 46 | setValFromMousePos(d3.event.x); 47 | }; 48 | var drag = d3.behavior.drag().on("drag", dragmove); 49 | 50 | // bind d3 events and insert background line and marker 51 | var svg = d3.select(this.getDOMNode()); 52 | svg.attr("class", "slider") 53 | .attr("width", opt.width) 54 | .attr("height", opt.height) 55 | .on("click", function () { setValFromMousePos(d3.mouse(this)[0]); }); 56 | svg.append("line") 57 | .attr("x1", x1) 58 | .attr("x2", x2) 59 | .attr("y1", "50%") 60 | .attr("y2", "50%") 61 | .attr("class", "sliderbg"); 62 | this.marker = svg.append("circle") 63 | .attr("cy", "50%") 64 | .attr("r", markerRadius) 65 | .attr("class", "slidermv") 66 | .datum(val) 67 | .attr("cx", function (d) { return scale(d); }) 68 | .call(drag); 69 | }, 70 | componentWillReceiveProps: function(props) { 71 | // update the marker location on receiving new props 72 | var scale = this.scale; 73 | this.marker.datum(props.value) 74 | .attr("cx", function (d) { return scale(d); }); 75 | } 76 | }); 77 | 78 | --------------------------------------------------------------------------------