├── .gitignore ├── CONTRIBUTING.md ├── Gruntfile.js ├── LICENSE ├── README.md ├── bower.json ├── browser.js ├── lib ├── brain.js ├── cross-validate.js ├── lookup.js └── neuralnetwork.js ├── package.json ├── stream-example.js └── test ├── README.md ├── cross-validation └── ocr.js └── unit ├── bitwise.js ├── hash.js ├── json.js ├── lookup.js ├── options.js ├── stream-bitwise.js └── trainopts.js /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac. 2 | .DS_STORE 3 | 4 | # Node. 5 | node_modules 6 | npm-debug.log 7 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Thanks for taking the time to contribute to brain.js. Follow these guidelines to make the process smoother: 2 | 3 | 1. One feature per pull request. Each PR should have one focus, and all the code changes should be supporting that one feature or bug fix. Using a [separate branch](https://guides.github.com/introduction/flow/index.html) for each feature should help you manage developing multiple features at once. 4 | 5 | 2. Follow the style of the file when it comes to syntax like curly braces and indents. 6 | 7 | 3. Add a test for the feature or fix, if possible. See the `test` directory for existing tests and README describing how to run these tests. 
8 | -------------------------------------------------------------------------------- /Gruntfile.js: -------------------------------------------------------------------------------- 1 | /* 2 | * To run this file: 3 | * `npm install --dev` 4 | * `npm install -g grunt` 5 | * 6 | * `grunt --help` 7 | */ 8 | 9 | var fs = require("fs"), 10 | browserify = require("browserify"), 11 | pkg = require("./package.json"); 12 | 13 | module.exports = function(grunt) { 14 | grunt.initConfig({ 15 | mochaTest: { 16 | test: { 17 | options: { 18 | style: 'bdd', 19 | reporter: 'spec' 20 | }, 21 | src: ['test/unit/*.js'] 22 | } 23 | }, 24 | pkg: grunt.file.readJSON('package.json'), 25 | uglify: { 26 | options: { 27 | banner: "/*\n" + grunt.file.read('LICENSE') + "*/" 28 | }, 29 | dist: { 30 | files: { 31 | '<%=pkg.name%>-<%=pkg.version%>.min.js': ['<%=pkg.name%>-<%=pkg.version%>.js'] 32 | } 33 | } 34 | } 35 | }); 36 | 37 | grunt.registerTask('build', 'build a browser file', function() { 38 | var done = this.async(); 39 | 40 | var outfile = './brain-' + pkg.version + '.js'; 41 | 42 | var bundle = browserify('./browser.js').bundle(function(err, src) { 43 | console.log("> " + outfile); 44 | 45 | // prepend license 46 | var license = fs.readFileSync("./LICENSE"); 47 | src = "/*\n" + license + "*/" + src; 48 | 49 | // write out the browser file 50 | fs.writeFileSync(outfile, src); 51 | done(); 52 | }); 53 | }); 54 | grunt.registerTask('test', 'mochaTest'); 55 | 56 | grunt.loadNpmTasks('grunt-mocha-test'); 57 | grunt.loadNpmTasks('grunt-contrib-uglify'); 58 | }; 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010 Heather Arthur 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | 
without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | *This project has reached the end of its development as a simple neural network library. Feel free to browse the code, but please use other JavaScript neural network libraries in development like [brain.js](https://github.com/BrainJS/brain.js) and [convnetjs](https://github.com/karpathy/convnetjs).* 2 | 3 | # brain 4 | 5 | `brain` is a JavaScript [neural network](http://neuralnetworksanddeeplearning.com/) library. 
Here's an example of using it to approximate the XOR function: 6 | 7 | ```javascript 8 | var net = new brain.NeuralNetwork(); 9 | 10 | net.train([{input: [0, 0], output: [0]}, 11 | {input: [0, 1], output: [1]}, 12 | {input: [1, 0], output: [1]}, 13 | {input: [1, 1], output: [0]}]); 14 | 15 | var output = net.run([1, 0]); // [0.987] 16 | ``` 17 | 18 | There's no reason to use a neural network to figure out XOR however (-: so here's a more involved, realistic example: 19 | [Demo: training a neural network to recognize color contrast](http://harthur.github.com/brain/) 20 | 21 | ## Using in node 22 | If you have [node](http://nodejs.org/) you can install with [npm](http://npmjs.org): 23 | 24 | ``` 25 | npm install brain 26 | ``` 27 | 28 | ## Using in the browser 29 | Download the latest [brain.js](https://github.com/harthur/brain/tree/gh-pages). Training is computationally expensive, so you should try to train the network offline (or on a Worker) and use the `toFunction()` or `toJSON()` options to plug the pre-trained network into your website. 30 | 31 | ## Training 32 | Use `train()` to train the network with an array of training data. The network has to be trained with all the data in bulk in one call to `train()`. The more training patterns, the longer it will probably take to train, but the better the network will be at classifying new patterns. 33 | 34 | #### Data format 35 | Each training pattern should have an `input` and an `output`, both of which can be either an array of numbers from `0` to `1` or a hash of numbers from `0` to `1`. 
For the [color contrast demo](http://harthur.github.com/brain/) it looks something like this: 36 | 37 | ```javascript 38 | var net = new brain.NeuralNetwork(); 39 | 40 | net.train([{input: { r: 0.03, g: 0.7, b: 0.5 }, output: { black: 1 }}, 41 | {input: { r: 0.16, g: 0.09, b: 0.2 }, output: { white: 1 }}, 42 | {input: { r: 0.5, g: 0.5, b: 1.0 }, output: { white: 1 }}]); 43 | 44 | var output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.99, black: 0.002 } 45 | ``` 46 | 47 | #### Options 48 | `train()` takes a hash of options as its second argument: 49 | 50 | ```javascript 51 | net.train(data, { 52 | errorThresh: 0.005, // error threshold to reach 53 | iterations: 20000, // maximum training iterations 54 | log: true, // console.log() progress periodically 55 | logPeriod: 10, // number of iterations between logging 56 | learningRate: 0.3 // learning rate 57 | }) 58 | ``` 59 | 60 | The network will train until the training error has gone below the threshold (default `0.005`) or the max number of iterations (default `20000`) has been reached, whichever comes first. 61 | 62 | By default training won't let you know how it's doing until the end, but set `log` to `true` to get periodic updates on the current training error of the network. The training error should decrease every time. The updates will be printed to the console. If you set `log` to a function, this function will be called with the updates instead of printing to the console. 63 | 64 | The learning rate is a parameter that influences how quickly the network trains. It's a number from `0` to `1`. If the learning rate is close to `0` it will take longer to train. If the learning rate is closer to `1` it will train faster but it's in danger of training to a local minimum and performing badly on new data. The default learning rate is `0.3`. 
65 | 66 | #### Output 67 | The output of `train()` is a hash of information about how the training went: 68 | 69 | ```javascript 70 | { 71 | error: 0.0039139985510105032, // training error 72 | iterations: 406 // training iterations 73 | } 74 | ``` 75 | 76 | #### Failing 77 | If the network failed to train, the error will be above the error threshold. This could happen because the training data is too noisy (most likely), the network doesn't have enough hidden layers or nodes to handle the complexity of the data, or it hasn't trained for enough iterations. 78 | 79 | If the training error is still something huge like `0.4` after 20000 iterations, it's a good sign that the network can't make sense of the data you're giving it. 80 | 81 | ## JSON 82 | Serialize or load in the state of a trained network with JSON: 83 | 84 | ```javascript 85 | var json = net.toJSON(); 86 | 87 | net.fromJSON(json); 88 | ``` 89 | 90 | You can also get a custom standalone function from a trained network that acts just like `run()`: 91 | 92 | ```javascript 93 | var run = net.toFunction(); 94 | 95 | var output = run({ r: 1, g: 0.4, b: 0 }); 96 | 97 | console.log(run.toString()); // copy and paste! no need to import brain.js 98 | ``` 99 | 100 | ## Options 101 | `NeuralNetwork()` takes a hash of options: 102 | 103 | ```javascript 104 | var net = new brain.NeuralNetwork({ 105 | hiddenLayers: [4], 106 | learningRate: 0.6 // global learning rate, useful when training using streams 107 | }); 108 | ``` 109 | 110 | #### hiddenLayers 111 | Specify the number of hidden layers in the network and the size of each layer. For example, if you want two hidden layers - the first with 3 nodes and the second with 4 nodes, you'd give: 112 | 113 | ``` 114 | hiddenLayers: [3, 4] 115 | ``` 116 | 117 | By default `brain` uses one hidden layer with size proportionate to the size of the input array. 
118 | 119 | ## Streams 120 | The network now has a [WriteStream](http://nodejs.org/api/stream.html#stream_class_stream_writable). You can train the network by using `pipe()` to send the training data to the network. 121 | 122 | #### Example 123 | Refer to `stream-example.js` for an example on how to train the network with a stream. 124 | 125 | #### Initialization 126 | To train the network using a stream you must first create the stream by calling `net.createTrainStream()` which takes the following options: 127 | 128 | * `floodCallback()` - the callback function to re-populate the stream. This gets called on every training iteration. 129 | * `doneTrainingCallback(info)` - the callback function to execute when the network is done training. The `info` param will contain a hash of information about how the training went: 130 | 131 | ```javascript 132 | { 133 | error: 0.0039139985510105032, // training error 134 | iterations: 406 // training iterations 135 | } 136 | ``` 137 | 138 | #### Transform 139 | Use a [Transform](http://nodejs.org/api/stream.html#stream_class_stream_transform) to coerce the data into the correct format. You might also use a Transform stream to normalize your data on the fly. 
140 | -------------------------------------------------------------------------------- /bower.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "brain", 3 | "version": "0.7.0", 4 | "homepage": "https://github.com/harthur/brain", 5 | "authors": [ 6 | "Heather Arthur " 7 | ], 8 | "description": "Neural network library", 9 | "keywords": [ 10 | "neural-networks", 11 | "machine-learning", 12 | "classifier" 13 | ], 14 | "main": "lib/brain.js", 15 | "ignore": [ 16 | "node_modules", 17 | "test" 18 | ], 19 | "dependencies": { 20 | "underscore": ">=1.5.1" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /browser.js: -------------------------------------------------------------------------------- 1 | // this file is the entrypoint for building a browser file with browserify 2 | 3 | brain = require("./lib/brain"); -------------------------------------------------------------------------------- /lib/brain.js: -------------------------------------------------------------------------------- 1 | exports.NeuralNetwork = require("./neuralnetwork").NeuralNetwork; 2 | exports.crossValidate = require("./cross-validate"); 3 | -------------------------------------------------------------------------------- /lib/cross-validate.js: -------------------------------------------------------------------------------- 1 | var _ = require("underscore")._; 2 | 3 | function testPartition(classifierConst, opts, trainOpts, trainSet, testSet) { 4 | var classifier = new classifierConst(opts); 5 | 6 | var beginTrain = Date.now(); 7 | 8 | var trainingStats = classifier.train(trainSet, trainOpts); 9 | 10 | var beginTest = Date.now(); 11 | 12 | var testStats = classifier.test(testSet); 13 | 14 | var endTest = Date.now(); 15 | 16 | var stats = _(testStats).extend({ 17 | trainTime : beginTest - beginTrain, 18 | testTime : endTest - beginTest, 19 | iterations: trainingStats.iterations, 20 | trainError: 
trainingStats.error, 21 | learningRate: trainOpts.learningRate, 22 | hidden: classifier.hiddenSizes, 23 | network: classifier.toJSON() 24 | }); 25 | 26 | return stats; 27 | } 28 | 29 | module.exports = function crossValidate(classifierConst, data, opts, trainOpts, k) { 30 | k = k || 4; 31 | var size = data.length / k; 32 | 33 | data = _(data).sortBy(function() { 34 | return Math.random(); 35 | }); 36 | 37 | var avgs = { 38 | error : 0, 39 | trainTime : 0, 40 | testTime : 0, 41 | iterations: 0, 42 | trainError: 0 43 | }; 44 | 45 | var stats = { 46 | truePos: 0, 47 | trueNeg: 0, 48 | falsePos: 0, 49 | falseNeg: 0, 50 | total: 0 51 | }; 52 | 53 | var misclasses = []; 54 | 55 | var results = _.range(k).map(function(i) { 56 | var dclone = _(data).clone(); 57 | var testSet = dclone.splice(i * size, size); 58 | var trainSet = dclone; 59 | 60 | var result = testPartition(classifierConst, opts, trainOpts, trainSet, testSet); 61 | 62 | _(avgs).each(function(sum, stat) { 63 | avgs[stat] = sum + result[stat]; 64 | }); 65 | 66 | _(stats).each(function(sum, stat) { 67 | stats[stat] = sum + result[stat]; 68 | }) 69 | 70 | misclasses.push(result.misclasses); 71 | 72 | return result; 73 | }); 74 | 75 | _(avgs).each(function(sum, i) { 76 | avgs[i] = sum / k; 77 | }); 78 | 79 | stats.precision = stats.truePos / (stats.truePos + stats.falsePos); 80 | stats.recall = stats.truePos / (stats.truePos + stats.falseNeg); 81 | stats.accuracy = (stats.trueNeg + stats.truePos) / stats.total; 82 | 83 | stats.testSize = size; 84 | stats.trainSize = data.length - size; 85 | 86 | return { 87 | avgs: avgs, 88 | stats: stats, 89 | sets: results, 90 | misclasses: _(misclasses).flatten() 91 | }; 92 | } 93 | -------------------------------------------------------------------------------- /lib/lookup.js: -------------------------------------------------------------------------------- 1 | var _ = require("underscore"); 2 | 3 | /* Functions for turning sparse hashes into arrays and vice versa */ 4 | 5 | 
/* Functions for turning sparse hashes into arrays and vice versa */

/*
  Build one lookup that covers every key seen in a list of hashes.
  [{a: 1}, {b: 6, c: 7}] -> {a: 0, b: 1, c: 2}
  Missing (null/undefined) entries and a missing list are ignored.
*/
function buildLookup(hashes) {
  var merged = {};
  (hashes || []).forEach(function(hash) {
    if (hash == null) {
      return;
    }
    // own keys only, so properties inherited through a prototype never
    // become network nodes
    Object.keys(hash).forEach(function(key) {
      merged[key] = hash[key];
    });
  });
  return lookupFromHash(merged);
}

/*
  Number the own keys of a hash in enumeration order.
  {a: 6, b: 7} -> {a: 0, b: 1}
*/
function lookupFromHash(hash) {
  var lookup = {};
  Object.keys(hash).forEach(function(key, index) {
    lookup[key] = index;
  });
  return lookup;
}

/*
  Spread a sparse hash over an array using the positions in `lookup`;
  keys that are missing (or falsy) in the hash become 0.
  {a: 0, b: 1}, {a: 6} -> [6, 0]
*/
function toArray(lookup, hash) {
  var array = [];
  Object.keys(lookup).forEach(function(key) {
    array[lookup[key]] = hash[key] || 0;
  });
  return array;
}

/*
  Inverse of toArray(): name the array entries with the lookup keys.
  {a: 0, b: 1}, [6, 7] -> {a: 6, b: 7}
*/
function toHash(lookup, array) {
  var hash = {};
  Object.keys(lookup).forEach(function(key) {
    hash[key] = array[lookup[key]];
  });
  return hash;
}

/*
  Build a lookup from a list of keys. The list is walked from its end, so
  the LAST key gets index 0:
  ['a', 'b'] -> {b: 0, a: 1}
  The numbering is kept as it is because existing lookups depend on it.
*/
function lookupFromArray(array) {
  var lookup = {};
  var z = 0;
  var i = array.length;
  while (i-- > 0) {
    lookup[array[i]] = z++;
  }
  return lookup;
}
this.sizes.length - 1; 19 | 20 | this.biases = []; // weights for bias nodes 21 | this.weights = []; 22 | this.outputs = []; 23 | 24 | // state for training 25 | this.deltas = []; 26 | this.changes = []; // for momentum 27 | this.errors = []; 28 | 29 | for (var layer = 0; layer <= this.outputLayer; layer++) { 30 | var size = this.sizes[layer]; 31 | this.deltas[layer] = zeros(size); 32 | this.errors[layer] = zeros(size); 33 | this.outputs[layer] = zeros(size); 34 | 35 | if (layer > 0) { 36 | this.biases[layer] = randos(size); 37 | this.weights[layer] = new Array(size); 38 | this.changes[layer] = new Array(size); 39 | 40 | for (var node = 0; node < size; node++) { 41 | var prevSize = this.sizes[layer - 1]; 42 | this.weights[layer][node] = randos(prevSize); 43 | this.changes[layer][node] = zeros(prevSize); 44 | } 45 | } 46 | } 47 | }, 48 | 49 | run: function(input) { 50 | if (this.inputLookup) { 51 | input = lookup.toArray(this.inputLookup, input); 52 | } 53 | 54 | var output = this.runInput(input); 55 | 56 | if (this.outputLookup) { 57 | output = lookup.toHash(this.outputLookup, output); 58 | } 59 | return output; 60 | }, 61 | 62 | runInput: function(input) { 63 | this.outputs[0] = input; // set output state of input layer 64 | 65 | for (var layer = 1; layer <= this.outputLayer; layer++) { 66 | for (var node = 0; node < this.sizes[layer]; node++) { 67 | var weights = this.weights[layer][node]; 68 | 69 | var sum = this.biases[layer][node]; 70 | for (var k = 0; k < weights.length; k++) { 71 | sum += weights[k] * input[k]; 72 | } 73 | this.outputs[layer][node] = 1 / (1 + Math.exp(-sum)); 74 | } 75 | var output = input = this.outputs[layer]; 76 | } 77 | return output; 78 | }, 79 | 80 | train: function(data, options) { 81 | data = this.formatData(data); 82 | 83 | options = options || {}; 84 | var iterations = options.iterations || 20000; 85 | var errorThresh = options.errorThresh || 0.005; 86 | var log = options.log ? (_.isFunction(options.log) ? 
options.log : console.log) : false; 87 | var logPeriod = options.logPeriod || 10; 88 | var learningRate = options.learningRate || this.learningRate || 0.3; 89 | var callback = options.callback; 90 | var callbackPeriod = options.callbackPeriod || 10; 91 | 92 | var inputSize = data[0].input.length; 93 | var outputSize = data[0].output.length; 94 | 95 | var hiddenSizes = this.hiddenSizes; 96 | if (!hiddenSizes) { 97 | hiddenSizes = [Math.max(3, Math.floor(inputSize / 2))]; 98 | } 99 | var sizes = _([inputSize, hiddenSizes, outputSize]).flatten(); 100 | this.initialize(sizes); 101 | 102 | var error = 1; 103 | for (var i = 0; i < iterations && error > errorThresh; i++) { 104 | var sum = 0; 105 | for (var j = 0; j < data.length; j++) { 106 | var err = this.trainPattern(data[j].input, data[j].output, learningRate); 107 | sum += err; 108 | } 109 | error = sum / data.length; 110 | 111 | if (log && (i % logPeriod == 0)) { 112 | log("iterations:", i, "training error:", error); 113 | } 114 | if (callback && (i % callbackPeriod == 0)) { 115 | callback({ error: error, iterations: i }); 116 | } 117 | } 118 | 119 | return { 120 | error: error, 121 | iterations: i 122 | }; 123 | }, 124 | 125 | trainPattern : function(input, target, learningRate) { 126 | learningRate = learningRate || this.learningRate; 127 | 128 | // forward propogate 129 | this.runInput(input); 130 | 131 | // back propogate 132 | this.calculateDeltas(target); 133 | this.adjustWeights(learningRate); 134 | 135 | var error = mse(this.errors[this.outputLayer]); 136 | return error; 137 | }, 138 | 139 | calculateDeltas: function(target) { 140 | for (var layer = this.outputLayer; layer >= 0; layer--) { 141 | for (var node = 0; node < this.sizes[layer]; node++) { 142 | var output = this.outputs[layer][node]; 143 | 144 | var error = 0; 145 | if (layer == this.outputLayer) { 146 | error = target[node] - output; 147 | } 148 | else { 149 | var deltas = this.deltas[layer + 1]; 150 | for (var k = 0; k < deltas.length; k++) { 
151 | error += deltas[k] * this.weights[layer + 1][k][node]; 152 | } 153 | } 154 | this.errors[layer][node] = error; 155 | this.deltas[layer][node] = error * output * (1 - output); 156 | } 157 | } 158 | }, 159 | 160 | adjustWeights: function(learningRate) { 161 | for (var layer = 1; layer <= this.outputLayer; layer++) { 162 | var incoming = this.outputs[layer - 1]; 163 | 164 | for (var node = 0; node < this.sizes[layer]; node++) { 165 | var delta = this.deltas[layer][node]; 166 | 167 | for (var k = 0; k < incoming.length; k++) { 168 | var change = this.changes[layer][node][k]; 169 | 170 | change = (learningRate * delta * incoming[k]) 171 | + (this.momentum * change); 172 | 173 | this.changes[layer][node][k] = change; 174 | this.weights[layer][node][k] += change; 175 | } 176 | this.biases[layer][node] += learningRate * delta; 177 | } 178 | } 179 | }, 180 | 181 | formatData: function(data) { 182 | if (!_.isArray(data)) { // turn stream datum into array 183 | var tmp = []; 184 | tmp.push(data); 185 | data = tmp; 186 | } 187 | // turn sparse hash input into arrays with 0s as filler 188 | var datum = data[0].input; 189 | if (!_(datum).isArray() && !(datum instanceof Float64Array)) { 190 | if (!this.inputLookup) { 191 | this.inputLookup = lookup.buildLookup(_(data).pluck("input")); 192 | } 193 | data = data.map(function(datum) { 194 | var array = lookup.toArray(this.inputLookup, datum.input) 195 | return _(_(datum).clone()).extend({ input: array }); 196 | }, this); 197 | } 198 | 199 | if (!_(data[0].output).isArray()) { 200 | if (!this.outputLookup) { 201 | this.outputLookup = lookup.buildLookup(_(data).pluck("output")); 202 | } 203 | data = data.map(function(datum) { 204 | var array = lookup.toArray(this.outputLookup, datum.output); 205 | return _(_(datum).clone()).extend({ output: array }); 206 | }, this); 207 | } 208 | return data; 209 | }, 210 | 211 | test : function(data) { 212 | data = this.formatData(data); 213 | 214 | // for binary classification problems with 
one output node 215 | var isBinary = data[0].output.length == 1; 216 | var falsePos = 0, 217 | falseNeg = 0, 218 | truePos = 0, 219 | trueNeg = 0; 220 | 221 | // for classification problems 222 | var misclasses = []; 223 | 224 | // run each pattern through the trained network and collect 225 | // error and misclassification statistics 226 | var sum = 0; 227 | for (var i = 0; i < data.length; i++) { 228 | var output = this.runInput(data[i].input); 229 | var target = data[i].output; 230 | 231 | var actual, expected; 232 | if (isBinary) { 233 | actual = output[0] > this.binaryThresh ? 1 : 0; 234 | expected = target[0]; 235 | } 236 | else { 237 | actual = output.indexOf(_(output).max()); 238 | expected = target.indexOf(_(target).max()); 239 | } 240 | 241 | if (actual != expected) { 242 | var misclass = data[i]; 243 | _(misclass).extend({ 244 | actual: actual, 245 | expected: expected 246 | }) 247 | misclasses.push(misclass); 248 | } 249 | 250 | if (isBinary) { 251 | if (actual == 0 && expected == 0) { 252 | trueNeg++; 253 | } 254 | else if (actual == 1 && expected == 1) { 255 | truePos++; 256 | } 257 | else if (actual == 0 && expected == 1) { 258 | falseNeg++; 259 | } 260 | else if (actual == 1 && expected == 0) { 261 | falsePos++; 262 | } 263 | } 264 | 265 | var errors = output.map(function(value, i) { 266 | return target[i] - value; 267 | }); 268 | sum += mse(errors); 269 | } 270 | var error = sum / data.length; 271 | 272 | var stats = { 273 | error: error, 274 | misclasses: misclasses 275 | }; 276 | 277 | if (isBinary) { 278 | _(stats).extend({ 279 | trueNeg: trueNeg, 280 | truePos: truePos, 281 | falseNeg: falseNeg, 282 | falsePos: falsePos, 283 | total: data.length, 284 | precision: truePos / (truePos + falsePos), 285 | recall: truePos / (truePos + falseNeg), 286 | accuracy: (trueNeg + truePos) / data.length 287 | }) 288 | } 289 | return stats; 290 | }, 291 | 292 | toJSON: function() { 293 | /* make json look like: 294 | { 295 | layers: [ 296 | { x: {}, 297 | y: 
{}}, 298 | {'0': {bias: -0.98771313, weights: {x: 0.8374838, y: 1.245858}, 299 | '1': {bias: 3.48192004, weights: {x: 1.7825821, y: -2.67899}}}, 300 | { f: {bias: 0.27205739, weights: {'0': 1.3161821, '1': 2.00436}}} 301 | ] 302 | } 303 | */ 304 | var layers = []; 305 | for (var layer = 0; layer <= this.outputLayer; layer++) { 306 | layers[layer] = {}; 307 | 308 | var nodes; 309 | // turn any internal arrays back into hashes for readable json 310 | if (layer == 0 && this.inputLookup) { 311 | nodes = _(this.inputLookup).keys(); 312 | } 313 | else if (layer == this.outputLayer && this.outputLookup) { 314 | nodes = _(this.outputLookup).keys(); 315 | } 316 | else { 317 | nodes = _.range(0, this.sizes[layer]); 318 | } 319 | 320 | for (var j = 0; j < nodes.length; j++) { 321 | var node = nodes[j]; 322 | layers[layer][node] = {}; 323 | 324 | if (layer > 0) { 325 | layers[layer][node].bias = this.biases[layer][j]; 326 | layers[layer][node].weights = {}; 327 | for (var k in layers[layer - 1]) { 328 | var index = k; 329 | if (layer == 1 && this.inputLookup) { 330 | index = this.inputLookup[k]; 331 | } 332 | layers[layer][node].weights[k] = this.weights[layer][j][index]; 333 | } 334 | } 335 | } 336 | } 337 | return { layers: layers, outputLookup:!!this.outputLookup, inputLookup:!!this.inputLookup }; 338 | }, 339 | 340 | fromJSON: function(json) { 341 | var size = json.layers.length; 342 | this.outputLayer = size - 1; 343 | 344 | this.sizes = new Array(size); 345 | this.weights = new Array(size); 346 | this.biases = new Array(size); 347 | this.outputs = new Array(size); 348 | 349 | for (var i = 0; i <= this.outputLayer; i++) { 350 | var layer = json.layers[i]; 351 | if (i == 0 && (!layer[0] || json.inputLookup)) { 352 | this.inputLookup = lookup.lookupFromHash(layer); 353 | } 354 | else if (i == this.outputLayer && (!layer[0] || json.outputLookup)) { 355 | this.outputLookup = lookup.lookupFromHash(layer); 356 | } 357 | 358 | var nodes = _(layer).keys(); 359 | this.sizes[i] = 
nodes.length; 360 | this.weights[i] = []; 361 | this.biases[i] = []; 362 | this.outputs[i] = []; 363 | 364 | for (var j in nodes) { 365 | var node = nodes[j]; 366 | this.biases[i][j] = layer[node].bias; 367 | this.weights[i][j] = _(layer[node].weights).toArray(); 368 | } 369 | } 370 | return this; 371 | }, 372 | 373 | toFunction: function() { 374 | var json = this.toJSON(); 375 | // return standalone function that mimics run() 376 | return new Function("input", 377 | ' var net = ' + JSON.stringify(json) + ';\n\n\ 378 | for (var i = 1; i < net.layers.length; i++) {\n\ 379 | var layer = net.layers[i];\n\ 380 | var output = {};\n\ 381 | \n\ 382 | for (var id in layer) {\n\ 383 | var node = layer[id];\n\ 384 | var sum = node.bias;\n\ 385 | \n\ 386 | for (var iid in node.weights) {\n\ 387 | sum += node.weights[iid] * input[iid];\n\ 388 | }\n\ 389 | output[id] = (1 / (1 + Math.exp(-sum)));\n\ 390 | }\n\ 391 | input = output;\n\ 392 | }\n\ 393 | return output;'); 394 | }, 395 | 396 | // This will create a TrainStream (WriteStream) 397 | // for us to send the training data to. 
/* Initial value for a weight or bias: a random number in [-0.2, 0.2). */
function randomWeight() {
  var spread = Math.random() * 0.4;
  return spread - 0.2;
}

/* Shared builder: an array of `size` entries, each produced by `produce()`. */
function fillArray(size, produce) {
  var filled = new Array(size);
  var idx = 0;
  while (idx < size) {
    filled[idx] = produce();
    idx += 1;
  }
  return filled;
}

/* Array of `size` zeros. */
function zeros(size) {
  return fillArray(size, function() {
    return 0;
  });
}

/* Array of `size` random weights (see randomWeight). */
function randos(size) {
  return fillArray(size, randomWeight);
}

/* Mean squared error: the average of the squared entries of `errors`. */
function mse(errors) {
  var count = errors.length;
  var squaredTotal = 0;
  var idx = 0;
  while (idx < count) {
    squaredTotal += Math.pow(errors[idx], 2);
    idx += 1;
  }
  return squaredTotal / count;
}
/*
  Handler for the stream's 'finish' event; closes one pass over the data.

  First pass (data format not yet determined): no training has happened
  yet. The keys collected by _write() are turned into lookups, the network
  is sized from the first datum and initialized, and floodCallback() is
  called so the data is streamed again for actual training.

  Later passes: computes the mean training error of the pass, reports it
  through `log`/`callback`, resets the counters and then either asks for
  the data again (floodCallback) or, once the iteration limit or the error
  threshold is reached, calls doneTrainingCallback({ error, iterations }).
*/
TrainStream.prototype.finishStreamIteration = function() {
  if (this.dataFormatDetermined && this.size !== this.count) {
    console.log("This iteration's data length was different from the first.");
  }

  if (!this.dataFormatDetermined) {
    // Create the lookups, but only for hash data. For array input the
    // collected "keys" are just indices; a lookup built from them would
    // make run() re-order plain array input after training.
    var firstInput = this.firstDatum.input;
    if (!_.isArray(firstInput) && !(firstInput instanceof Float64Array)) {
      this.neuralNetwork.inputLookup = lookup.lookupFromArray(this.inputKeys);
    }
    if (!_.isArray(this.firstDatum.output)) {
      this.neuralNetwork.outputLookup = lookup.lookupFromArray(this.outputKeys);
    }

    var data = this.neuralNetwork.formatData(this.firstDatum);
    var inputSize = data[0].input.length;
    var outputSize = data[0].output.length;

    // The hidden layer sizes are configured on the network; the stream
    // itself never defines `hiddenSizes`, so reading only this.hiddenSizes
    // silently ignored the network's `hiddenLayers` option.
    var hiddenSizes = this.hiddenSizes || this.neuralNetwork.hiddenSizes;
    if (!hiddenSizes) {
      hiddenSizes = [Math.max(3, Math.floor(inputSize / 2))];
    }
    var sizes = _([inputSize, hiddenSizes, outputSize]).flatten();
    this.dataFormatDetermined = true;
    this.neuralNetwork.initialize(sizes);

    if (typeof this.floodCallback === 'function') {
      this.floodCallback();
    }
    return;
  }

  // mean error over the number of patterns seen in the first pass
  var error = this.sum / this.size;

  if (this.log && (this.i % this.logPeriod === 0)) {
    this.log("iterations:", this.i, "training error:", error);
  }
  if (this.callback && (this.i % this.callbackPeriod === 0)) {
    this.callback({
      error: error,
      iterations: this.i
    });
  }

  this.sum = 0;
  this.count = 0;
  // update the iterations
  this.i++;

  // do a check here to see if we need the stream again
  if (this.i < this.iterations && error > this.errorThresh) {
    if (typeof this.floodCallback === 'function') {
      return this.floodCallback();
    }
  } else {
    // done training
    if (typeof this.doneTrainingCallback === 'function') {
      return this.doneTrainingCallback({
        error: error,
        iterations: this.i
      });
    }
  }
};
test-unit && npm run test-cv" 14 | }, 15 | "main": "./lib/brain", 16 | "dependencies": { 17 | "underscore": ">=1.5.1", 18 | "inherits": "~2.0.1" 19 | }, 20 | "devDependencies": { 21 | "mocha": ">=1.0.0", 22 | "canvas": ">=0.10.0", 23 | "grunt": "~0.4.3", 24 | "grunt-contrib-uglify": "~0.2.0", 25 | "grunt-mocha-test": "~0.11.0", 26 | "browserify": "~3.32.0" 27 | }, 28 | "keywords": ["neural network", "classifier", "machine learning"] 29 | } 30 | -------------------------------------------------------------------------------- /stream-example.js: -------------------------------------------------------------------------------- 1 | var assert = require("assert"), 2 | brain = require("./lib/brain"); 3 | 4 | var net = new brain.NeuralNetwork(); 5 | 6 | var xor = [ 7 | { input: [0, 0], output: [0]}, 8 | { input: [0, 1], output: [1]}, 9 | { input: [1, 0], output: [1]}, 10 | { input: [1, 1], output: [0]}]; 11 | 12 | var trainStream = net.createTrainStream({ 13 | /** 14 | * Write training data to the stream. Called on each training iteration. 15 | */ 16 | floodCallback: function() { 17 | flood(trainStream, xor); 18 | }, 19 | 20 | /** 21 | * Called when the network is done training. 
22 | */ 23 | doneTrainingCallback: function(obj) { 24 | console.log("trained in " + obj.iterations + " iterations with error: " 25 | + obj.error); 26 | 27 | var result = net.run([0, 1]); 28 | 29 | console.log("0 XOR 1: ", result); // 0.987 30 | } 31 | }); 32 | 33 | // kick it off 34 | flood(trainStream, xor); 35 | 36 | 37 | function flood(stream, data) { 38 | for (var i = 0; i < data.length; i++) { 39 | stream.write(data[i]); 40 | } 41 | // let it know we've reached the end of the data 42 | stream.write(null); 43 | } 44 | -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | To run the tests in this directory, make sure you've installed the dev dependencies with this command from the top-level directory: 4 | 5 | ``` 6 | npm install 7 | ``` 8 | 9 | Then you can run all tests (unit and cross-validation) using `npm test`. 10 | 11 | # Unit tests 12 | Run the unit tests with: 13 | 14 | ``` 15 | grunt test 16 | ``` 17 | 18 | or 19 | 20 | `npm run test-unit` 21 | 22 | # Cross-validation tests 23 | The cross-validation tests will actually test how good the neural network is at training by getting a bunch of training data, training it with some, and using the rest as verification. 24 | 25 | Cross-validation tests will take a long time to run, and in the end will give you a printout with the average error of the test data. 
// Side length, in pixels, of the square canvas each letter is drawn on.
var dim = 24;

/**
 * Draws one letter in the given font on the canvas context and returns the
 * luma samples of the rendered image.
 *
 * @param {Object} context - 2d drawing context; must provide clearRect,
 *   fillText, getImageData and accept a `font` assignment
 * @param {string} letter - the character to draw
 * @param {string} font - font family name
 * @returns {number[]} luma samples as produced by extractPoints
 */
function getSampling(context, letter, font) {
  context.clearRect(0, 0, dim, dim);
  context.font = dim + "px " + font;
  context.fillText(letter, 0, dim);

  var data = context.getImageData(0, 0, dim, dim);
  return extractPoints(data);
}

/**
 * Samples an RGBA image on a regular grid and converts each sampled pixel to
 * a luma value. Fully transparent pixels (alpha 0) count as 1.
 *
 * Samples are emitted column by column: x is the outer loop, y the inner one.
 *
 * @param {Object} imageData - object with `width`, `height` and a flat `data`
 *   array holding 4 values (r, g, b, a) per pixel, row by row
 * @param {number} [step=2] - distance in pixels between samples on both axes;
 *   must be a positive integer
 * @returns {number[]} one luma value per sampled pixel
 * @throws {RangeError} if `step` is not a positive integer
 */
function extractPoints(imageData, step) {
  if (step === undefined) {
    step = 2;
  }
  // a zero, negative, fractional or NaN step would loop forever or misindex
  if (!(step >= 1) || Math.floor(step) !== step) {
    throw new RangeError("step must be a positive integer, got: " + step);
  }

  var points = [];
  for (var x = 0; x < imageData.width; x += step) {
    for (var y = 0; y < imageData.height; y += step) {
      // 4 array slots per pixel; rows are `width` pixels long
      var i = x * 4 + y * 4 * imageData.width;
      var r = imageData.data[i],
          g = imageData.data[i + 1],
          b = imageData.data[i + 2],
          a = imageData.data[i + 3];

      var luma = a === 0 ? 1 : (r * 299/1000 + g * 587/1000
                 + b * 114/1000 ) / 255;

      points.push(luma);
    }
  }
  return points;
}
// Maximum allowed distance between a network output and its target.
var wiggle = 0.1;

/**
 * Trains a fresh network on `data` and asserts that, for every datum, each
 * output value lies strictly within `wiggle` of the matching target value.
 *
 * Outputs and targets are compared element by element as numbers. Adding
 * `wiggle` to a whole array would concatenate strings ("1" + 0.1 is "10.1"),
 * which makes the upper bound meaningless.
 *
 * @param {Array} data - training data; each item has `input` and `output`
 * @param {string} op - name of the operation, used in the failure message
 * @throws {AssertionError} if any output value is outside the tolerance
 */
function testBitwise(data, op) {
  var net = new brain.NeuralNetwork();
  net.train(data, { errorThresh: 0.003 });

  for (var i = 0; i < data.length; i++) {
    var output = net.run(data[i].input);
    var target = data[i].output;

    // Object.keys covers both array targets (indices) and hash targets
    var keys = Object.keys(target);
    for (var k = 0; k < keys.length; k++) {
      var actual = output[keys[k]];
      var expected = target[keys[k]];
      assert.ok(actual < (expected + wiggle) && actual > (expected - wiggle),
        "failed to train " + op + " - output: " + output + " target: " + target);
    }
  }
}
net.run([1, 0]); 13 | 14 | assert.ok(output[0] > 0.9, "output: " + output[0]); 15 | }) 16 | 17 | it('runs correctly with hash input', function() { 18 | var net = new brain.NeuralNetwork(); 19 | 20 | var info = net.train([{input: { x: 0, y: 0 }, output: [0]}, 21 | {input: { x: 0, y: 1 }, output: [1]}, 22 | {input: { x: 1, y: 0 }, output: [1]}, 23 | {input: { x: 1, y: 1 }, output: [0]}]); 24 | var output = net.run({x: 1, y: 0}); 25 | 26 | assert.ok(output[0] > 0.9, "output: " + output[0]); 27 | }) 28 | 29 | it('runs correctly with hash output', function() { 30 | var net = new brain.NeuralNetwork(); 31 | 32 | net.train([{input: [0, 0], output: { answer: 0 }}, 33 | {input: [0, 1], output: { answer: 1 }}, 34 | {input: [1, 0], output: { answer: 1 }}, 35 | {input: [1, 1], output: { answer: 0 }}]); 36 | 37 | var output = net.run([1, 0]); 38 | 39 | assert.ok(output.answer > 0.9, "output: " + output.answer); 40 | }) 41 | 42 | it('runs correctly with hash input and output', function() { 43 | var net = new brain.NeuralNetwork(); 44 | 45 | net.train([{input: { x: 0, y: 0 }, output: { answer: 0 }}, 46 | {input: { x: 0, y: 1 }, output: { answer: 1 }}, 47 | {input: { x: 1, y: 0 }, output: { answer: 1 }}, 48 | {input: { x: 1, y: 1 }, output: { answer: 0 }}]); 49 | 50 | var output = net.run({x: 1, y: 0}); 51 | 52 | assert.ok(output.answer > 0.9, "output: " + output.answer); 53 | }) 54 | 55 | it('runs correctly with sparse hashes', function() { 56 | var net = new brain.NeuralNetwork(); 57 | 58 | net.train([{input: {}, output: {}}, 59 | {input: { y: 1 }, output: { answer: 1 }}, 60 | {input: { x: 1 }, output: { answer: 1 }}, 61 | {input: { x: 1, y: 1 }, output: {}}]); 62 | 63 | 64 | var output = net.run({x: 1}); 65 | 66 | assert.ok(output.answer > 0.9); 67 | }) 68 | 69 | it('runs correctly with unseen input', function() { 70 | var net = new brain.NeuralNetwork(); 71 | 72 | net.train([{input: {}, output: {}}, 73 | {input: { y: 1 }, output: { answer: 1 }}, 74 | {input: { x: 1 }, output: 
{ answer: 1 }}, 75 | {input: { x: 1, y: 1 }, output: {}}]); 76 | 77 | var output = net.run({x: 1, z: 1}); 78 | assert.ok(output.answer > 0.9); 79 | }) 80 | }) 81 | -------------------------------------------------------------------------------- /test/unit/json.js: -------------------------------------------------------------------------------- 1 | var assert = require("assert"), 2 | brain = require("../../lib/brain"); 3 | 4 | describe('JSON', function() { 5 | var net = new brain.NeuralNetwork(); 6 | 7 | net.train([{input: {"0": Math.random(), b: Math.random()}, 8 | output: {c: Math.random(), "0": Math.random()}}, 9 | {input: {"0": Math.random(), b: Math.random()}, 10 | output: {c: Math.random(), "0": Math.random()}}]); 11 | 12 | var serialized = net.toJSON(); 13 | var net2 = new brain.NeuralNetwork().fromJSON(serialized); 14 | 15 | var input = {"0" : Math.random(), b: Math.random()}; 16 | 17 | it('toJSON()/fromJSON()', function() { 18 | var output1 = net.run(input); 19 | var output2 = net2.run(input); 20 | 21 | assert.equal(JSON.stringify(output1), JSON.stringify(output2), 22 | "loading json serialized network failed"); 23 | }) 24 | 25 | 26 | it('toFunction()', function() { 27 | var output1 = net.run(input); 28 | var output2 = net.toFunction()(input); 29 | 30 | assert.equal(JSON.stringify(output1), JSON.stringify(output2), 31 | "standalone network function failed"); 32 | }) 33 | }) 34 | -------------------------------------------------------------------------------- /test/unit/lookup.js: -------------------------------------------------------------------------------- 1 | var assert = require("assert"), 2 | lookup = require("../../lib/lookup"); 3 | 4 | 5 | describe('lookup', function() { 6 | it('lookupFromHash()', function() { 7 | var lup = lookup.lookupFromHash({ a: 6, b: 7, c: 8 }); 8 | 9 | assert.deepEqual(lup, { a: 0, b: 1, c: 2 }); 10 | }) 11 | 12 | it('buildLookup()', function() { 13 | var lup = lookup.buildLookup([{ x: 0, y: 0 }, 14 | { x: 1, z: 0 }, 15 | { 
q: 0 }, 16 | { x: 1, y: 1 }]); 17 | 18 | assert.deepEqual(lup, { x: 0, y: 1, z: 2, q: 3 }) 19 | }) 20 | 21 | it('toArray()', function() { 22 | var lup = { a: 0, b: 1, c: 2 }; 23 | 24 | var array = lookup.toArray(lup, { b: 8, notinlookup: 9 }); 25 | 26 | assert.deepEqual(array, [0, 8, 0]) 27 | }) 28 | 29 | it('toHash()', function() { 30 | var lup = { b: 1, a: 0, c: 2 }; 31 | 32 | var hash = lookup.toHash(lup, [0, 9, 8]); 33 | 34 | assert.deepEqual(hash, {a: 0, b: 9, c: 8}) 35 | }) 36 | }) 37 | -------------------------------------------------------------------------------- /test/unit/options.js: -------------------------------------------------------------------------------- 1 | var assert = require("assert"), 2 | _ = require("underscore"), 3 | brain = require("../../lib/brain"); 4 | 5 | describe('neural network options', function() { 6 | it('hiddenLayers', function() { 7 | var net = new brain.NeuralNetwork({ hiddenLayers: [8, 7] }); 8 | 9 | net.train([{input: [0, 0], output: [0]}, 10 | {input: [0, 1], output: [1]}, 11 | {input: [1, 0], output: [1]}, 12 | {input: [1, 1], output: [0]}]); 13 | 14 | var json = net.toJSON(); 15 | 16 | assert.equal(json.layers.length, 4); 17 | assert.equal(_(json.layers[1]).keys().length, 8); 18 | assert.equal(_(json.layers[2]).keys().length, 7); 19 | }) 20 | 21 | it('hiddenLayers default expand to input size', function() { 22 | var net = new brain.NeuralNetwork(); 23 | 24 | net.train([{input: [0, 0, 1, 1, 1, 1, 1, 1, 1], output: [0]}, 25 | {input: [0, 1, 1, 1, 1, 1, 1, 1, 1], output: [1]}, 26 | {input: [1, 0, 1, 1, 1, 1, 1, 1, 1], output: [1]}, 27 | {input: [1, 1, 1, 1, 1, 1, 1, 1, 1], output: [0]}]); 28 | 29 | var json = net.toJSON(); 30 | 31 | assert.equal(json.layers.length, 3); 32 | assert.equal(_(json.layers[1]).keys().length, 4, "9 input units means 4 hidden"); 33 | }) 34 | 35 | 36 | it('learningRate - higher learning rate should train faster', function() { 37 | var data = [{input: [0, 0], output: [0]}, 38 | {input: [0, 1], 
output: [1]}, 39 | {input: [1, 0], output: [1]}, 40 | {input: [1, 1], output: [1]}]; 41 | 42 | var net1 = new brain.NeuralNetwork(); 43 | var iters1 = net1.train(data, { learningRate: 0.5 }).iterations; 44 | 45 | var net2 = new brain.NeuralNetwork(); 46 | var iters2 = net2.train(data, { learningRate: 0.8 }).iterations; 47 | 48 | assert.ok(iters1 > (iters2 * 1.1), iters1 + " !> " + iters2 * 1.1); 49 | }) 50 | 51 | it('learningRate - backwards compatibility', function() { 52 | var data = [{input: [0, 0], output: [0]}, 53 | {input: [0, 1], output: [1]}, 54 | {input: [1, 0], output: [1]}, 55 | {input: [1, 1], output: [1]}]; 56 | 57 | var net1 = new brain.NeuralNetwork({ learningRate: 0.5 }); 58 | var iters1 = net1.train(data).iterations; 59 | 60 | var net2 = new brain.NeuralNetwork( { learningRate: 0.8 }); 61 | var iters2 = net2.train(data).iterations; 62 | 63 | assert.ok(iters1 > (iters2 * 1.1), iters1 + " !> " + iters2 * 1.1); 64 | }) 65 | 66 | it('momentum - higher momentum should train faster', function() { 67 | var data = [{input: [0, 0], output: [0]}, 68 | {input: [0, 1], output: [1]}, 69 | {input: [1, 0], output: [1]}, 70 | {input: [1, 1], output: [1]}]; 71 | 72 | var net1 = new brain.NeuralNetwork({ momentum: 0.1 }); 73 | var iters1 = net1.train(data).iterations; 74 | 75 | var net2 = new brain.NeuralNetwork({ momentum: 0.5 }); 76 | var iters2 = net2.train(data).iterations; 77 | 78 | assert.ok(iters1 > (iters2 * 1.1), iters1 + " !> " + (iters2 * 1.1)); 79 | }) 80 | 81 | describe('log', function () { 82 | var logCalled; 83 | 84 | beforeEach(function () { 85 | logCalled = false; 86 | }); 87 | 88 | function logFunction() { 89 | logCalled = true; 90 | } 91 | 92 | function trainWithLog(log) { 93 | var net = new brain.NeuralNetwork(); 94 | net.train([{input: [0], output: [0]}], 95 | { 96 | log: log, 97 | logPeriod: 1 98 | }); 99 | } 100 | 101 | it('should call console.log if log === true', function () { 102 | var originalLog = console.log; 103 | console.log = 
/**
 * Test harness that trains a network through its train stream and checks the
 * trained outputs once training has finished.
 *
 * May be called with or without `new`. Construction creates the network and
 * the train stream, then immediately writes the first pass of data.
 *
 * @param {Object} opts
 * @param {Array} opts.testData - data; each item has `input` and `output`
 * @param {string} opts.op - operation name used in failure messages
 * @param {number} [opts.wiggle=0.1] - allowed distance from the target
 * @param {number} [opts.errorThresh=0.004] - error threshold for training
 */
function StreamTester(opts) {
  if (!(this instanceof StreamTester)) return new StreamTester(opts);

  this.wiggle = opts.wiggle || 0.1;
  this.op = opts.op;

  this.testData = opts.testData;
  this.fakeBuffer = [];
  this.errorThresh = opts.errorThresh || 0.004;

  this.net = new brain.NeuralNetwork();

  this.trainStream = this.net.createTrainStream({
    // called whenever the stream wants another pass over the data
    floodCallback: this.flood.bind(this),
    // called once when training has finished
    doneTrainingCallback: this.doneTraining.bind(this),
    errorThresh: this.errorThresh // error threshold to reach
  });
  this.flood();
}

/**
 * Writes one full pass of the test data to the train stream, last item
 * first, followed by `null`.
 *
 * Every pass must end with a null write so the stream knows the end of the
 * pass has been reached.
 */
StreamTester.prototype.flood = function() {
  for (var i = this.testData.length - 1; i >= 0; i--) {
    this.trainStream.write(this.testData[i]);
  }
  this.trainStream.write(null);
};

/**
 * Runs every test datum through the trained network and asserts that each
 * output value lies strictly within `wiggle` of the matching target value.
 *
 * Values are compared element by element as numbers. Adding `wiggle` to the
 * whole target array would concatenate strings ("1" + 0.1 is "10.1"), which
 * loosens the upper bound to 10.1.
 *
 * @param {Object} info - training summary passed by the stream (unused)
 * @throws {AssertionError} if any output value is outside the tolerance
 */
StreamTester.prototype.doneTraining = function(info) {
  for (var i = 0; i < this.testData.length; i++) {
    var output = this.net.run(this.testData[i].input);
    var target = this.testData[i].output;

    // Object.keys covers both array targets (indices) and hash targets
    var keys = Object.keys(target);
    for (var k = 0; k < keys.length; k++) {
      var actual = output[keys[k]];
      var expected = target[keys[k]];
      assert.ok(actual < (expected + this.wiggle) && actual > (expected - this.wiggle),
        "failed to train " + this.op + " - output: " + actual + " target: " + expected);
    }
  }
};

/**
 * Trains a network on `data` through the streaming interface and verifies
 * the result, using a tolerance of 0.1 and an error threshold of 0.003.
 *
 * @param {Array} data - training data; each item has `input` and `output`
 * @param {string} op - operation name used in failure messages
 */
function testBitwise(data, op) {
  StreamTester({
    testData: data,
    op: op,
    wiggle: 0.1,
    errorThresh: 0.003
  });
}
"network did not train until error threshold was reached"); 18 | }); 19 | 20 | it('train until max iterations reached', function() { 21 | var net = new brain.NeuralNetwork(); 22 | var stats = net.train(data, { 23 | errorThresh: 0.001, 24 | iterations: 1 25 | }); 26 | 27 | assert.equal(stats.iterations, 1); 28 | }) 29 | 30 | it('training callback called with training stats', function(done) { 31 | var iters = 100; 32 | var period = 20; 33 | var target = iters / 20; 34 | 35 | var calls = 0; 36 | 37 | var net = new brain.NeuralNetwork(); 38 | net.train(data, { 39 | iterations: iters, 40 | callback: function(stats) { 41 | assert.ok(stats.iterations % period == 0); 42 | 43 | calls++; 44 | if (calls == target) { 45 | done(); 46 | } 47 | }, 48 | callbackPeriod: 20 49 | }); 50 | }); 51 | }) 52 | --------------------------------------------------------------------------------