├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── README.md ├── model.test.ts ├── package.json ├── src ├── data.ts ├── main.ts └── model.ts ├── tests └── model.test.ts ├── tsconfig.json └── tslint.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | 8 | # Runtime data 9 | pids 10 | *.pid 11 | *.seed 12 | *.pid.lock 13 | 14 | # Directory for instrumented libs generated by jscoverage/JSCover 15 | lib-cov 16 | 17 | # Coverage directory used by tools like istanbul 18 | coverage 19 | 20 | # nyc test coverage 21 | .nyc_output 22 | 23 | # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) 24 | .grunt 25 | 26 | # Bower dependency directory (https://bower.io/) 27 | bower_components 28 | 29 | # node-waf configuration 30 | .lock-wscript 31 | 32 | # Compiled binary addons (https://nodejs.org/api/addons.html) 33 | build/Release 34 | 35 | # Dependency directories 36 | node_modules/ 37 | jspm_packages/ 38 | 39 | # TypeScript v1 declaration files 40 | typings/ 41 | 42 | # Optional npm cache directory 43 | .npm 44 | 45 | # Optional eslint cache 46 | .eslintcache 47 | 48 | # Optional REPL history 49 | .node_repl_history 50 | 51 | # Output of 'npm pack' 52 | *.tgz 53 | 54 | # Yarn Integrity file 55 | .yarn-integrity 56 | 57 | # dotenv environment variables file 58 | .env 59 | 60 | # next.js build output 61 | .next 62 | t10k-images-idx3-ubyte 63 | t10k-labels-idx1-ubyte 64 | train-images-idx3-ubyte 65 | train-labels-idx1-ubyte 66 | package-lock.json 67 | 68 | # should be build 69 | dist 70 | 71 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "type": "node", 6 | "request": "launch", 7 | "name": "Mocha All", 8 
| "program": "${workspaceFolder}/node_modules/mocha/bin/_mocha", 9 | "args": [ 10 | "--no-timeouts", 11 | "--colors", 12 | "${workspaceFolder}/tests/**/*.test.ts", 13 | "--require", 14 | "ts-node/register" 15 | ], 16 | "console": "integratedTerminal", 17 | "internalConsoleOptions": "neverOpen" 18 | }, 19 | { 20 | "type": "node", 21 | "request": "launch", 22 | "name": "Mocha Current File", 23 | "program": "${workspaceFolder}/node_modules/mocha/bin/_mocha", 24 | "args": [ 25 | "--no-timeouts", 26 | "--colors", 27 | "${file}", 28 | "--require", 29 | "ts-node/register" 30 | ], 31 | "console": "integratedTerminal", 32 | "sourceMaps": true, 33 | "internalConsoleOptions": "neverOpen" 34 | } 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Starter Project for TensorFlow.js in TypeScript # 2 | 3 | [TensorFlow.js](https://js.tensorflow.org/) is an incredibly easy and powerful way to work with deep learning both in the browser and on the server. 4 | 5 | TensorFlow.js would be even easier to use if it was written to run in TypeScript. 6 | 7 | It is not hard to get this working, but there is still a little trial and error. This project is a starter project for running TensorFlow.js in TypeScript. 8 | 9 | The example task is the *hello world* example of deep learning. Recognize the handwritten digits in the MNIST dataset by training a convolutional neural network. 10 | 11 | This project is based on the [MNIST TensorFlow.js example](https://github.com/tensorflow/tfjs-examples/tree/master/mnist-node).
12 | 13 | 14 | 15 | # To Run # 16 | 17 | ``` bash 18 | 19 | git clone https://github.com/sami-badawi/tensorflow-typescript-starter.git 20 | 21 | cd tensorflow-typescript-starter 22 | 23 | npm i 24 | 25 | npm run build 26 | 27 | npm run test 28 | 29 | npm run start 30 | 31 | ``` 32 | 33 | 34 | ## To Run with Parameters ## 35 | 36 | ``` base 37 | 38 | node dist/main.js --epochs 1 39 | 40 | ``` 41 | 42 | If you want to run with only 1 epoch. 43 | 44 | 45 | 46 | ## Output ## 47 | 48 | ``` 49 | node dist/main.js --epochs 1 50 | 2019-02-16 13:49:35.108255: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA 51 | _________________________________________________________________ 52 | Layer (type) Output shape Param # 53 | ================================================================= 54 | conv2d_Conv2D1 (Conv2D) [null,26,26,32] 320 55 | _________________________________________________________________ 56 | conv2d_Conv2D2 (Conv2D) [null,24,24,32] 9248 57 | _________________________________________________________________ 58 | max_pooling2d_MaxPooling2D1 [null,12,12,32] 0 59 | _________________________________________________________________ 60 | conv2d_Conv2D3 (Conv2D) [null,10,10,64] 18496 61 | _________________________________________________________________ 62 | conv2d_Conv2D4 (Conv2D) [null,8,8,64] 36928 63 | _________________________________________________________________ 64 | max_pooling2d_MaxPooling2D2 [null,4,4,64] 0 65 | _________________________________________________________________ 66 | flatten_Flatten1 (Flatten) [null,1024] 0 67 | _________________________________________________________________ 68 | dropout_Dropout1 (Dropout) [null,1024] 0 69 | _________________________________________________________________ 70 | dense_Dense1 (Dense) [null,512] 524800 71 | _________________________________________________________________ 72 | dropout_Dropout2 
(Dropout) [null,512] 0 73 | _________________________________________________________________ 74 | dense_Dense2 (Dense) [null,10] 5130 75 | ================================================================= 76 | Total params: 594922 77 | Trainable params: 594922 78 | Non-trainable params: 0 79 | _________________________________________________________________ 80 | Epoch 1 / 1 81 | eta=0.0 =====================================================================================================================================> 82 | 477998ms 9373us/step - acc=0.921 loss=0.250 val_acc=0.983 val_loss=0.0553 83 | ``` 84 | 85 | ### Accuracy ### 86 | 87 | The model gets to 98.3% accuracy on test data in 1 epoch and to 99.3% accuracy in 20 epochs. 88 | 89 | 90 | # Unit Test # 91 | 92 | There is only one unit tests. It checks that TensorFlow has been loaded. 93 | 94 | The purpose is to set up the needed dependencies to work with unit tests. 95 | It is using Mocha and Chai libraries. 96 | -------------------------------------------------------------------------------- /model.test.ts: -------------------------------------------------------------------------------- 1 | import * as tf from "@tensorflow/tfjs-node"; 2 | import { expect } from "chai"; 3 | 4 | describe("tensorflow.js", () => { 5 | it("Check that it loaded", () => { 6 | expect(tf.version.tfjs.length).greaterThan(3); 7 | }); 8 | }); 9 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tensorflow-typescript-starter", 3 | "version": "0.1.0", 4 | "description": "", 5 | "main": "dist/main.js", 6 | "license": "Apache-2.0", 7 | "private": true, 8 | "engines": { 9 | "node": ">=8.11.0" 10 | }, 11 | "dependencies": { 12 | "@tensorflow/tfjs-node": "^0.3.0", 13 | "@types/argparse": "^1.0.35", 14 | "argparse": "^1.0.10", 15 | "fs": "0.0.1-security" 16 | }, 17 | "devDependencies": { 18 | 
"clang-format": "~1.2.2", 19 | "@types/chai": "^4.1.7", 20 | "@types/mocha": "^5.2.6", 21 | "chai": "^4.2.0", 22 | "mocha": "^5.2.0", 23 | "nyc": "^13.3.0", 24 | "ts-node": "^8.0.2", 25 | "tslint": "^5.12.1", 26 | "typescript": "^3.3.3" 27 | }, 28 | "scripts": { 29 | "prebuild": "tslint -c tslint.json -p tsconfig.json --fix", 30 | "build": "tsc", 31 | "start": "node .", 32 | "test": "mocha -r ts-node/register tests/**/*.test.ts", 33 | "testWithCoverage": "nyc -r lcov -e .ts -x \"tests/*.test.ts\" mocha -r ts-node/register tests/**/*.test.ts && nyc report" 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/data.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from "@tensorflow/tfjs"; 19 | import * as assert from "assert"; 20 | import * as fs from "fs"; 21 | import * as https from "https"; 22 | import * as util from "util"; 23 | import * as zlib from "zlib"; 24 | 25 | export const readFile = util.promisify(fs.readFile); 26 | 27 | // MNIST data constants: 28 | const BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; 29 | const TRAIN_IMAGES_FILE = "train-images-idx3-ubyte"; 30 | const TRAIN_LABELS_FILE = "train-labels-idx1-ubyte"; 31 | const TEST_IMAGES_FILE = "t10k-images-idx3-ubyte"; 32 | const TEST_LABELS_FILE = "t10k-labels-idx1-ubyte"; 33 | const IMAGE_HEADER_MAGIC_NUM = 2051; 34 | const IMAGE_HEADER_BYTES = 16; 35 | const IMAGE_HEIGHT = 28; 36 | const IMAGE_WIDTH = 28; 37 | const IMAGE_FLAT_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH; 38 | const LABEL_HEADER_MAGIC_NUM = 2049; 39 | const LABEL_HEADER_BYTES = 8; 40 | const LABEL_RECORD_BYTE = 1; 41 | const LABEL_FLAT_SIZE = 10; 42 | 43 | // Downloads a test file only once and returns the buffer for the file. 
/**
 * Returns the contents of one MNIST data file, downloading it on first use.
 *
 * If `filename` already exists on disk it is read and returned directly.
 * Otherwise `${BASE_URL}${filename}.gz` is fetched, gunzipped into a temporary
 * `${filename}.download` file, and renamed to `filename` only after every byte
 * has been flushed. An interrupted download therefore never leaves a truncated
 * file that a later run would mistake for a valid cache entry.
 *
 * @param filename Name of the uncompressed MNIST file (also the local path).
 * @returns The uncompressed file contents.
 * @throws Rejects if the request fails, the server does not answer 200, or
 *     decompressing/writing the file fails.
 */
export async function fetchOnceAndSaveToDiskWithBuffer(filename: string): Promise<Buffer> {
  if (fs.existsSync(filename)) {
    return readFile(filename);
  }

  const url = `${BASE_URL}${filename}.gz`;
  const partialFilename = `${filename}.download`;
  console.log(` * Downloading from: ${url}`);

  await new Promise<void>((resolve, reject) => {
    // Remove the partial file (ignoring unlink errors) before reporting failure.
    const fail = (err: Error) => fs.unlink(partialFilename, () => reject(err));

    https.get(url, (response) => {
      if (response.statusCode !== 200) {
        response.resume(); // Drain the body so the socket is released.
        reject(new Error(`Failed to download ${url}: HTTP ${response.statusCode}`));
        return;
      }
      const unzip = zlib.createGunzip();
      const file = fs.createWriteStream(partialFilename);
      response.on("error", fail);
      unzip.on("error", fail);
      file.on("error", fail);
      // "finish" fires only after the write stream has flushed all data; the
      // gunzip "end" event can fire while writes are still pending.
      file.on("finish", () => resolve());
      response.pipe(unzip).pipe(file);
    }).on("error", reject);
  });

  await util.promisify(fs.rename)(partialFilename, filename);
  return readFile(filename);
}

/**
 * Reads the header words at the start of an MNIST file.
 *
 * @param buffer Raw file contents.
 * @param headerLength Header size in bytes; it is read as `headerLength / 4`
 *     consecutive unsigned 32-bit big-endian integers.
 * @returns The header words in file order.
 */
export function loadHeaderValues(buffer: Buffer, headerLength: number): number[] {
  const headerValues: number[] = [];
  for (let i = 0; i < headerLength / 4; i++) {
    // Header data is stored in-order (aka big-endian)
    headerValues[i] = buffer.readUInt32BE(i * 4);
  }
  return headerValues;
}

/**
 * Loads an MNIST image file and returns one normalized array per image.
 *
 * The header is checked against the expected magic number and image
 * dimensions, and the number of decoded images against the header's count.
 *
 * @param filename Name of the uncompressed image file.
 * @returns One `Float32Array` of `IMAGE_HEIGHT * IMAGE_WIDTH` values in
 *     [0, 1] per image.
 * @throws AssertionError if the header or the image count does not match.
 */
export async function loadImages(filename: string): Promise<Float32Array[]> {
  const buffer = await fetchOnceAndSaveToDiskWithBuffer(filename);

  const headerBytes = IMAGE_HEADER_BYTES;
  const recordBytes = IMAGE_HEIGHT * IMAGE_WIDTH;

  const headerValues = loadHeaderValues(buffer, headerBytes);
  assert.strictEqual(headerValues[0], IMAGE_HEADER_MAGIC_NUM);
  assert.strictEqual(headerValues[2], IMAGE_HEIGHT);
  assert.strictEqual(headerValues[3], IMAGE_WIDTH);

  const images: Float32Array[] = [];
  let index = headerBytes;
  while (index < buffer.byteLength) {
    const array = new Float32Array(recordBytes);
    for (let i = 0; i < recordBytes; i++) {
      // Normalize the pixel values into the 0-1 interval, from
      // the original 0-255 interval.
      array[i] = buffer.readUInt8(index++) / 255;
    }
    images.push(array);
  }

  assert.strictEqual(images.length, headerValues[1]);
  return images;
}

/**
 * Loads an MNIST label file and returns one single-element array per label.
 *
 * @param filename Name of the uncompressed label file.
 * @returns One `Int32Array` of length `LABEL_RECORD_BYTE` per label.
 * @throws AssertionError if the magic number or the label count does not match
 *     the header.
 */
export async function loadLabels(filename: string): Promise<Int32Array[]> {
  const buffer = await fetchOnceAndSaveToDiskWithBuffer(filename);

  const headerBytes = LABEL_HEADER_BYTES;
  const recordBytes = LABEL_RECORD_BYTE;

  const headerValues = loadHeaderValues(buffer, headerBytes);
  assert.strictEqual(headerValues[0], LABEL_HEADER_MAGIC_NUM);

  const labels: Int32Array[] = [];
  let index = headerBytes;
  while (index < buffer.byteLength) {
    const array = new Int32Array(recordBytes);
    for (let i = 0; i < recordBytes; i++) {
      array[i] = buffer.readUInt8(index++);
    }
    labels.push(array);
  }

  assert.strictEqual(labels.length, headerValues[1]);
  return labels;
}

/** Helper class to handle loading training and test data. */
export class MnistDataset {
  /**
   * After `loadData()`: [train images, train labels, test images, test labels].
   * `null` until then.
   */
  public dataset: any[];
  /** Number of training images; 0 until `loadData()` has completed. */
  public trainSize: number;
  /** Number of test images; 0 until `loadData()` has completed. */
  public testSize: number;
  /** Initialized to 0; not otherwise used by this class. */
  public trainBatchIndex: number;
  /** Initialized to 0; not otherwise used by this class. */
  public testBatchIndex: number;

  constructor() {
    this.dataset = null;
    this.trainSize = 0;
    this.testSize = 0;
    this.trainBatchIndex = 0;
    this.testBatchIndex = 0;
  }

  /** Loads training and test data. */
  public async loadData() {
    this.dataset = await Promise.all([
      loadImages(TRAIN_IMAGES_FILE), loadLabels(TRAIN_LABELS_FILE),
      loadImages(TEST_IMAGES_FILE), loadLabels(TEST_LABELS_FILE)
    ]);
    this.trainSize = this.dataset[0].length;
    this.testSize = this.dataset[2].length;
  }

  /** Returns the training images and one-hot labels as tensors. */
  public getTrainData() {
    return this.getData_(true);
  }

  /** Returns the test images and one-hot labels as tensors. */
  public getTestData() {
    return this.getData_(false);
  }

  /**
   * Packs the selected split into a pair of tensors.
   *
   * @param isTrainingData `true` for the training split, `false` for test.
   * @returns `images` with shape [size, IMAGE_HEIGHT, IMAGE_WIDTH, 1] and
   *     `labels` one-hot encoded over LABEL_FLAT_SIZE classes, as floats.
   * @throws Error if `loadData()` has not completed yet.
   */
  public getData_(isTrainingData: boolean) {
    if (this.dataset == null) {
      throw new Error("MNIST data has not been loaded; call loadData() first.");
    }

    let imagesIndex;
    let labelsIndex;
    if (isTrainingData) {
      imagesIndex = 0;
      labelsIndex = 1;
    } else {
      imagesIndex = 2;
      labelsIndex = 3;
    }
    const size = this.dataset[imagesIndex].length;
    tf.util.assert(
        this.dataset[labelsIndex].length === size,
        `Mismatch in the number of images (${size}) and ` +
            `the number of labels (${this.dataset[labelsIndex].length})`);

    // Only create one big array to hold batch of images.
    const imagesShape: [number, number, number, number] = [size, IMAGE_HEIGHT, IMAGE_WIDTH, 1];
    const images: Float32Array = new Float32Array(tf.util.sizeFromShape(imagesShape));
    const labels: Int32Array = new Int32Array(tf.util.sizeFromShape([size, 1]));

    let imageOffset = 0;
    let labelOffset = 0;
    for (let i = 0; i < size; ++i) {
      images.set(this.dataset[imagesIndex][i], imageOffset);
      labels.set(this.dataset[labelsIndex][i], labelOffset);
      imageOffset += IMAGE_FLAT_SIZE;
      labelOffset += 1;
    }

    return {
      images: tf.tensor4d(images, imagesShape),
      labels: tf.oneHot(tf.tensor1d(labels, "int32"), LABEL_FLAT_SIZE).toFloat()
    };
  }
}

export default new MnistDataset();
-------------------------------------------------------------------------------- /src/main.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */ 17 | 18 | import * as argparse from "argparse"; 19 | 20 | import data from "./data"; 21 | import model from "./model"; 22 |
/**
 * Trains the model on the MNIST training split, evaluates it on the test
 * split, and optionally saves it.
 *
 * @param epochs Number of training epochs.
 * @param batchSize Number of examples per training batch.
 * @param modelSavePath Directory to save the trained model to; nothing is
 *     saved when this is null or undefined.
 */
async function run(epochs: number, batchSize: number, modelSavePath: string) {
  await data.loadData();

  const {images: trainImages, labels: trainLabels} = data.getTrainData();
  model.summary();

  // Fraction of the training data passed to fit() as its validationSplit.
  const validationSplit = 0.15;
  await model.fit(trainImages, trainLabels, {
    batchSize,
    epochs,
    validationSplit
  });

  const {images: testImages, labels: testLabels} = data.getTestData();
  const evalOutput = model.evaluate(testImages, testLabels);

  // evalOutput[0] is printed as the loss and evalOutput[1] as the accuracy.
  console.log(
      `\nEvaluation result:\n` +
      ` Loss = ${evalOutput[0].dataSync()[0].toFixed(3)}; ` +
      `Accuracy = ${evalOutput[1].dataSync()[0].toFixed(3)}`);

  if (modelSavePath != null) {
    await model.save(`file://${modelSavePath}`);
    console.log(`Saved model to path: ${modelSavePath}`);
  }
}

const parser = new argparse.ArgumentParser({
  addHelp: true,
  description: "TensorFlow.js-Node MNIST Example."
});
parser.addArgument("--epochs", {
  defaultValue: 20,
  help: "Number of epochs to train the model for.",
  type: "int"
});
parser.addArgument("--batch_size", {
  defaultValue: 128,
  help: "Batch size to be used during model training.",
  type: "int"
});
parser.addArgument("--model_save_path", {
  help: "Path to which the model will be saved after training.",
  type: "string"
});
const args = parser.parseArgs();

// Report failures (download, training, save) and exit non-zero instead of
// leaving an unhandled promise rejection.
run(args.epochs, args.batch_size, args.model_save_path).catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
-------------------------------------------------------------------------------- /src/model.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from "@tensorflow/tfjs-node"; 19 |
/**
 * Convolutional network for 28x28x1 inputs: two conv/conv/max-pool stages
 * (32 then 64 filters), followed by flatten, dropout, a 512-unit dense layer,
 * dropout, and a 10-way softmax output.
 */
export const model = tf.sequential();

// Layers in the order they are added to the model.
const layerStack = [
  tf.layers.conv2d({
    activation: "relu",
    filters: 32,
    inputShape: [28, 28, 1],
    kernelSize: 3,
  }),
  tf.layers.conv2d({
    activation: "relu",
    filters: 32,
    kernelSize: 3,
  }),
  tf.layers.maxPooling2d({poolSize: [2, 2]}),
  tf.layers.conv2d({
    activation: "relu",
    filters: 64,
    kernelSize: 3,
  }),
  tf.layers.conv2d({
    activation: "relu",
    filters: 64,
    kernelSize: 3,
  }),
  tf.layers.maxPooling2d({poolSize: [2, 2]}),
  tf.layers.flatten(),
  tf.layers.dropout({rate: 0.25}),
  tf.layers.dense({units: 512, activation: "relu"}),
  tf.layers.dropout({rate: 0.5}),
  tf.layers.dense({units: 10, activation: "softmax"}),
];
for (const layer of layerStack) {
  model.add(layer);
}

/** Name of the optimizer the model is compiled with. */
export const optimizer = "rmsprop";
model.compile({
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
  optimizer,
});

export default model;
-------------------------------------------------------------------------------- /tests/model.test.ts: --------------------------------------------------------------------------------
import * as tf from "@tensorflow/tfjs-node";
import { expect } from "chai";

// Smoke test: the TensorFlow.js binding loads and reports a version string.
describe("tensorflow.js", () => {
  it("Check that it loaded", () => {
    const tfjsVersion = tf.version.tfjs;
    expect(tfjsVersion.length).greaterThan(3);
  });
});
-------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | /* Basic Options */ 4 | "target": "ES2018", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016',
'ES2017', 'ES2018' or 'ESNEXT'. */ 5 | "module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', or 'ESNext'. */ 6 | /* Tested: 'none', 'commonjs', and 'umd' work; 'es2015' does not. */ 7 | // "lib": [], /* Specify library files to be included in the compilation. */ 8 | // "allowJs": true, /* Allow javascript files to be compiled. */ 9 | // "checkJs": true, /* Report errors in .js files. */ 10 | // "jsx": "preserve", /* Specify JSX code generation: 'preserve', 'react-native', or 'react'. */ 11 | // "declaration": true, /* Generates corresponding '.d.ts' file. */ 12 | // "declarationMap": true, /* Generates a sourcemap for each corresponding '.d.ts' file. */ 13 | // "sourceMap": true, /* Generates corresponding '.map' file. */ 14 | // "outFile": "./", /* Concatenate and emit output to single file. */ 15 | "outDir": "./dist", /* Redirect output structure to the directory. */ 16 | // "rootDir": "./", /* Specify the root directory of input files. Use to control the output directory structure with --outDir. */ 17 | // "composite": true, /* Enable project compilation */ 18 | // "removeComments": true, /* Do not emit comments to output. */ 19 | // "noEmit": true, /* Do not emit outputs. */ 20 | // "importHelpers": true, /* Import emit helpers from 'tslib'. */ 21 | // "downlevelIteration": true, /* Provide full support for iterables in 'for-of', spread, and destructuring when targeting 'ES5' or 'ES3'. */ 22 | // "isolatedModules": true, /* Transpile each file as a separate module (similar to 'ts.transpileModule'). */ 23 | 24 | /* Strict Type-Checking Options */ 25 | "strict": false, /* Enable all strict type-checking options. */ 26 | // "noImplicitAny": true, /* Raise error on expressions and declarations with an implied 'any' type. */ 27 | // "strictNullChecks": true, /* Enable strict null checks. */ 28 | // "strictFunctionTypes": true, /* Enable strict checking of function types. 
*/ 29 | // "strictBindCallApply": true, /* Enable strict 'bind', 'call', and 'apply' methods on functions. */ 30 | // "strictPropertyInitialization": true, /* Enable strict checking of property initialization in classes. */ 31 | // "noImplicitThis": true, /* Raise error on 'this' expressions with an implied 'any' type. */ 32 | // "alwaysStrict": true, /* Parse in strict mode and emit "use strict" for each source file. */ 33 | 34 | /* Additional Checks */ 35 | // "noUnusedLocals": true, /* Report errors on unused locals. */ 36 | // "noUnusedParameters": true, /* Report errors on unused parameters. */ 37 | // "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */ 38 | // "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */ 39 | 40 | /* Module Resolution Options */ 41 | "moduleResolution": "node", /* Specify module resolution strategy: 'node' (Node.js) or 'classic' (TypeScript pre-1.6). */ 42 | // "baseUrl": "./", /* Base directory to resolve non-absolute module names. */ 43 | // "paths": {}, /* A series of entries which re-map imports to lookup locations relative to the 'baseUrl'. */ 44 | // "rootDirs": [], /* List of root folders whose combined content represents the structure of the project at runtime. */ 45 | // "typeRoots": [], /* List of folders to include type definitions from. */ 46 | // "types": [], /* Type declaration files to be included in compilation. */ 47 | // "allowSyntheticDefaultImports": true, /* Allow default imports from modules with no default export. This does not affect code emit, just typechecking. */ 48 | "esModuleInterop": true /* Enables emit interoperability between CommonJS and ES Modules via creation of namespace objects for all imports. Implies 'allowSyntheticDefaultImports'. */ 49 | // "preserveSymlinks": true, /* Do not resolve the real path of symlinks. 
*/ 50 | 51 | /* Source Map Options */ 52 | // "sourceRoot": "", /* Specify the location where debugger should locate TypeScript files instead of source locations. */ 53 | // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */ 54 | // "inlineSourceMap": true, /* Emit a single file with source maps instead of having a separate file. */ 55 | // "inlineSources": true, /* Emit the source alongside the sourcemaps within a single file; requires '--inlineSourceMap' or '--sourceMap' to be set. */ 56 | 57 | /* Experimental Options */ 58 | // "experimentalDecorators": true, /* Enables experimental support for ES7 decorators. */ 59 | // "emitDecoratorMetadata": true, /* Enables experimental support for emitting type metadata for decorators. */ 60 | }, 61 | "include": [ 62 | "src/*.ts", 63 | "src/**/*.ts" 64 | ], 65 | "exclude": [ 66 | "node_modules", 67 | "**/*.spec.ts" 68 | ] 69 | } 70 | -------------------------------------------------------------------------------- /tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "defaultSeverity": "error", 3 | "extends": [ 4 | "tslint:recommended" 5 | ], 6 | "jsRules": {}, 7 | "rules": { 8 | "indent": [ true, "spaces" ], 9 | "trailing-comma": [ false ], 10 | "no-console": false 11 | }, 12 | "rulesDirectory": [] 13 | } 14 | --------------------------------------------------------------------------------