├── .gitignore ├── README.md ├── examples ├── char-rnn-multithreaded │ ├── data │ │ ├── tiny-100000.txt │ │ └── tinyshakespeare.txt │ ├── embed.html │ ├── index.html │ ├── js │ │ └── main.js │ └── lib │ │ ├── discrete.js │ │ └── tfjs@0.11.2.js ├── char-rnn │ ├── data │ │ ├── tiny-100000.txt │ │ └── tinyshakespeare.txt │ ├── index.html │ ├── js │ │ └── main.js │ └── lib │ │ ├── discrete.js │ │ └── tfjs@0.11.2.js ├── mnist │ ├── README.md │ ├── data.js │ ├── index.html │ ├── index.js │ └── ui.js └── simple-linear-network │ ├── index.html │ ├── lib │ └── tfjs@0.11.2.js │ └── main.js ├── index.html ├── index.js ├── package-lock.json └── package.json /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow.js w/ Electron 2 | 3 | This repository contains several Tensorflow.js examples bundled as an Electron app. 4 | 5 | ## Getting Started 6 | 7 | ```bash 8 | # clone the repo 9 | git clone https://github.com/brangerbriz/tf-electron 10 | cd tf-electron 11 | 12 | # install dependencies 13 | npm install 14 | 15 | # run the electron app 16 | npm start 17 | ``` 18 | You should now see an Electron window pop up with a list of several examples. Several of these examples are taken from [tfjs-examples](https://github.com/tensorflow/tfjs-examples) and modified slightly to work with Node.js/Electron out-of-the-box. I've also annotated the source code so that it is more readable for beginners. 
19 | -------------------------------------------------------------------------------- /examples/char-rnn-multithreaded/embed.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Char RNN Iframe> 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /examples/char-rnn-multithreaded/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Char RNN Text Example> 5 | 10 | 11 | 12 | 13 |

14 | This example uses Electron's webview feature to embed the code that runs Tensorflow.js operations into this 15 | "parent" page. By doing so we can remove the calls to "await tf.nextFrame()" without 16 | worrying about blocking the main UI thread. This is sort of a hacky form 17 | of multithreading. Tensorflow.js also works in Web Workers, but unfortunately 18 | the WebGL backend isn't available there (Web Workers are CPU-only) 19 | (https://github.com/tensorflow/tfjs/issues/102). Using this method you 20 | can use the WebGL backend and write your tfjs code in a "blocking" fashion 21 | without actually worrying about blocking the main UI thread. It effectively 22 | means your tfjs operations will run faster. 23 |

24 |

25 | Here we are borrowing most of the code from the char-rnn example, so check that example out to get a better idea of what's happening here. This example is really just used to demonstrate how webview (and potentially iframes in web browsers) can be used to offload compute-intensive Tensorflow.js operations to a separate thread so as to not block the main UI. On my computer this technique reduces training time for one epoch from ~285 seconds using examples/char-rnn to ~240 seconds using examples/char-rnn-multithreaded. 26 |

27 |

28 | Open your developer console to see what's happening! 29 |

30 |
31 |

32 | Index page/main UI thread. Not getting blocked. 33 | 34 |

35 |
36 |

37 | Embed. This thing does the thinking without blocking the main UI thread. 38 | 39 |

40 |

Back

41 | 42 | 54 | 55 | -------------------------------------------------------------------------------- /examples/char-rnn-multithreaded/js/main.js: -------------------------------------------------------------------------------- 1 | const SEQLEN = 40 2 | const SKIP = 3 3 | const EPOCHS = 10 4 | const LEARNING_RATE = 0.01 5 | const BATCH_SIZE = 128 6 | const LOADMODEL = false 7 | 8 | // uncomment this if you want to use IPC to communicate with the parent page 9 | // const { ipcRenderer } = require('electron') 10 | 11 | // promisified fs.readFile() 12 | function loadFile(path) { 13 | const fs = require('fs') 14 | return new Promise((resolve, reject) => { 15 | fs.readFile(path, (err, data) => { 16 | if (err) reject(err) 17 | else resolve(data) 18 | }) 19 | }) 20 | } 21 | 22 | // define and return an LSTM RNN model architecture. RNNs are used 23 | // with sequential data. 24 | function getModel(inputShape) { 25 | const model = tf.sequential() 26 | model.add(tf.layers.lstm({units: 128, inputShape: inputShape })) 27 | model.add(tf.layers.dense({units: inputShape[1]})) 28 | model.add(tf.layers.activation({ activation: 'softmax' })) 29 | model.compile({ 30 | optimizer: tf.train.rmsprop(LEARNING_RATE), 31 | loss: 'categoricalCrossentropy', 32 | metrics: ['accuracy'] 33 | }) 34 | return model 35 | } 36 | 37 | // utility function for loading and prepairing training data. Returns an object. 
38 | async function getData(path) { 39 | 40 | const buf = await loadFile(path) 41 | // convert all of the characters to lowercase to reduce our number of 42 | // output classes 43 | const text = buf.toString().toLowerCase() 44 | console.log(`corpus length: ${text.length}`) 45 | 46 | const chars = Array.from(new Set(text)) 47 | console.log(`total chars: ${chars.length}`) 48 | 49 | const charIndicies = {}; chars.forEach((c, i) => charIndicies[c] = i) 50 | 51 | // cut the text in semi-redundant sequences of maxlen characters 52 | const sentences = [] 53 | const nextChars = [] 54 | for (let i = 0; i < text.length - SEQLEN; i += SKIP) { 55 | sentences.push(text.slice(i, i + SEQLEN)) 56 | nextChars.push(text.slice(i + SEQLEN, i + SEQLEN + 1)) 57 | // console.log(text.slice(i, i + SEQLEN), '->', text.slice(i + SEQLEN, i + SEQLEN + 1)) 58 | } 59 | 60 | // each array element will hold a batch of BATCH_SIZE examples 61 | const X = [] 62 | const Y = [] 63 | const partitionSize = BATCH_SIZE * 100 64 | for (let batch = 0; batch < sentences.length; batch += partitionSize) { 65 | 66 | // we can't store all of the data in GPU memory, so instead we use 67 | // tf.TensorBuffer objects which are allocated by the CPU. 
Becore we 68 | // train on a batch we transform it into a Tensor using TensorBuffer.toTensor() 69 | 70 | const xBuff = tf.buffer([partitionSize, SEQLEN, chars.length]) 71 | const yBuff = tf.buffer([partitionSize, chars.length]) 72 | 73 | const sentenceBatch = sentences.slice(batch, batch + partitionSize) 74 | const nextCharsBatch = nextChars.slice(batch, batch + partitionSize) 75 | 76 | // one-hotify batches 77 | sentenceBatch.forEach((sentence, j) => { 78 | sentence.split('').forEach((char, k) => { 79 | xBuff.set(1, j, k, charIndicies[char]) 80 | }) 81 | yBuff.set(1, j, charIndicies[nextCharsBatch[j]]) 82 | }) 83 | 84 | X.push(xBuff) 85 | Y.push(yBuff) 86 | } 87 | return { 88 | X, Y, charIndicies, chars 89 | } 90 | } 91 | 92 | // sample an output from the multinomial distribution output of model.predict() 93 | function sample(preds) { 94 | let probas = Sampling.Multinomial(1, preds).draw() 95 | // lazy argmax that returns the index of the "hot" value in a one-hot vector 96 | return probas.reduce((acc, val, i) => acc + (val == 1 ? i : 0)) 97 | } 98 | 99 | async function generate(seed, numChars, charIndicies, chars, model) { 100 | console.assert(seed.length >= SEQLEN) 101 | seed = seed.toLowerCase().split('').slice(0, SEQLEN) 102 | const output = [] 103 | for (let i = 0; i < numChars; i++) { 104 | 105 | const x = tf.zeros([1, SEQLEN, chars.length]) 106 | seed.forEach((char, j) => { 107 | x.buffer().set(1, 0, j, charIndicies[char]) 108 | }) 109 | 110 | const preds = model.predict(x, { verbose: true }) 111 | const y = sample(preds.dataSync()) 112 | const char = chars[y] 113 | 114 | seed.shift(); seed.push(char) 115 | output.push(char) 116 | 117 | preds.dispose() 118 | x.dispose() 119 | } 120 | 121 | return output.join('') 122 | } 123 | 124 | async function main() { 125 | const data = await getData(`${__dirname}/data/tinyshakespeare.txt`) 126 | let model = null 127 | const seed = "This is a test sentence. It will be used to seed the model." 
128 | 129 | // load model from IndexedDB if the flag says so 130 | if (LOADMODEL) { 131 | console.log('Loading model from IndexedDB') 132 | model = await tf.loadModel('indexeddb://model') 133 | } else { 134 | // otherwise train a new model and save it to IndexedDB overwriting any 135 | // existing models that may be saved there 136 | console.log('Training model...') 137 | model = getModel([SEQLEN, data.chars.length]) 138 | let history = null 139 | for (let i = 0; i < EPOCHS; i++) { 140 | const then = Date.now() 141 | for (let batch = 0; batch < data.X.length; batch++) { 142 | const batchX = data.X[batch].toTensor() 143 | const batchY = data.Y[batch].toTensor() 144 | history = await model.fit(batchX, batchY, { batchSize: BATCH_SIZE }) 145 | batchX.dispose() 146 | batchY.dispose() 147 | } 148 | console.log(`Epoch ${i + 1} loss: ${history.history.loss[0]}, accuracy: ${history.history.acc[0]}`) 149 | console.log(`Epoch lasted ${((Date.now() - then) / 1000).toFixed(0)} seconds`) 150 | await model.save('indexeddb://model') 151 | const text = await generate(seed, 100, data.charIndicies, data.chars, model) 152 | console.log(text) 153 | } 154 | } 155 | 156 | console.log(`Finished training for ${EPOCHS}`) 157 | console.log(`Generating 1000 characters of synthetic text:`) 158 | const text = await generate(seed, 1000, data.charIndicies, data.chars, model) 159 | console.log(text) 160 | 161 | } 162 | 163 | main() -------------------------------------------------------------------------------- /examples/char-rnn-multithreaded/lib/discrete.js: -------------------------------------------------------------------------------- 1 | // discrete.js 2 | // Sample from discrete distributions. 3 | 4 | var Sampling = SJS = (function(){ 5 | 6 | // Utility functions 7 | function _sum(a, b) { 8 | return a + b; 9 | }; 10 | function _fillArrayWithNumber(size, num) { 11 | // thanks be to stackOverflow... 
this is a beautiful one-liner 12 | return Array.apply(null, Array(size)).map(Number.prototype.valueOf, num); 13 | }; 14 | function _rangeFunc(upper) { 15 | var i = 0, out = []; 16 | while (i < upper) out.push(i++); 17 | return out; 18 | }; 19 | // Prototype function 20 | function _samplerFunction(size) { 21 | if (!Number.isInteger(size) || size < 0) { 22 | throw new Error ("Number of samples must be a non-negative integer."); 23 | } 24 | if (!this.draw) { 25 | throw new Error ("Distribution must specify a draw function."); 26 | } 27 | var result = []; 28 | while (size--) { 29 | result.push(this.draw()); 30 | } 31 | return result; 32 | }; 33 | // Prototype for discrete distributions 34 | var _samplerPrototype = { 35 | sample: _samplerFunction 36 | }; 37 | 38 | function Bernoulli(p) { 39 | 40 | var result = Object.create(_samplerPrototype); 41 | 42 | result.draw = function() { 43 | return (Math.random() < p) ? 1 : 0; 44 | }; 45 | 46 | result.toString = function() { 47 | return "Bernoulli( " + p + " )"; 48 | }; 49 | 50 | return result; 51 | } 52 | 53 | function Binomial(n, p) { 54 | 55 | var result = Object.create(_samplerPrototype), 56 | bern = Sampling.Bernoulli(p); 57 | 58 | result.draw = function() { 59 | return bern.sample(n).reduce(_sum, 0); // less space efficient than adding a bunch of draws, but cleaner :) 60 | } 61 | 62 | result.toString = function() { 63 | return "Binom( " + 64 | [n, p].join(", ") + 65 | " )"; 66 | } 67 | 68 | return result; 69 | } 70 | 71 | function Discrete(probs) { // probs should be an array of probabilities. 
(they get normalized automagically) // 72 | 73 | var result = Object.create(_samplerPrototype), 74 | k = probs.length; 75 | 76 | result.draw = function() { 77 | var i, p; 78 | for (i = 0; i < k; i++) { 79 | p = probs[i] / probs.slice(i).reduce(_sum, 0); // this is the (normalized) head of a slice of probs 80 | if (Bernoulli(p).draw()) return i; // using the truthiness of a Bernoulli draw 81 | } 82 | return k - 1; 83 | }; 84 | 85 | result.sampleNoReplace = function(size) { 86 | if (size>probs.length) { 87 | throw new Error("Sampling without replacement, and the sample size exceeds vector size.") 88 | } 89 | var disc, index, sum, samp = []; 90 | var currentProbs = probs; 91 | var live = _rangeFunc(probs.length); 92 | while (size--) { 93 | sum = currentProbs.reduce(_sum, 0); 94 | currentProbs = currentProbs.map(function(x) {return x/sum; }); 95 | disc = SJS.Discrete(currentProbs); 96 | index = disc.draw(); 97 | samp.push(live[index]); 98 | live.splice(index, 1); 99 | currentProbs.splice(index, 1); 100 | sum = currentProbs.reduce(_sum, 0); 101 | currentProbs = currentProbs.map(function(x) {return x/sum; }); 102 | } 103 | currentProbs = probs; 104 | live = _rangeFunc(probs.length); 105 | return samp; 106 | } 107 | 108 | result.toString = function() { 109 | return "Dicrete( [" + 110 | probs.join(", ") + 111 | "] )"; 112 | }; 113 | 114 | return result; 115 | } 116 | 117 | function Multinomial(n, probs) { 118 | 119 | var result = Object.create(_samplerPrototype), 120 | k = probs.length, 121 | disc = Discrete(probs); 122 | 123 | result.draw = function() { 124 | var draw_result = _fillArrayWithNumber(k, 0), 125 | i = n; 126 | while (i--) { 127 | draw_result[disc.draw()] += 1; 128 | } 129 | return draw_result; 130 | }; 131 | 132 | result.toString = function() { 133 | return "Multinom( " + 134 | n + 135 | ", [" + probs.join(", ") + 136 | "] )"; 137 | }; 138 | 139 | return result; 140 | } 141 | 142 | function NegBinomial(r, p) { 143 | var result = 
Object.create(_samplerPrototype); 144 | 145 | result.draw = function() { 146 | var draw_result = 0, failures = r; 147 | while (failures) { 148 | Bernoulli(p).draw() ? draw_result++ : failures--; 149 | } 150 | return draw_result; 151 | }; 152 | 153 | result.toString = function() { 154 | return "NegBinomial( " + r + 155 | ", " + p + " )"; 156 | }; 157 | 158 | return result; 159 | } 160 | 161 | function Poisson(lambda) { 162 | var result = Object.create(_samplerPrototype); 163 | 164 | result.draw = function() { 165 | var draw_result, L = Math.exp(- lambda), k = 0, p = 1; 166 | 167 | do { 168 | k++; 169 | p = p * Math.random() 170 | } while (p > L); 171 | return k-1; 172 | } 173 | 174 | result.toString = function() { 175 | return "Poisson( " + lambda + " )"; 176 | } 177 | 178 | return result; 179 | } 180 | 181 | return { 182 | _fillArrayWithNumber: _fillArrayWithNumber, // REMOVE EVENTUALLY - this is just so the Array.prototype mod can work 183 | _rangeFunc: _rangeFunc, 184 | Bernoulli: Bernoulli, 185 | Binomial: Binomial, 186 | Discrete: Discrete, 187 | Multinomial: Multinomial, 188 | NegBinomial: NegBinomial, 189 | Poisson: Poisson 190 | }; 191 | })(); 192 | 193 | //*** Sampling from arrays ***// 194 | // Eventually merge into SJS ??? 195 | function sample_from_array(array, numSamples, withReplacement) { 196 | var n = numSamples || 1, 197 | result = [], 198 | copy, 199 | disc, 200 | index; 201 | 202 | if (!withReplacement && numSamples > array.length) { 203 | throw new Error("Sampling without replacement, and the sample size exceeds vector size.") 204 | } 205 | 206 | if (withReplacement) { 207 | while(numSamples--) { 208 | disc = SJS.Discrete(SJS._fillArrayWithNumber(array.length, 1)); 209 | result.push(array[disc.draw()]); 210 | } 211 | } else { 212 | // instead of splicing, consider sampling from an array of possible indices? meh? 
213 | copy = array.slice(0); 214 | while (numSamples--) { 215 | disc = SJS.Discrete(SJS._fillArrayWithNumber(copy.length, 1)); 216 | index = disc.draw(); 217 | result.push(copy[index]); 218 | copy.splice(index, 1); 219 | console.log("array: "+copy); 220 | } 221 | } 222 | return result; 223 | } 224 | -------------------------------------------------------------------------------- /examples/char-rnn/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Char RNN Text Example> 5 | 10 | 11 | 12 | 13 |

Tensorflow.js operations share the main UI thread with the browser, so the GIF of Nyancat above is there so that you can see if and when Tensorflow.js freezes user interaction.

14 |

15 | Char-rnn is a generative language model that creates text in the style of an author or passage given a substantial amount (1MB+) of training text. It uses a recurrent neural network (RNN) that is fed 40-character windows of sequential text data and learns to predict the next character in the sequence. This model is being trained using data from William Shakespeare's plays. For more information about Char-rnn, see this blog post: https://karpathy.github.io/2015/05/21/rnn-effectiveness/ 16 |

17 |

18 | Open your developer console to view training losses and newly generated Shakespeare text each epoch. 19 |

20 | 21 |

Back

22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /examples/char-rnn/js/main.js: -------------------------------------------------------------------------------- 1 | const SEQLEN = 40 2 | const SKIP = 3 3 | const EPOCHS = 10 4 | const LEARNING_RATE = 0.01 5 | const BATCH_SIZE = 128 6 | const LOADMODEL = false 7 | 8 | // promisified fs.readFile() 9 | function loadFile(path) { 10 | const fs = require('fs') 11 | return new Promise((resolve, reject) => { 12 | fs.readFile(path, (err, data) => { 13 | if (err) reject(err) 14 | else resolve(data) 15 | }) 16 | }) 17 | } 18 | 19 | // define and return an LSTM RNN model architecture. RNNs are used 20 | // with sequential data. 21 | function getModel(inputShape) { 22 | const model = tf.sequential() 23 | model.add(tf.layers.lstm({units: 128, inputShape: inputShape })) 24 | model.add(tf.layers.dense({units: inputShape[1]})) 25 | model.add(tf.layers.activation({ activation: 'softmax' })) 26 | model.compile({ 27 | optimizer: tf.train.rmsprop(LEARNING_RATE), 28 | loss: 'categoricalCrossentropy', 29 | metrics: ['accuracy'] 30 | }) 31 | return model 32 | } 33 | 34 | // utility function for loading and prepairing training data. Returns an object. 
35 | async function getData(path) { 36 | 37 | const buf = await loadFile(path) 38 | // convert all of the characters to lowercase to reduce our number of 39 | // output classes 40 | const text = buf.toString().toLowerCase() 41 | console.log(`corpus length: ${text.length}`) 42 | 43 | const chars = Array.from(new Set(text)) 44 | console.log(`total chars: ${chars.length}`) 45 | 46 | const charIndicies = {}; chars.forEach((c, i) => charIndicies[c] = i) 47 | 48 | // cut the text in semi-redundant sequences of maxlen characters 49 | const sentences = [] 50 | const nextChars = [] 51 | for (let i = 0; i < text.length - SEQLEN; i += SKIP) { 52 | sentences.push(text.slice(i, i + SEQLEN)) 53 | nextChars.push(text.slice(i + SEQLEN, i + SEQLEN + 1)) 54 | // console.log(text.slice(i, i + SEQLEN), '->', text.slice(i + SEQLEN, i + SEQLEN + 1)) 55 | } 56 | 57 | // each array element will hold a batch of BATCH_SIZE examples 58 | const X = [] 59 | const Y = [] 60 | for (let batch = 0; batch < sentences.length; batch += BATCH_SIZE) { 61 | 62 | // we can't store all of the data in GPU memory, so instead we use 63 | // tf.TensorBuffer objects which are allocated by the CPU. 
Becore we 64 | // train on a batch we transform it into a Tensor using TensorBuffer.toTensor() 65 | 66 | const xBuff = tf.buffer([BATCH_SIZE, SEQLEN, chars.length]) 67 | const yBuff = tf.buffer([BATCH_SIZE, chars.length]) 68 | 69 | const sentenceBatch = sentences.slice(batch, batch + BATCH_SIZE) 70 | const nextCharsBatch = nextChars.slice(batch, batch + BATCH_SIZE) 71 | 72 | // one-hotify batches 73 | sentenceBatch.forEach((sentence, j) => { 74 | sentence.split('').forEach((char, k) => { 75 | xBuff.set(1, j, k, charIndicies[char]) 76 | }) 77 | yBuff.set(1, j, charIndicies[nextCharsBatch[j]]) 78 | }) 79 | 80 | X.push(xBuff) 81 | Y.push(yBuff) 82 | } 83 | return { 84 | X, Y, charIndicies, chars 85 | } 86 | } 87 | 88 | // sample an output from the multinomial distribution output of model.predict() 89 | function sample(preds) { 90 | let probas = Sampling.Multinomial(1, preds).draw() 91 | // lazy argmax that returns the index of the "hot" value in a one-hot vector 92 | return probas.reduce((acc, val, i) => acc + (val == 1 ? i : 0)) 93 | } 94 | 95 | async function generate(seed, numChars, charIndicies, chars, model) { 96 | console.assert(seed.length >= SEQLEN) 97 | seed = seed.toLowerCase().split('').slice(0, SEQLEN) 98 | const output = [] 99 | for (let i = 0; i < numChars; i++) { 100 | 101 | const x = tf.zeros([1, SEQLEN, chars.length]) 102 | seed.forEach((char, j) => { 103 | x.buffer().set(1, 0, j, charIndicies[char]) 104 | }) 105 | 106 | const preds = model.predict(x, { verbose: true }) 107 | const y = sample(preds.dataSync()) 108 | const char = chars[y] 109 | 110 | seed.shift(); seed.push(char) 111 | output.push(char) 112 | 113 | preds.dispose() 114 | x.dispose() 115 | await tf.nextFrame() 116 | } 117 | 118 | return output.join('') 119 | } 120 | 121 | async function main() { 122 | 123 | const data = await getData(`${__dirname}/data/tinyshakespeare.txt`) 124 | let model = null 125 | 126 | const seed = "This is a test sentence. It will be used to seed the model." 
127 | 128 | // load model from IndexedDB if the flag says so 129 | if (LOADMODEL) { 130 | console.log('Loading model from IndexedDB') 131 | model = await tf.loadModel('indexeddb://model') 132 | } else { 133 | // otherwise train a new model and save it to IndexedDB overwriting any 134 | // existing models that may be saved there 135 | console.log('Training model...') 136 | model = getModel([SEQLEN, data.chars.length]) 137 | let history = null 138 | for (let i = 0; i < EPOCHS; i++) { 139 | const then = Date.now() 140 | for (let batch = 0; batch < data.X.length; batch++) { 141 | const batchX = data.X[batch].toTensor() 142 | const batchY = data.Y[batch].toTensor() 143 | history = await model.fit(batchX, batchY, { batchSize: BATCH_SIZE }) 144 | batchX.dispose() 145 | batchY.dispose() 146 | await tf.nextFrame() 147 | } 148 | console.log(`Epoch ${i + 1} loss: ${history.history.loss[0]}, accuracy: ${history.history.acc[0]}`) 149 | console.log(`Epoch lasted ${((Date.now() - then) / 1000).toFixed(0)} seconds`) 150 | await model.save('indexeddb://model') 151 | const text = await generate(seed, 100, data.charIndicies, data.chars, model) 152 | console.log(text) 153 | } 154 | } 155 | 156 | console.log(`Finished training for ${EPOCHS}`) 157 | console.log(`Generating 1000 characters of synthetic text:`) 158 | const text = await generate(seed, 1000, data.charIndicies, data.chars, model) 159 | console.log(text) 160 | 161 | } 162 | 163 | main() -------------------------------------------------------------------------------- /examples/char-rnn/lib/discrete.js: -------------------------------------------------------------------------------- 1 | // discrete.js 2 | // Sample from discrete distributions. 3 | 4 | var Sampling = SJS = (function(){ 5 | 6 | // Utility functions 7 | function _sum(a, b) { 8 | return a + b; 9 | }; 10 | function _fillArrayWithNumber(size, num) { 11 | // thanks be to stackOverflow... 
this is a beautiful one-liner 12 | return Array.apply(null, Array(size)).map(Number.prototype.valueOf, num); 13 | }; 14 | function _rangeFunc(upper) { 15 | var i = 0, out = []; 16 | while (i < upper) out.push(i++); 17 | return out; 18 | }; 19 | // Prototype function 20 | function _samplerFunction(size) { 21 | if (!Number.isInteger(size) || size < 0) { 22 | throw new Error ("Number of samples must be a non-negative integer."); 23 | } 24 | if (!this.draw) { 25 | throw new Error ("Distribution must specify a draw function."); 26 | } 27 | var result = []; 28 | while (size--) { 29 | result.push(this.draw()); 30 | } 31 | return result; 32 | }; 33 | // Prototype for discrete distributions 34 | var _samplerPrototype = { 35 | sample: _samplerFunction 36 | }; 37 | 38 | function Bernoulli(p) { 39 | 40 | var result = Object.create(_samplerPrototype); 41 | 42 | result.draw = function() { 43 | return (Math.random() < p) ? 1 : 0; 44 | }; 45 | 46 | result.toString = function() { 47 | return "Bernoulli( " + p + " )"; 48 | }; 49 | 50 | return result; 51 | } 52 | 53 | function Binomial(n, p) { 54 | 55 | var result = Object.create(_samplerPrototype), 56 | bern = Sampling.Bernoulli(p); 57 | 58 | result.draw = function() { 59 | return bern.sample(n).reduce(_sum, 0); // less space efficient than adding a bunch of draws, but cleaner :) 60 | } 61 | 62 | result.toString = function() { 63 | return "Binom( " + 64 | [n, p].join(", ") + 65 | " )"; 66 | } 67 | 68 | return result; 69 | } 70 | 71 | function Discrete(probs) { // probs should be an array of probabilities. 
(they get normalized automagically) // 72 | 73 | var result = Object.create(_samplerPrototype), 74 | k = probs.length; 75 | 76 | result.draw = function() { 77 | var i, p; 78 | for (i = 0; i < k; i++) { 79 | p = probs[i] / probs.slice(i).reduce(_sum, 0); // this is the (normalized) head of a slice of probs 80 | if (Bernoulli(p).draw()) return i; // using the truthiness of a Bernoulli draw 81 | } 82 | return k - 1; 83 | }; 84 | 85 | result.sampleNoReplace = function(size) { 86 | if (size>probs.length) { 87 | throw new Error("Sampling without replacement, and the sample size exceeds vector size.") 88 | } 89 | var disc, index, sum, samp = []; 90 | var currentProbs = probs; 91 | var live = _rangeFunc(probs.length); 92 | while (size--) { 93 | sum = currentProbs.reduce(_sum, 0); 94 | currentProbs = currentProbs.map(function(x) {return x/sum; }); 95 | disc = SJS.Discrete(currentProbs); 96 | index = disc.draw(); 97 | samp.push(live[index]); 98 | live.splice(index, 1); 99 | currentProbs.splice(index, 1); 100 | sum = currentProbs.reduce(_sum, 0); 101 | currentProbs = currentProbs.map(function(x) {return x/sum; }); 102 | } 103 | currentProbs = probs; 104 | live = _rangeFunc(probs.length); 105 | return samp; 106 | } 107 | 108 | result.toString = function() { 109 | return "Dicrete( [" + 110 | probs.join(", ") + 111 | "] )"; 112 | }; 113 | 114 | return result; 115 | } 116 | 117 | function Multinomial(n, probs) { 118 | 119 | var result = Object.create(_samplerPrototype), 120 | k = probs.length, 121 | disc = Discrete(probs); 122 | 123 | result.draw = function() { 124 | var draw_result = _fillArrayWithNumber(k, 0), 125 | i = n; 126 | while (i--) { 127 | draw_result[disc.draw()] += 1; 128 | } 129 | return draw_result; 130 | }; 131 | 132 | result.toString = function() { 133 | return "Multinom( " + 134 | n + 135 | ", [" + probs.join(", ") + 136 | "] )"; 137 | }; 138 | 139 | return result; 140 | } 141 | 142 | function NegBinomial(r, p) { 143 | var result = 
Object.create(_samplerPrototype); 144 | 145 | result.draw = function() { 146 | var draw_result = 0, failures = r; 147 | while (failures) { 148 | Bernoulli(p).draw() ? draw_result++ : failures--; 149 | } 150 | return draw_result; 151 | }; 152 | 153 | result.toString = function() { 154 | return "NegBinomial( " + r + 155 | ", " + p + " )"; 156 | }; 157 | 158 | return result; 159 | } 160 | 161 | function Poisson(lambda) { 162 | var result = Object.create(_samplerPrototype); 163 | 164 | result.draw = function() { 165 | var draw_result, L = Math.exp(- lambda), k = 0, p = 1; 166 | 167 | do { 168 | k++; 169 | p = p * Math.random() 170 | } while (p > L); 171 | return k-1; 172 | } 173 | 174 | result.toString = function() { 175 | return "Poisson( " + lambda + " )"; 176 | } 177 | 178 | return result; 179 | } 180 | 181 | return { 182 | _fillArrayWithNumber: _fillArrayWithNumber, // REMOVE EVENTUALLY - this is just so the Array.prototype mod can work 183 | _rangeFunc: _rangeFunc, 184 | Bernoulli: Bernoulli, 185 | Binomial: Binomial, 186 | Discrete: Discrete, 187 | Multinomial: Multinomial, 188 | NegBinomial: NegBinomial, 189 | Poisson: Poisson 190 | }; 191 | })(); 192 | 193 | //*** Sampling from arrays ***// 194 | // Eventually merge into SJS ??? 195 | function sample_from_array(array, numSamples, withReplacement) { 196 | var n = numSamples || 1, 197 | result = [], 198 | copy, 199 | disc, 200 | index; 201 | 202 | if (!withReplacement && numSamples > array.length) { 203 | throw new Error("Sampling without replacement, and the sample size exceeds vector size.") 204 | } 205 | 206 | if (withReplacement) { 207 | while(numSamples--) { 208 | disc = SJS.Discrete(SJS._fillArrayWithNumber(array.length, 1)); 209 | result.push(array[disc.draw()]); 210 | } 211 | } else { 212 | // instead of splicing, consider sampling from an array of possible indices? meh? 
213 | copy = array.slice(0); 214 | while (numSamples--) { 215 | disc = SJS.Discrete(SJS._fillArrayWithNumber(copy.length, 1)); 216 | index = disc.draw(); 217 | result.push(copy[index]); 218 | copy.splice(index, 1); 219 | console.log("array: "+copy); 220 | } 221 | } 222 | return result; 223 | } 224 | -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow.js Example: Training MNIST 2 | 3 | This example shows you how to train MNIST (using the layers API). 4 | 5 | You can check out the tutorial that accompanies this example [here](https://js.tensorflow.org/tutorials/mnist.html). 6 | 7 | This model will compute accuracy over 1000 random test set examples every 5 8 | steps, plotting loss and accuracy as the model is training. Training time can 9 | be reduced by computing accuracy over fewer examples less often. 10 | 11 | Note: currently the entire dataset of MNIST images is stored in a PNG image we have 12 | sprited, and the code in `data.js` is responsible for converting it into 13 | `Tensor`s. This will become much simpler in the near future. 14 | 15 | [See this example live!](https://storage.googleapis.com/tfjs-examples/mnist/dist/index.html) 16 | -------------------------------------------------------------------------------- /examples/mnist/data.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | const tf = require('@tensorflow/tfjs'); 19 | 20 | const IMAGE_SIZE = 784; 21 | const NUM_CLASSES = 10; 22 | const NUM_DATASET_ELEMENTS = 65000; 23 | 24 | const NUM_TRAIN_ELEMENTS = 55000; 25 | const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS; 26 | 27 | const MNIST_IMAGES_SPRITE_PATH = 28 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; 29 | const MNIST_LABELS_PATH = 30 | 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; 31 | 32 | /** 33 | * A class that fetches the sprited MNIST dataset and returns shuffled batches. 34 | * 35 | * NOTE: This will get much easier. For now, we do data fetching and 36 | * manipulation manually. 37 | */ 38 | class MnistData { 39 | constructor() { 40 | this.shuffledTrainIndex = 0; 41 | this.shuffledTestIndex = 0; 42 | } 43 | 44 | async load() { 45 | // Make a request for the MNIST sprited image. 
46 | const img = new Image(); 47 | const canvas = document.createElement('canvas'); 48 | const ctx = canvas.getContext('2d'); 49 | const imgRequest = new Promise((resolve, reject) => { 50 | img.crossOrigin = ''; 51 | img.onload = () => { 52 | img.width = img.naturalWidth; 53 | img.height = img.naturalHeight; 54 | 55 | const datasetBytesBuffer = 56 | new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); 57 | 58 | const chunkSize = 5000; 59 | canvas.width = img.width; 60 | canvas.height = chunkSize; 61 | 62 | for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { 63 | const datasetBytesView = new Float32Array( 64 | datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, 65 | IMAGE_SIZE * chunkSize); 66 | ctx.drawImage( 67 | img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, 68 | chunkSize); 69 | 70 | const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); 71 | 72 | for (let j = 0; j < imageData.data.length / 4; j++) { 73 | // All channels hold an equal value since the image is grayscale, so 74 | // just read the red channel. 75 | datasetBytesView[j] = imageData.data[j * 4] / 255; 76 | } 77 | } 78 | this.datasetImages = new Float32Array(datasetBytesBuffer); 79 | 80 | resolve(); 81 | }; 82 | img.src = MNIST_IMAGES_SPRITE_PATH; 83 | }); 84 | 85 | const labelsRequest = fetch(MNIST_LABELS_PATH); 86 | const [imgResponse, labelsResponse] = 87 | await Promise.all([imgRequest, labelsRequest]); 88 | 89 | this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); 90 | 91 | // Create shuffled indices into the train/test set for when we select a 92 | // random dataset element for training / validation. 93 | this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS); 94 | this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS); 95 | 96 | // Slice the the images and labels into train and test sets. 
97 | this.trainImages = 98 | this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 99 | this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); 100 | this.trainLabels = 101 | this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); 102 | this.testLabels = 103 | this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); 104 | } 105 | 106 | nextTrainBatch(batchSize) { 107 | return this.nextBatch( 108 | batchSize, [this.trainImages, this.trainLabels], () => { 109 | this.shuffledTrainIndex = 110 | (this.shuffledTrainIndex + 1) % this.trainIndices.length; 111 | return this.trainIndices[this.shuffledTrainIndex]; 112 | }); 113 | } 114 | 115 | nextTestBatch(batchSize) { 116 | return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => { 117 | this.shuffledTestIndex = 118 | (this.shuffledTestIndex + 1) % this.testIndices.length; 119 | return this.testIndices[this.shuffledTestIndex]; 120 | }); 121 | } 122 | 123 | nextBatch(batchSize, data, index) { 124 | const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE); 125 | const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES); 126 | 127 | for (let i = 0; i < batchSize; i++) { 128 | const idx = index(); 129 | 130 | const image = 131 | data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE); 132 | batchImagesArray.set(image, i * IMAGE_SIZE); 133 | 134 | const label = 135 | data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES); 136 | batchLabelsArray.set(label, i * NUM_CLASSES); 137 | } 138 | 139 | const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]); 140 | const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]); 141 | 142 | return {xs, labels}; 143 | } 144 | } 145 | 146 | module.exports = { MnistData } 147 | -------------------------------------------------------------------------------- /examples/mnist/index.html: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 
| 21 | 22 | 23 | 57 | 58 |
59 |

TensorFlow.js: Train MNIST with the Layers API

60 |

Back

61 |
Loading data...
62 |
63 | 64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 | 75 |
76 |
77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /examples/mnist/index.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | const tf = require('@tensorflow/tfjs'); 19 | 20 | // This is a helper class for loading and managing MNIST data specifically. 21 | // It is a useful example of how you could create your own data manager class 22 | // for arbitrary data though. It's worth a look :) 23 | const { MnistData } = require('./data'); 24 | 25 | // This is a helper class for drawing loss graphs and MNIST images to the 26 | // window. For the purposes of understanding the machine learning bits, you can 27 | // largely ignore it 28 | import * as ui from './ui'; 29 | 30 | // Create a sequential neural network model. tf.sequential provides an API for 31 | // creating "stacked" models where the output from one layer is used as the 32 | // input to the next layer. 33 | const model = tf.sequential(); 34 | 35 | // The first layer of the convolutional neural network plays a dual role: 36 | // it is both the input layer of the neural network and a layer that performs 37 | // the first convolution operation on the input. 
It receives the 28x28 pixels 38 | // black and white images. This input layer uses 8 filters with a kernel size 39 | // of 5 pixels each. It uses a simple RELU activation function which pretty 40 | // much just looks like this: __/ 41 | model.add(tf.layers.conv2d({ 42 | inputShape: [28, 28, 1], 43 | kernelSize: 5, 44 | filters: 8, 45 | strides: 1, 46 | activation: 'relu', 47 | kernelInitializer: 'varianceScaling' 48 | })); 49 | 50 | // After the first layer we include a MaxPooling layer. This acts as a sort of 51 | // downsampling using max values in a region instead of averaging. 52 | // https://www.quora.com/What-is-max-pooling-in-convolutional-neural-networks 53 | model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]})); 54 | 55 | // Our third layer is another convolution, this time with 16 filters. 56 | model.add(tf.layers.conv2d({ 57 | kernelSize: 5, 58 | filters: 16, 59 | strides: 1, 60 | activation: 'relu', 61 | kernelInitializer: 'varianceScaling' 62 | })); 63 | 64 | // Max pooling again. 65 | model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]})); 66 | 67 | // Now we flatten the output from the 2D filters into a 1D vector to prepare 68 | // it for input into our last layer. This is common practice when feeding 69 | // higher dimensional data to a final classification output layer. 70 | model.add(tf.layers.flatten()); 71 | 72 | // Our last layer is a dense layer which has 10 output units, one for each 73 | // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9). Here the classes actually 74 | // represent numbers, but it's the same idea if you had classes that represented 75 | // other entities like dogs and cats (two output classes: 0, 1). 76 | // We use the softmax function as the activation for the output layer as it 77 | // creates a probability distribution over our 10 classes so their output values 78 | // sum to 1. 
79 | model.add(tf.layers.dense( 80 | {units: 10, kernelInitializer: 'varianceScaling', activation: 'softmax'})); 81 | 82 | // Now that we've defined our model, we will define our optimizer. The optimizer 83 | // will be used to optimize our model's weight values during training so that 84 | // we can decrease our training loss and increase our classification accuracy. 85 | 86 | // The learning rate defines the magnitude by which we update our weights each 87 | // training step. The higher the value, the faster our loss values converge, 88 | // but also the more likely we are to overshoot optimal parameters 89 | // when making an update. A learning rate that is too low will take too long to 90 | // find optimal (or good enough) weight parameters while a learning rate that is 91 | // too high may overshoot optimal parameters. Learning rate is one of the most 92 | // important hyperparameters to set correctly. Finding the right value takes 93 | // practice and is often best found empirically by trying many values. 94 | const LEARNING_RATE = 0.15; 95 | 96 | // We are using Stochastic Gradient Descent (SGD) as our optimization algorithm. 97 | // This is the most famous modern optimization algorithm in deep learning and 98 | // it is largely to thank for the current machine learning renaissance. 99 | // Most other optimizers you will come across (e.g. ADAM, RMSProp, AdaGrad, 100 | // Momentum) are variants on SGD. SGD is an iterative method for minimizing an 101 | // objective function. It tries to find the minimum of our loss function with 102 | // respect to the model's weight parameters. 103 | const optimizer = tf.train.sgd(LEARNING_RATE); 104 | 105 | // We compile our model by specifying an optimizer, a loss function, and a list 106 | // of metrics that we will use for model evaluation. Here we're using a 107 | // categorical crossentropy loss, the standard choice for a multi-class 108 | // classification problem like MNIST digits. 
109 | model.compile({ 110 | optimizer: optimizer, 111 | loss: 'categoricalCrossentropy', 112 | metrics: ['accuracy'], 113 | }); 114 | 115 | // Batch size is another important hyperparameter. It defines the number of 116 | // examples we group together, or batch, between updates to the model's weights 117 | // during training. A value that is too low will update weights using too few 118 | // examples and will not generalize well. Larger batch sizes require more memory 119 | // resources and aren't guaranteed to perform better. 120 | const BATCH_SIZE = 64; 121 | 122 | // The number of batches to train on before freezing the model and considering 123 | // it trained. This will result in BATCH_SIZE x TRAIN_BATCHES examples being 124 | // fed to the model during training. 125 | const TRAIN_BATCHES = 150; 126 | 127 | // Every few batches, test accuracy over many examples. Ideally, we'd compute 128 | // accuracy over the whole test set, but for performance we'll use a subset. 129 | 130 | // The number of test examples to predict each time we test. Because we don't 131 | // update model weights during testing this value doesn't affect model training. 132 | const TEST_BATCH_SIZE = 1000; 133 | // The number of training batches we will run between each test batch. 134 | const TEST_ITERATION_FREQUENCY = 5; 135 | 136 | async function train() { 137 | ui.isTraining(); 138 | 139 | // We'll keep a buffer of loss and accuracy values over time. 140 | const lossValues = []; 141 | const accuracyValues = []; 142 | 143 | // Iteratively train our model on mini-batches of data. 144 | for (let i = 0; i < TRAIN_BATCHES; i++) { 145 | 146 | const batch = data.nextTrainBatch(BATCH_SIZE); 147 | 148 | let testBatch; 149 | let validationData; 150 | // Every few batches test the accuracy of the model. 
151 | if (i % TEST_ITERATION_FREQUENCY === 0) { 152 | testBatch = data.nextTestBatch(TEST_BATCH_SIZE); 153 | validationData = [ 154 | // Reshape the training data from [64, 28x28] to [64, 28, 28, 1] so 155 | // that we can feed it to our convolutional neural net. 156 | testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]), testBatch.labels 157 | ]; 158 | } 159 | 160 | // The entire dataset doesn't fit into memory so we call train repeatedly 161 | // with batches using the fit() method. 162 | const history = await model.fit( 163 | batch.xs.reshape([BATCH_SIZE, 28, 28, 1]), batch.labels, 164 | {batchSize: BATCH_SIZE, validationData, epochs: 1}); 165 | 166 | const loss = history.history.loss[0]; 167 | const accuracy = history.history.acc[0]; 168 | 169 | // Plot loss / accuracy. 170 | lossValues.push({'batch': i, 'loss': loss, 'set': 'train'}); 171 | ui.plotLosses(lossValues); 172 | 173 | if (testBatch != null) { 174 | accuracyValues.push({'batch': i, 'accuracy': accuracy, 'set': 'train'}); 175 | ui.plotAccuracies(accuracyValues); 176 | } 177 | 178 | // Call dispose on the training/test tensors to free their GPU memory. 179 | batch.xs.dispose(); 180 | batch.labels.dispose(); 181 | if (testBatch != null) { 182 | testBatch.xs.dispose(); 183 | testBatch.labels.dispose(); 184 | } 185 | 186 | // tf.nextFrame() returns a promise that resolves at the next call to 187 | // requestAnimationFrame(). By awaiting this promise we keep our model 188 | // training from blocking the main UI thread and freezing the browser. 189 | await tf.nextFrame(); 190 | } 191 | } 192 | 193 | async function showPredictions() { 194 | const testExamples = 100; 195 | const batch = data.nextTestBatch(testExamples); 196 | 197 | // Code wrapped in a tf.tidy() function callback will have their tensors freed 198 | // from GPU memory after execution without having to call dispose(). 199 | // The tf.tidy callback runs synchronously. 
200 | tf.tidy(() => { 201 | const output = model.predict(batch.xs.reshape([-1, 28, 28, 1])); 202 | 203 | // tf.argMax() returns the indices of the maximum values in the tensor along 204 | // a specific axis. Categorical classification tasks like this one often 205 | // represent classes as one-hot vectors. One-hot vectors are 1D vectors with 206 | // one element for each output class. All values in the vector are 0 207 | // except for one, which has a value of 1 (e.g. [0, 0, 0, 1, 0]). The 208 | // output from model.predict() will be a probability distribution, so we use 209 | // argMax to get the index of the vector element that has the highest 210 | // probability. This is our prediction. 211 | // (e.g. argmax([0.07, 0.1, 0.03, 0.75, 0.05]) == 3) 212 | // dataSync() synchronously downloads the tf.tensor values from the GPU so 213 | // that we can use them in our normal CPU JavaScript code 214 | // (for a non-blocking version of this function, use data()). 215 | const axis = 1; 216 | const labels = Array.from(batch.labels.argMax(axis).dataSync()); 217 | const predictions = Array.from(output.argMax(axis).dataSync()); 218 | 219 | ui.showTestResults(batch, predictions, labels); 220 | }); 221 | } 222 | 223 | let data; 224 | async function load() { 225 | data = new MnistData(); 226 | await data.load(); 227 | } 228 | 229 | // This is our main function. It loads the MNIST data, trains the model, and 230 | // then shows what the model predicted on unseen test data. 231 | async function mnist() { 232 | await load(); 233 | await train(); 234 | showPredictions(); 235 | } 236 | mnist(); 237 | -------------------------------------------------------------------------------- /examples/mnist/ui.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | const embed = require('vega-embed').default; 19 | // console.log(embed.default) 20 | // console.log(require('vega-embed')) 21 | 22 | const statusElement = document.getElementById('status'); 23 | const messageElement = document.getElementById('message'); 24 | const imagesElement = document.getElementById('images'); 25 | 26 | function isTraining() { 27 | statusElement.innerText = 'Training...'; 28 | } 29 | function trainingLog(message) { 30 | messageElement.innerText = `${message}\n`; 31 | console.log(message); 32 | } 33 | 34 | function showTestResults(batch, predictions, labels) { 35 | statusElement.innerText = 'Testing...'; 36 | 37 | const testExamples = batch.xs.shape[0]; 38 | let totalCorrect = 0; 39 | for (let i = 0; i < testExamples; i++) { 40 | const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]); 41 | 42 | const div = document.createElement('div'); 43 | div.className = 'pred-container'; 44 | 45 | const canvas = document.createElement('canvas'); 46 | canvas.className = 'prediction-canvas'; 47 | draw(image.flatten(), canvas); 48 | 49 | const pred = document.createElement('div'); 50 | 51 | const prediction = predictions[i]; 52 | const label = labels[i]; 53 | const correct = prediction === label; 54 | 55 | pred.className = `pred ${(correct ? 
'pred-correct' : 'pred-incorrect')}`; 56 | pred.innerText = `pred: ${prediction}`; 57 | 58 | div.appendChild(pred); 59 | div.appendChild(canvas); 60 | 61 | imagesElement.appendChild(div); 62 | } 63 | } 64 | 65 | const lossLabelElement = document.getElementById('loss-label'); 66 | const accuracyLabelElement = document.getElementById('accuracy-label'); 67 | function plotLosses(lossValues) { 68 | embed( 69 | '#lossCanvas', { 70 | '$schema': 'https://vega.github.io/schema/vega-lite/v2.json', 71 | 'data': {'values': lossValues}, 72 | 'mark': {'type': 'line'}, 73 | 'width': 260, 74 | 'orient': 'vertical', 75 | 'encoding': { 76 | 'x': {'field': 'batch', 'type': 'ordinal'}, 77 | 'y': {'field': 'loss', 'type': 'quantitative'}, 78 | 'color': {'field': 'set', 'type': 'nominal', 'legend': null}, 79 | } 80 | }, 81 | {width: 360}); 82 | lossLabelElement.innerText = 83 | 'last loss: ' + lossValues[lossValues.length - 1].loss.toFixed(2); 84 | } 85 | 86 | function plotAccuracies(accuracyValues) { 87 | embed( 88 | '#accuracyCanvas', { 89 | '$schema': 'https://vega.github.io/schema/vega-lite/v2.json', 90 | 'data': {'values': accuracyValues}, 91 | 'width': 260, 92 | 'mark': {'type': 'line', 'legend': null}, 93 | 'orient': 'vertical', 94 | 'encoding': { 95 | 'x': {'field': 'batch', 'type': 'ordinal'}, 96 | 'y': {'field': 'accuracy', 'type': 'quantitative'}, 97 | 'color': {'field': 'set', 'type': 'nominal', 'legend': null}, 98 | } 99 | }, 100 | {'width': 360}); 101 | accuracyLabelElement.innerText = 'last accuracy: ' + 102 | (accuracyValues[accuracyValues.length - 1].accuracy * 100).toFixed(2) + 103 | '%'; 104 | } 105 | 106 | function draw(image, canvas) { 107 | const [width, height] = [28, 28]; 108 | canvas.width = width; 109 | canvas.height = height; 110 | const ctx = canvas.getContext('2d'); 111 | const imageData = new ImageData(width, height); 112 | const data = image.dataSync(); 113 | for (let i = 0; i < height * width; ++i) { 114 | const j = i * 4; 115 | imageData.data[j + 0] = 
data[i] * 255; 116 | imageData.data[j + 1] = data[i] * 255; 117 | imageData.data[j + 2] = data[i] * 255; 118 | imageData.data[j + 3] = 255; 119 | } 120 | ctx.putImageData(imageData, 0, 0); 121 | } 122 | 123 | module.exports = { 124 | isTraining, 125 | trainingLog, 126 | showTestResults, 127 | plotLosses, 128 | plotAccuracies, 129 | draw 130 | } 131 | -------------------------------------------------------------------------------- /examples/simple-linear-network/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Simple Linear Prediction> 5 | 10 | 11 | 12 |

Training a 1 neuron linear network to learn a simple function: y = 2x - 1.

13 |

Training on 100 examples 100 times...

14 |

Back

15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /examples/simple-linear-network/main.js: -------------------------------------------------------------------------------- 1 | // This example is inspired by https://medium.com/tensorflow/getting-started-with-tensorflow-js-50f6783489b2 2 | // I encourage you to read that short post! 3 | 4 | async function main() { 5 | 6 | // define a single neuron linear model 7 | const model = tf.sequential() 8 | model.add(tf.layers.dense({units: 1, inputShape: [1]})) 9 | // Mean Squared Error (MSE) is the most common loss function for regression 10 | // tasks like this one. We'll use the basic Stochastic Gradient Descent 11 | // optimizer as well! 12 | model.compile({ loss: 'meanSquaredError', optimizer: 'sgd' }) 13 | 14 | // create our training set 15 | const { x, y } = getRealData(100) 16 | 17 | // create tf.tensors from the training data for use during model.fit() 18 | const xTensor = tf.tensor(x) 19 | const yTensor = tf.tensor(y) 20 | 21 | // train the model! This will train for 100 full passes, or epochs, through 22 | // our 100 sample datas. Meaning our model will see each sample 100 times. 23 | await model.fit(xTensor, yTensor, { epochs: 100 }) 24 | 25 | // don't forget to free the GPU memory now that 26 | // we're done with our training data 27 | xTensor.dispose() 28 | yTensor.dispose() 29 | 30 | // Now that we've trained our model, it's time for model inference. Let's 31 | // use our model to predict the output values of ten random numbers. 32 | for (let i = 0; i < 10; i++) { 33 | 34 | const rand = random(-10, 10) 35 | const real = realFunction(rand) 36 | // we use model.predict() followed by .dataSync() to download our predicted 37 | // data from the GPU. Our data is wrapped in a "batch dimension", so we'll 38 | // have to grab the first (and only) element to get our prediction as a float. 
39 | // Notice that we pass [rand] into tf.tensor and specify an input shape as [1, 1]. 40 | // This is because model.predict(...) expects data in batches, so we are 41 | // technically passing rand into predict as a mini-batch of one example. 42 | const pred = model.predict(tf.tensor([rand], [1, 1])).dataSync()[0] 43 | 44 | // let's see how our model did by comparing the predicted value to the 45 | // real output of realFunction(). Write the results to the DOM. 46 | if (i == 0) document.getElementById('output').innerText = '' 47 | const html = `real: ${real.toFixed(2)}, pred: ${pred.toFixed(2)}, error: ${Math.abs(pred - real).toFixed(2)}
` 48 | document.getElementById('output').innerHTML += html 49 | } 50 | } 51 | 52 | // get a random float between min and max 53 | function random(min, max) { 54 | return Math.random() * (max - min) + min 55 | } 56 | 57 | // this is true function that our neural net is trying to learn 58 | // we use it to generate training data and to compare against 59 | // our model's predictions 60 | function realFunction(x) { 61 | return 2.0 * x - 1.0 62 | } 63 | 64 | // generate training data samples in the form of { x, y } using realFunction() 65 | function getRealData(numSamples) { 66 | 67 | const x = [], y = [] 68 | 69 | for (let i = 0; i < numSamples; i++) { 70 | // pick a random number between -10 and 10 71 | const rand = random(-10, 10) 72 | x.push(rand) 73 | y.push(realFunction(rand)) 74 | } 75 | 76 | return { x, y } 77 | } 78 | 79 | // run it! 80 | main() -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Tensorflow.js Examples 6 | 7 | 8 |

Tensorflow.js Examples

9 |