├── LICENSE
├── README.md
├── bgrep4.png
├── build
│   ├── gpapp.js
│   ├── gputils.js
│   └── slider.js
├── distancematrix.js
├── gp.css
├── halfmoon9.ico
├── index.html
├── site.css
└── src
    ├── gpapp.js
    ├── gputils.js
    └── slider.js
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Tomi Peltola
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Gaussian process regression demo in JavaScript
2 | ==============================================
3 |
4 | This repository contains code for a Gaussian process regression demonstration in JavaScript. See http://www.tmpl.fi/gp/ for the live version.
5 |
6 | The code depends on the following JavaScript libraries, which are not included in this repository:
7 |
8 | * [React](http://facebook.github.io/react/)
9 | * [D3.js](http://d3js.org)
10 | * [Numeric.js](http://www.numericjs.com/)
11 |
12 |
13 | Simulation of continuous trajectories
14 | -------------------------------------
15 |
16 | Continuous trajectories are simulated using Hamiltonian Monte Carlo (HMC) with partial momentum refreshment and analytically solved dynamics for the Gaussian posterior distribution.
17 |
18 | For an excellent HMC reference, see: Radford M. Neal, _MCMC using Hamiltonian dynamics_. [arXiv:1206.1901](http://arxiv.org/abs/1206.1901), 2012.
19 |
20 | Contact
21 | -------
22 |
23 | E-mail: tomi.peltola@aalto.fi
24 |
25 | License
26 | -------
27 |
28 | MIT. See `LICENSE`.
29 |
--------------------------------------------------------------------------------
/bgrep4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/to-mi/gp-demo-js/21fac4e7d24847a0d347e8f3621f674e37d715a1/bgrep4.png
--------------------------------------------------------------------------------
/build/gpapp.js:
--------------------------------------------------------------------------------
1 | /** @jsx React.DOM */
2 |
/**
 * Root component of the demo. Owns all UI state and passes it down:
 *  - GPs: the GP objects being displayed (see GP in gputils.js)
 *  - newGPParam / newGPNoise / newGPcf: settings for the next process to add;
 *    newGPavailableIDs holds the unused ids (popped from the end, so 2, 3, ...
 *    are handed out first)
 *  - alfa / stepSize / NSteps: continuous-trajectory (HMC) settings
 *  - addTrPoints, trPointsX / trPointsY: observation editing;
 *    dmTr / dmTeTr: distance matrices the GPs were last conditioned on
 *  - samplingState (0 = stopped, 1 = discrete, 2 = continuous) and
 *    oldSamplingState, the mode to resume after sampling is suspended
 *  - showSamples / showMeanAndVar: display toggles
 * State objects and arrays are never mutated in place; new copies are passed to setState.
 */
var GPApp = React.createClass({displayName: 'GPApp',
  getInitialState: function() {
    return { GPs: [new GP(0, [1,0.2], 1, [], [], [])],
             newGPParam: 1.0,
             newGPNoise: 0.2,
             newGPcf: 0,
             newGPavailableIDs: [10, 9, 8, 7, 6, 5, 4, 3, 2],
             alfa: 0.3,
             stepSize: 3.14,
             NSteps: 15,
             addTrPoints: false,
             trPointsX: [],
             trPointsY: [],
             dmTr: [],
             dmTeTr: [],
             samplingState: 0, // 0 = stopped, 1 = discrete, 2 = continuous
             oldSamplingState: 0,
             showSamples: true,
             showMeanAndVar: false
           };
  },
  setAlfa: function(newVal) { this.setState({ alfa: newVal }); },
  setStepSize: function(newVal) { this.setState({ stepSize: newVal }); },
  setNSteps: function(newVal) { this.setState({ NSteps: newVal }); },
  toggleAddTrPoints: function() {
    if (this.state.addTrPoints) {
      // Finished adding observations: condition every GP on the current points.
      var dmTr = computeDistanceMatrix(this.state.trPointsX, this.state.trPointsX);
      var dmTeTr = computeDistanceMatrix(tePointsX, this.state.trPointsX);

      var newGPs = recomputeProjections(this.state.GPs, dmTr, dmTeTr, this.state.trPointsY);

      // Resume the suspended sampling mode, but only while samples are visible
      // (hidden samples always keep sampling stopped, see toggleShowSamples).
      this.setState({ addTrPoints: false, GPs: newGPs, dmTr: dmTr, dmTeTr: dmTeTr,
                      samplingState: this.state.showSamples ? this.state.oldSamplingState : 0 });
    } else {
      // Beginning to add observations: suspend sampling. When samples are hidden,
      // samplingState is already 0 and oldSamplingState holds the mode to restore
      // later, so it must not be overwritten with that 0.
      this.setState({ addTrPoints: true,
                      oldSamplingState: this.state.showSamples ? this.state.samplingState : this.state.oldSamplingState,
                      samplingState: 0 });
    }
  },
  clearTrPoints: function() { this.setState({ trPointsX: [], trPointsY: [] }); },
  toggleShowMeanAndVar: function() { if (!this.state.addTrPoints) this.setState({ showMeanAndVar: !this.state.showMeanAndVar }); },
  toggleShowSamples: function() {
    if (!this.state.addTrPoints) {
      if (this.state.showSamples) {
        this.setState({ samplingState: 0, showSamples: false });
      } else {
        this.setState({ samplingState: this.state.oldSamplingState, showSamples: true });
      }
    }
  },
  setNewGPParam: function(newVal) { this.setState({ newGPParam: newVal }); },
  setNewGPNoise: function(newVal) { this.setState({ newGPNoise: newVal }); },
  setNewGPcf: function(event) {
    // <select> reports its value as a string; keep newGPcf a number (an index
    // into cfs) like its initial value.
    this.setState({ newGPcf: parseInt(event.target.value, 10) });
  },
  addGP: function() {
    // Adding is blocked while observations are being edited: dmTr/dmTeTr still
    // describe the previous point set and would not match trPointsY.
    if (this.state.addTrPoints || this.state.newGPavailableIDs.length < 1) return;
    // Copy before popping so the current state array is never mutated in place.
    var availableIDs = this.state.newGPavailableIDs.slice();
    var id = availableIDs.pop();
    var gp = new GP(this.state.newGPcf, [this.state.newGPParam, this.state.newGPNoise], id,
                    this.state.dmTr, this.state.dmTeTr, this.state.trPointsY);
    this.setState({ GPs: this.state.GPs.concat([gp]), newGPavailableIDs: availableIDs });
  },
  delGP: function(id) {
    // Returns a click handler (bound to this component) that removes the GP with
    // this id and makes the id available again.
    return (function() {
      var remaining = this.state.GPs.filter(function (g) { return g.id !== id; });
      if (remaining.length === this.state.GPs.length) return;
      this.setState({ GPs: remaining, newGPavailableIDs: this.state.newGPavailableIDs.concat([id]) });
    }).bind(this);
  },
  addTrPoint: function(x, y) {
    // Only accept points inside the plotted data range.
    if (x >= -5 && x <= 5 && y >= -3 && y <= 3){
      var newTrPointsX = this.state.trPointsX.concat([x]);
      var newTrPointsY = this.state.trPointsY.concat([y]);
      this.setState({ trPointsX: newTrPointsX, trPointsY: newTrPointsY });
    }
  },
  stopSampling: function() { this.setState({ samplingState: 0, oldSamplingState: 0 }); },
  startDiscreteSampling: function() { this.setState({ samplingState: 1, oldSamplingState: 1 }); },
  startContinuousSampling: function() { this.setState({ samplingState: 2, oldSamplingState: 2 }); },
  render: function() {
    var sliderOptAlfa = { width: 200, height: 9, min: 0, max: 1 };
    var sliderOptStepSize = { width: 200, height: 9, min: 0, max: 2*Math.PI };
    var sliderOptNSteps = { width: 200, height: 9, min: 1, max: 100, step: 1 };
    var sliderOptGPParam = { width: 200, height: 9, min: 0.01, max: 5 };
    var sliderOptGPNoise = { width: 200, height: 9, min: 0, max: 2 };
    var gpoptions = cfs.map(function (c) {
      return (React.DOM.option( {key:c.id, value:c.id}, c.name));
    });
    return (
      React.DOM.div( {id:"gp"},
        GPAxis( {state:this.state, addTrPoint:this.addTrPoint} ),
        React.DOM.div( {id:"controls"},
          React.DOM.input( {type:"checkbox", checked:this.state.showMeanAndVar, onChange:this.toggleShowMeanAndVar} ),"Show mean and credible intervals"+' '+
          " ",React.DOM.input( {type:"checkbox", checked:this.state.showSamples, onChange:this.toggleShowSamples} ),"Show samples",React.DOM.br(null ),
          React.DOM.button( {onClick:this.startDiscreteSampling, disabled:this.state.samplingState === 1 || this.state.addTrPoints || !this.state.showSamples}, "sample independently"),
          React.DOM.button( {onClick:this.startContinuousSampling, disabled:this.state.samplingState === 2 || this.state.addTrPoints || !this.state.showSamples}, "sample continuous trajectories"),
          React.DOM.button( {onClick:this.stopSampling, disabled:this.state.samplingState === 0 || this.state.addTrPoints}, "stop sampling"),
          React.DOM.br(null ),
          this.state.addTrPoints ? React.DOM.span( {className:"info"}, "click on the figure to add an observation " ) : '',
          React.DOM.button( {onClick:this.toggleAddTrPoints}, this.state.addTrPoints ? "done" : "add observations"),
          this.state.addTrPoints ? React.DOM.button( {onClick:this.clearTrPoints}, "clear") : ''
        ),
        React.DOM.div( {id:"opts"},
          React.DOM.h2(null, "Trajectory simulation settings"),
          React.DOM.table(null,
            React.DOM.tr(null, React.DOM.td(null, "Momentum refreshment"),React.DOM.td(null, Slider( {value:this.state.alfa, setValue:this.setAlfa, opt:sliderOptAlfa} ), " ", this.state.alfa.toFixed(2))),
            React.DOM.tr(null, React.DOM.td(null, "Path length"),React.DOM.td(null, Slider( {value:this.state.stepSize, setValue:this.setStepSize, opt:sliderOptStepSize} ), " ", this.state.stepSize.toFixed(2))),
            React.DOM.tr(null, React.DOM.td(null, "Number of steps in path"),React.DOM.td(null, Slider( {value:this.state.NSteps, setValue:this.setNSteps, opt:sliderOptNSteps} ), " ", this.state.NSteps))
          )
        ),
        React.DOM.div( {id:"gplist"},
          React.DOM.div( {id:"addgp"},
            React.DOM.h2(null, "Add new process"),
            React.DOM.table(null,
              React.DOM.tr(null, React.DOM.td(null, "Covariance function"),React.DOM.td(null, React.DOM.select( {value:this.state.newGPcf, onChange:this.setNewGPcf}, gpoptions))),
              React.DOM.tr(null, React.DOM.td(null, "Length scale"),React.DOM.td(null, Slider( {value:this.state.newGPParam, setValue:this.setNewGPParam, opt:sliderOptGPParam} ), " ", this.state.newGPParam.toFixed(2))),
              React.DOM.tr(null, React.DOM.td(null, "Noise"),React.DOM.td(null, Slider( {value:this.state.newGPNoise, setValue:this.setNewGPNoise, opt:sliderOptGPNoise} ), " ", this.state.newGPNoise.toFixed(2)))
            ),
            React.DOM.button( {onClick:this.addGP, disabled:this.state.newGPavailableIDs.length <= 0 || this.state.addTrPoints}, "add")
          ),
          React.DOM.h2(null, "Process list"),
          React.DOM.table(null,
            React.DOM.thead(null,
              React.DOM.tr(null, React.DOM.th(null, "id"),React.DOM.th(null, "covariance"),React.DOM.th(null, "length scale"),React.DOM.th(null, "noise"),React.DOM.th(null))
            ),
            GPList( {GPs:this.state.GPs, delGP:this.delGP} )
          )
        )
      )
    );
  }
});
136 |
// Mount the application into the page element whose id is "gp-outer".
React.renderComponent(GPApp(null), document.getElementById('gp-outer'));
141 |
--------------------------------------------------------------------------------
/build/gputils.js:
--------------------------------------------------------------------------------
1 | /** @jsx React.DOM */
// x locations of the test grid: one per row of distmatTe, evenly spaced over [-5, 5].
// NOTE(review): distmatTe is not defined in this file; presumably it comes from
// distancematrix.js, which must be loaded before this script — confirm.
var tePointsX = (function () {
  var nGrid = numeric.dim(distmatTe)[0];
  return numeric.linspace(-5, 5, nGrid);
})();
// Normal random-number generator from d3, created with d3's default mean/deviation.
var randn = d3.random.normal();
/**
 * Returns a new array of `size` independent draws from `randn`.
 * @param {number} size - number of draws
 * @returns {number[]}
 */
function randnArray(size) {
  var draws = new Array(size);
  var k = 0;
  while (k < size) {
    draws[k] = randn();
    k += 1;
  }
  return draws;
}
11 |
// Covariance functions k(r; params) of the distance r = |x - x'|, applied
// elementwise to a distance matrix (numeric.js nested arrays). params[0] is the
// length scale l; params[1] (noise) is not used by any of them.
// ids must be in order of the array: callers look entries up as cfs[id].
var cfs = (function () {
  // Replaces every entry of the 2-D array m that is not > 0 (negatives, zero,
  // NaN) with 0.0, in place, and returns m.
  function positivePart(m) {
    var dims = numeric.dim(m);
    for (var row = 0; row < dims[0]; row++) {
      for (var col = 0; col < dims[1]; col++) {
        if (!(m[row][col] > 0.0)) {
          m[row][col] = 0.0;
        }
      }
    }
    return m;
  }

  return [
    {
      id: 0,
      name: 'Exponentiated quadratic',
      // exp(-r^2 / (2 l^2))
      f: function (r, params) {
        var ell = params[0];
        return numeric.exp(numeric.mul(-0.5 / (ell * ell), numeric.pow(r, 2)));
      }
    },
    {
      id: 1,
      name: 'Exponential',
      // exp(-r / (2 l))
      // NOTE(review): the textbook exponential kernel is exp(-r / l); the extra
      // factor 0.5 rescales the length scale. Confirm whether this is intended.
      f: function (r, params) {
        var ell = params[0];
        return numeric.exp(numeric.mul(-0.5 / ell, r));
      }
    },
    {
      id: 2,
      name: 'Matern 3/2',
      // (1 + s) exp(-s), s = sqrt(3) r / l
      f: function (r, params) {
        var s = numeric.mul(Math.sqrt(3.0) / params[0], r);
        return numeric.mul(numeric.add(1.0, s), numeric.exp(numeric.neg(s)));
      }
    },
    {
      id: 3,
      name: 'Matern 5/2',
      // (1 + s + s^2 / 3) exp(-s), s = sqrt(5) r / l
      f: function (r, params) {
        var s = numeric.mul(Math.sqrt(5.0) / params[0], r);
        var sSquaredOver3 = numeric.div(numeric.mul(s, s), 3.0);
        return numeric.mul(numeric.add(numeric.add(1, s), sSquaredOver3), numeric.exp(numeric.neg(s)));
      }
    },
    {
      id: 4,
      name: 'Rational quadratic (alpha=1)',
      // (1 + r^2 / (2 l^2))^-1
      f: function (r, params) {
        var ell = params[0];
        var scaled = numeric.div(numeric.pow(r, 2), 2.0 * ell * ell);
        return numeric.pow(numeric.add(1.0, scaled), -1);
      }
    },
    {
      id: 5,
      name: 'Piecewise polynomial (q=0)',
      // max(0, 1 - r / l)
      f: function (r, params) {
        return positivePart(numeric.sub(1.0, numeric.div(r, params[0])));
      }
    },
    {
      id: 6,
      name: 'Piecewise polynomial (q=1)',
      // max(0, 1 - u)^3 (3 u + 1), u = r / l
      f: function (r, params) {
        var u = numeric.div(r, params[0]);
        var clipped = positivePart(numeric.sub(1.0, u));
        return numeric.mul(numeric.pow(clipped, 3), numeric.add(numeric.mul(3.0, u), 1.0));
      }
    },
    {
      id: 7,
      name: 'Periodic (period=pi)',
      // exp(-2 sin^2(r) / l^2)
      f: function (r, params) {
        var ell = params[0];
        return numeric.exp(numeric.mul(-2.0 / (ell * ell), numeric.pow(numeric.sin(r), 2)));
      }
    },
    {
      id: 8,
      name: 'Periodic (period=1)',
      // exp(-2 sin^2(pi r) / l^2)
      f: function (r, params) {
        var ell = params[0];
        return numeric.exp(numeric.mul(-2.0 / (ell * ell), numeric.pow(numeric.sin(numeric.mul(Math.PI, r)), 2)));
      }
    }
  ];
})();
87 |
/**
 * One Gaussian process shown in the demo.
 *
 * @param cf     index into cfs selecting the covariance function
 * @param params [length scale, noise]
 * @param id     display id (also picks the line<id> / tr<id> colour classes)
 * @param dmTr   training-training distance matrix ([] without observations)
 * @param dmTeTr test-training distance matrix
 * @param trY    training targets
 *
 * Fields: z and p (one random draw per column of distmatTe, used as the latent
 * sample state and its momentum), cf, params, id, Kte (prior covariance over the
 * test grid), and proj / mu / sd95 as returned by computeProjection.
 */
function GP(cf, params, id, dmTr, dmTeTr, trY) {
  var nGrid = numeric.dim(distmatTe)[1];

  // z is drawn before p, as in the original implementation.
  this.z = randnArray(nGrid);
  this.p = randnArray(nGrid);
  this.cf = cf;
  this.params = params;
  this.id = id;

  this.Kte = cfs[cf].f(distmatTe, params);

  var fit = computeProjection(this.Kte, cf, params, dmTr, dmTeTr, trY);
  this.proj = fit.proj;
  this.mu = fit.mu;
  this.sd95 = fit.sd95;
}
104 |
105 |
/**
 * Conditions a GP on the training data and returns what is needed to draw it.
 *
 * With training points (Mtr > 0) the posterior over the test grid is computed,
 * using an SVD of K + s*I (singular values <= numeric.epsilon are treated as 0):
 *   mu  = K* (K + s*I)^-1 y
 *   cov = Kte - K* (K + s*I)^-1 K*^T
 * where K = k(dmTr), K* = k(dmTeTr), s = params[1] and k = cfs[cf].f.
 * Without training points the prior (mean 0, covariance Kte) is used.
 *
 * @param Kte    prior covariance over the test grid
 * @param cf     index into cfs
 * @param params [length scale, noise]
 * @param dmTr   training-training distance matrix ([] when there are no observations)
 * @param dmTeTr test-training distance matrix
 * @param trY    training targets
 * @returns {{proj, mu, sd95}} proj is a square-root factor with proj * proj^T ≈ cov
 *          (sample paths are mu + proj * z), mu is the mean, and sd95 is the
 *          half-width of the central 95% interval (1.96 standard deviations).
 */
function computeProjection(Kte, cf, params, dmTr, dmTeTr, trY) {
  // 97.5% quantile of the standard normal: mu ± Z95 * sd is the central 95% interval.
  var Z95 = 1.96;
  var Mtr = numeric.dim(dmTr)[0];
  var Mte = numeric.dim(distmatTe)[0];
  var i, proj, mu, sd95;

  // diag(A * A^T) computed as row-wise sums of squares, without forming the
  // full Mte x Mte product.
  function rowSumsOfSquares(A) {
    return A.map(function (row) {
      var sum = 0;
      for (var j = 0; j < row.length; j++) {
        sum += row[j] * row[j];
      }
      return sum;
    });
  }

  if (Mtr > 0) {
    var Kxx_p_noise = cfs[cf].f(dmTr, params);
    for (i = 0; i < Mtr; i++) {
      Kxx_p_noise[i][i] += params[1];
    }

    // Pseudo-inverse of the singular values.
    var svd1 = numeric.svd(Kxx_p_noise);
    for (i = 0; i < Mtr; i++) {
      svd1.S[i] = svd1.S[i] > numeric.epsilon ? 1.0 / svd1.S[i] : 0.0;
    }

    var tmp = numeric.dot(cfs[cf].f(dmTeTr, params), svd1.U);
    // there seems to be a bug in numeric.svd: svd1.U and transpose(svd1.V) are not
    // always equal for a symmetric matrix, so U is used on both sides.
    mu = numeric.dot(tmp, numeric.mul(svd1.S, numeric.dot(numeric.transpose(svd1.U), trY)));
    var cov = numeric.dot(tmp, numeric.diag(numeric.sqrt(svd1.S)));
    cov = numeric.sub(Kte, numeric.dot(cov, numeric.transpose(cov)));

    var svd2 = numeric.svd(cov);
    for (i = 0; i < Mte; i++) {
      if (svd2.S[i] < numeric.epsilon) {
        svd2.S[i] = 0.0;
      }
    }
    proj = numeric.dot(svd2.U, numeric.diag(numeric.sqrt(svd2.S)));
    sd95 = numeric.mul(Z95, numeric.sqrt(rowSumsOfSquares(proj)));
  } else {
    sd95 = numeric.mul(Z95, numeric.sqrt(numeric.getDiag(Kte)));
    var svd = numeric.svd(Kte);
    proj = numeric.dot(svd.U, numeric.diag(numeric.sqrt(svd.S)));
    mu = numeric.rep([Mte], 0);
  }

  return { proj: proj, mu: mu, sd95: sd95 };
}
148 |
/**
 * Re-conditions every GP on a new training set, in place.
 *
 * Each GP keeps its prior (Kte, cf, params) and latent state; only proj, mu and
 * sd95 are replaced with the result of computeProjection. The same array is
 * returned (its elements are mutated, not copied).
 */
function recomputeProjections(GPs, dmTr, dmTeTr, trY) {
  var k = 0;
  while (k < GPs.length) {
    var gp = GPs[k];
    var fit = computeProjection(gp.Kte, gp.cf, gp.params, dmTr, dmTeTr, trY);
    gp.proj = fit.proj;
    gp.mu = fit.mu;
    gp.sd95 = fit.sd95;
    k += 1;
  }
  return GPs;
}
161 |
/**
 * Pairwise absolute distances between two lists of 1-D points.
 * @param {number[]} xdata1 - row points
 * @param {number[]} xdata2 - column points
 * @returns {number[][]} matrix with entry [i][j] = |xdata2[j] - xdata1[i]|
 */
function computeDistanceMatrix(xdata1, xdata2) {
  return xdata1.map(function (a) {
    return xdata2.map(function (b) {
      return Math.abs(b - a);
    });
  });
}
172 |
/**
 * Plot of the GPs: sample paths, mean and 95% band lines, and training
 * observations, drawn with d3.
 *
 * React renders only an empty <svg> and never re-renders it
 * (shouldComponentUpdate returns false). Everything else is drawn through d3:
 * first in componentDidMount, then in componentWillReceiveProps on every parent update.
 *
 * Props: `state` (the GPApp state object) and `addTrPoint(x, y)`.
 */
var GPAxis = React.createClass({displayName: 'GPAxis',
  render: function() {
    return (React.DOM.svg(null));
  },
  shouldComponentUpdate: function() { return false; },

  // Pixel margin between the svg border and the plotting area, shared by the
  // drawing code and the click-to-data coordinate conversion.
  margin: 50,

  /** Draws the training observations as small circles (d3 data join). */
  drawTrPoints: function(pointsX, pointsY) {
    var x = this.scales.x;
    var y = this.scales.y;
    var p = this.trPoints.selectAll("circle.trpoints")
      .data(d3.zip(pointsX, pointsY))
      .attr("cx", function(d) { return x(d[0]); })
      .attr("cy", function(d) { return y(d[1]); });
    p.enter().append("circle")
      .attr("class", "trpoints")
      .attr("r", 2)
      .attr("cx", function(d) { return x(d[0]); })
      .attr("cy", function(d) { return y(d[1]); });
    p.exit().remove();
  },

  // id of the active setInterval timer (0 = none started yet)
  animationId: 0,

  componentWillReceiveProps: function(props) {
    // bind/unbind the click handler used to add observations
    if (props.state.addTrPoints) {
      d3.select(this.getDOMNode()).on("click", this.addTrPoint);
    } else {
      d3.select(this.getDOMNode()).on("click", null);
    }
    this.drawTrPoints(props.state.trPointsX, props.state.trPointsY);
    this.drawMeanAndVar(props);
    // Redraw the sample paths on every update, not only when showSamples flips:
    // processes can be added, removed or re-conditioned while sampling is
    // stopped, and their paths would otherwise be missing or stale.
    this.drawPaths(props);

    if (this.props.state.samplingState !== props.state.samplingState) {
      clearInterval(this.animationId);
      if (props.state.samplingState === 1) {
        this.animationId = setInterval((function() { this.updateState(); this.drawPaths(); }).bind(this), 500);
      } else if (props.state.samplingState === 2) {
        this.animationId = setInterval((function() { this.contUpdateState(); this.drawPaths(); }).bind(this), 50);
      }
    }
  },

  componentWillUnmount: function() {
    // Do not let the sampling animation outlive the component.
    clearInterval(this.animationId);
  },

  /** Click handler: converts the mouse position to data coordinates and reports it. */
  addTrPoint: function() {
    var mousePos = d3.mouse(this.getDOMNode());
    var x = this.scales.x;
    var y = this.scales.y;
    var margin = this.margin;

    // x is snapped to a grid of 200 points between -5 and 5
    // NOTE(review): presumably this must match the test grid in distmatTe — confirm.
    this.props.addTrPoint(Math.round((x.invert(mousePos[0]-margin)+5)/10*199)/199*10-5, y.invert(mousePos[1]-margin));
  },

  /** Independent sampling: give every GP a fresh random z. */
  updateState: function() {
    var M = numeric.dim(distmatTe)[1];
    for (var i = 0; i < this.props.state.GPs.length; i++) {
      this.props.state.GPs[i].z = randnArray(M);
    }
  },

  // position within the current trajectory of NSteps steps
  stepState: 0,

  /**
   * Continuous sampling: advance every GP's (z, p) by one step of the exact
   * solution of dz/dt = p, dp/dt = -z (a rotation by t_step), and partially
   * refresh the momentum p once per trajectory of NSteps steps.
   */
  contUpdateState: function() {
    var M = numeric.dim(distmatTe)[1];
    var alfa = 1.0 - this.props.state.alfa;
    var n_steps = this.props.state.NSteps;
    var t_step = this.props.state.stepSize / n_steps;
    var sinT = Math.sin(t_step);
    var cosT = Math.cos(t_step);
    this.stepState = this.stepState % n_steps;

    for (var i = 0; i < this.props.state.GPs.length; i++) {
      var gp = this.props.state.GPs[i];

      // refresh momentum: p = alfa * p + sqrt(1 - alfa^2) * randn(size(p))
      if (this.stepState === (n_steps-1)) {
        gp.p = numeric.add(numeric.mul(alfa, gp.p), numeric.mul(Math.sqrt(1 - alfa*alfa), randnArray(M)));
      }

      // numeric.js operations return new arrays, so the old z/p stay intact.
      var z = gp.z;
      var p = gp.p;
      gp.z = numeric.add(numeric.mul(p, sinT), numeric.mul(z, cosT));
      gp.p = numeric.add(numeric.mul(numeric.mul(-1, z), sinT), numeric.mul(p, cosT));
    }
    this.stepState = this.stepState + 1;
  },

  /**
   * Data-joins `gps` (keyed by id) to the <path> elements in `group`. Existing
   * paths are updated, new ones are appended with class `classPrefix + id`, and
   * paths of GPs no longer present are removed. `yValues(d)` gives the y-values
   * over tePointsX.
   */
  drawLineSet: function(group, gps, yValues, classPrefix) {
    var gpline = this.gpline;
    var pathFor = function (d) { return gpline(d3.zip(tePointsX, yValues(d))); };
    var paths = group.selectAll("path").data(gps, function (d) { return d.id; })
      .attr("d", pathFor);
    paths.enter().append("path")
      .attr("d", pathFor)
      .attr("class", function (d) { return classPrefix + d.id; });
    paths.exit().remove();
  },

  /** Draws (or clears, when disabled) the mean and mean ± sd95 lines. */
  drawMeanAndVar: function(props) {
    var gps = props.state.showMeanAndVar ? props.state.GPs : [];
    this.drawLineSet(this.meanLines, gps, function (d) { return d.mu; }, "muline line line");
    this.drawLineSet(this.upSd95Lines, gps, function (d) { return numeric.add(d.mu, d.sd95); }, "sdline line line");
    this.drawLineSet(this.downSd95Lines, gps, function (d) { return numeric.sub(d.mu, d.sd95); }, "sdline line line");
  },

  /** Draws (or clears, when disabled) one sample path mu + proj * z per GP. */
  drawPaths: function(props) {
    props = props || this.props;
    var gps = props.state.showSamples ? props.state.GPs : [];
    this.drawLineSet(this.lines, gps, function (d) { return numeric.add(numeric.dot(d.proj, d.z), d.mu); }, "line line");
  },

  componentDidMount: function() {
    var svg = d3.select(this.getDOMNode());
    var height = svg.attr("height"),
        width = svg.attr("width");
    if (!height) {
      height = 300;
      svg.attr("height", height);
    }
    if (!width) {
      width = 500;
      svg.attr("width", width);
    }
    var margin = this.margin;
    svg = svg.append("g")
      .attr("transform", "translate("+margin+","+margin+")");
    this.svg = svg;
    var fig_height = height - 2*margin,
        fig_width = width - 2*margin;

    // data -> pixel scales for the plotting area
    var x = d3.scale.linear().range([0, fig_width]).domain([-5, 5]);
    var y = d3.scale.linear().range([fig_height, 0]).domain([-3, 3]);
    // Stored per instance: a non-function createClass spec property would sit on
    // the shared prototype, and mutating it would leak between instances.
    this.scales = { x: x, y: y };
    var xAxis = d3.svg.axis()
      .scale(x)
      .orient("bottom");
    var yAxis = d3.svg.axis()
      .scale(y)
      .orient("left");
    this.gpline = d3.svg.line()
      .x(function(d) { return x(d[0]); })
      .y(function(d) { return y(d[1]); });

    // axes
    svg.append("g")
      .attr("class", "x axis")
      .attr("transform", "translate(0,"+fig_height+")")
      .call(xAxis);

    svg.append("g")
      .attr("class", "y axis")
      .call(yAxis);

    // layer groups; later siblings are painted on top
    this.meanLines = svg.append("g");
    this.upSd95Lines = svg.append("g");
    this.downSd95Lines = svg.append("g");
    this.lines = svg.append("g");
    this.trPoints = svg.append("g");
    this.drawTrPoints(this.props.state.trPointsX, this.props.state.trPointsY);
    this.drawPaths();
  }
});
384 |
385 |
/**
 * <tbody> listing the active processes, one row per GP:
 *  - id (coloured via the "tr<id>" class)
 *  - covariance function name
 *  - length scale and noise, each with 2 decimals
 *  - a remove button whose handler comes from the `delGP(id)` factory prop
 */
var GPList = React.createClass({displayName: 'GPList',
  render: function() {
    var makeRemoveHandler = this.props.delGP;
    var renderRow = function (gp) {
      var lengthScale = gp.params[0].toFixed(2);
      var noise = gp.params[1].toFixed(2);
      var removeButton = React.DOM.button( {onClick:makeRemoveHandler(gp.id)}, "remove");
      return React.DOM.tr( {key:gp.id},
        React.DOM.td( {className:"tr"+gp.id}, gp.id),
        React.DOM.td(null, cfs[gp.cf].name),
        React.DOM.td(null, lengthScale),
        React.DOM.td(null, noise),
        React.DOM.td(null, removeButton)
      );
    };
    return React.DOM.tbody(null, this.props.GPs.map(renderRow));
  }
});
397 |
--------------------------------------------------------------------------------
/build/slider.js:
--------------------------------------------------------------------------------
1 | /** @jsx React.DOM */
/**
 * Horizontal slider drawn with d3 inside an <svg>.
 *
 * Props:
 *   value    - current value; the parent owns it
 *   setValue - called with the new value after a click or drag; also called at
 *              mount if `value` lies outside [min, max]
 *   opt      - optional {width, height, min, max, step}
 *              Defaults: width 200, height 20, min 0, max 1, continuous (no step).
 *              min and max are swapped if given in the wrong order.
 *              The object itself is not modified.
 */
var Slider = React.createClass({displayName: 'Slider',
  render: function() {
    // Just insert the svg-element and render rest in componentDidMount.
    // Marker location is updated in componentWillReceiveProps using d3.
    return (React.DOM.svg(null));
  },
  shouldComponentUpdate: function() { return false; }, // Never re-render.
  componentDidMount: function() {
    var val = this.props.value;
    var setVal = this.props.setValue;
    var given = this.props.opt || {};

    // Resolve options into a private copy instead of mutating the caller's prop.
    var opt = {
      width: given.width || 200,
      height: given.height || 20,
      step: given.step
    };
    if (!opt.step) {
      opt.round = function(x) { return x; };
    } else {
      opt.round = function(x) { return Math.round(x / opt.step) * opt.step; };
    }
    // Test for "not given" (null/undefined), not falsiness, so an explicit
    // bound of 0 (e.g. max: 0 with a negative min) is respected.
    opt.min = given.min == null ? opt.round(0) : given.min;
    opt.max = given.max == null ? opt.round(1) : given.max;
    if (opt.min > opt.max) {
      var tmp = opt.min;
      opt.min = opt.max;
      opt.max = tmp;
    }
    if (val > opt.max) setVal(opt.max);
    if (val < opt.min) setVal(opt.min);

    // calculate range so the whole marker stays inside the svg
    var markerRadius = opt.height * 0.5;
    var x1 = markerRadius;
    var x2 = opt.width - markerRadius;

    // d3 helpers
    var scale = d3.scale.linear()
      .domain([opt.min, opt.max])
      .range([x1, x2]);
    this.scale = scale;
    var setValFromMousePos = function(x) {
      setVal(opt.round(scale.invert(Math.max(x1, Math.min(x2, x)))));
    };
    var dragmove = function() {
      setValFromMousePos(d3.event.x);
    };
    var drag = d3.behavior.drag().on("drag", dragmove);

    // bind d3 events and insert background line and marker
    var svg = d3.select(this.getDOMNode());
    svg.attr("class", "slider")
      .attr("width", opt.width)
      .attr("height", opt.height)
      .on("click", function () { setValFromMousePos(d3.mouse(this)[0]); });
    svg.append("line")
      .attr("x1", x1)
      .attr("x2", x2)
      .attr("y1", "50%")
      .attr("y2", "50%")
      .attr("class", "sliderbg");
    this.marker = svg.append("circle")
      .attr("cy", "50%")
      .attr("r", markerRadius)
      .attr("class", "slidermv")
      .datum(val)
      .attr("cx", function (d) { return scale(d); })
      .call(drag);
  },
  componentWillReceiveProps: function(props) {
    // update the marker location on receiving new props
    var scale = this.scale;
    this.marker.datum(props.value)
      .attr("cx", function (d) { return scale(d); });
  }
});
77 |
78 |
--------------------------------------------------------------------------------
/gp.css:
--------------------------------------------------------------------------------
/* Styles for the GP demo: slider widget, plot axes and lines, per-process
   colours (.lineN for plot lines, .trN for list rows; N = GP id 1..10),
   tables and form controls. */
.slider circle.slidermv {
  fill: #c00;
  stroke: none;
}
.slider line.sliderbg {
  stroke: #000;
  fill: none;
  stroke-width: 2px;
  stroke-linecap: round;
}
.axis path,
.axis line {
  fill: none;
  stroke: #000;
  shape-rendering: crispEdges;
  stroke-width: 1.5px;
}
.trpoints {
  fill: #000;
}
.line {
  fill: none;
  stroke-width: 1.5px;
}
/* "//" is not a comment in CSS; these disabled declarations use block comments. */
.muline {
  /* stroke-dasharray: 10,10; */
  opacity: 0.5;
}
.sdline {
  /* stroke-dasharray: 5,10; */
  opacity: 0.5;
}
.line1 { stroke: #204a87; }
.tr1 { border-left: 24px solid #204a87; }
.line2 { stroke: #ce5c00; }
.tr2 { border-left: 24px solid #ce5c00; }
.line3 { stroke: #ad7fa8; }
.tr3 { border-left: 24px solid #ad7fa8; }
.line4 { stroke: #4e9a06; }
.tr4 { border-left: 24px solid #4e9a06; }
.line5 { stroke: #c4a000; }
.tr5 { border-left: 24px solid #c4a000; }
.line6 { stroke: #a40000; }
.tr6 { border-left: 24px solid #a40000; }
.line7 { stroke: #2e3436; }
.tr7 { border-left: 24px solid #2e3436; }
.line8 { stroke: #e9b96e; }
.tr8 { border-left: 24px solid #e9b96e; }
.line9 { stroke: #8f5902; }
.tr9 { border-left: 24px solid #8f5902; }
.line10 { stroke: #5c3566; }
.tr10 { border-left: 24px solid #5c3566; }
td {
  padding: 0 5px;
}
th {
  border-bottom: 2px solid #000;
}
button[disabled=disabled], button:disabled {
  text-decoration: line-through;
}
input, button, select {
  font: inherit;
}
.info {
  color: #c00;
}
68 |
--------------------------------------------------------------------------------
/halfmoon9.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/to-mi/gp-demo-js/21fac4e7d24847a0d347e8f3621f674e37d715a1/halfmoon9.ico
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Gaussian process regression demo
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
Gaussian process regression demo
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/site.css:
--------------------------------------------------------------------------------
/* Page-level typography, background and layout. */
body {
  margin: 0;
  padding: 0;
  font-family: 'Lora', serif;
  font-size: 14px;
  color: #222222;
  background: url('bgrep4.png') repeat;
}

h1 {
  font-family: 'Special Elite', cursive;
  font-size: 28px;
  color: #222222;
}
h1#sitetitle {
  font-size: 10px;
}

a {
  color: #222222;
}
a:hover {
  text-decoration: none;
}

/* Centered fixed-width content column. */
#main {
  width: 802px;
  margin: auto;
}
#tmain {
  margin: 28px 0 0 0;
  padding-top: 14px;
  padding-bottom: 14px;
}
34 |
--------------------------------------------------------------------------------
/src/gpapp.js:
--------------------------------------------------------------------------------
1 | /** @jsx React.DOM */
2 |
3 | var GPApp = React.createClass({
4 | getInitialState: function() {
5 | return { GPs: [new GP(0, [1,0.2], 1, [], [], [])],
6 | newGPParam: 1.0,
7 | newGPNoise: 0.2,
8 | newGPcf: 0,
9 | newGPavailableIDs: [10, 9, 8, 7, 6, 5, 4, 3, 2],
10 | alfa: 0.3,
11 | stepSize: 3.14,
12 | NSteps: 15,
13 | addTrPoints: false,
14 | trPointsX: [],
15 | trPointsY: [],
16 | dmTr: [],
17 | dmTeTr: [],
18 | samplingState: 0, // 0 = stopped, 1 = discrete, 2 = continuous
19 | oldSamplingState: 0,
20 | showSamples: true,
21 | showMeanAndVar: false
22 | }
23 | },
24 | setAlfa: function(newVal) { this.setState({ alfa: newVal }); },
25 | setStepSize: function(newVal) { this.setState({ stepSize: newVal }); },
26 | setNSteps: function(newVal) { this.setState({ NSteps: newVal }); },
27 | toggleAddTrPoints: function() {
28 | if (this.state.addTrPoints){
29 | // added training points
30 | var dmTr = computeDistanceMatrix(this.state.trPointsX, this.state.trPointsX);
31 | var dmTeTr = computeDistanceMatrix(tePointsX, this.state.trPointsX);
32 |
33 | var newGPs = recomputeProjections(this.state.GPs, dmTr, dmTeTr, this.state.trPointsY);
34 |
35 | this.setState({ addTrPoints: !this.state.addTrPoints, GPs: newGPs, dmTr: dmTr, dmTeTr: dmTeTr, samplingState: this.state.oldSamplingState });
36 | } else {
37 | // beginning to add training points
38 | this.setState({ addTrPoints: !this.state.addTrPoints, oldSamplingState: this.state.samplingState, samplingState: 0 });
39 | }
40 | },
41 | clearTrPoints: function() { this.setState({ trPointsX: [], trPointsY: [] }); },
42 | toggleShowMeanAndVar: function() { if (!this.state.addTrPoints) this.setState({ showMeanAndVar: !this.state.showMeanAndVar }); },
43 | toggleShowSamples: function() {
44 | if (!this.state.addTrPoints) {
45 | if (this.state.showSamples) {
46 | this.setState({ samplingState: 0, showSamples: false });
47 | } else {
48 | this.setState({ samplingState: this.state.oldSamplingState, showSamples: true });
49 | }
50 | }
51 | },
52 | setNewGPParam: function(newVal) { this.setState({ newGPParam: newVal }); },
53 | setNewGPNoise: function(newVal) { this.setState({ newGPNoise: newVal }); },
54 | setNewGPcf: function(event) { this.setState({ newGPcf: event.target.value }); },
55 | addGP: function() {
56 | if (this.state.newGPavailableIDs.length < 1) return;
57 | var id = this.state.newGPavailableIDs.pop();
58 | var newGPs = this.state.GPs.concat([new GP(this.state.newGPcf, [this.state.newGPParam, this.state.newGPNoise], id, this.state.dmTr, this.state.dmTeTr, this.state.trPointsY)]);
59 | this.setState({ GPs: newGPs, newGPavailableIDs: this.state.newGPavailableIDs });
60 | },
61 | delGP: function(id) {
62 | return (function() {
63 | var newGPs = this.state.GPs;
64 | var delIdx = newGPs.findIndex(function (g) { return g.id == id; });
65 | if (delIdx >= 0) {
66 | newGPs.splice(delIdx, 1);
67 | this.state.newGPavailableIDs.push(id);
68 | this.setState({ GPs: newGPs });
69 | }
70 | }).bind(this);
71 | },
72 | addTrPoint: function(x, y) {
73 | if (x >= -5 && x <= 5 && y >= -3 && y <= 3){
74 | var newTrPointsX = this.state.trPointsX.concat([x]);
75 | var newTrPointsY = this.state.trPointsY.concat([y]);
76 | this.setState({ trPointsX: newTrPointsX, trPointsY: newTrPointsY });
77 | }
78 | },
79 | stopSampling: function() { this.setState({ samplingState: 0, oldSamplingState: 0 }); },
80 | startDiscreteSampling: function() { this.setState({ samplingState: 1, oldSamplingState: 1 }); },
81 | startContinuousSampling: function() { this.setState({ samplingState: 2, oldSamplingState: 2 }); },
// NOTE(review): the JSX markup of this render method was stripped when this
// listing was produced. Only text nodes and {expressions} survive, so the body
// below (e.g. `return ();`) is not valid as shown. Restore it from version
// control before editing.
// The sliderOpt* objects are Slider option sets ({width, height, min, max,
// step}). Presumably each is passed to the Slider bound to the state field of
// the same name (alfa, stepSize, NSteps, newGPParam, newGPNoise) — confirm.
// gpoptions maps over cfs (presumably one <option> per covariance function),
// and delGP is presumably handed to GPList — confirm.
82 | render: function() {
83 | var sliderOptAlfa = { width: 200, height: 9, min: 0, max: 1 };
84 | var sliderOptStepSize = { width: 200, height: 9, min: 0, max: 2*Math.PI };
85 | var sliderOptNSteps = { width: 200, height: 9, min: 1, max: 100, step: 1 };
86 | var sliderOptGPParam = { width: 200, height: 9, min: 0.01, max: 5 };
87 | var sliderOptGPNoise = { width: 200, height: 9, min: 0, max: 2 };
88 | var delGP = this.delGP;
89 | var gpoptions = cfs.map(function (c) {
90 | return ();
91 | });
92 | return (
93 |
94 |
95 |
96 | Show mean and credible intervals
97 | Show samples
98 |
99 |
100 |
101 |
102 | {this.state.addTrPoints ? click on the figure to add an observation : ''}
103 |
104 | {this.state.addTrPoints ? : ''}
105 |
106 |
107 |
Trajectory simulation settings
108 |
109 | Momentum refreshment | {this.state.alfa.toFixed(2)} |
110 | Path length | {this.state.stepSize.toFixed(2)} |
111 | Number of steps in path | {this.state.NSteps} |
112 |
113 |
114 |
115 |
116 |
Add new process
117 |
118 | Covariance function | |
119 | Length scale | {this.state.newGPParam.toFixed(2)} |
120 | Noise | {this.state.newGPNoise.toFixed(2)} |
121 |
122 |
123 |
124 |
Process list
125 |
126 |
127 | id | covariance | length scale | noise | |
128 |
129 |
130 |
131 |
132 |
133 | )
134 | }
135 | });
136 |
// Mount the application into the element with id "gp-outer".
// NOTE(review): the JSX element passed as the first argument was stripped in
// this listing; only the separating comma remains. Presumably it was
// <GPApp /> — confirm before editing.
137 | React.renderComponent(
138 | ,
139 | document.getElementById('gp-outer')
140 | );
141 |
--------------------------------------------------------------------------------
/src/gputils.js:
--------------------------------------------------------------------------------
1 | /** @jsx React.DOM */
// tePointsX: test inputs evenly spaced on [-5, 5], one per row of distmatTe.
// distmatTe is the test/test distance matrix; it is not defined in this file
// and is presumably provided by distancematrix.js — TODO confirm.
// randn: a shared normal random generator (presumably standard normal, i.e.
// d3's default mean 0 / sd 1 — confirm).
2 | var tePointsX = numeric.linspace(-5, 5, numeric.dim(distmatTe)[0]);
3 | var randn = d3.random.normal();
/**
 * Draw `size` samples from the shared `randn` generator.
 * @param {number} size - number of samples.
 * @returns {number[]} a new array of length `size`, filled in index order.
 */
function randnArray(size){
  var draws = new Array(size);
  var k = 0;
  while (k < size) {
    draws[k++] = randn();
  }
  return draws;
}
11 |
// Covariance functions k(r) of the distance r = |x - x'|.
// Each f(r, params) works elementwise on a distance matrix (or nested array /
// scalar). params[0] is the length scale l; params[1] (the noise) is not used
// by the kernels themselves.
// The ids must equal the array indices: processes store cf as an index into cfs.
var cfs = (function () {
  // Elementwise max(m, 0) for a scalar or (nested) array; arrays are clamped
  // in place and returned. Used by the compactly supported piecewise kernels.
  function positivePart(m) {
    if (Array.isArray(m)) {
      for (var i = 0; i < m.length; i++) {
        m[i] = positivePart(m[i]);
      }
      return m;
    }
    return m > 0.0 ? m : 0.0;
  }

  return [
    {'id': 0,
     'name': 'Exponentiated quadratic',
     // exp(-r^2 / (2 l^2))
     'f': function(r, params) {
       return numeric.exp(numeric.mul(-0.5 / (params[0] * params[0]), numeric.pow(r, 2)));
     }
    },
    {'id': 1,
     'name': 'Exponential',
     // exp(-r / l), i.e. Matern 1/2. The previous exp(-0.5 r / l) silently
     // doubled the length scale compared with the other kernels.
     'f': function(r, params) {
       return numeric.exp(numeric.mul(-1.0 / params[0], r));
     }
    },
    {'id': 2,
     'name': 'Matern 3/2',
     // (1 + sqrt(3) r / l) exp(-sqrt(3) r / l)
     'f': function(r, params) {
       var s = numeric.mul(Math.sqrt(3.0) / params[0], r);
       return numeric.mul(numeric.add(1.0, s), numeric.exp(numeric.neg(s)));
     }
    },
    {'id': 3,
     'name': 'Matern 5/2',
     // (1 + sqrt(5) r / l + 5 r^2 / (3 l^2)) exp(-sqrt(5) r / l)
     'f': function(r, params) {
       var s = numeric.mul(Math.sqrt(5.0) / params[0], r);
       var s2 = numeric.div(numeric.mul(s, s), 3.0);
       return numeric.mul(numeric.add(numeric.add(1, s), s2), numeric.exp(numeric.neg(s)));
     }
    },
    {'id': 4,
     'name': 'Rational quadratic (alpha=1)',
     // (1 + r^2 / (2 l^2))^-1
     'f': function(r, params) {
       return numeric.pow(numeric.add(1.0, numeric.div(numeric.pow(r, 2), 2.0 * params[0] * params[0])), -1);
     }
    },
    {'id': 5,
     'name': 'Piecewise polynomial (q=0)',
     // max(1 - r / l, 0)
     'f': function(r, params) {
       return positivePart(numeric.sub(1.0, numeric.div(r, params[0])));
     }
    },
    {'id': 6,
     'name': 'Piecewise polynomial (q=1)',
     // max(1 - r / l, 0)^3 (3 r / l + 1)
     'f': function(r, params) {
       var rs = numeric.div(r, params[0]);
       var base = positivePart(numeric.sub(1.0, rs));
       return numeric.mul(numeric.pow(base, 3), numeric.add(numeric.mul(3.0, rs), 1.0));
     }
    },
    {'id': 7,
     'name': 'Periodic (period=pi)',
     // exp(-2 sin^2(r) / l^2)
     'f': function(r, params) {
       return numeric.exp(numeric.mul(-2.0/(params[0]*params[0]), numeric.pow(numeric.sin(r), 2)));
     }
    },
    {'id': 8,
     'name': 'Periodic (period=1)',
     // exp(-2 sin^2(pi r) / l^2)
     'f': function(r, params) {
       return numeric.exp(numeric.mul(-2.0/(params[0]*params[0]), numeric.pow(numeric.sin(numeric.mul(Math.PI, r)), 2)));
     }
    }
  ];
})();
87 |
/**
 * One Gaussian process shown in the figure.
 *
 * @param {number} cf - index into cfs selecting the covariance function.
 * @param {number[]} params - [length scale, noise].
 * @param {number} id - process id (also used in the "line<id>" CSS classes).
 * @param dmTr - train/train distance matrix ([] when there is no training data).
 * @param dmTeTr - test/train distance matrix.
 * @param trY - training targets.
 *
 * Fields: z and p are random latent/momentum vectors driving the sample path,
 * Kte is the test/test covariance, and proj/mu/sd95 come from computeProjection.
 */
function GP(cf, params, id, dmTr, dmTeTr, trY) {
  var nTest = numeric.dim(distmatTe)[1];

  this.z = randnArray(nTest);
  this.p = randnArray(nTest);
  this.cf = cf;
  this.params = params;
  this.id = id;

  this.Kte = cfs[cf].f(distmatTe, params);

  var predictive = computeProjection(this.Kte, cf, params, dmTr, dmTeTr, trY);
  this.proj = predictive.proj;
  this.mu = predictive.mu;
  this.sd95 = predictive.sd95;
}
104 |
105 |
/**
 * Predictive distribution of a GP on the test grid.
 *
 * With training data (dmTr non-empty):
 *   mu  = K_te,tr (K_tr + noise*I)^+ y
 *   cov = Kte - K_te,tr (K_tr + noise*I)^+ K_tr,te
 * where noise = params[1] is added to the diagonal and ^+ is an SVD
 * pseudo-inverse. Without training data the prior (mu = 0, cov = Kte) is used.
 *
 * @param Kte    test/test covariance matrix.
 * @param cf     index into cfs.
 * @param params [length scale, noise].
 * @param dmTr   train/train distance matrix ([] for none).
 * @param dmTeTr test/train distance matrix.
 * @param trY    training targets.
 * @returns {{proj, mu, sd95}} proj is a square-root factor of cov
 *   (proj * proj^T ~= cov), so mu + proj * z is a sample path for a
 *   standard-normal z. sd95 is the half-width of the pointwise central 95%
 *   interval.
 */
function computeProjection(Kte, cf, params, dmTr, dmTeTr, trY) {
  var kernel = cfs[cf].f;
  var Mtr = numeric.dim(dmTr)[0];
  var Mte = numeric.dim(Kte)[0];
  // 97.5% quantile of the standard normal (was 1.98, which is not a normal quantile).
  var Z95 = 1.96;

  if (!(Mtr > 0)) {
    // No training data: the prior.
    var svd = numeric.svd(Kte);
    return {
      proj: numeric.dot(svd.U, numeric.diag(numeric.sqrt(svd.S))),
      mu: numeric.rep([Mte], 0),
      sd95: numeric.mul(Z95, numeric.sqrt(numeric.getDiag(Kte)))
    };
  }

  // Pseudo-inverse spectrum of K_tr + noise*I; singular values at or below
  // machine epsilon are dropped.
  var Kxx_p_noise = kernel(dmTr, params);
  for (var i = 0; i < Mtr; i++) {
    Kxx_p_noise[i][i] += params[1];
  }
  var svd1 = numeric.svd(Kxx_p_noise);
  var invS = svd1.S.map(function (s) { return s > numeric.epsilon ? 1.0 / s : 0.0; });

  // U is used on both sides because the matrix is symmetric. There seems to be
  // a bug in numeric.svd: U and transpose(V) are not always equal for a
  // symmetric matrix.
  var KteTrU = numeric.dot(kernel(dmTeTr, params), svd1.U);
  var mu = numeric.dot(KteTrU, numeric.mul(invS, numeric.dot(numeric.transpose(svd1.U), trY)));
  var explainedHalf = numeric.dot(KteTrU, numeric.diag(numeric.sqrt(invS)));
  var cov = numeric.sub(Kte, numeric.dot(explainedHalf, numeric.transpose(explainedHalf)));

  // Square-root factor of the posterior covariance; singular values below
  // machine epsilon (round-off) are zeroed.
  var svd2 = numeric.svd(cov);
  for (var j = 0; j < Mte; j++) {
    if (svd2.S[j] < numeric.epsilon) {
      svd2.S[j] = 0.0;
    }
  }
  var proj = numeric.dot(svd2.U, numeric.diag(numeric.sqrt(svd2.S)));
  var sd95 = numeric.mul(Z95, numeric.sqrt(numeric.getDiag(numeric.dot(proj, numeric.transpose(proj)))));

  return { proj: proj, mu: mu, sd95: sd95 };
}
148 |
/**
 * Re-condition every process on new training data.
 * Each GP object is updated in place (proj, mu, sd95); the same array is
 * returned, not a copy. z, p, Kte, cf and params are left untouched.
 */
function recomputeProjections(GPs, dmTr, dmTeTr, trY) {
  GPs.forEach(function (gp) {
    var updated = computeProjection(gp.Kte, gp.cf, gp.params, dmTr, dmTeTr, trY);
    gp.proj = updated.proj;
    gp.mu = updated.mu;
    gp.sd95 = updated.sd95;
  });
  return GPs;
}
161 |
/**
 * Pairwise absolute distances between two lists of 1-D inputs.
 * @returns {number[][]} xdata1.length x xdata2.length matrix whose entry
 *   [i][j] is |xdata2[j] - xdata1[i]|.
 */
function computeDistanceMatrix(xdata1, xdata2) {
  return xdata1.map(function (a) {
    return xdata2.map(function (b) {
      return Math.abs(b - a);
    });
  });
}
172 |
173 | var GPAxis = React.createClass({
// NOTE(review): the JSX returned here was stripped in this listing, leaving
// `return ();`. componentDidMount reads width/height attributes from this
// component's DOM node and appends all drawing into it with d3, so the markup
// was presumably a single <svg> element — confirm.
174 | render: function() {
175 | return ();
176 | },
177 | shouldComponentUpdate: function() { return false; },
178 | drawTrPoints: function(pointsX, pointsY) {
179 | var x = this.scales.x;
180 | var y = this.scales.y;
181 | var p = this.trPoints.selectAll("circle.trpoints")
182 | .data(d3.zip(pointsX, pointsY))
183 | .attr("cx", function(d) { return x(d[0]); })
184 | .attr("cy", function(d) { return y(d[1]); });
185 | p.enter().append("circle")
186 | .attr("class", "trpoints")
187 | .attr("r", 2)
188 | .attr("cx", function(d) { return x(d[0]); })
189 | .attr("cy", function(d) { return y(d[1]); });
190 | p.exit().remove();
191 | },
192 | animationId: 0,
193 | componentWillReceiveProps: function(props) {
194 | // bind events
195 | if (props.state.addTrPoints) {
196 | d3.select(this.getDOMNode()).on("click", this.addTrPoint);
197 | } else {
198 | d3.select(this.getDOMNode()).on("click", null);
199 | }
200 | // redraw training points
201 | this.drawTrPoints(props.state.trPointsX, props.state.trPointsY);
202 |
203 | this.drawMeanAndVar(props);
204 |
205 | if (this.props.state.showSamples !== props.state.showSamples){
206 | this.drawPaths(props);
207 | }
208 |
209 | if (this.props.state.samplingState !== props.state.samplingState){
210 | clearInterval(this.animationId);
211 | if (props.state.samplingState === 1){
212 | this.animationId = setInterval((function() { this.updateState(); this.drawPaths(); }).bind(this), 500);
213 | } else if (props.state.samplingState === 2){
214 | this.animationId = setInterval((function() { this.contUpdateState(); this.drawPaths(); }).bind(this), 50);
215 | }
216 | }
217 | },
218 | addTrPoint: function() {
219 | var mousePos = d3.mouse(this.getDOMNode());
220 | var x = this.scales.x;
221 | var y = this.scales.y;
222 |
223 | // x is transformed to a point on a grid of 200 points between -5 and 5
224 | this.props.addTrPoint(Math.round((x.invert(mousePos[0]-50)+5)/10*199)/199*10-5, y.invert(mousePos[1]-50));
225 | },
226 | updateState: function() {
227 | var M = numeric.dim(distmatTe)[1];
228 | for (var i = 0; i < this.props.state.GPs.length; i++){
229 | var gp = this.props.state.GPs[i];
230 | gp.z = randnArray(M);
231 | }
232 | },
233 | stepState: 0,
234 | contUpdateState: function() {
235 | var M = numeric.dim(distmatTe)[1];
236 | var alfa = 1.0-this.props.state.alfa;
237 | var n_steps = this.props.state.NSteps;
238 | var t_step = this.props.state.stepSize / n_steps;
239 | this.stepState = this.stepState % n_steps;
240 |
241 | for (var i = 0; i < this.props.state.GPs.length; i++){
242 | var gp = this.props.state.GPs[i];
243 |
244 | // refresh momentum: p = alfa * p + sqrt(1 - alfa^2) * randn(size(p))
245 | if (this.stepState == (n_steps-1))
246 | gp.p = numeric.add(numeric.mul(alfa, gp.p), numeric.mul(Math.sqrt(1 - alfa*alfa), randnArray(M)));
247 |
248 | var a = gp.p.slice(0),
249 | b = gp.z.slice(0),
250 | c = numeric.mul(-1, gp.z.slice(0)),
251 | d = gp.p.slice(0);
252 |
253 | gp.z = numeric.add(numeric.mul(a, Math.sin(t_step)), numeric.mul(b, Math.cos(t_step)));
254 | gp.p = numeric.add(numeric.mul(c, Math.sin(t_step)), numeric.mul(d, Math.cos(t_step)));
255 | }
256 | this.stepState = this.stepState + 1;
257 | },
258 | drawMeanAndVar: function(props) {
259 | var gpline = this.gpline;
260 | if (props.state.showMeanAndVar){
261 | var gps = props.state.GPs;
262 | } else {
263 | var gps = [];
264 | }
265 |
266 | var paths = this.meanLines.selectAll("path").data(gps, function (d) { return d.id; })
267 | .attr("d", function (d) {
268 | var datay = d.mu;
269 | return gpline(d3.zip(tePointsX, datay));
270 | });
271 | paths.enter().append("path").attr("d", function (d) {
272 | var datay = d.mu;
273 | return gpline(d3.zip(tePointsX, datay));
274 | })
275 | .attr("class", function(d) {
276 | return "muline line line"+d.id;
277 | });
278 | paths.exit().remove();
279 |
280 | var pathsUp = this.upSd95Lines.selectAll("path").data(gps, function (d) { return d.id; })
281 | .attr("d", function (d) {
282 | var datay = numeric.add(d.mu, d.sd95);
283 | return gpline(d3.zip(tePointsX, datay));
284 | });
285 | pathsUp.enter().append("path").attr("d", function (d) {
286 | var datay = numeric.add(d.mu, d.sd95);
287 | return gpline(d3.zip(tePointsX, datay));
288 | })
289 | .attr("class", function(d) {
290 | return "sdline line line"+d.id;
291 | });
292 | pathsUp.exit().remove();
293 |
294 | var pathsDown = this.downSd95Lines.selectAll("path").data(gps, function (d) { return d.id; })
295 | .attr("d", function (d) {
296 | var datay = numeric.sub(d.mu, d.sd95);
297 | return gpline(d3.zip(tePointsX, datay));
298 | });
299 | pathsDown.enter().append("path").attr("d", function (d) {
300 | var datay = numeric.sub(d.mu, d.sd95);
301 | return gpline(d3.zip(tePointsX, datay));
302 | })
303 | .attr("class", function(d) {
304 | return "sdline line line"+d.id;
305 | });
306 | pathsDown.exit().remove();
307 | },
308 | drawPaths: function(props) {
309 | if (!props) var props = this.props;
310 | var gpline = this.gpline;
311 | if (props.state.showSamples){
312 | var gps = props.state.GPs;
313 | } else {
314 | var gps = [];
315 | }
316 | var paths = this.lines.selectAll("path").data(gps, function (d) { return d.id; })
317 | .attr("d", function (d) {
318 | var datay = numeric.add(numeric.dot(d.proj, d.z), d.mu);
319 | return gpline(d3.zip(tePointsX, datay));
320 | });
321 | paths.enter().append("path").attr("d", function (d) {
322 | var datay = numeric.add(numeric.dot(d.proj, d.z), d.mu);
323 | return gpline(d3.zip(tePointsX, datay));
324 | })
325 | .attr("class", function(d) {
326 | return "line line"+d.id;
327 | });
328 | paths.exit().remove();
329 | },
330 | scales: { x: null, y: null },
331 | componentDidMount: function() {
332 | var svg = d3.select(this.getDOMNode());
333 | var height = svg.attr("height"),
334 | width = svg.attr("width");
335 | if (!height) {
336 | height = 300;
337 | svg.attr("height", height);
338 | }
339 | if (!width) {
340 | width = 500;
341 | svg.attr("width", width);
342 | }
343 | var margin = 50;
344 | svg = svg.append("g")
345 | .attr("transform", "translate("+margin+","+margin+")");
346 | this.svg = svg;
347 | var fig_height = height - 2*margin,
348 | fig_width = width - 2*margin;
349 |
350 | // helper functions
351 | var x = d3.scale.linear().range([0, fig_width]).domain([-5, 5]);
352 | var y = d3.scale.linear().range([fig_height, 0]).domain([-3, 3]);
353 | this.scales.x = x;
354 | this.scales.y = y;
355 | var xAxis = d3.svg.axis()
356 | .scale(x)
357 | .orient("bottom");
358 | var yAxis = d3.svg.axis()
359 | .scale(y)
360 | .orient("left");
361 | this.gpline = d3.svg.line()
362 | .x(function(d) { return x(d[0]); })
363 | .y(function(d) { return y(d[1]); });
364 |
365 | // axes
366 | svg.append("g")
367 | .attr("class", "x axis")
368 | .attr("transform", "translate(0,"+fig_height+")")
369 | .call(xAxis);
370 |
371 | svg.append("g")
372 | .attr("class", "y axis")
373 | .call(yAxis);
374 |
375 | this.meanLines = svg.append("g");
376 | this.upSd95Lines = svg.append("g");
377 | this.downSd95Lines = svg.append("g");
378 | this.lines = svg.append("g");
379 | this.trPoints = svg.append("g");
380 | this.drawTrPoints(this.props.state.trPointsX, this.props.state.trPointsY);
381 | this.drawPaths();
382 | }
383 | });
384 |
385 |
// GPList: one table entry per process in props.GPs, showing id, covariance
// function name, length scale and noise (two decimals). props.delGP is
// presumably used to build each row's delete control — confirm.
// NOTE(review): the JSX markup (rows/cells) was stripped in this listing;
// only the {expressions} remain, so the body is not valid as shown.
386 | var GPList = React.createClass({
387 | render: function() {
388 | var delGP = this.props.delGP;
389 | var gplist = this.props.GPs.map(function (gp) {
390 | return (
391 | {gp.id} | {cfs[gp.cf].name} | {gp.params[0].toFixed(2)} | {gp.params[1].toFixed(2)} | |
392 |
);
393 | });
394 | return ({gplist});
395 | }
396 | });
397 |
--------------------------------------------------------------------------------
/src/slider.js:
--------------------------------------------------------------------------------
1 | /** @jsx React.DOM */
2 | var Slider = React.createClass({
// NOTE(review): the JSX returned here was stripped in this listing (only
// `return ();` remains). Per the comment below it was the bare svg element
// that componentDidMount fills in with d3.
3 | render: function() {
4 | // Just insert the svg-element and render rest in componentDidMount.
5 | // Marker location is updated in componentWillReceiveProps using d3.
6 | return ();
7 | },
8 | shouldComponentUpdate: function() { return false; }, // Never re-render.
9 | componentDidMount: function() {
10 | var val = this.props.value;
11 | var setVal = this.props.setValue;
12 | var opt = this.props.opt;
13 |
14 | // set defaults for options if not given
15 | if (!opt.width) opt.width = 200;
16 | if (!opt.height) opt.height = 20;
17 | if (!opt.step) {
18 | opt.round = function(x) { return x; }
19 | } else {
20 | opt.round = function(x) { return Math.round(x / opt.step) * opt.step; }
21 | }
22 | if (!opt.min) opt.min = opt.round(0);
23 | if (!opt.max) opt.max = opt.round(1);
24 | if (opt.min > opt.max) {
25 | var tmp = opt.min;
26 | opt.min = opt.max;
27 | opt.max = tmp;
28 | }
29 | if (val > opt.max) setVal(opt.max);
30 | if (val < opt.min) setVal(opt.min);
31 |
32 | // calculate range
33 | var markerRadius = opt.height * 0.5;
34 | var x1 = markerRadius;
35 | var x2 = opt.width - markerRadius;
36 |
37 | // d3 helpers
38 | var scale = d3.scale.linear()
39 | .domain([opt.min, opt.max])
40 | .range([x1, x2]);
41 | this.scale = scale;
42 | var setValFromMousePos = function(x) {
43 | setVal(opt.round(scale.invert(Math.max(x1, Math.min(x2, x)))));
44 | };
45 | var dragmove = function() {
46 | setValFromMousePos(d3.event.x);
47 | };
48 | var drag = d3.behavior.drag().on("drag", dragmove);
49 |
50 | // bind d3 events and insert background line and marker
51 | var svg = d3.select(this.getDOMNode());
52 | svg.attr("class", "slider")
53 | .attr("width", opt.width)
54 | .attr("height", opt.height)
55 | .on("click", function () { setValFromMousePos(d3.mouse(this)[0]); });
56 | svg.append("line")
57 | .attr("x1", x1)
58 | .attr("x2", x2)
59 | .attr("y1", "50%")
60 | .attr("y2", "50%")
61 | .attr("class", "sliderbg");
62 | this.marker = svg.append("circle")
63 | .attr("cy", "50%")
64 | .attr("r", markerRadius)
65 | .attr("class", "slidermv")
66 | .datum(val)
67 | .attr("cx", function (d) { return scale(d); })
68 | .call(drag);
69 | },
70 | componentWillReceiveProps: function(props) {
71 | // update the marker location on receiving new props
72 | var scale = this.scale;
73 | this.marker.datum(props.value)
74 | .attr("cx", function (d) { return scale(d); });
75 | }
76 | });
77 |
78 |
--------------------------------------------------------------------------------