├── .babelrc ├── .bmp.yml ├── .editorconfig ├── .gitignore ├── .releaseignore ├── LICENSE ├── README.md ├── circle.yml ├── data ├── .gitkeep ├── test_input.json.gz ├── test_output.json.gz ├── training_input.json.gz ├── training_output.json.gz ├── validation_input.json.gz └── validation_output.json.gz ├── esdoc.json ├── examples └── mnist.js ├── gulpfile.js ├── package.json ├── spec └── lib.spec.js └── src ├── data_loader.js ├── index.js ├── layers ├── fully_connected_layer.js ├── index.js ├── relu_layer.js ├── sigmoid_layer.js └── softmax_layer.js ├── lib.js ├── mnist_loader.js └── network.js /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": ["es2015", "stage-0"] 3 | } 4 | 5 | -------------------------------------------------------------------------------- /.bmp.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 1.0.0 3 | files: 4 | package.json: '"version": "%.%.%",' 5 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*.js] 4 | trim_trailing_whitespace = true 5 | insert_final_newline = true 6 | indent_size = 2 7 | 8 | [package.json] 9 | indent_size = 2 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Source for distribution 2 | dist 3 | 4 | # ESDoc 5 | doc 6 | 7 | # Logs 8 | logs 9 | *.log 10 | npm-debug.log* 11 | 12 | # Temporary data 13 | /tmp/* 14 | 15 | # Runtime data 16 | pids 17 | *.pid 18 | *.seed 19 | 20 | # Directory for instrumented libs generated by jscoverage/JSCover 21 | lib-cov 22 | 23 | # Coverage directory used by tools like istanbul 24 | coverage 25 | 26 | # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) 27 |
.grunt 28 | 29 | # node-waf configuration 30 | .lock-wscript 31 | 32 | # Compiled binary addons (http://nodejs.org/api/addons.html) 33 | build/Release 34 | 35 | # Dependency directory 36 | node_modules 37 | 38 | # Optional npm cache directory 39 | .npm 40 | 41 | # Optional REPL history 42 | .node_repl_history 43 | -------------------------------------------------------------------------------- /.releaseignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | .editorconfig 3 | spec 4 | .releaseignore 5 | .gitignore 6 | .bmp.yml 7 | circle.yml 8 | npm-debug.log 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 yujiosaka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # js-mind 2 | Deep Learning Library Written in ES2015. 3 | 4 | [API Documentation](https://shinout.github.io/js-mind/api) 5 | 6 | ## Features 7 | 8 | * Activation functions 9 | * Sigmoid function 10 | * Softmax function 11 | * ReLU function 12 | * Cost Function 13 | * Cross Entropy Cost Function 14 | * Regularization 15 | * L2 Regularization 16 | * Dropout 17 | 18 | ## How To Use 19 | 20 | ```javascript 21 | var Promise = require('bluebird'); 22 | 23 | var _ = require('lodash'); 24 | 25 | var jsmind = require('../dist'); 26 | 27 | var net = new jsmind.Network([ 28 | new jsmind.layers.ReLULayer(784, 100, {pDropout: 0.5}), 29 | new jsmind.layers.ReLULayer(100, 100, {pDropout: 0.5}), 30 | new jsmind.layers.SoftmaxLayer(100, 10, {pDropout: 0.5}) 31 | ]); 32 | 33 | Promise.all([ 34 | jsmind.MnistLoader.loadTrainingDataWrapper(), 35 | jsmind.MnistLoader.loadValidationDataWrapper() 36 | ]).spread(function(trainingData, validationData) { 37 | net.SGD( 38 | trainingData, 39 | 60, // epochs 40 | 10, // miniBatchSize 41 | 0.03 // eta 42 | , { 43 | validationData: validationData, 44 | lmbda: 0.1 45 | }); 46 | }).then(function() { 47 | return jsmind.MnistLoader.loadTestDataWrapper(); 48 | }).then(function(testData) { 49 | var testInput, prediction, accuracy; 50 | testInput = _.unzip(testData)[0]; 51 | accuracy = net.accuracy(testData); 52 | prediction = net.predict(testInput); 53 | console.log('Test accuracy ' + accuracy); 54 | console.log('Test prediction ' + prediction.toString()); 55 | }); 56 | ``` 57 | -------------------------------------------------------------------------------- /circle.yml: -------------------------------------------------------------------------------- 1 | general: 2 | branches: 3 | ignore: 4 | - gh-pages 5 | - '/release.*/' 6 | machine: 7 | environment: 8 | PATH: 
'$PATH:$HOME/$CIRCLE_PROJECT_REPONAME/node_modules/node-circleci-autorelease/bin' 9 | VERSION_PREFIX: v 10 | CREATE_GH_PAGES: 1 11 | GH_PAGES_DIR: doc 12 | pre: 13 | - "git config --global user.name 'CircleCI'" 14 | - "git config --global user.email 'circleci@cureapp.jp'" 15 | node: 16 | version: 4.4.2 17 | dependencies: 18 | post: 19 | - npm run post-dependencies 20 | deployment: 21 | create_release_branch: 22 | branch: 23 | - master 24 | commands: 25 | - 'cc-prepare-for-release && npm run pre-release && cc-release || cc-not-released' 26 | - cc-gh-pages 27 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/.gitkeep -------------------------------------------------------------------------------- /data/test_input.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/test_input.json.gz -------------------------------------------------------------------------------- /data/test_output.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/test_output.json.gz -------------------------------------------------------------------------------- /data/training_input.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/training_input.json.gz -------------------------------------------------------------------------------- /data/training_output.json.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/training_output.json.gz -------------------------------------------------------------------------------- /data/validation_input.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/validation_input.json.gz -------------------------------------------------------------------------------- /data/validation_output.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yujiosaka/js-mind/0d33426da9d8525227ce7601f9c306e81cb86120/data/validation_output.json.gz -------------------------------------------------------------------------------- /esdoc.json: -------------------------------------------------------------------------------- 1 | { 2 | "source": "./src", 3 | "destination": "./doc/api", 4 | "plugins": [ 5 | {"name": "esdoc-es7-plugin"} 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /examples/mnist.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | var Promise = require('bluebird'); 4 | 5 | var _ = require('lodash'); 6 | 7 | var jsmind = require('../dist'); 8 | 9 | var net = new jsmind.Network([ 10 | new jsmind.layers.ReLULayer(784, 100, {pDropout: 0.5}), 11 | new jsmind.layers.ReLULayer(100, 100, {pDropout: 0.5}), 12 | new jsmind.layers.SoftmaxLayer(100, 10, {pDropout: 0.5}) 13 | ]); 14 | 15 | Promise.all([ 16 | jsmind.MnistLoader.loadTrainingDataWrapper(), 17 | jsmind.MnistLoader.loadValidationDataWrapper() 18 | ]).spread(function(trainingData, validationData) { 19 | net.SGD( 20 | trainingData, 21 | 60, // epochs 22 | 10, // miniBatchSize 23 | 0.03 // eta 24 | , { 25 | validationData: validationData, 26 | lmbda: 0.1 27 | }); 28 | }).then(function() { 29 | return 
jsmind.MnistLoader.loadTestDataWrapper(); 30 | }).then(function(testData) { 31 | var testInput, prediction, accuracy; 32 | testInput = _.unzip(testData)[0]; 33 | accuracy = net.accuracy(testData); 34 | prediction = net.predict(testInput); 35 | console.log('Test accuracy ' + accuracy); 36 | console.log('Test prediction ' + prediction.toString()); 37 | }); 38 | -------------------------------------------------------------------------------- /gulpfile.js: -------------------------------------------------------------------------------- 1 | const gulp = require('gulp') 2 | const gutil = require('gulp-util') 3 | const babel = require('gulp-babel') 4 | const plumber = require('gulp-plumber') 5 | const dirname = require('path').dirname 6 | 7 | const srcdir = __dirname + '/src' 8 | const destdir = __dirname + '/dist' 9 | 10 | /** 11 | * 1. transpile all babel files 12 | * 2. watch src dir 13 | * 3. transpile on changed 14 | */ 15 | gulp.task('watch', x => { 16 | 17 | gulp.start('babel:all') 18 | 19 | gulp.watch('src/**/*.js', (info) => { 20 | 21 | const src = info.path 22 | const relpath = src.slice(srcdir.length + 1) 23 | const dest = destdir + '/' + dirname(relpath) 24 | 25 | gutil.log(`[${info.type}]: ${relpath}`) 26 | compileBabel(src, dest) 27 | .on('end', x => { 28 | gutil.log(`compilation finished: ${dest}`) 29 | }) 30 | }) 31 | }) 32 | 33 | /** 34 | * transpile all babel files 35 | */ 36 | gulp.task('babel:all', x => { 37 | compileBabel('src/**/*.js', 'dist') 38 | .on('end', x => { 39 | gutil.log('compilation finished: all js files.') 40 | }) 41 | }) 42 | 43 | 44 | /** 45 | * transipile babel src to dest 46 | */ 47 | function compileBabel(src, dest) { 48 | return gulp.src(src) 49 | .pipe(plumber()) 50 | .pipe(babel()) 51 | .pipe(gulp.dest(dest)) 52 | } 53 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "js-mind", 3 | 
"version": "1.0.0", 4 | "description": "Deep Learning Library Written in ES6+.", 5 | "main": "dist/index.js", 6 | "directories": { 7 | "test": "spec" 8 | }, 9 | "engines": { 10 | "node": ">=0.12" 11 | }, 12 | "scripts": { 13 | "test": "mocha --compilers js:espower-babel/guess spec/**/*.spec.js", 14 | "build": "babel src -d dist", 15 | "bmp": "cc-bmp", 16 | "bmp-p": "cc-bmp -p", 17 | "bmp-m": "cc-bmp -m", 18 | "bmp-j": "cc-bmp -j", 19 | "circle": "cc-generate-yml", 20 | "post-dependencies": "echo post-dependencies", 21 | "pre-release": "gulp babel:all", 22 | "post-release": "echo post-release", 23 | "gh-pages": "esdoc -c esdoc.json" 24 | }, 25 | "repository": { 26 | "type": "git", 27 | "url": "git+https://github.com/yujiosaka/js-mind.git" 28 | }, 29 | "keywords": [ 30 | "es6", 31 | "neural", 32 | "network", 33 | "convolutional", 34 | "deep", 35 | "machine", 36 | "learning" 37 | ], 38 | "author": "Yuji Isobe", 39 | "contributors": [ 40 | { 41 | "name": "Shin Suzuki", 42 | "email": "shinout310@gmail.com" 43 | } 44 | ], 45 | "license": "MIT", 46 | "bugs": { 47 | "url": "https://github.com/yujiosaka/js-mind/issues" 48 | }, 49 | "homepage": "https://github.com/yujiosaka/js-mind#readme", 50 | "dependencies": { 51 | "bluebird": "^3.4.1", 52 | "linear-algebra": "git://github.com/yujiosaka/linear-algebra#for_neural_networks_and_deep_learning", 53 | "lodash": "^4.3.0" 54 | }, 55 | "devDependencies": { 56 | "babel-cli": "^6.6.5", 57 | "babel-preset-es2015": "^6.5.0", 58 | "babel-preset-stage-0": "^6.5.0", 59 | "esdoc": "^0.4.4", 60 | "esdoc-es7-plugin": "0.0.3", 61 | "espower-babel": "^4.0.1", 62 | "gulp": "^3.9.1", 63 | "gulp-babel": "^6.1.2", 64 | "gulp-plumber": "^1.1.0", 65 | "gulp-util": "^3.0.7", 66 | "mocha": "^3.0.1", 67 | "node-circleci-autorelease": "^2.2.0", 68 | "power-assert": "^1.2.0" 69 | }, 70 | "node-circleci-autorelease": { 71 | "machine": { 72 | "node": { 73 | "version": "4.4.2" 74 | } 75 | }, 76 | "config": { 77 | "git-user-name": "CircleCI", 78 | 
"git-user-email": "circleci@cureapp.jp", 79 | "version-prefix": "v", 80 | "create-branch": false, 81 | "create-gh-pages": true, 82 | "gh-pages-dir": "doc" 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /spec/lib.spec.js: -------------------------------------------------------------------------------- 1 | 2 | import assert from 'power-assert' 3 | 4 | const linearAlgebra = require('linear-algebra')(); 5 | const { Matrix } = linearAlgebra; 6 | 7 | const ___ = (...args) => new Matrix(args) 8 | 9 | import { 10 | sigmoidPrime, 11 | vectorizedResult, 12 | dropoutLayer, 13 | norm, 14 | randn 15 | } from '../src/lib' 16 | 17 | 18 | describe('sigmoidPrime', ()=> { 19 | 20 | it('returns 0.25 when zero is given', ()=> { 21 | 22 | const zero = Matrix.zeros(1, 1) 23 | 24 | assert(zero.sigmoid().data[0][0] === 0.5) 25 | 26 | assert(sigmoidPrime(zero).data[0][0] === 0.5 * (-0.5 + 1)) 27 | }) 28 | 29 | }) 30 | 31 | describe('vectorizedResult', ()=> { 32 | 33 | it('returns a 10-dimensional unit vector', ()=> { 34 | 35 | const result = vectorizedResult(3); 36 | 37 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9].forEach(i => { 38 | const elem = result.data[i][0] 39 | const expected = (i === 3) ? 
1 : 0 40 | assert(elem === expected) 41 | }) 42 | }) 43 | }) 44 | 45 | describe('dropoutLayer', ()=> { 46 | 47 | const mat = ___( 48 | [1,4,9], 49 | [3,8,2], 50 | [6,5,7], 51 | ) 52 | 53 | it('always drops out values from a layer when pDropout is 1', ()=> { 54 | 55 | const pDropout = 1 // probability to dropout 56 | const result = dropoutLayer(mat, pDropout) 57 | 58 | assert.deepEqual(result, Matrix.zeros(3, 3)) 59 | }) 60 | 61 | it('never drops out values from a layer when pDropout is 0', ()=> { 62 | 63 | const pDropout = 0 64 | const result = dropoutLayer(mat, pDropout) 65 | 66 | assert.deepEqual(result, mat) 67 | }) 68 | }) 69 | 70 | 71 | describe('norm', ()=> { 72 | 73 | it('always returns the mean when sigma is 0', ()=> { 74 | 75 | assert(norm(123, 0) === 123) 76 | }) 77 | 78 | it('returns different values every time', ()=> { 79 | assert(norm(123, 1) !== norm(123, 1)) 80 | }) 81 | 82 | 83 | }) 84 | 85 | 86 | describe('randn', ()=> { 87 | 88 | it('returns matrix, all of whose elements follow the standard normal distribution', ()=> { 89 | 90 | const result = randn(3, 3) 91 | 92 | const Z = 4.0 // p < 1.0e-4 93 | 94 | result.map(elem => { 95 | assert(Math.abs(elem) < Z) 96 | }) 97 | 98 | }) 99 | }) 100 | -------------------------------------------------------------------------------- /src/data_loader.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import Promise from 'bluebird'; 4 | import fs from 'fs'; 5 | import zlib from 'zlib'; 6 | import path from 'path'; 7 | 8 | Promise.promisifyAll(fs); 9 | Promise.promisifyAll(zlib); 10 | 11 | class DataLoader { 12 | static loadTrainingData() { 13 | return Promise.all([ 14 | DataLoader._loadDate('training_input.json.gz'), 15 | DataLoader._loadDate('training_output.json.gz') 16 | ]); 17 | } 18 | 19 | static loadValidationData() { 20 | return Promise.all([ 21 | DataLoader._loadDate('validation_input.json.gz'), 22 | 
DataLoader._loadDate('validation_output.json.gz') 23 | ]); 24 | } 25 | 26 | static loadTestData() { 27 | return Promise.all([ 28 | DataLoader._loadDate('test_input.json.gz'), 29 | DataLoader._loadDate('test_output.json.gz') 30 | ]); 31 | } 32 | 33 | static _loadDate(filename) { 34 | return fs.readFileAsync( 35 | path.join(__dirname, `../data/${filename}`) 36 | ).then(content => { 37 | return zlib.gunzipAsync(content); 38 | }).then(binary => { 39 | return JSON.parse(binary.toString()); 40 | }); 41 | } 42 | } 43 | 44 | module.exports = DataLoader; 45 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import layers from './layers'; 4 | import Network from './network'; 5 | import MnistLoader from './mnist_loader'; 6 | import DataLoader from './data_loader'; 7 | 8 | export { layers, Network, MnistLoader, DataLoader} 9 | -------------------------------------------------------------------------------- /src/layers/fully_connected_layer.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import linearAlgebra from 'linear-algebra'; 4 | 5 | import { randn, dropoutLayer } from '../lib'; 6 | import layers from './'; 7 | 8 | const { Matrix } = linearAlgebra(); 9 | 10 | const MATRIX_OBJ = [ 11 | 'w', 12 | 'b', 13 | 'input', 14 | 'output', 15 | 'inputDropout', 16 | 'outputDropout' 17 | ].reduce((obj, key) => { 18 | obj[key] = true; 19 | return obj; 20 | }, {}); 21 | 22 | export default class FullyConnectedLayer { 23 | constructor(nIn, nOut, opts = {}) { 24 | this.pDropout = opts.pDropout || (opts.pDropout = 0); 25 | this.activationFn = opts.activationFn; 26 | this.w = randn(nOut, nIn).mulEach(1 / Math.sqrt(nIn)); 27 | this.b = randn(nOut, 1); 28 | } 29 | 30 | setInput(input, inputDropout, miniBatchSize) { 31 | const axis = 0; 32 | const bMask = new 
Matrix(this.b.ravel().map(v => { 33 | let results = []; 34 | for (let i = 0; i < miniBatchSize; i++) { 35 | results.push(v); 36 | } 37 | return results; 38 | })); 39 | this.input = input; 40 | this.output = this.w.dot(input).mulEach(1 - this.pDropout).plus(bMask)[this.activationFn](axis); 41 | this.yOut = this.output.getArgMax(); 42 | this.inputDropout = dropoutLayer(inputDropout, this.pDropout); 43 | this.outputDropout = this.w.dot(this.inputDropout).plus(bMask)[this.activationFn](axis); 44 | } 45 | 46 | accuracy(y) { 47 | return this.yOut === y; 48 | } 49 | 50 | update(delta) { 51 | this.nb = new Matrix(delta.getSum(1)).trans(); 52 | this.nw = delta.dot(this.inputDropout.trans()); 53 | } 54 | 55 | dump() { 56 | let properties = Object.keys(this).reduce((obj, key) => { 57 | let val = this[key]; 58 | obj[key] = (MATRIX_OBJ[key]) ? val.toArray() : val; 59 | return obj; 60 | }, {}); 61 | return { 62 | className: this.constructor.name, 63 | properties: properties 64 | }; 65 | } 66 | 67 | static load(className, properties) { 68 | let layer = new layers[className](0, 0); 69 | Object.keys(properties).forEach(key => { 70 | let val = properties[key]; 71 | layer[key] = (MATRIX_OBJ[key]) ? 
new Matrix(val) : val; 72 | }); 73 | return layer; 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/layers/index.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import FullyConnectedLayer from './fully_connected_layer'; 4 | import SigmoidLayer from './sigmoid_layer'; 5 | import SoftmaxLayer from './softmax_layer'; 6 | import ReLULayer from './relu_layer'; 7 | 8 | export default { FullyConnectedLayer, SigmoidLayer, SoftmaxLayer, ReLULayer } 9 | -------------------------------------------------------------------------------- /src/layers/relu_layer.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import FullyConnectedLayer from './fully_connected_layer'; 4 | 5 | export default class ReLULayer extends FullyConnectedLayer { 6 | constructor(nIn, nOut, opts = {}) { 7 | opts.activationFn = 'relu'; 8 | super(nIn, nOut, opts); 9 | } 10 | 11 | costDelta(y) { 12 | return this.outputDropout.eleMap(v => { 13 | return (v > 0) ? 
1 : 0; 14 | }); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/layers/sigmoid_layer.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import FullyConnectedLayer from './fully_connected_layer'; 4 | 5 | export default class SigmoidLayer extends FullyConnectedLayer { 6 | constructor(nIn, nOut, opts = {}) { 7 | opts.activationFn = 'sigmoid'; 8 | super(nIn, nOut, opts); 9 | } 10 | 11 | costDelta(y) { 12 | return this.outputDropout.minus(y); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/layers/softmax_layer.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import FullyConnectedLayer from './fully_connected_layer'; 4 | 5 | export default class SoftmaxLayer extends FullyConnectedLayer { 6 | constructor(nIn, nOut, opts = {}) { 7 | opts.activationFn = 'softmax'; 8 | super(nIn, nOut, opts); 9 | } 10 | 11 | costDelta(y) { 12 | return this.outputDropout.minus(y); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/lib.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import linearAlgebra from 'linear-algebra'; 4 | const { Matrix } = linearAlgebra(); 5 | 6 | /** 7 | * Randomly drop values from a layer. 8 | * @param {Matrix} layer 9 | * @param {number} pDropout probability to drop out value for each element 10 | * @return {Matrix} converted matrix with the same nRow, nCol 11 | */ 12 | export function dropoutLayer(layer, pDropout) { 13 | return layer.eleMap(function(elem) { 14 | return (Math.random() < pDropout ? 0 : elem); 15 | }); 16 | }; 17 | 18 | /** 19 | * Return a sample from a normal distribution. 
20 | * @param {number} mu mean 21 | * @param {number} sigma sd (must be greater than 0) 22 | * @return {number} sample 23 | */ 24 | export function norm(mu, sigma) { 25 | const a = 1 - Math.random(); 26 | const b = 1 - Math.random(); 27 | const c = Math.sqrt(-2 * Math.log(a)); 28 | if (0.5 - Math.random() > 0) { 29 | return c * Math.sin(Math.PI * 2 * b) * sigma + mu; 30 | } else { 31 | return c * Math.cos(Math.PI * 2 * b) * sigma + mu; 32 | } 33 | } 34 | 35 | /** 36 | * Return a matrix, all of whose element are sampled from the standard normal distribution. 37 | * see http://d.hatena.ne.jp/iroiro123/20111210/1323515616 38 | * 39 | * @param {number} rows the number of rows 40 | * @param {number} cols the number of cols 41 | * @return {Matrix} random matrix 42 | */ 43 | export function randn(rows, cols) { 44 | const result = new Array(rows); 45 | for (let i = 0; i < rows; i++) { 46 | result[i] = [] 47 | for (let j = 0; j < cols; j++) { 48 | result[i][j] = norm(0, 1) 49 | } 50 | } 51 | return new Matrix(result); 52 | } 53 | -------------------------------------------------------------------------------- /src/mnist_loader.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import _ from 'lodash'; 4 | import linearAlgebra from 'linear-algebra'; 5 | 6 | import DataLoader from './data_loader'; 7 | 8 | const { Matrix } = linearAlgebra(); 9 | 10 | class MnistLoader { 11 | static loadTrainingDataWrapper() { 12 | return DataLoader.loadTrainingData().then(trD => { 13 | const trainingInputs = trD[0].map(x => { return Matrix.reshape(x, 784, 1); }); 14 | const trainingResults = trD[1].map(y => { return MnistLoader._vectorizedResult(y); }); 15 | const trainingData = _.zip(trainingInputs, trainingResults); 16 | return trainingData; 17 | }); 18 | } 19 | 20 | static loadValidationDataWrapper() { 21 | return DataLoader.loadValidationData().then(vaD => { 22 | const validationInputs = vaD[0].map(x => { return Matrix.reshape(x, 
784, 1); }); 23 | const validationData = _.zip(validationInputs, vaD[1]); 24 | return validationData; 25 | }); 26 | } 27 | 28 | static loadTestDataWrapper() { 29 | return DataLoader.loadTestData().then(teD => { 30 | const testInputs = teD[0].map(x => { return Matrix.reshape(x, 784, 1); }); 31 | const testData = _.zip(testInputs, teD[1]); 32 | return testData; 33 | }); 34 | } 35 | 36 | static _vectorizedResult(j) { 37 | const e = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9].map(i => [0]); 38 | e[j] = [1]; 39 | return new Matrix(e); 40 | } 41 | } 42 | 43 | module.exports = MnistLoader; 44 | -------------------------------------------------------------------------------- /src/network.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | 3 | import _ from 'lodash'; 4 | import fs from 'fs'; 5 | import Promise from 'bluebird'; 6 | import linearAlgebra from 'linear-algebra'; 7 | 8 | import layers from './layers'; 9 | 10 | const { Matrix } = linearAlgebra(); 11 | 12 | const ENCODING = 'utf8'; 13 | 14 | Promise.promisifyAll(fs); 15 | 16 | class Network { 17 | constructor(layers) { 18 | this.layers = layers; 19 | } 20 | 21 | SGD(trainingData, epochs, miniBatchSize, eta, opts = {}) { 22 | opts.lmbda || (opts.lmbda = 0); 23 | let bestValidationAccuracy = 0; 24 | for (let j = 0; j < epochs; j++) { 25 | trainingData = _.shuffle(trainingData); 26 | let miniBatches = this.createMiniBatches(trainingData, miniBatchSize); 27 | for (let i = 0; i < miniBatches.length; i++) { 28 | let miniBatch = miniBatches[i]; 29 | let iteration = trainingData.length / miniBatchSize * j + i; 30 | if (iteration % 1000 === 0) { 31 | console.log(`Training mini-batch number ${iteration}`); 32 | } 33 | this.updateMiniBatch(miniBatch, eta, opts.lmbda, trainingData.length); 34 | } 35 | if (opts.validationData) { 36 | let validationAccuracy = this.accuracy(opts.validationData); 37 | console.log(`Epoch ${j}: validation accuracy ${validationAccuracy}`); 38 | if 
(validationAccuracy >= bestValidationAccuracy) { 39 | console.log('This is the best validation accuracy to date.'); 40 | bestValidationAccuracy = validationAccuracy; 41 | } 42 | } 43 | } 44 | console.log('Finished training network.'); 45 | if (opts.validationData) { 46 | console.log(`Best validation accuracy ${bestValidationAccuracy}`); 47 | } 48 | } 49 | 50 | createMiniBatches(trainingData, miniBatchSize) { 51 | let results = []; 52 | for (let k = 0; k < trainingData.length; k += miniBatchSize) { 53 | results.push(trainingData.slice(k, k + miniBatchSize)); 54 | } 55 | return results; 56 | } 57 | 58 | updateMiniBatch(miniBatch, eta, lmbda, n) { 59 | let x = new Matrix(miniBatch.map(([_x, _y]) => { return _x.ravel();})).trans(); 60 | let y = new Matrix(miniBatch.map(([_x, _y]) => { return _y.ravel();})).trans(); 61 | this.train(x, miniBatch.length); 62 | this.backprop(y); 63 | for (let i = 0; i < this.layers.length; i++) { 64 | let layer = this.layers[i]; 65 | // l2 regularization 66 | layer.w = layer.w.mulEach(1 - eta * (lmbda / n)).minus((layer.nw.mulEach(eta / miniBatch.length))); 67 | layer.b = layer.b.minus(layer.nb.mulEach(eta / miniBatch.length)); 68 | } 69 | } 70 | 71 | train(x, miniBatchSize) { 72 | let initLayer = this.layers[0]; 73 | initLayer.setInput(x, x, miniBatchSize); 74 | for (let j = 1; j < this.layers.length; j++) { 75 | let prevLayer = this.layers[j - 1]; 76 | let layer = this.layers[j]; 77 | layer.setInput(prevLayer.output, prevLayer.outputDropout, miniBatchSize); 78 | } 79 | } 80 | 81 | backprop(y) { 82 | let lastLayer = this.layers[this.layers.length - 1]; 83 | let delta = lastLayer.costDelta(y); 84 | lastLayer.update(delta); 85 | for (let l = 2; l <= this.layers.length; l++) { 86 | let followinglayer = this.layers[this.layers.length - l + 1]; 87 | let layer = this.layers[this.layers.length - l]; 88 | delta = followinglayer.w.trans().dot(delta).mul(layer.costDelta(y)); 89 | layer.update(delta); 90 | } 91 | } 92 | 93 | accuracy(data) { 94 | 
return _.mean(data.map(([x, y]) => { return this.feedforward(x).accuracy(y); })); 95 | } 96 | 97 | feedforward(a) { 98 | this.train(a, 1); 99 | return this.layers[this.layers.length - 1]; 100 | } 101 | 102 | predict(inputs) { 103 | return inputs.map(x => { return this.feedforward(x).yOut; }); 104 | } 105 | 106 | save(file) { 107 | let json = JSON.stringify(this.layers.map(layer => { 108 | return layer.dump(); 109 | })); 110 | return fs.writeFileAsync(file, json, ENCODING); 111 | } 112 | 113 | static load(file) { 114 | return fs.readFileAsync(file, ENCODING).then(json => { 115 | return new Network(JSON.parse(json).map(data => { 116 | return layers[data.className].load(data.className, data.properties); 117 | })); 118 | }); 119 | } 120 | } 121 | 122 | module.exports = Network; 123 | --------------------------------------------------------------------------------