├── tfjs-layers ├── integration_tests │ ├── tfjs2keras │ │ ├── requirements-stable.txt │ │ ├── requirements-dev.txt │ │ ├── tsconfig.json │ │ ├── package.json │ │ └── tfjs2keras_test.py │ └── benchmarks │ │ ├── README.md │ │ ├── .eslintrc.json │ │ └── .babelrc ├── src │ ├── exports_models.ts │ ├── version.ts │ ├── engine │ │ ├── dataset_stub.ts │ │ ├── dataset_fakes_test.ts │ │ └── training_utils_test.ts │ ├── errors_test.ts │ ├── backend │ │ ├── state_test.ts │ │ ├── common.ts │ │ └── state.ts │ ├── keras_format │ │ ├── utils.ts │ │ ├── loss_config.ts │ │ ├── activation_config.ts │ │ ├── regularizer_config.ts │ │ ├── training_config.ts │ │ ├── layers │ │ │ ├── padding_serialization.ts │ │ │ ├── embeddings_serialization.ts │ │ │ ├── wrappers_serialization.ts │ │ │ ├── convolutional_depthwise_serialization.ts │ │ │ ├── normalization_serialization.ts │ │ │ ├── merge_serialization.ts │ │ │ ├── layer_serialization.ts │ │ │ ├── advanced_activation_serialization.ts │ │ │ ├── pooling_serialization.ts │ │ │ ├── core_serialization.ts │ │ │ ├── convolutional_serialization.ts │ │ │ └── recurrent_serialization.ts │ │ ├── input_config.ts │ │ ├── keras_class_names.ts │ │ ├── common.ts │ │ ├── constraint_config.ts │ │ ├── topology_config.ts │ │ ├── node_config.ts │ │ ├── model_serialization.ts │ │ ├── README.md │ │ ├── optimizer_config.ts │ │ ├── initializer_config.ts │ │ └── types.ts │ ├── utils │ │ ├── variable_utils.ts │ │ ├── variable_utils_test.ts │ │ ├── types_utils_test.ts │ │ ├── types_utils.ts │ │ ├── conv_utils.ts │ │ ├── math_utils.ts │ │ ├── math_utils_test.ts │ │ ├── test_utils.ts │ │ ├── serialization_utils.ts │ │ └── serialization_utils_test.ts │ ├── version_test.ts │ ├── layers │ │ ├── serialization_test.ts │ │ └── serialization.ts │ ├── exports_regularizers.ts │ ├── optimizers.ts │ ├── exports_constraints.ts │ ├── index.ts │ ├── types_test.ts │ ├── logs.ts │ ├── errors.ts │ ├── types.ts │ ├── optimizers_test.ts │ ├── regularizers_test.ts │ ├── common.ts │ ├── 
user_defined_metadata.ts │ ├── regularizers.ts │ ├── common_test.ts │ ├── constraints_test.ts │ └── metrics.ts ├── scripts │ ├── test-ci.sh │ ├── tfjs2keras-js.sh │ ├── tag-version │ ├── test_snippets.ts │ ├── build-npm.sh │ ├── publish-npm.sh │ ├── make-version │ ├── tfjs2keras-py.sh │ └── switch-tfjs-core-version.sh ├── .npmignore ├── tsconfig.json ├── demos │ └── README.md ├── cloudbuild.yml ├── tslint.json ├── DEVELOPMENT.md ├── package.json ├── tools │ └── clang_format_ts.sh ├── karma.conf.js ├── rollup.config.js └── README.md ├── .gitignore ├── README.md └── ISSUE_TEMPLATE.md /tfjs-layers/integration_tests/tfjs2keras/requirements-stable.txt: -------------------------------------------------------------------------------- 1 | keras==2.2.4 2 | tensorflow==1.13.1 3 | tensorflowjs==0.8.5 4 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/tfjs2keras/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | keras==2.2.4 2 | tensorflowjs>=1.0.0 3 | tf-nightly>=1.14.1.dev20190410 4 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/tfjs2keras/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../../tsconfig.json", 3 | "compilerOptions": { 4 | "outDir": "./dist" 5 | }, 6 | "include": ["./"] 7 | } 8 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow.js: Layers API Benchmarks 2 | 3 | The performance benchmarks for tfjs-layers have moved to 4 | [a new location](https://github.com/tensorflow/tfjs/tree/master/integration_tests/benchmarks). 
5 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/benchmarks/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "parserOptions": { 3 | "ecmaVersion": 2017, 4 | "sourceType": "module", 5 | "ecmaFeatures": { 6 | "jsx": true 7 | } 8 | }, 9 | "rules": { 10 | "semi": 2 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/node_modules/ 2 | **/bundle.js 3 | **/*.js.map 4 | coverage/ 5 | *.tgz 6 | npm-debug.log 7 | yarn-error.log 8 | .DS_Store 9 | dist/ 10 | .idea/ 11 | .yalc/ 12 | yalc.lock 13 | integration_tests/tfjs2keras/test-data/ 14 | 15 | bazel-* 16 | 17 | *.pyc 18 | 19 | .cache 20 | .rpt2_cache/ 21 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/benchmarks/.babelrc: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "presets": [ 4 | [ 5 | "env", 6 | { 7 | "esmodules": false, 8 | "targets": { 9 | "browsers": [ 10 | "> 3%" 11 | ] 12 | } 13 | } 14 | ] 15 | ], 16 | "plugins": [ 17 | "transform-runtime" 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # This repository has been archived in favor of tensorflow/tfjs. 2 | 3 | This repo will remain around for some time to keep history but all future PRs should be sent to [tensorflow/tfjs](https://github.com/tensorflow/tfjs) inside the [tfjs-layers](https://github.com/tensorflow/tfjs/tree/master/tfjs-layers) folder. 4 | 5 | All history and contributions have been preserved in the monorepo. 
6 | -------------------------------------------------------------------------------- /tfjs-layers/src/exports_models.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | export {modelFromJSON} from './models'; 12 | -------------------------------------------------------------------------------- /ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | To get help from the community, check out our [Google group](https://groups.google.com/a/tensorflow.org/forum/#!forum/tfjs). 2 | 3 | GitHub issues for this repository are tracked in the [tfjs union repository](https://github.com/tensorflow/tfjs/issues). 4 | 5 | Please file your issue there, following the guidance in [that issue template](https://github.com/tensorflow/tfjs/blob/master/ISSUE_TEMPLATE.md). 6 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/test-ci.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Use of this source code is governed by an MIT-style 6 | # license that can be found in the LICENSE file or at 7 | # https://opensource.org/licenses/MIT. 8 | # ============================================================================= 9 | 10 | set -e 11 | 12 | # Regular testing: build, lint, then run the browser tests via `yarn run-browserstack` (stops at the first failure because of `set -e`).
13 | yarn build 14 | yarn lint 15 | yarn run-browserstack 16 | -------------------------------------------------------------------------------- /tfjs-layers/src/version.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | // This code is auto-generated, do not modify this file! 12 | const version = '1.2.7'; 13 | export {version}; 14 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/tfjs2keras/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tfjs-layers-tfjs2keras-test", 3 | "version": "0.0.1", 4 | "description": "Testing sending models from tfjs-layers to Keras", 5 | "private": false, 6 | "license": "Apache-2.0 AND MIT", 7 | "devDependencies": { 8 | "@tensorflow/tfjs-core": "1.2.3", 9 | "@tensorflow/tfjs-layers": "1.2.2", 10 | "@tensorflow/tfjs-node": "1.1.2", 11 | "clang-format": "~1.2.2" 12 | }, 13 | "scripts": { 14 | "lint": "tslint -p .
--type-check -t verbose" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /tfjs-layers/.npmignore: -------------------------------------------------------------------------------- 1 | .rpt2_cache/ 2 | .vscode/ 3 | src/**/*_test.ts 4 | demos/ 5 | python/ 6 | scripts/ 7 | tools/ 8 | coverage/ 9 | package/ 10 | **/node_modules/ 11 | karma.conf.js 12 | dist/demos/ 13 | dist/**/*_fakes.js 14 | dist/**/*_test.js 15 | dist/**/*_test.d.ts 16 | integration_tests/ 17 | *.tgz 18 | .travis.yml 19 | cloudbuild.yml 20 | .npmignore 21 | .pylintrc 22 | CONTRIBUTING.md 23 | DEVELOPMENT.md 24 | ISSUE_TEMPLATE.md 25 | pull_request_template.md 26 | tslint.json 27 | tsconfig.json 28 | rollup.config.js 29 | WORKSPACE 30 | yarn.lock 31 | yarn-error.log 32 | .yalc/ 33 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/tfjs2keras-js.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Use of this source code is governed by an MIT-style 6 | # license that can be found in the LICENSE file or at 7 | # https://opensource.org/licenses/MIT. 8 | # ============================================================================= 9 | 10 | set -e 11 | 12 | TEST_DATA="test-data/" 13 | # Link this checkout into integration_tests/tfjs2keras/, then recreate test-data/ there and populate it by running tfjs_save.js. 14 | yarn link 15 | cd integration_tests/tfjs2keras/ 16 | yarn 17 | yarn link @tensorflow/tfjs-layers 18 | rm -rf "$TEST_DATA" 19 | mkdir "$TEST_DATA" 20 | node tfjs_save.js "$TEST_DATA" 21 | cd ../..
22 | -------------------------------------------------------------------------------- /tfjs-layers/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "module": "commonjs", 4 | "moduleResolution": "node", 5 | "noImplicitAny": true, 6 | "sourceMap": true, 7 | "removeComments": false, 8 | "preserveConstEnums": true, 9 | "declaration": true, 10 | "target": "es5", 11 | "lib": ["es2015", "dom"], 12 | "outDir": "./dist", 13 | "noUnusedLocals": true, 14 | "noImplicitReturns": true, 15 | "noImplicitThis": true, 16 | "alwaysStrict": true, 17 | "noUnusedParameters": false, 18 | "pretty": true, 19 | "noFallthroughCasesInSwitch": true, 20 | "allowUnreachableCode": false 21 | }, 22 | "include": ["src/"] 23 | } 24 | -------------------------------------------------------------------------------- /tfjs-layers/src/engine/dataset_stub.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Stub interfaces and classes for testing tf.LayersModel.fitDataset(). 13 | * 14 | * TODO(cais, soergel): Remove this in favor of actual interfaces and classes 15 | * when ready. 
16 | */ 17 | // `async` is not used on these abstract methods: it only affects a method body, and newer TypeScript rejects `abstract async`. 18 | export abstract class LazyIterator<T> { 19 | abstract next(): Promise<IteratorResult<T>>; 20 | } 21 | 22 | export abstract class Dataset<T> { 23 | abstract iterator(): Promise<LazyIterator<T>>; 24 | size: number; 25 | } 26 | -------------------------------------------------------------------------------- /tfjs-layers/src/errors_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {AttributeError, RuntimeError, ValueError} from './errors'; 12 | 13 | describe('Error classes', () => { 14 | // tslint:disable-next-line:variable-name 15 | for (const SomeClass of [AttributeError, RuntimeError, ValueError]) { 16 | it('pass instanceof tests.', () => { 17 | const msg = 'Some message'; 18 | const e = new SomeClass(msg); 19 | expect(e.message).toEqual(msg); 20 | expect(e instanceof SomeClass).toBe(true); 21 | }); 22 | } 23 | }); 24 | -------------------------------------------------------------------------------- /tfjs-layers/src/backend/state_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT.
8 | * ============================================================================= 9 | */ 10 | 11 | import {getUid} from '../backend/state'; 12 | 13 | describe('getUid', () => { 14 | it('second UID is different.', () => { 15 | const name = 'def'; 16 | const firstUID = getUid(name); 17 | const secondUID = getUid(name); 18 | expect(secondUID).not.toEqual(firstUID); 19 | }); 20 | 21 | it('with no prefix works and returns different UIDs.', () => { 22 | const firstUID = getUid(); 23 | const secondUID = getUid(); 24 | expect(firstUID).not.toEqual(secondUID); 25 | }); 26 | }); 27 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Infers a string union type from an array of string literals, and returns 13 | * the array as an array of that type. 14 | * 15 | * For instance: 16 | * 17 | * ``` 18 | * const fruits = stringLiteralArray(['apple', 'banana', 'orange']); 19 | * type Fruit = typeof fruits[number]; 20 | * ``` 21 | * 22 | * Now `Fruit` is the union type `'apple'|'banana'|'orange'`.
23 | * 24 | * https://stackoverflow.com/questions/52085454/typescript-define-a-union-type-from-an-array-of-strings/52085658 25 | */ 26 | export function stringLiteralArray<T extends string>(a: T[]): T[] { 27 | return a; 28 | } 29 | -------------------------------------------------------------------------------- /tfjs-layers/src/utils/variable_utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {LayerVariable} from '../variables'; 12 | 13 | /** 14 | * Count the elements in an Array of LayerVariables. 15 | * 16 | * @param weights The LayerVariables whose elements are to be counted. 17 | * @returns The total number of elements across all the LayerVariables; a 18 | * scalar (rank-0) variable counts as one element. 19 | */ 20 | export function countParamsInWeights(weights: LayerVariable[]): number { 21 | let count = 0; 22 | for (const weight of weights) { 23 | // The product over an empty (rank-0) shape is the initial value 1. 24 | count += weight.shape.reduce((a, b) => a * b, 1); 25 | } 26 | return count; 27 | } 28 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/loss_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {stringLiteralArray} from './utils'; 12 | 13 | /** 14 | * List of all known loss names.
15 | */ 16 | export const lossOptions = stringLiteralArray([ 17 | 'mean_squared_error', 'mean_absolute_error', 'mean_absolute_percentage_error', 18 | 'mean_squared_logarithmic_error', 'squared_hinge', 'hinge', 19 | 'categorical_hinge', 'logcosh', 'categorical_crossentropy', 20 | 'sparse_categorical_crossentropy', 'kullback_leibler_divergence', 'poisson', 21 | 'cosine_proximity' 22 | ]); 23 | 24 | /** 25 | * A type representing the strings that are valid loss names. 26 | */ 27 | export type LossIdentifier = typeof lossOptions[number]; 28 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/tag-version: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | // Copyright 2018 Google LLC 3 | // 4 | // Use of this source code is governed by an MIT-style 5 | // license that can be found in the LICENSE file or at 6 | // https://opensource.org/licenses/MIT. 7 | // ============================================================================= 8 | 9 | 10 | // Run this script from the base directory (not the script directory): 11 | // ./scripts/make-version 12 | 13 | var fs = require('fs'); 14 | var exec = require('child_process').exec; 15 | 16 | var version = JSON.parse(fs.readFileSync('package.json', 'utf8')).version; 17 | var tag = `v${version}`; 18 | 19 | exec(`git tag ${tag}`, err => { 20 | if (err) { 21 | throw new Error(`Could not git tag with ${tag}: ${err.message}.`); 22 | } 23 | console.log(`Successfully tagged with ${tag}.`); 24 | }); 25 | 26 | exec(`git push --tags`, err => { 27 | if (err) { 28 | throw new Error(`Could not push git tags: ${err.message}.`); 29 | } 30 | console.log(`Successfully pushed tags.`); 31 | }); 32 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/test_snippets.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 
2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | import * as tfc from '@tensorflow/tfjs-core'; 18 | import {parseAndEvaluateSnippets} from '@tensorflow/tfjs-core/dist/scripts/test_snippets/util'; 19 | 20 | import * as tfl from '../src/index'; 21 | // tfc is spread last, so a name exported by both packages resolves to tfjs-core's export. 22 | const tf = { 23 | ...tfl, 24 | ...tfc 25 | }; 26 | parseAndEvaluateSnippets(tf); 27 | -------------------------------------------------------------------------------- /tfjs-layers/src/backend/common.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {backend} from '@tensorflow/tfjs-core'; 12 | import {DataFormat} from '../keras_format/common'; 13 | 14 | let _epsilon: number; 15 | 16 | /** 17 | * Returns the fuzz factor used in numeric expressions; on first use (unless set via setEpsilon) it is read from `backend().epsilon()` and cached. 18 | */ 19 | export function epsilon() { 20 | if (_epsilon == null) { 21 | _epsilon = backend().epsilon(); 22 | } 23 | return _epsilon; 24 | } 25 | 26 | /** 27 | * Sets the value of the fuzz factor used in numeric expressions.
28 | * @param e New value of epsilon. 29 | */ 30 | export function setEpsilon(e: number) { 31 | _epsilon = e; 32 | } 33 | 34 | /** 35 | * Returns the default image data format convention, which is always 'channelsLast'. 36 | */ 37 | export function imageDataFormat(): DataFormat { 38 | return 'channelsLast'; 39 | } 40 | -------------------------------------------------------------------------------- /tfjs-layers/demos/README.md: -------------------------------------------------------------------------------- 1 | # tfjs-layers benchmarks 2 | 3 | To run the benchmark script, first set up your environment. 4 | 5 | (You may wish to set up the Python requirements in a virtual environment using 6 | [pipenv](https://github.com/pypa/pipenv) or [virtualenv](https://virtualenv.pypa.io).) 7 | 8 | ``` 9 | pip install tensorflowjs 10 | ``` 11 | 12 | Once the development environment is prepared, execute the build script from the root of tfjs-layers. 13 | 14 | ``` 15 | ./scripts/build-benchmarks-demo.sh 16 | ``` 17 | 18 | The script will construct a number of Keras models in Python and benchmark their training using the TensorFlow backend. When it is complete, it will bring up a 19 | local HTTP server. Navigate to the local URL specified in stdout to bring up 20 | the benchmarks page UI. There will be a button to begin the JS side of the 21 | benchmarks. Clicking the button will run through and time the same models, now 22 | running in the browser. 23 | 24 | Once complete, the models' `fit()` and `predict()` costs are listed in a table. 25 | 26 | Press Ctrl-C to end the http-server process. 27 | -------------------------------------------------------------------------------- /tfjs-layers/src/version_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT.
8 | * ============================================================================= 9 | */ 10 | 11 | // tslint:disable-next-line:no-require-imports 12 | const packageJSON = require('../package.json'); 13 | import {version_layers} from './index'; 14 | 15 | describe('tfjs-core version consistency', () => { 16 | it('dev-peer match', () => { 17 | const tfjsCoreDevDepVersion = 18 | packageJSON.devDependencies['@tensorflow/tfjs-core']; 19 | const tfjsCorePeerDepVersion = 20 | packageJSON.peerDependencies['@tensorflow/tfjs-core']; 21 | expect(tfjsCoreDevDepVersion).toEqual(tfjsCorePeerDepVersion); 22 | }); 23 | 24 | it('version.ts matches package version', () => { 25 | expect(version_layers).toBe(packageJSON.version); 26 | }); 27 | }); 28 | -------------------------------------------------------------------------------- /tfjs-layers/src/backend/state.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Utilities related to persistent state in the backend. 13 | */ 14 | 15 | /** 16 | * An ID to track `tf.SymbolicTensor`s and derived classes. 17 | * Required in different places in engine/topology.ts to identify unique 18 | * tensors. 19 | */ 20 | let _nextUniqueTensorId = 0; 21 | /** Returns the current tensor ID and advances the counter (0, 1, 2, ...). */ 22 | export function getNextUniqueTensorId(): number { 23 | return _nextUniqueTensorId++; 24 | } 25 | // Null-prototype object: with a plain `{}`, the `in` check in getUid would treat inherited keys such as 'constructor' or '__proto__' as existing counters. 26 | const _uidPrefixes: {[prefix: string]: number} = Object.create(null); 27 | 28 | /** 29 | * Provides a unique UID given a string prefix.
30 | * @param prefix String to which a per-prefix counter is appended. Defaults to ''. 31 | * @returns `prefix` followed by the number of times getUid has been called with that prefix (starting at 1). 32 | */ 33 | export function getUid(prefix = ''): string { 34 | if (!(prefix in _uidPrefixes)) { 35 | _uidPrefixes[prefix] = 0; 36 | } 37 | _uidPrefixes[prefix] += 1; 38 | return prefix + _uidPrefixes[prefix].toString(); 39 | } 40 | -------------------------------------------------------------------------------- /tfjs-layers/src/layers/serialization_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================= 16 | 17 | # Exit immediately if a command exits with a non-zero status. 18 | set -e 19 | 20 | rimraf dist/ 21 | yarn 22 | yarn build 23 | rollup -c 24 | 25 | # Use minified files for miniprogram 26 | mkdir dist/miniprogram 27 | cp dist/tf-layers.min.js dist/miniprogram/index.js 28 | cp dist/tf-layers.min.js.map dist/miniprogram/index.js.map 29 | 30 | echo "Stored standalone library at dist/tf-layers(.min).js" 31 | npm pack 32 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/activation_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {stringLiteralArray} from './utils'; 12 | 13 | /** 14 | * List of all known activation names. 
15 | */ 16 | export const activationOptions = stringLiteralArray([ 17 | 'elu', 'hard_sigmoid', 'linear', 'relu', 'relu6', 'selu', 'sigmoid', 18 | 'softmax', 'softplus', 'softsign', 'tanh' 19 | ]); 20 | 21 | /** 22 | * A type representing the strings that are valid loss names. 23 | */ 24 | export type ActivationSerialization = typeof activationOptions[number]; 25 | 26 | // Sad that we have to do all this just for hard_sigmoid vs. hardSigmoid. 27 | // TODO(soergel): Move the CamelCase versions back out of keras_format 28 | // e.g. to src/common.ts. Maybe even duplicate *all* of these to be pedantic? 29 | /** @docinline */ 30 | export type ActivationIdentifier = 'elu'|'hardSigmoid'|'linear'|'relu'|'relu6'| 31 | 'selu'|'sigmoid'|'softmax'|'softplus'|'softsign'|'tanh'; 32 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/regularizer_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {BaseSerialization} from './types'; 12 | 13 | export type L1L2Config = { 14 | l1?: number; 15 | l2?: number; 16 | }; 17 | 18 | export type L1L2Serialization = BaseSerialization<'L1L2', L1L2Config>; 19 | 20 | // Update regularizerClassNames below in concert with this. 21 | export type RegularizerSerialization = L1L2Serialization; 22 | 23 | export type RegularizerClassName = RegularizerSerialization['class_name']; 24 | 25 | // We can't easily extract a string[] from the string union type, but we can 26 | // recapitulate the list, enforcing at compile time that the values are valid 27 | // and that we have the right number of them. 
28 | 29 | /** 30 | * A string array of valid Regularizer class names. 31 | * 32 | * This is guaranteed to match the `RegularizerClassName` union type. 33 | */ 34 | export const regularizerClassNames: RegularizerClassName[] = ['L1L2']; 35 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/publish-npm.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Use of this source code is governed by an MIT-style 4 | # license that can be found in the LICENSE file or at 5 | # https://opensource.org/licenses/MIT. 6 | # ============================================================================= 7 | 8 | # Before you run this script, do: 9 | # 1) Update the version in package.json 10 | # 2) Run ./scripts/make-version from the base dir of the project. 11 | # 3) Run `yarn` to update `yarn.lock`, in case you updated dependencies 12 | # 4) Commit to the master branch. 13 | 14 | # Then: 15 | # 5) Checkout the master branch of this repo. 16 | # 6) Run this script as `./scripts/publish-npm.sh` from the project base dir. 17 | 18 | set -e 19 | 20 | BRANCH=`git rev-parse --abbrev-ref HEAD` 21 | ORIGIN=`git config --get remote.origin.url` 22 | 23 | if [ "$BRANCH" != "master" ]; then 24 | echo "Error: Switch to the master branch before publishing." 25 | exit 26 | fi 27 | 28 | if ! [[ "$ORIGIN" =~ tensorflow/tfjs-layers ]]; then 29 | echo "Error: Switch to the main repo (tensorflow/tfjs-layers)." 30 | exit 31 | fi 32 | 33 | yarn build-npm 34 | ./scripts/make-version # This is for safety in case you forgot to do 2). 35 | ./scripts/tag-version 36 | npm publish 37 | echo 'Yay! Published a new package to npm.' 
38 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/make-version: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | // Copyright 2018 Google LLC 3 | // 4 | // Use of this source code is governed by an MIT-style 5 | // license that can be found in the LICENSE file or at 6 | // https://opensource.org/licenses/MIT. 7 | // ============================================================================= 8 | 9 | 10 | // Run this script from the base directory (not the script directory): 11 | // ./scripts/make-version 12 | 13 | const fs = require('fs'); 14 | 15 | const packageJSON = JSON.parse(fs.readFileSync('package.json', 'utf8')); 16 | const version = packageJSON.version; 17 | 18 | const tag = `v${version}`; 19 | 20 | const versionCode = 21 | `/** 22 | * @license 23 | * Copyright 2019 Google LLC 24 | * 25 | * Use of this source code is governed by an MIT-style 26 | * license that can be found in the LICENSE file or at 27 | * https://opensource.org/licenses/MIT. 28 | * ============================================================================= 29 | */ 30 | 31 | // This code is auto-generated, do not modify this file! 32 | const version = '${version}'; 33 | export {version}; 34 | `; 35 | 36 | fs.writeFile('src/version.ts', versionCode, err => { 37 | if (err) { 38 | throw new Error(`Could not save version file ${version}: ${err}`); 39 | } 40 | console.log(`Version file for version ${version} saved sucessfully.`); 41 | }); 42 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/training_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 
8 | * ============================================================================= 9 | */ 10 | import {SampleWeightMode} from './common'; 11 | import {LossIdentifier} from './loss_config'; 12 | import {OptimizerSerialization} from './optimizer_config'; 13 | import {PyJsonDict} from './types'; 14 | 15 | // TODO(soergel): flesh out known metrics options 16 | export type MetricsIdentifier = string; 17 | 18 | /** 19 | * a type for valid values of the `loss_weights` field. 20 | */ 21 | export type LossWeights = number[]|{[key: string]: number}; 22 | 23 | /** 24 | * Configuration of the Keras trainer. This includes the configuration to the 25 | * optimizer, the loss, any metrics to be calculated, etc. 26 | */ 27 | export interface TrainingConfig extends PyJsonDict { 28 | // tslint:disable-next-line:no-any 29 | optimizer_config: OptimizerSerialization; 30 | loss: LossIdentifier|LossIdentifier[]|{[key: string]: LossIdentifier}; 31 | metrics?: MetricsIdentifier[]|{[key: string]: MetricsIdentifier}; 32 | weighted_metrics?: MetricsIdentifier[]; 33 | sample_weight_mode?: SampleWeightMode; 34 | loss_weights?: LossWeights; 35 | } 36 | -------------------------------------------------------------------------------- /tfjs-layers/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: 'node:10' 3 | entrypoint: 'yarn' 4 | id: 'yarn' 5 | args: ['prep'] 6 | - name: 'node:10' 7 | entrypoint: 'yarn' 8 | id: 'test-browser' 9 | args: ['test-ci'] 10 | waitFor: ['yarn'] 11 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1', 'NIGHTLY=$_NIGHTLY'] 12 | secretEnv: ['BROWSERSTACK_KEY'] 13 | # - name: 'node:10' # TODO(cais): Reinstate the tests after new tfjs-node release and/or smarting linking. 
14 | # entrypoint: 'yarn' 15 | # id: 'tfjs2keras-js' 16 | # args: ['tfjs2keras-js'] 17 | # waitFor: ['yarn'] 18 | # - name: 'python:2' 19 | # entrypoint: 'bash' 20 | # id: 'tfjs2keras-py' 21 | # args: ['-c', './scripts/tfjs2keras-py.sh --stable && ./scripts/tfjs2keras-py.sh --stable --tfkeras && ./scripts/tfjs2keras-py.sh --dev --tfkeras'] 22 | # waitFor: ['tfjs2keras-js'] 23 | - name: 'node:10' 24 | entrypoint: 'yarn' 25 | id: 'test-snippets' 26 | args: ['test-snippets'] 27 | waitFor: ['yarn'] 28 | secrets: 29 | - kmsKeyName: projects/learnjs-174218/locations/global/keyRings/tfjs/cryptoKeys/enc 30 | secretEnv: 31 | BROWSERSTACK_KEY: CiQAkwyoIW0LcnxymzotLwaH4udVTQFBEN4AEA5CA+a3+yflL2ASPQAD8BdZnGARf78MhH5T9rQqyz9HNODwVjVIj64CTkFlUCGrP1B2HX9LXHWHLmtKutEGTeFFX9XhuBzNExA= 32 | timeout: 1800s 33 | logsBucket: 'gs://tfjs-build-logs' 34 | substitutions: 35 | _NIGHTLY: '' 36 | options: 37 | logStreamingOption: 'STREAM_ON' 38 | substitution_option: 'ALLOW_LOOSE' 39 | -------------------------------------------------------------------------------- /tfjs-layers/src/layers/serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /* Original Source layers/__init__.py */ 12 | import {serialization} from '@tensorflow/tfjs-core'; 13 | 14 | import {deserializeKerasObject} from '../utils/generic_utils'; 15 | 16 | /** 17 | * Instantiate a layer from a config dictionary. 
18 | * @param config dict of the form {class_name: str, config: dict} 19 | * @param customObjects dict mapping class names (or function names) 20 | * of custom (non-Keras) objects to class/functions 21 | * @param fastWeightInit Optional flag to use fast weight initialization 22 | * during deserialization. This is applicable to cases in which 23 | * the initialization will be immediately overwritten by loaded weight 24 | * values. Default: `false`. 25 | * @returns Layer instance (may be LayersModel, Sequential, Layer...) 26 | */ 27 | export function deserialize( 28 | config: serialization.ConfigDict, 29 | customObjects = {} as serialization.ConfigDict, 30 | fastWeightInit = false): serialization.Serializable { 31 | return deserializeKerasObject( 32 | config, serialization.SerializationMap.getMap().classNameMap, 33 | customObjects, 'layer', fastWeightInit); 34 | } 35 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/padding_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {DataFormatSerialization} from '../common'; 12 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 13 | 14 | export interface ZeroPadding2DLayerConfig extends LayerConfig { 15 | padding?: number|[number, number]|[[number, number], [number, number]]; 16 | data_format?: DataFormatSerialization; 17 | } 18 | 19 | // Update paddingLayerClassNames below in concert with this. 
20 | export type ZeroPadding2DLayerSerialization = 21 | BaseLayerSerialization<'ZeroPadding2D', ZeroPadding2DLayerConfig>; 22 | 23 | export type PaddingLayerSerialization = ZeroPadding2DLayerSerialization; 24 | 25 | export type PaddingLayerClassName = PaddingLayerSerialization['class_name']; 26 | 27 | // We can't easily extract a string[] from the string union type, but we can 28 | // recapitulate the list, enforcing at compile time that the values are valid. 29 | 30 | /** 31 | * A string array of valid PaddingLayer class names. 32 | * 33 | * This is guaranteed to match the `PaddingLayerClassName` union type. 34 | */ 35 | export const paddingLayerClassNames: PaddingLayerClassName[] = [ 36 | 'ZeroPadding2D', 37 | ]; 38 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/input_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {DataType} from '@tensorflow/tfjs-core'; 12 | import {Shape} from './common'; 13 | import {BaseLayerSerialization} from './topology_config'; 14 | 15 | 16 | export type InputLayerConfig = { 17 | name?: string; 18 | input_shape?: Shape; 19 | batch_size?: number; 20 | batch_input_shape?: Shape; 21 | dtype?: DataType; 22 | sparse?: boolean; 23 | }; 24 | 25 | // This really should be BaseSerialization because an input layer has no 26 | // inbound_nodes. But, that makes type safety more difficult. 27 | 28 | // Update inputLayerClassNames below in concert with this. 
29 | export type InputLayerSerialization = 30 | BaseLayerSerialization<'InputLayer', InputLayerConfig>; 31 | 32 | export type InputLayerClassName = InputLayerSerialization['class_name']; 33 | 34 | // We can't easily extract a string[] from the string union type, but we can 35 | // recapitulate the list, enforcing at compile time that the values are valid. 36 | 37 | /** 38 | * A string array of valid InputLayer class names. 39 | * 40 | * This is guaranteed to match the `InputLayerClassName` union type. 41 | */ 42 | export const inputLayerClassNames: InputLayerClassName[] = [ 43 | 'InputLayer', 44 | ]; 45 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/keras_class_names.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {constraintClassNames, ConstraintSerialization} from './constraint_config'; 12 | import {initializerClassNames, InitializerSerialization} from './initializer_config'; 13 | import {layerClassNames, LayerSerialization} from './layers/layer_serialization'; 14 | import {optimizerClassNames, OptimizerSerialization} from './optimizer_config'; 15 | import {regularizerClassNames, RegularizerSerialization} from './regularizer_config'; 16 | 17 | /** 18 | * A type representing all possible Serializations of Keras objects, including 19 | * Layers, Constraints, Optimizers, etc. 
20 | */ 21 | export type KerasSerialization = LayerSerialization|ConstraintSerialization| 22 | InitializerSerialization|RegularizerSerialization|OptimizerSerialization; 23 | 24 | /** 25 | * A type representing all valid values of `class_name` in a Keras JSON file 26 | * (regardless of context, which will naturally further restrict the valid 27 | * values). 28 | */ 29 | export type KerasClassName = KerasSerialization['class_name']; 30 | 31 | export const kerasClassNames: KerasClassName[] = [ 32 | ...layerClassNames, ...constraintClassNames, ...initializerClassNames, 33 | ...regularizerClassNames, ...optimizerClassNames 34 | ]; 35 | -------------------------------------------------------------------------------- /tfjs-layers/src/exports_regularizers.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | import * as regularizers from './regularizers'; 11 | // tslint:disable-next-line:max-line-length 12 | import {L1Args, L1L2, L1L2Args, L2Args, Regularizer} from './regularizers'; 13 | 14 | /** 15 | * Regularizer for L1 and L2 regularization. 16 | * 17 | * Adds a term to the loss to penalize large weights: 18 | * loss += sum(l1 * abs(x)) + sum(l2 * x^2) 19 | */ 20 | /** @doc {heading: 'Regularizers', namespace: 'regularizers'} */ 21 | export function l1l2(config?: L1L2Args): Regularizer { 22 | return new L1L2(config); 23 | } 24 | 25 | /** 26 | * Regularizer for L1 regularization. 27 | * 28 | * Adds a term to the loss to penalize large weights: 29 | * loss += sum(l1 * abs(x)) 30 | * @param args l1 config. 
31 | */ 32 | /** @doc {heading: 'Regularizers', namespace: 'regularizers'} */ 33 | export function l1(config?: L1Args): Regularizer { 34 | return regularizers.l1(config); 35 | } 36 | 37 | /** 38 | * Regularizer for L2 regularization. 39 | * 40 | * Adds a term to the loss to penalize large weights: 41 | * loss += sum(l2 * x^2) 42 | * @param args l2 config. 43 | */ 44 | /** @doc {heading: 'Regularizers', namespace: 'regularizers'} */ 45 | export function l2(config?: L2Args): Regularizer { 46 | return regularizers.l2(config); 47 | } 48 | -------------------------------------------------------------------------------- /tfjs-layers/src/utils/variable_utils_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {scalar, zeros} from '@tensorflow/tfjs-core'; 12 | 13 | import {LayerVariable} from '../variables'; 14 | 15 | import * as variable_utils from './variable_utils'; 16 | 17 | describe('countParamsInWeights', () => { 18 | it('Zero weights', () => { 19 | expect(variable_utils.countParamsInWeights([])).toEqual(0); 20 | }); 21 | 22 | it('One float32 weight', () => { 23 | const weight1 = new LayerVariable(zeros([2, 3])); 24 | expect(variable_utils.countParamsInWeights([weight1])).toEqual(6); 25 | }); 26 | 27 | it('One float32 scalar weight', () => { 28 | const weight1 = new LayerVariable(scalar(42)); 29 | expect(variable_utils.countParamsInWeights([weight1])).toEqual(1); 30 | }); 31 | 32 | it('One int32 weight', () => { 33 | const weight1 = new LayerVariable(zeros([1, 3, 4], 'int32'), 'int32'); 34 | expect(variable_utils.countParamsInWeights([weight1])).toEqual(12); 35 | }); 36 | 37 | it('Two weights, mixed types 
and shapes', () => { 38 | const weight1 = new LayerVariable(scalar(42)); 39 | const weight2 = new LayerVariable(zeros([2, 3])); 40 | const weight3 = new LayerVariable(zeros([1, 3, 4], 'int32'), 'int32'); 41 | expect(variable_utils.countParamsInWeights([ 42 | weight1, weight2, weight3 43 | ])).toEqual(19); 44 | }); 45 | }); 46 | -------------------------------------------------------------------------------- /tfjs-layers/src/optimizers.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Optimizers. 13 | */ 14 | 15 | import {Optimizer, train} from '@tensorflow/tfjs-core'; 16 | 17 | import {epsilon} from './backend/common'; 18 | 19 | import {ValueError} from './errors'; 20 | 21 | // Add (de)serialize() 22 | 23 | // Porting note: This diverges from the PyKeras implementation and may need to 24 | // change based on (de)serialization requirements. 
25 | export function getOptimizer(identifier: string): Optimizer { 26 | const optimizerMap: {[optimizerName: string]: () => Optimizer} = { 27 | 'Adagrad': () => train.adagrad(0.01), 28 | 'Adadelta': () => train.adadelta(1, 0.95, epsilon()), 29 | 'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()), 30 | 'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0), 31 | 'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()), 32 | 'SGD': () => train.sgd(0.01) 33 | }; 34 | optimizerMap['adagrad'] = optimizerMap['Adagrad']; 35 | optimizerMap['adadelta'] = optimizerMap['Adadelta']; 36 | optimizerMap['adam'] = optimizerMap['Adam']; 37 | optimizerMap['adamax'] = optimizerMap['Adamax']; 38 | optimizerMap['rmsprop'] = optimizerMap['RMSProp']; 39 | optimizerMap['sgd'] = optimizerMap['SGD']; 40 | 41 | if (identifier in optimizerMap) { 42 | return optimizerMap[identifier](); 43 | } 44 | throw new ValueError(`Unknown Optimizer ${identifier}`); 45 | } 46 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/common.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | // TODO(huan): add layer-specific input shape types (see: https://github.com/tensorflow/tfjs-layers/pull/492) 12 | /** @docalias (null | number)[] */ 13 | export type Shape = Array; 14 | 15 | // The tfjs-core version of DataType must stay synced with this. 16 | export type DataType = 'float32'|'int32'|'bool'|'complex64'|'string'; 17 | 18 | // TODO(soergel): Move the CamelCase versions back out of keras_format 19 | // e.g. to src/common.ts. Maybe even duplicate *all* of these to be pedantic? 
20 | /** @docinline */ 21 | export type DataFormat = 'channelsFirst'|'channelsLast'; 22 | export const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast']; 23 | 24 | // These constants have a snake vs. camel distinction. 25 | export type DataFormatSerialization = 'channels_first'|'channels_last'; 26 | 27 | /** @docinline */ 28 | export type PaddingMode = 'valid'|'same'|'causal'; 29 | export const VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal']; 30 | 31 | /** @docinline */ 32 | export type PoolMode = 'max'|'avg'; 33 | export const VALID_POOL_MODE_VALUES = ['max', 'avg']; 34 | 35 | /** @docinline */ 36 | export type BidirectionalMergeMode = 'sum'|'mul'|'concat'|'ave'; 37 | export const VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave']; 38 | 39 | /** @docinline */ 40 | export type SampleWeightMode = 'temporal'; 41 | export const VALID_SAMPLE_WEIGHT_MODES = ['temporal']; 42 | -------------------------------------------------------------------------------- /tfjs-layers/src/exports_constraints.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | // tslint:disable-next-line:max-line-length 11 | import {Constraint, MaxNorm, MaxNormArgs, MinMaxNorm, MinMaxNormArgs, NonNeg, UnitNorm, UnitNormArgs} from './constraints'; 12 | 13 | 14 | /** 15 | * MaxNorm weight constraint. 16 | * 17 | * Constrains the weights incident to each hidden unit 18 | * to have a norm less than or equal to a desired value. 19 | * 20 | * References 21 | * - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting 22 | * Srivastava, Hinton, et al. 
23 | * 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) 24 | */ 25 | /** @doc {heading: 'Constraints',namespace: 'constraints'} */ 26 | export function maxNorm(args: MaxNormArgs): Constraint { 27 | return new MaxNorm(args); 28 | } 29 | 30 | /** 31 | * Constrains the weights incident to each hidden unit to have unit norm. 32 | */ 33 | /** @doc {heading: 'Constraints', namespace: 'constraints'} */ 34 | export function unitNorm(args: UnitNormArgs): Constraint { 35 | return new UnitNorm(args); 36 | } 37 | 38 | /** 39 | * Constains the weight to be non-negative. 40 | */ 41 | /** @doc {heading: 'Constraints', namespace: 'constraints'} */ 42 | export function nonNeg(): Constraint { 43 | return new NonNeg(); 44 | } 45 | 46 | /** @doc {heading: 'Constraints', namespace: 'constraints'} */ 47 | export function minMaxNorm(config: MinMaxNormArgs): Constraint { 48 | return new MinMaxNorm(config); 49 | } 50 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/embeddings_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 
8 | * ============================================================================= 9 | */ 10 | 11 | import {ConstraintSerialization} from '../constraint_config'; 12 | import {InitializerSerialization} from '../initializer_config'; 13 | import {RegularizerSerialization} from '../regularizer_config'; 14 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 15 | 16 | export interface EmbeddingLayerConfig extends LayerConfig { 17 | input_dim: number; 18 | output_dim: number; 19 | embeddings_initializer?: InitializerSerialization; 20 | embeddings_regularizer?: RegularizerSerialization; 21 | activity_regularizer?: RegularizerSerialization; 22 | embeddings_constraint?: ConstraintSerialization; 23 | mask_zero?: boolean; 24 | input_length?: number|number[]; 25 | } 26 | 27 | // Update embeddingLayerClassNames below in concert with this. 28 | export type EmbeddingLayerSerialization = 29 | BaseLayerSerialization<'Embedding', EmbeddingLayerConfig>; 30 | 31 | export type EmbeddingLayerClassName = EmbeddingLayerSerialization['class_name']; 32 | 33 | // We can't easily extract a string[] from the string union type, but we can 34 | // recapitulate the list, enforcing at compile time that the values are valid. 35 | 36 | /** 37 | * A string array of valid EmbeddingLayer class names. 38 | * 39 | * This is guaranteed to match the `EmbeddingLayerClassName` union type. 40 | */ 41 | export const embeddingLayerClassNames: EmbeddingLayerClassName[] = [ 42 | 'Embedding', 43 | ]; 44 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/tfjs2keras-py.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 Google LLC 4 | # 5 | # Use of this source code is governed by an MIT-style 6 | # license that can be found in the LICENSE file or at 7 | # https://opensource.org/licenses/MIT. 
8 | # ============================================================================= 9 | 10 | set -e 11 | 12 | cd integration_tests/tfjs2keras 13 | DEV_VERSION="" 14 | TFJS2KERAS_TEST_USING_TF_KERAS=0 15 | while [[ ! -z "$1" ]]; do 16 | if [[ "$1" == "--stable" ]]; then 17 | DEV_VERSION="stable" 18 | elif [[ "$1" == "--dev" ]]; then 19 | DEV_VERSION="dev" 20 | elif [[ "$1" == "--tfkeras" ]]; then 21 | TFJS2KERAS_TEST_USING_TF_KERAS=1 22 | else 23 | echo "ERROR: Unrecognized command-line flag $1" 24 | exit 1 25 | fi 26 | shift 27 | done 28 | 29 | echo "DEV_VERSION: ${DEV_VERSION}" 30 | echo "TFJS2KERAS_TEST_USING_TF_KERAS: ${TFJS2KERAS_TEST_USING_TF_KERAS}" 31 | 32 | if [[ -z "${DEV_VERSION}" ]]; then 33 | echo "Must specify one of --stable and --dev." 34 | exit 1 35 | fi 36 | 37 | if [[ "${DEV_VERSION}" == "dev" && 38 | "${TFJS2KERAS_TEST_USING_TF_KERAS}" == "0" ]]; then 39 | echo "--dev && keras-team/keras is not a valid combination." 40 | echo "Use --dev and --tfkeras together." 41 | exit 1 42 | fi 43 | 44 | VENV_DIR="$(mktemp -d)_venv" 45 | echo "Creating virtualenv at ${VENV_DIR} ..." 46 | virtualenv "${VENV_DIR}" 47 | source "${VENV_DIR}/bin/activate" 48 | 49 | if [[ "${DEV_VERSION}" == "stable" ]]; then 50 | pip install -r requirements-stable.txt 51 | else 52 | pip install -r requirements-dev.txt 53 | fi 54 | 55 | export TFJS2KERAS_TEST_USING_TF_KERAS="${TFJS2KERAS_TEST_USING_TF_KERAS}" 56 | 57 | python tfjs2keras_test.py 58 | 59 | # Clean up virtualenv directory. 60 | rm -rf "${VENV_DIR}" 61 | cd ../.. 62 | -------------------------------------------------------------------------------- /tfjs-layers/src/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 
8 | * ============================================================================= 9 | */ 10 | 11 | // This file lists all exports of TensorFlow.js Layers 12 | 13 | import * as constraints from './exports_constraints'; 14 | import * as initializers from './exports_initializers'; 15 | import * as layers from './exports_layers'; 16 | import * as metrics from './exports_metrics'; 17 | import * as models from './exports_models'; 18 | import * as regularizers from './exports_regularizers'; 19 | 20 | export {CallbackList, CustomCallback, CustomCallbackArgs, History} from './base_callbacks'; 21 | export {Callback, callbacks, EarlyStopping, EarlyStoppingCallbackArgs} from './callbacks'; 22 | export {InputSpec, SymbolicTensor} from './engine/topology'; 23 | export {LayersModel, ModelCompileArgs, ModelEvaluateArgs} from './engine/training'; 24 | export {ClassWeight, ClassWeightMap} from './engine/training_utils'; 25 | export {ModelFitDatasetArgs} from './engine/training_dataset'; 26 | export {ModelFitArgs} from './engine/training_tensors'; 27 | export {input, loadLayersModel, model, registerCallbackConstructor, sequential} from './exports'; 28 | export {Shape} from './keras_format/common'; 29 | export {GRUCellLayerArgs, GRULayerArgs, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNLayerArgs, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs} from './layers/recurrent'; 30 | export {Logs} from './logs'; 31 | export {ModelAndWeightsConfig, Sequential, SequentialArgs} from './models'; 32 | export {LayerVariable} from './variables'; 33 | export {version as version_layers} from './version'; 34 | export {constraints, initializers, layers, metrics, models, regularizers}; 35 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/constraint_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an 
MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {BaseSerialization} from './types'; 12 | 13 | export type MaxNormConfig = { 14 | max_value?: number; 15 | axis?: number; 16 | }; 17 | 18 | export type MaxNormSerialization = BaseSerialization<'MaxNorm', MaxNormConfig>; 19 | 20 | export type UnitNormConfig = { 21 | axis?: number; 22 | }; 23 | 24 | export type UnitNormSerialization = 25 | BaseSerialization<'UnitNorm', UnitNormConfig>; 26 | 27 | export type NonNegSerialization = BaseSerialization<'NonNeg', null>; 28 | 29 | export type MinMaxNormConfig = { 30 | min_value?: number; 31 | max_value?: number; 32 | axis?: number; 33 | rate?: number; 34 | }; 35 | 36 | export type MinMaxNormSerialization = 37 | BaseSerialization<'MinMaxNorm', MinMaxNormConfig>; 38 | 39 | // Update constraintClassNames below in concert with this. 40 | export type ConstraintSerialization = MaxNormSerialization|NonNegSerialization| 41 | UnitNormSerialization|MinMaxNormSerialization; 42 | 43 | export type ConstraintClassName = ConstraintSerialization['class_name']; 44 | 45 | // We can't easily extract a string[] from the string union type, but we can 46 | // recapitulate the list, enforcing at compile time that the values are valid 47 | // and that we have the right number of them. 48 | 49 | /** 50 | * A string array of valid Constraint class names. 51 | * 52 | * This is guaranteed to match the `ConstraintClassName` union type. 
53 | */ 54 | export const constraintClassNames: ConstraintClassName[] = 55 | ['MaxNorm', 'UnitNorm', 'NonNeg', 'MinMaxNorm']; 56 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/wrappers_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {BidirectionalMergeMode} from '../common'; 12 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 13 | import {LayerSerialization} from './layer_serialization'; 14 | import {RecurrentLayerSerialization} from './recurrent_serialization'; 15 | 16 | 17 | export type TimeDistributedLayerSerialization = 18 | BaseLayerSerialization<'TimeDistributed', TimeDistributedLayerConfig>; 19 | 20 | export interface TimeDistributedLayerConfig extends LayerConfig { 21 | layer: LayerSerialization; 22 | } 23 | 24 | export type BidirectionalLayerSerialization = 25 | BaseLayerSerialization<'Bidirectional', BidirectionalLayerConfig>; 26 | 27 | export interface BidirectionalLayerConfig extends LayerConfig { 28 | layer: RecurrentLayerSerialization; 29 | merge_mode?: BidirectionalMergeMode; 30 | } 31 | 32 | // Update wrapperLayerClassNames below in concert with this. 33 | export type WrapperLayerSerialization = 34 | TimeDistributedLayerSerialization|BidirectionalLayerSerialization; 35 | 36 | export type WrapperLayerClassName = WrapperLayerSerialization['class_name']; 37 | 38 | // We can't easily extract a string[] from the string union type, but we can 39 | // recapitulate the list, enforcing at compile time that the values are valid. 
40 | 41 | /** 42 | * A string array of valid WrapperLayer class names. 43 | * 44 | * This is guaranteed to match the `WrapperLayerClassName` union type. 45 | */ 46 | export const wrapperLayerClassNames: WrapperLayerClassName[] = [ 47 | 'Bidirectional', 48 | 'TimeDistributed', 49 | ]; 50 | -------------------------------------------------------------------------------- /tfjs-layers/src/types_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Unit tests for tfjs-layers-specific types. 13 | */ 14 | 15 | 16 | import {SymbolicTensor} from './engine/topology'; 17 | 18 | 19 | /** 20 | * Unit tests for SymbolicTensor. 21 | */ 22 | describe('SymbolicTensor Test', () => { 23 | it('Correct dtype and shape properties', () => { 24 | const st1 = new SymbolicTensor('float32', [4, 6], null, [], {}); 25 | expect(st1.dtype).toEqual('float32'); 26 | expect(st1.shape).toEqual([4, 6]); 27 | expect(st1.rank).toEqual(2); 28 | }); 29 | it('Correct when operating on scalars', () => { 30 | const scalar = new SymbolicTensor('float32', [], null, [], {}); 31 | expect(scalar.dtype).toEqual('float32'); 32 | expect(scalar.shape).toEqual([]); 33 | expect(scalar.rank).toEqual(0); 34 | }); 35 | 36 | it('Correct names and ids', () => { 37 | const st1 = new SymbolicTensor( 38 | 'float32', [2, 2], null, [], {}, 'TestSymbolicTensor'); 39 | const st2 = new SymbolicTensor( 40 | 'float32', [2, 2], null, [], {}, 'TestSymbolicTensor'); 41 | expect(st1.name.indexOf('TestSymbolicTensor')).toEqual(0); 42 | expect(st2.name.indexOf('TestSymbolicTensor')).toEqual(0); 43 | // Explicit names of symbolic tensors should be unique.
44 | expect(st1.name === st2.name).toBe(false); 45 | 46 | expect(st1.id).toBeGreaterThanOrEqual(0); 47 | expect(st2.id).toBeGreaterThanOrEqual(0); 48 | expect(st1.id === st2.id).toBe(false); 49 | }); 50 | 51 | it('Invalid tensor name leads to error', () => { 52 | expect(() => new SymbolicTensor('float32', [2, 2], null, [], {}, '!')) 53 | .toThrowError(); 54 | }); 55 | }); 56 | -------------------------------------------------------------------------------- /tfjs-layers/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["tslint-no-circular-imports"], 3 | "rules": { 4 | "array-type": [true, "array-simple"], 5 | "arrow-return-shorthand": true, 6 | "ban": [true, 7 | ["fit"], 8 | ["fdescribe"], 9 | ["xit"], 10 | ["xdescribe"], 11 | ["fitAsync"], 12 | ["xitAsync"], 13 | ["fitFakeAsync"], 14 | ["xitFakeAsync"] 15 | ], 16 | "ban-types": [true, 17 | ["Object", "Use {} instead."], 18 | ["String", "Use 'string' instead."], 19 | ["Number", "Use 'number' instead."], 20 | ["Boolean", "Use 'boolean' instead."] 21 | ], 22 | "class-name": true, 23 | "interface-name": [true, "never-prefix"], 24 | "jsdoc-format": true, 25 | "forin": false, 26 | "label-position": true, 27 | "max-line-length": { 28 | "options": {"limit": 80, "ignore-pattern": "^import |^export |https?://"} 29 | }, 30 | "new-parens": true, 31 | "no-angle-bracket-type-assertion": true, 32 | "no-any": true, 33 | "no-construct": true, 34 | "no-debugger": true, 35 | "no-default-export": true, 36 | "no-inferrable-types": true, 37 | "no-namespace": [true, "allow-declarations"], 38 | "no-reference": true, 39 | "no-require-imports": true, 40 | "no-string-throw": true, 41 | "no-unused-expression": true, 42 | "no-var-keyword": true, 43 | "object-literal-shorthand": true, 44 | "only-arrow-functions": [true, "allow-declarations", "allow-named-functions"], 45 | "prefer-const": true, 46 | "quotemark": [true, "single"], 47 | "radix": true, 48 | "semicolon": [true, "always", 
"ignore-bound-class-methods"], 49 | "switch-default": true, 50 | "triple-equals": [true, "allow-null-check"], 51 | "use-isnan": true, 52 | "variable-name": [ 53 | true, 54 | "check-format", 55 | "ban-keywords", 56 | "allow-leading-underscore", 57 | "allow-trailing-underscore" 58 | ] 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /tfjs-layers/DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | ## Development process 2 | 3 | As a preparatory step, run `yarn`, which installs all dev dependencies. 4 | 5 | Before submitting a PR with a change, make sure the following 6 | commands succeed: 7 | * `yarn build` which compiles the project to ES5 JavaScript. 8 | * `yarn format` to format your code. 9 | * `yarn lint` to check for linter errors. 10 | * `yarn test` to run unit tests in Chrome and Firefox. Make sure all unit tests pass. 11 | 12 | When you send a PR, the above commands will also run on [Cloud Build](https://pantheon.corp.google.com/cloud-build/builds?organizationId=433637338589&project=learnjs-174218) 13 | and show up as GitHub checks. If you see Cloud Build failing, click on the `Details` 14 | link next to the check to open the log. 15 | 16 | ## Changing @tensorflow/tfjs-layers and testing @tensorflow/tfjs 17 | 18 | Often we want to make a change in `tfjs-layers` or `tfjs-core` and create a new 19 | `tfjs` package that reflects that change. There is a 3-step initial process to 20 | set this up. The instructions below are for `tfjs-layers`, but they should work 21 | for developing `tfjs-core` if you replace `tfjs-layers` with `tfjs-core`. 22 | 23 | 1. In the `tfjs-layers` repo, run `yarn publish-local`. This builds the 24 | project and publishes a new package in a local registry. 25 | 26 | 2. In the `tfjs` repo, run `yarn link-local @tensorflow/tfjs-layers`. This makes 27 | `tfjs` depend on the locally published `tfjs-layers` package. 28 | 29 | 3.
In the `tfjs` repo, run `yarn build-npm` to build a new npm package. 30 | 31 | Every time you make a change in `tfjs-layers`, re-run: 32 | - `yarn publish-local` in the `tfjs-layers` repo 33 | - `yarn build-npm` in the `tfjs` repo to make a new package. 34 | 35 | ## Running integration tests 36 | 37 | ### tfjs2keras 38 | 39 | This is an integration test that checks the models exported by tfjs-layers 40 | can be loaded correctly by Keras in Python. To run this test, do: 41 | 42 | ```sh 43 | yarn tfjs2keras 44 | ``` 45 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/convolutional_depthwise_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {ConstraintSerialization} from '../constraint_config'; 12 | import {InitializerSerialization} from '../initializer_config'; 13 | import {RegularizerSerialization} from '../regularizer_config'; 14 | import {BaseLayerSerialization} from '../topology_config'; 15 | import {BaseConvLayerConfig} from './convolutional_serialization'; 16 | 17 | 18 | export interface DepthwiseConv2DLayerConfig extends BaseConvLayerConfig { 19 | kernel_size: number|[number, number]; 20 | depth_multiplier?: number; 21 | depthwise_initializer?: InitializerSerialization; 22 | depthwise_constraint?: ConstraintSerialization; 23 | depthwise_regularizer?: RegularizerSerialization; 24 | } 25 | 26 | // Update convolutionalDepthwiseLayerClassNames below in concert with this.
27 | export type DepthwiseConv2DLayerSerialization = 28 | BaseLayerSerialization<'DepthwiseConv2D', DepthwiseConv2DLayerConfig>; 29 | 30 | export type ConvolutionalDepthwiseLayerSerialization = 31 | DepthwiseConv2DLayerSerialization; 32 | 33 | export type ConvolutionalDepthwiseLayerClassName = 34 | ConvolutionalDepthwiseLayerSerialization['class_name']; 35 | 36 | // We can't easily extract a string[] from the string union type, but we can 37 | // recapitulate the list, enforcing at compile time that the values are valid. 38 | 39 | /** 40 | * A string array of valid ConvolutionalDepthwiseLayer class names. 41 | * 42 | * This is guaranteed to match the `ConvolutionalDepthwiseLayerClassName` union 43 | * type. 44 | */ 45 | export const convolutionalDepthwiseLayerClassNames: 46 | ConvolutionalDepthwiseLayerClassName[] = [ 47 | 'DepthwiseConv2D', 48 | ]; 49 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/topology_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {DataType} from '@tensorflow/tfjs-core'; 12 | 13 | import {Shape} from './common'; 14 | import {NodeConfig} from './node_config'; 15 | import {BaseSerialization, PyJson, PyJsonDict} from './types'; 16 | 17 | /** Constructor arguments for Layer. */ 18 | export interface LayerConfig extends PyJsonDict { 19 | input_shape?: Shape; 20 | batch_input_shape?: Shape; 21 | batch_size?: number; 22 | dtype?: DataType; 23 | name?: string; 24 | trainable?: boolean; 25 | input_dtype?: DataType; 26 | } 27 | 28 | /** 29 | * Converts a subtype of `LayerConfig` to a variant with restricted keys. 
30 | * 31 | * This is a bit tricky because `keyof` obtains only local fields, not inherited 32 | * fields. Thus, this type combines the keys from the `LayerConfig` supertype 33 | * with those of the specific subtype. 34 | * 35 | * See ./types.ts for an explanation of the PyJson type. 36 | */ 37 | export type JsonLayer<C extends LayerConfig> = C&LayerConfig& 38 | PyJson<Extract<keyof C, string>|Extract<keyof LayerConfig, string>>; 39 | 40 | /** 41 | * A Keras JSON entry representing a layer. 42 | * 43 | * The Keras JSON convention is to provide the `class_name` (i.e., the layer 44 | * type) at the top level, and then to place the layer-specific configuration in 45 | * a `config` subtree. These layer-specific configurations are provided by 46 | * subtypes of `LayerConfig`. Thus, this `*Serialization` has a type parameter 47 | * giving the specific type of the wrapped `LayerConfig`. 48 | */ 49 | export interface BaseLayerSerialization<N extends string, C extends LayerConfig> 50 | extends BaseSerialization<N, JsonLayer<C>> { 51 | name: string; 52 | inbound_nodes?: NodeConfig[]; 53 | } 54 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/normalization_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT.
8 | * ============================================================================= 9 | */ 10 | 11 | import {ConstraintSerialization} from '../constraint_config'; 12 | import {InitializerSerialization} from '../initializer_config'; 13 | import {RegularizerSerialization} from '../regularizer_config'; 14 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 15 | 16 | export interface BatchNormalizationLayerConfig extends LayerConfig { 17 | axis?: number; 18 | momentum?: number; 19 | epsilon?: number; 20 | center?: boolean; 21 | scale?: boolean; 22 | beta_initializer?: InitializerSerialization; 23 | gamma_initializer?: InitializerSerialization; 24 | moving_mean_initializer?: InitializerSerialization; 25 | moving_variance_initializer?: InitializerSerialization; 26 | beta_constraint?: ConstraintSerialization; 27 | gamma_constraint?: ConstraintSerialization; 28 | beta_regularizer?: RegularizerSerialization; 29 | gamma_regularizer?: RegularizerSerialization; 30 | } 31 | 32 | // Update normalizationLayerClassNames below in concert with this. 33 | export type BatchNormalizationLayerSerialization = 34 | BaseLayerSerialization<'BatchNormalization', BatchNormalizationLayerConfig>; 35 | 36 | export type NormalizationLayerSerialization = 37 | BatchNormalizationLayerSerialization; 38 | 39 | export type NormalizationLayerClassName = 40 | NormalizationLayerSerialization['class_name']; 41 | 42 | // We can't easily extract a string[] from the string union type, but we can 43 | // recapitulate the list, enforcing at compile time that the values are valid. 44 | 45 | /** 46 | * A string array of valid NormalizationLayer class names. 47 | * 48 | * This is guaranteed to match the `NormalizationLayerClassName` union 49 | * type.
50 | */ 51 | export const normalizationLayerClassNames: NormalizationLayerClassName[] = [ 52 | 'BatchNormalization', 53 | ]; 54 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/node_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {PyJsonDict} from './types'; 12 | 13 | /** 14 | * The unique string name of a Layer. 15 | */ 16 | export type LayerName = string; 17 | 18 | /** 19 | * The index of a Node, identifying a specific invocation of a given Layer. 20 | */ 21 | export type NodeIndex = number; 22 | 23 | /** 24 | * The index of a Tensor output by a given Node of a given Layer. 25 | */ 26 | export type TensorIndex = number; 27 | 28 | /** 29 | * Arguments to the apply(...) method that produced a specific Node. 30 | */ 31 | // tslint:disable-next-line:no-empty-interface 32 | export interface NodeArgs extends PyJsonDict {} 33 | 34 | /** 35 | * A reference to a specific Tensor, given by its Layer name, Node index, and 36 | * output index, including the apply() arguments associated with the Node. 37 | * 38 | * This is used in `NodeConfig` to specify the inputs to each Node. 39 | */ 40 | export type TensorKeyWithArgsArray = 41 | [LayerName, NodeIndex, TensorIndex, NodeArgs]; 42 | 43 | // TODO(soergel): verify behavior of Python Keras; maybe PR to standardize it. 44 | /** 45 | * A reference to a specific Tensor, given by its Layer name, Node index, and 46 | * output index. 47 | * 48 | * This does not include the apply() arguments associated with the Node. It is 49 | * used in the LayersModel config to specify the inputLayers and outputLayers. 
50 | * It seems to be an idiosyncrasy of Python Keras that the node arguments are 51 | * not included here. 52 | */ 53 | export type TensorKeyArray = [LayerName, NodeIndex, TensorIndex]; 54 | 55 | /** 56 | * A Keras JSON entry representing a Node, i.e. a specific instance of a Layer. 57 | * 58 | * By Keras JSON convention, a Node is specified as an array of Tensor keys 59 | * (i.e., references to Tensors output by other Layers) providing the inputs to 60 | * this Layer in order. 61 | */ 62 | export type NodeConfig = TensorKeyWithArgsArray[]; 63 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/model_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {LayerSerialization} from './layers/layer_serialization'; 12 | import {TensorKeyArray} from './node_config'; 13 | import {TrainingConfig} from './training_config'; 14 | import {BaseSerialization} from './types'; 15 | 16 | export type ModelConfig = { 17 | name: string, 18 | layers: LayerSerialization[], 19 | input_layers: TensorKeyArray[], 20 | output_layers: TensorKeyArray[], 21 | }; 22 | 23 | /** 24 | * A standard Keras JSON 'Model' configuration. 25 | */ 26 | export interface ModelSerialization extends 27 | BaseSerialization<'Model', ModelConfig> { 28 | backend?: string; 29 | keras_version?: string; 30 | } 31 | 32 | export type SequentialConfig = { 33 | layers: LayerSerialization[] 34 | }; 35 | 36 | /** 37 | * A standard Keras JSON 'Sequential' configuration. 
38 | */ 39 | export interface SequentialSerialization extends 40 | BaseSerialization<'Sequential', SequentialConfig> { 41 | backend?: string; 42 | keras_version?: string; 43 | } 44 | 45 | /** 46 | * A legacy Keras JSON 'Sequential' configuration. 47 | * 48 | * It was a bug that Keras Sequential models were recorded with 49 | * model_config.config as an array of layers, instead of a dict containing a 50 | * 'layers' entry. While the bug has been fixed, we still need to be able to 51 | * read this legacy format. 52 | */ 53 | export type LegacySequentialSerialization = { 54 | // Note this cannot extend `BaseSerialization` because of the bug. 55 | class_name: 'Sequential'; 56 | 57 | config: LayerSerialization[]; 58 | backend?: string; 59 | keras_version?: string; 60 | }; 61 | 62 | /** 63 | * Contains the description of a KerasModel, as well as the configuration 64 | * necessary to train that model. 65 | */ 66 | export type KerasFileSerialization = { 67 | // aka ModelTopology? 68 | model_config: ModelSerialization|SequentialSerialization| 69 | LegacySequentialSerialization; 70 | training_config: TrainingConfig; 71 | }; 72 | -------------------------------------------------------------------------------- /tfjs-layers/src/logs.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {dispose, Scalar} from '@tensorflow/tfjs-core'; 12 | 13 | /** 14 | * Logs in which values can be either numbers or Tensors (Scalars). 15 | * 16 | * Used internally. 17 | */ 18 | export type UnresolvedLogs = { 19 | [key: string]: number|Scalar; 20 | }; 21 | 22 | /** 23 | * Turn any Scalar values in a Logs object into actual number values. 
24 | * 25 | * @param logs The `Logs` object to be resolved in place. 26 | */ 27 | export async function resolveScalarsInLogs(logs: UnresolvedLogs) { 28 | if (logs == null) { 29 | return; 30 | } 31 | const promises: Array<Promise<Float32Array|Int32Array|Uint8Array>> = []; 32 | const keys: string[] = []; 33 | const scalarsToDispose: Scalar[] = []; 34 | for (const key in logs) { 35 | const value = logs[key]; 36 | if (typeof value !== 'number') { 37 | const valueScalar = value as Scalar; 38 | promises.push(valueScalar.data()); 39 | keys.push(key); 40 | scalarsToDispose.push(valueScalar); 41 | } 42 | } 43 | if (promises.length > 0) { 44 | const values = await Promise.all(promises); 45 | for (let i = 0; i < values.length; ++i) { 46 | logs[keys[i]] = values[i][0]; 47 | } 48 | // Dispose the original scalar tensors. 49 | dispose(scalarsToDispose); 50 | } 51 | } 52 | 53 | /** 54 | * Dispose all Tensors in an UnresolvedLogs object. 55 | * 56 | * @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in 57 | * places where the values can be `tf.Tensor` or `number`. 58 | */ 59 | export function disposeTensorsInLogs(logs: UnresolvedLogs) { 60 | if (logs == null) { 61 | return; 62 | } 63 | for (const key in logs) { 64 | const value = logs[key]; 65 | if (typeof value !== 'number') { 66 | value.dispose(); 67 | } 68 | } 69 | } 70 | 71 | /** 72 | * Logs in which values can only be numbers. 73 | * 74 | * Used when calling client-provided custom callbacks.
75 | */ 76 | export type Logs = { 77 | [key: string]: number; 78 | }; 79 | -------------------------------------------------------------------------------- /tfjs-layers/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@tensorflow/tfjs-layers", 3 | "version": "1.2.7", 4 | "description": "TensorFlow layers API in JavaScript", 5 | "license": "Apache-2.0 AND MIT", 6 | "private": false, 7 | "main": "dist/index.js", 8 | "types": "dist/index.d.ts", 9 | "jsnext:main": "dist/tf-layers.esm.js", 10 | "module": "dist/tf-layers.esm.js", 11 | "jsdelivr": "dist/tf-layers.min.js", 12 | "unpkg": "dist/tf-layers.min.js", 13 | "miniprogram": "dist/miniprogram", 14 | "devDependencies": { 15 | "@tensorflow/tfjs-core": "1.2.8", 16 | "@types/jasmine": "~2.5.53", 17 | "clang-format": "~1.2.2", 18 | "http-server": "~0.10.0", 19 | "jasmine-core": "~3.1.0", 20 | "karma": "~4.2.0", 21 | "karma-browserstack-launcher": "~1.4.0", 22 | "karma-chrome-launcher": "~2.2.0", 23 | "karma-firefox-launcher": "~1.1.0", 24 | "karma-jasmine": "~1.1.1", 25 | "karma-typescript": "~4.0.0", 26 | "rimraf": "~2.6.2", 27 | "rollup": "^0.58.2", 28 | "rollup-plugin-commonjs": "9.1.3", 29 | "rollup-plugin-node-resolve": "3.3.0", 30 | "rollup-plugin-typescript2": "0.13.0", 31 | "rollup-plugin-uglify": "~3.0.0", 32 | "ts-node": "7.0.0", 33 | "tslint": "~5.11.0", 34 | "tslint-no-circular-imports": "^0.5.0", 35 | "typescript": "3.3.3333", 36 | "yalc": "~1.0.0-pre.21" 37 | }, 38 | "scripts": { 39 | "prep": "yarn install && yarn build", 40 | "build": "tsc", 41 | "build-npm": "./scripts/build-npm.sh", 42 | "format": "./tools/clang_format_ts.sh", 43 | "publish-npm": "./scripts/publish-npm.sh", 44 | "link-local": "yalc link", 45 | "publish-local": "yarn build-npm && yalc push", 46 | "test": "karma start", 47 | "tfjs2keras": "yarn tfjs2keras-js && yarn tfjs2keras-py --stable && yarn tfjs2keras-py --stable --tfkeras && yarn tfjs2keras-py --dev --tfkeras", 48 | 
"tfjs2keras-js": "./scripts/tfjs2keras-js.sh", 49 | "tfjs2keras-py": "./scripts/tfjs2keras-py.sh", 50 | "test-ci": "./scripts/test-ci.sh", 51 | "test-snippets": "ts-node ./scripts/test_snippets.ts", 52 | "run-browserstack": "karma start --browsers='bs_firefox_mac,bs_chrome_mac' --singleRun --reporters='dots,karma-typescript'", 53 | "lint": "tslint -p . -t verbose" 54 | }, 55 | "peerDependencies": { 56 | "@tensorflow/tfjs-core": "1.2.8" 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/merge_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 12 | 13 | 14 | 15 | export type AddLayerSerialization = BaseLayerSerialization<'Add', LayerConfig>; 16 | 17 | export type MultiplyLayerSerialization = 18 | BaseLayerSerialization<'Multiply', LayerConfig>; 19 | 20 | export type AverageLayerSerialization = 21 | BaseLayerSerialization<'Average', LayerConfig>; 22 | 23 | export type MaximumLayerSerialization = 24 | BaseLayerSerialization<'Maximum', LayerConfig>; 25 | 26 | export type MinimumLayerSerialization = 27 | BaseLayerSerialization<'Minimum', LayerConfig>; 28 | 29 | export interface ConcatenateLayerConfig extends LayerConfig { 30 | axis?: number; 31 | } 32 | 33 | export type ConcatenateLayerSerialization = 34 | BaseLayerSerialization<'Concatenate', ConcatenateLayerConfig>; 35 | 36 | export interface DotLayerConfig extends LayerConfig { 37 | axes: number|[number, number]; 38 | normalize?: boolean; 39 | } 40 | 41 | export type DotLayerSerialization = 42 | 
BaseLayerSerialization<'Dot', DotLayerConfig>; 43 | 44 | // Update mergeLayerClassNames below in concert with this. 45 | export type MergeLayerSerialization = 46 | AddLayerSerialization|MultiplyLayerSerialization|AverageLayerSerialization| 47 | MaximumLayerSerialization|MinimumLayerSerialization| 48 | ConcatenateLayerSerialization|DotLayerSerialization; 49 | 50 | export type MergeLayerClassName = MergeLayerSerialization['class_name']; 51 | 52 | // We can't easily extract a string[] from the string union type, but we can 53 | // recapitulate the list, enforcing at compile time that the values are valid. 54 | 55 | /** 56 | * A string array of valid MergeLayer class names. 57 | * 58 | * This is guaranteed to match the `MergeLayerClassName` union type. 59 | */ 60 | export const mergeLayerClassNames: MergeLayerClassName[] = [ 61 | 'Add', 62 | 'Average', 63 | 'Concatenate', 64 | 'Dot', 65 | 'Maximum', 66 | 'Minimum', 67 | 'Multiply', 68 | ]; 69 | -------------------------------------------------------------------------------- /tfjs-layers/tools/clang_format_ts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2018 Google LLC 4 | # 5 | # Use of this source code is governed by an MIT-style 6 | # license that can be found in the LICENSE file or at 7 | # https://opensource.org/licenses/MIT. 8 | # ============================================================================== 9 | 10 | 11 | # This script applies google-style clang format on all TypeScript (.ts) files 12 | # within a certain scope. 13 | # 14 | # Usage examples: 15 | # 1. Format all .ts files touched by this change (unstaged or staged in git). 16 | # clang_format_ts.sh 17 | # 18 | # 2. Format all .ts files under the source tree. 19 | # clang_format_ts.sh -a 20 | # 21 | # 3. Format specific files. 
22 | # clang_format_ts.sh src/types.ts 23 | 24 | set -e 25 | 26 | FILE_SCOPE="" 27 | 28 | if [[ "$#" -gt 0 ]]; then 29 | while true; do 30 | if [[ -z "$1" ]]; then 31 | break 32 | fi 33 | if [[ "$1" == "-a" ]]; then 34 | if [[ -z "${FILE_SCOPE}" ]]; then 35 | FILE_SCOPE="__all__" 36 | else 37 | echo "ERROR: -a flag should not be used with file names" 38 | exit 1 39 | fi 40 | else 41 | if [[ "${FILE_SCOPE}" != "__all__" ]]; then 42 | FILE_SCOPE="${FILE_SCOPE} $1" 43 | else 44 | echo "ERROR: -a flag should not be used with file names" 45 | exit 1 46 | fi 47 | fi 48 | shift 49 | done 50 | else 51 | FILE_SCOPE="__touched__" 52 | fi 53 | 54 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 55 | CLANG_FORMAT_PREFIX="${SCRIPT_DIR}/../node_modules/.bin/clang-format -i --verbose --style=google" 56 | if [[ "${FILE_SCOPE}" == "__touched__" ]]; then 57 | TOUCHED_TS_FILES="$(git status --porcelain | grep '.*\.ts$' | sed s/^...//)" 58 | 59 | if [[ -z "${TOUCHED_TS_FILES}" ]]; then 60 | exit 0 61 | else 62 | pushd "${SCRIPT_DIR}/.." > /dev/null 63 | for TS_FILE in ${TOUCHED_TS_FILES}; do 64 | if [[ -f ${TS_FILE} ]]; then 65 | ${CLANG_FORMAT_PREFIX} "${TS_FILE}" 66 | fi 67 | done 68 | popd > /dev/null 69 | fi 70 | elif [[ "${FILE_SCOPE}" == "__all__" ]]; then 71 | ALL_TS_FILES="$(find "${SCRIPT_DIR}/../src" "${SCRIPT_DIR}/../demos" -name '*.ts')" 72 | for TS_FILE in ${ALL_TS_FILES}; do 73 | ${CLANG_FORMAT_PREFIX} "${TS_FILE}" 74 | done 75 | 76 | else 77 | ${CLANG_FORMAT_PREFIX} ${FILE_SCOPE} 78 | fi 79 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/README.md: -------------------------------------------------------------------------------- 1 | TypeScript Interfaces describing the Keras JSON format 2 | ------------------------------------------------------ 3 | 4 | This directory contains a description of the current Keras JSON serialization 5 | format, in the form of TypeScript interfaces. 
The intent is that any valid 6 | Keras JSON file can be parsed in a type-safe manner using these types. 7 | 8 | The Keras JSON format originated in the Python Keras implementation. The basic 9 | design is that the format mirrors the Python API. Each class instance in a 10 | Python model is serialized as a JSON object containing the class name and its 11 | serialized constructor arguments. 12 | 13 | Here, we provide a type called `*Serialization` to describe the on-disk JSON 14 | representation for each class. It always provides a `class_name` and a `config` 15 | representing the constructor arguments required to reconstruct a given instance. 16 | 17 | The constructor arguments may be primitives, arrays of primitives, or plain 18 | key-value dictionaries, in which case the JSON serialization is straightforward. 19 | 20 | If a constructor argument is another object, then it is represented by a nested 21 | `*Serialization`. This structure is illustrated below: 22 | 23 | FooSerialization { 24 | class_name: 'Foo'; 25 | config: { 26 | bar: string; 27 | baz: number[]; 28 | qux: QuxSerialization; 29 | } 30 | } 31 | 32 | Deserializing such a nested object configuration requires recursively 33 | deserializing any object arguments, and finally calling the top-level 34 | constructor using the reconstructed object arguments. 35 | 36 | In general this means that deserialization is purely tree-like, so instances 37 | cannot be reused. (The deserialization code for Models is an exception to this 38 | principle, because it allows Layers to refer to each other in order to describe 39 | a DAG). 40 | 41 | As a consequence of this design, our deserialization code requires an `*Args` 42 | type mirroring each of the `*Serialization` types here. `*Args` types represent 43 | the actual arguments passed to a constructor, after any nested objects have been 44 | deserialized. 
For instance, the above `FooSerialization` will be resolved into 45 | `FooArgs` like this: 46 | 47 | FooArgs { 48 | bar: string; 49 | baz: number[]; 50 | qux: Qux; 51 | } 52 | 53 | which can then be passed to the `Foo` constructor. 54 | -------------------------------------------------------------------------------- /tfjs-layers/src/errors.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Explicit error types. 13 | * 14 | * See the following link for more information about why the code includes 15 | * calls to setPrototypeOf: 16 | * 17 | * https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work 18 | */ 19 | // tslint:enable 20 | 21 | /** 22 | * Equivalent of Python's AttributeError. 23 | */ 24 | export class AttributeError extends Error { 25 | constructor(message?: string) { 26 | super(message); 27 | // Set the prototype explicitly. 28 | Object.setPrototypeOf(this, AttributeError.prototype); 29 | } 30 | } 31 | 32 | /** 33 | * Equivalent of Python's RuntimeError. 34 | */ 35 | export class RuntimeError extends Error { 36 | constructor(message?: string) { 37 | super(message); 38 | // Set the prototype explicitly. 39 | Object.setPrototypeOf(this, RuntimeError.prototype); 40 | } 41 | } 42 | 43 | /** 44 | * Equivalent of Python's ValueError. 45 | */ 46 | export class ValueError extends Error { 47 | constructor(message?: string) { 48 | super(message); 49 | // Set the prototype explicitly. 50 | Object.setPrototypeOf(this, ValueError.prototype); 51 | } 52 | } 53 | 54 | /** 55 | * Equivalent of Python's NotImplementedError. 
56 | */ 57 | export class NotImplementedError extends Error { 58 | constructor(message?: string) { 59 | super(message); 60 | // Set the prototype explicitly. 61 | Object.setPrototypeOf(this, NotImplementedError.prototype); 62 | } 63 | } 64 | 65 | /** 66 | * Equivalent of Python's AssertionError. 67 | */ 68 | export class AssertionError extends Error { 69 | constructor(message?: string) { 70 | super(message); 71 | // Set the prototype explicitly. 72 | Object.setPrototypeOf(this, AssertionError.prototype); 73 | } 74 | } 75 | 76 | /** 77 | * Equivalent of Python's IndexError. 78 | */ 79 | export class IndexError extends Error { 80 | constructor(message?: string) { 81 | super(message); 82 | // Set the prototype explicitly. 83 | Object.setPrototypeOf(this, IndexError.prototype); 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/layer_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 
8 |  * =============================================================================
9 |  */
10 | 
11 | import {inputLayerClassNames, InputLayerSerialization} from '../input_config';
12 | import {advancedActivationLayerClassNames, AdvancedActivationLayerSerialization} from './advanced_activation_serialization';
13 | import {convolutionalDepthwiseLayerClassNames, ConvolutionalDepthwiseLayerSerialization} from './convolutional_depthwise_serialization';
14 | import {convolutionalLayerClassNames, ConvolutionalLayerSerialization} from './convolutional_serialization';
15 | import {coreLayerClassNames, CoreLayerSerialization} from './core_serialization';
16 | import {embeddingLayerClassNames, EmbeddingLayerSerialization} from './embeddings_serialization';
17 | import {mergeLayerClassNames, MergeLayerSerialization} from './merge_serialization';
18 | import {normalizationLayerClassNames, NormalizationLayerSerialization} from './normalization_serialization';
19 | import {paddingLayerClassNames, PaddingLayerSerialization} from './padding_serialization';
20 | import {poolingLayerClassNames, PoolingLayerSerialization} from './pooling_serialization';
21 | import {recurrentLayerClassNames, RecurrentLayerSerialization} from './recurrent_serialization';
22 | 
23 | 
24 | export type LayerSerialization = AdvancedActivationLayerSerialization|
25 |     ConvolutionalDepthwiseLayerSerialization|ConvolutionalLayerSerialization|
26 |     CoreLayerSerialization|EmbeddingLayerSerialization|MergeLayerSerialization|
27 |     NormalizationLayerSerialization|PaddingLayerSerialization|
28 |     PoolingLayerSerialization|RecurrentLayerSerialization|
29 |     InputLayerSerialization;
30 | 
31 | export type LayerClassName = LayerSerialization['class_name'];
32 | 
33 | /**
34 |  * A string array of valid Layer class names.
35 |  *
36 |  * Each entry must be a `LayerClassName`; completeness is not checked.
37 |  */
38 | export const layerClassNames: LayerClassName[] = [
39 |   ...advancedActivationLayerClassNames,
40 |   ...convolutionalDepthwiseLayerClassNames, ...convolutionalLayerClassNames,
41 |   ...coreLayerClassNames, ...embeddingLayerClassNames, ...mergeLayerClassNames,
42 |   ...normalizationLayerClassNames, ...paddingLayerClassNames,
43 |   ...poolingLayerClassNames, ...recurrentLayerClassNames,
44 |   ...inputLayerClassNames
45 | ];
46 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/utils/types_utils_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | import * as types_utils from './types_utils';
12 | 
13 | describe('isArrayOfShapes', () => {
14 |   it('returns false for a single non-empty shape', () => {
15 |     expect(types_utils.isArrayOfShapes([1, 2, 3])).toEqual(false);
16 |   });
17 |   it('returns false for a single empty shape', () => {
18 |     expect(types_utils.isArrayOfShapes([])).toEqual(false);
19 |   });
20 |   it('returns true for an array of shapes', () => {
21 |     expect(types_utils.isArrayOfShapes([[1], [2, 3]])).toEqual(true);
22 |   });
23 |   it('returns true for an array of shapes that includes empty shapes', () => {
24 |     expect(types_utils.isArrayOfShapes([[], [2, 3]])).toEqual(true);
25 |     expect(types_utils.isArrayOfShapes([[]])).toEqual(true);
26 |     expect(types_utils.isArrayOfShapes([[], []])).toEqual(true);
27 |   });
28 | });
29 | 
30 | describe('normalizeShapeList', () => {
31 |   it('returns an empty list if an empty list is passed in.', () => {
32 |     expect(types_utils.normalizeShapeList([])).toEqual([]);
33 |   });
34 | 
35 |   it('returns a list of shapes if a single shape is passed in.', () => {
36 |     expect(types_utils.normalizeShapeList([1])).toEqual([[1]]);
37 |   });
38 | 
39 |   it('returns a list of shapes if an empty shape is passed in.', () => {
40 |     expect(types_utils.normalizeShapeList([[]])).toEqual([[]]);
41 |   });
42 | 
43 |   it('returns a list of shapes if a list of shapes is passed in.', () => {
44 |     expect(types_utils.normalizeShapeList([[1]])).toEqual([[1]]);
45 |   });
46 | });
47 | 
48 | describe('getExactlyOneShape', () => {
49 |   it('single instance', () => {
50 |     expect(types_utils.getExactlyOneShape([1, 2, 3])).toEqual([1, 2, 3]);
51 |     expect(types_utils.getExactlyOneShape([null, 8])).toEqual([null, 8]);
52 |     expect(types_utils.getExactlyOneShape([])).toEqual([]);
53 |   });
54 |   it('Array of length 1', () => {
55 |     expect(types_utils.getExactlyOneShape([[1, 2]])).toEqual([1, 2]);
56 |     expect(types_utils.getExactlyOneShape([[]])).toEqual([]);
57 |   });
58 |   it('Array of length 2: ValueError', () => {
59 |     expect(() => types_utils.getExactlyOneShape([
60 |       [1], [2]
61 |     ])).toThrowError(/Expected exactly 1 Shape; got 2/);
62 |   });
63 | });
64 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/types.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | /** Common type aliases: shapes, loss/metric/regularizer/RNN fns, kwargs. */
12 | 
13 | import {NamedTensorMap, Scalar, Tensor} from '@tensorflow/tfjs-core';
14 | 
15 | import {Shape} from './keras_format/common';
16 | 
17 | export type HasShape = {
18 |   shape: Shape;
19 | };
20 | 
21 | /**
22 |  * Type for a loss or metric function.
23 |  *
24 |  * Takes a true value and a predicted value, and returns a loss or metric value.
25 |  */
26 | export type LossOrMetricFn = (yTrue: Tensor, yPred: Tensor) => Tensor;
27 | 
28 | /**
29 |  * Type for a regularizer function.
30 |  */
31 | export type RegularizerFn = () => Scalar;
32 | 
33 | /**
34 |  * The type for an RNN step function.
35 |  * The arguments are:
36 |  *   - inputs: Input data, with shape `[samples, ...]`.
37 |  *   - states: Array of state tensors; `newStates` below mirrors its length
38 |  *     and element shapes.
39 |  * The return values are:
40 |  *   - outputs: tensor with shape `[samples, outputDim]` (no time dimension).
41 |  *   - newStates: Array of tensors. The `Array` has the same length as `states`
42 |  *     in the input arguments. Each `tf.Tensor` has the same shape as the
43 |  *     corresponding element in `states`.
44 |  */
45 | export type RnnStepFunction =
46 |     (inputs: Tensor, states: Tensor[]) => [Tensor, Tensor[]];
47 | 
48 | /**
49 |  * A single Tensor or a non-nested collection of Tensors.
50 |  *
51 |  * An object of this type can always be reduced to `Tensor[]`.  A single
52 |  * `Tensor` becomes `[Tensor]`.  A `Tensor[]` is unchanged.  A `NamedTensorMap`
53 |  * can be converted with the help of a list of names, providing the order in
54 |  * which the Tensors should appear in the resulting array.
55 |  */
56 | export type TensorOrArrayOrMap = Tensor|Tensor[]|NamedTensorMap;
57 | 
58 | /**
59 |  * Type representing a loosely-typed bundle of keyword arguments.
60 |  *
61 |  * This is a looser type than PyJsonDict/serialization.ConfigDict as it
62 |  * can contain arbitrary objects as its values.  It is most appropriate
63 |  * for functions that pass through keyword arguments to other functions
64 |  * without knowledge of the structure.  If the function can place type
65 |  * restrictions on the keyword arguments, it should do so via the Config
66 |  * interface convention used throughout.
67 |  */
68 | export type Kwargs = {
69 |   // tslint:disable-next-line:no-any
70 |   [key: string]: any
71 | };
72 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/utils/types_utils.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | /* Original source: utils/generic_utils.py */
12 | 
13 | import {Tensor} from '@tensorflow/tfjs-core';
14 | import {ValueError} from '../errors';
15 | import {Shape} from '../keras_format/common';
16 | // tslint:enable
17 | 
18 | 
19 | /**
20 |  * Determine whether the input is an Array of Shapes.
21 |  *
22 |  * An empty array yields `false`, i.e. it is treated as a single empty Shape.
23 |  */
24 | export function isArrayOfShapes(x: Shape|Shape[]): boolean {
25 |   return Array.isArray(x) && Array.isArray(x[0]);
26 | }
27 | 
28 | /**
29 |  * Special case of normalizing shapes to lists.
30 |  *
31 |  * Note: an empty input yields an empty list (zero Shapes), not `[[]]`.
32 |  *
33 |  * @param x A shape or list of shapes to normalize into a list of Shapes.
34 |  * @return A list of Shapes.
35 |  */
36 | export function normalizeShapeList(x: Shape|Shape[]): Shape[] {
37 |   if (x.length === 0) {
38 |     return [];
39 |   }
40 |   if (!Array.isArray(x[0])) {
41 |     return [x] as Shape[];
42 |   }
43 |   return x as Shape[];
44 | }
45 | 
46 | /**
47 |  * Helper function to obtain exactly one Tensor.
48 |  * @param xs: A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
49 |  * @return A single `tf.Tensor`. If `xs` is an `Array`, return its only element.
50 |  * @throws ValueError: If `xs` is an `Array` and its length is not 1.
51 |  */
52 | export function getExactlyOneTensor(xs: Tensor|Tensor[]): Tensor {
53 |   if (!Array.isArray(xs)) {
54 |     return xs;
55 |   }
56 |   if (xs.length !== 1) {
57 |     throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
58 |   }
59 |   return xs[0];
60 | }
61 | 
62 | /**
63 |  * Helper function to obtain exactly one instance of Shape.
64 |  *
65 |  * @param shapes Input single `Shape` or Array of `Shape`s.
66 |  * @returns If input is a single `Shape`, return it unchanged. If the input is
67 |  *   an `Array` containing exactly one instance of `Shape`, return the instance.
68 |  *   Otherwise, throw a `ValueError`.
69 |  * @throws ValueError: If input is an `Array` of `Shape`s, and its length is not
70 |  *   1.
71 |  */
72 | export function getExactlyOneShape(shapes: Shape|Shape[]): Shape {
73 |   // Same test as isArrayOfShapes(): a list of Shapes is an array whose first
74 |   // element is itself an array, so `[]` is treated as a single empty Shape.
75 |   if (!isArrayOfShapes(shapes)) {
76 |     return shapes as Shape;
77 |   }
78 |   if (shapes.length !== 1) {
79 |     throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
80 |   }
81 |   return (shapes as Shape[])[0];
82 | }
83 | 
--------------------------------------------------------------------------------
/tfjs-layers/karma.conf.js:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC. All Rights Reserved.
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  * http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  * =============================================================================
16 |  */
17 | 
18 | const karmaTypescriptConfig = {
19 |   tsconfig: 'tsconfig.json',
20 |   // Disable coverage reports and instrumentation by default for tests
21 |   coverageOptions: {instrumentation: false},
22 |   reports: {}
23 | };
24 | 
25 | // Enable coverage reports and instrumentation under KARMA_COVERAGE=1 env
26 | const coverageEnabled = !!process.env.KARMA_COVERAGE;
27 | if (coverageEnabled) {
28 |   karmaTypescriptConfig.coverageOptions.instrumentation = true;
29 |   karmaTypescriptConfig.coverageOptions.exclude = /_test\.ts$/;
30 |   karmaTypescriptConfig.reports = {html: 'coverage', 'text-summary': ''};
31 | }
32 | 
33 | module.exports = function(config) {
34 |   config.set({
35 |     frameworks: ['jasmine', 'karma-typescript'],
36 |     files: [{pattern: 'src/**/*.ts'}],
37 |     preprocessors: {
38 |       '**/*.ts': ['karma-typescript'],
39 |     },
40 |     karmaTypescriptConfig,
41 |     reporters: ['progress', 'karma-typescript'],
42 |     browsers: ['Chrome'],
43 |     browserStack: {
44 |       username: process.env.BROWSERSTACK_USERNAME,
45 |       accessKey: process.env.BROWSERSTACK_KEY
46 |     },
47 |     reportSlowerThan: 500,
48 |     browserNoActivityTimeout: 30000,
49 |     customLaunchers: {
50 |       bs_chrome_mac: {
51 |         base: 'BrowserStack',
52 |         browser: 'chrome',
53 |         browser_version: 'latest',
54 |         os: 'OS X',
55 |         os_version: 'High Sierra'
56 |       },
57 |       bs_firefox_mac: {
58 |         base: 'BrowserStack',
59 |         browser: 'firefox',
60 |         // TODO(cais): Change to latest after browser stack infrastructure
61 |         // stabilizes. https://github.com/tensorflow/tfjs/issues/1620
62 |         browser_version: '66.0',
63 |         os: 'OS X',
64 |         os_version: 'High Sierra'
65 |       },
66 |       chrome_with_swift_shader: {
67 |         base: 'Chrome',
68 |         flags: ['--blacklist-accelerated-compositing', '--blacklist-webgl']
69 |       }
70 |     },
71 |     client: {
72 |       jasmine: {
73 |         random: false
74 |       },
75 |       args: ['--grep', config.grep || '']
76 |     }
77 |   });
78 | };
79 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/optimizers_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | /**
12 |  * Unit tests for optimizers.ts.
13 |  */
14 | 
15 | import {AdagradOptimizer, AdadeltaOptimizer, AdamOptimizer, AdamaxOptimizer, RMSPropOptimizer, SGDOptimizer} from '@tensorflow/tfjs-core';
16 | 
17 | import {getOptimizer} from './optimizers';
18 | import {describeMathCPU} from './utils/test_utils';
19 | 
20 | 
21 | describeMathCPU('getOptimizer', () => {
22 |   // TODO(nsthorat): Assert defaults by getting config from the optimizer.
23 | 
24 |   it(`can instantiate SGD`, () => {
25 |     const optimizer = getOptimizer('SGD');
26 |     expect(optimizer instanceof SGDOptimizer).toBe(true);
27 |   });
28 |   it(`can instantiate sgd`, () => {
29 |     const optimizer = getOptimizer('sgd');
30 |     expect(optimizer instanceof SGDOptimizer).toBe(true);
31 |   });
32 |   it(`can instantiate Adam`, () => {
33 |     const optimizer = getOptimizer('Adam');
34 |     expect(optimizer instanceof AdamOptimizer).toBe(true);
35 |   });
36 |   it(`can instantiate adam`, () => {
37 |     const optimizer = getOptimizer('adam');
38 |     expect(optimizer instanceof AdamOptimizer).toBe(true);
39 |   });
40 |   it(`can instantiate RMSProp`, () => {
41 |     const optimizer = getOptimizer('RMSProp');
42 |     expect(optimizer instanceof RMSPropOptimizer).toBe(true);
43 |   });
44 |   it(`can instantiate rmsprop`, () => {
45 |     const optimizer = getOptimizer('rmsprop');
46 |     expect(optimizer instanceof RMSPropOptimizer).toBe(true);
47 |   });
48 |   it(`can instantiate Adagrad`, () => {
49 |     const optimizer = getOptimizer('Adagrad');
50 |     expect(optimizer instanceof AdagradOptimizer).toBe(true);
51 |   });
52 |   it(`can instantiate adagrad`, () => {
53 |     const optimizer = getOptimizer('adagrad');
54 |     expect(optimizer instanceof AdagradOptimizer).toBe(true);
55 |   });
56 |   it(`can instantiate Adadelta`, () => {
57 |     const optimizer = getOptimizer('Adadelta');
58 |     expect(optimizer instanceof AdadeltaOptimizer).toBe(true);
59 |   });
60 |   it(`can instantiate adadelta`, () => {
61 |     const optimizer = getOptimizer('adadelta');
62 |     expect(optimizer instanceof AdadeltaOptimizer).toBe(true);
63 |   });
64 |   it(`can instantiate Adamax`, () => {
65 |     const optimizer = getOptimizer('Adamax');
66 |     expect(optimizer instanceof AdamaxOptimizer).toBe(true);
67 |   });
68 |   it(`can instantiate adamax`, () => {
69 |     const optimizer = getOptimizer('adamax');
70 |     expect(optimizer instanceof AdamaxOptimizer).toBe(true);
71 |   });
72 |   it('throws for non-existent optimizer', () => {
73 |     expect(() => getOptimizer('not an optimizer'))
74 |         .toThrowError(/Unknown Optimizer/);
75 |   });
76 | 
77 | });
78 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/utils/conv_utils.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | import {ValueError} from '../errors';
12 | import {PaddingMode} from '../keras_format/common';
13 | 
14 | import {pyListRepeat} from './generic_utils';
15 | import {isInteger, max} from './math_utils';
16 | 
17 | /**
18 |  * Transforms a single number or an array of numbers into an array of numbers.
19 |  * @param value A single number, expanded with `pyListRepeat(value, n)`, or an
20 |  *   array of exactly `n` integers, which is returned as-is (not copied).
21 |  * @param n: The size of the tuple to be returned.
22 |  * @param name: Name of the parameter, used for generating error messages.
23 |  * @returns An array of numbers.
24 |  * @throws ValueError If `value` is an array whose length is not `n`, or if it
25 |  *   contains a non-integer element.
26 |  */
27 | export function normalizeArray(
28 |     value: number|number[], n: number, name: string): number[] {
29 |   if (typeof value === 'number') {
30 |     return pyListRepeat(value, n);
31 |   } else {
32 |     if (value.length !== n) {
33 |       throw new ValueError(
34 |           `The ${name} argument must be an integer or tuple of ${n} integers.` +
35 |           ` Received: ${value.length} elements.`);
36 |     }
37 |     for (let i = 0; i < n; ++i) {
38 |       const singleValue = value[i];
39 |       if (!isInteger(singleValue)) {
40 |         throw new ValueError(
41 |             `The ${name} argument must be an integer or tuple of ${n}` +
42 |             ` integers. Received: ${JSON.stringify(value)} including a` +
43 |             ` non-integer number ${singleValue}`);
44 |       }
45 |     }
46 |     return value;
47 |   }
48 | }
49 | 
50 | /**
51 |  * Determines output length of a convolution given input length.
52 |  *
53 |  * Follows Keras' `conv_output_length`: 'same' and 'causal' padding keep the
54 |  * input length, 'valid' padding subtracts `dilatedFilterSize - 1`, and the
55 |  * result is then divided by `stride`, rounding up.
56 |  *
57 |  * @param inputLength Input length along one dimension. `null`/`undefined`
58 |  *   (unknown length) is returned unchanged.
59 |  * @param filterSize Kernel size along that dimension.
60 |  * @param padding Padding mode.
61 |  * @param stride Stride along that dimension.
62 |  * @param dilation: dilation rate.
63 |  */
64 | export function convOutputLength(
65 |     inputLength: number, filterSize: number, padding: PaddingMode,
66 |     stride: number, dilation = 1): number {
67 |   if (inputLength == null) {
68 |     return inputLength;
69 |   }
70 |   const dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
71 |   let outputLength: number;
72 |   if (padding === 'same' || padding === 'causal') {
73 |     // Keras treats 'causal' like 'same' here: the input is padded on the left
74 |     // only, so the (pre-stride) output keeps the input length.
75 |     outputLength = inputLength;
76 |   } else {  // VALID
77 |     outputLength = inputLength - dilatedFilterSize + 1;
78 |   }
79 |   // Equivalent to ceil(outputLength / stride) for integer lengths and strides.
80 |   return Math.floor((outputLength + stride - 1) / stride);
81 | }
82 | 
83 | /**
84 |  * Determines output length of a deconvolution along one dimension given its
85 |  * input length.
86 |  *
87 |  * @param dimSize Input length; `null`/`undefined` yields `null`.
88 |  * @param strideSize Stride along the dimension.
89 |  * @param kernelSize Kernel size along the dimension.
90 |  * @param padding 'valid' gives `dimSize * strideSize +
91 |  *   max(kernelSize - strideSize, 0)`; 'same' gives `dimSize * strideSize`.
92 |  * @throws ValueError For any other padding mode.
93 |  */
94 | export function deconvLength(
95 |     dimSize: number, strideSize: number, kernelSize: number,
96 |     padding: PaddingMode): number {
97 |   if (dimSize == null) {
98 |     return null;
99 |   }
100 | 
101 |   if (padding === 'valid') {
102 |     dimSize = dimSize * strideSize + max([kernelSize - strideSize, 0]);
103 |   } else if (padding === 'same') {
104 |     dimSize = dimSize * strideSize;
105 |   } else {
106 |     throw new ValueError(`Unsupport padding mode: ${padding}.`);
107 |   }
108 |   return dimSize;
109 | }
110 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/keras_format/layers/advanced_activation_serialization.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | // Serialized forms (class_name + config) of advanced activation layers.
11 | import {ConstraintSerialization} from '../constraint_config';
12 | import {InitializerSerialization} from '../initializer_config';
13 | import {RegularizerSerialization} from '../regularizer_config';
14 | import {BaseLayerSerialization, LayerConfig} from '../topology_config';
15 | // Config fields are snake_case since these types describe serialized data.
16 | export interface ReLULayerConfig extends LayerConfig {
17 |   max_value?: number;
18 | }
19 | 
20 | export type ReLULayerSerialization =
21 |     BaseLayerSerialization<'ReLU', ReLULayerConfig>;
22 | 
23 | export interface LeakyReLULayerConfig extends LayerConfig {
24 |   alpha?: number;
25 | }
26 | 
27 | export type LeakyReLULayerSerialization =
28 |     BaseLayerSerialization<'LeakyReLU', LeakyReLULayerConfig>;
29 | 
30 | export interface PReLULayerConfig extends LayerConfig {
31 |   alpha_initializer?: InitializerSerialization;
32 |   alpha_regularizer?: RegularizerSerialization;
33 |   alpha_constraint?: ConstraintSerialization;
34 |   shared_axes?: number|number[];  // A single axis or a list of axes.
35 | }
36 | 
37 | export type PReLULayerSerialization =
38 |     BaseLayerSerialization<'PReLU', PReLULayerConfig>;
39 | 
40 | export interface ELULayerConfig extends LayerConfig {
41 |   alpha?: number;
42 | }
43 | 
44 | export type ELULayerSerialization =
45 |     BaseLayerSerialization<'ELU', ELULayerConfig>;
46 | 
47 | export interface ThresholdedReLULayerConfig extends LayerConfig {
48 |   theta?: number;
49 | }
50 | 
51 | export type ThresholdedReLULayerSerialization =
52 |     BaseLayerSerialization<'ThresholdedReLU', ThresholdedReLULayerConfig>;
53 | 
54 | export interface SoftmaxLayerConfig extends LayerConfig {
55 |   axis?: number;
56 | }
57 | 
58 | export type SoftmaxLayerSerialization =
59 |     BaseLayerSerialization<'Softmax', SoftmaxLayerConfig>;
60 | 
61 | // Update advancedActivationLayerClassNames below in concert with this.
62 | export type AdvancedActivationLayerSerialization = ReLULayerSerialization|
63 |     LeakyReLULayerSerialization|PReLULayerSerialization|ELULayerSerialization|
64 |     ThresholdedReLULayerSerialization|SoftmaxLayerSerialization;
65 | 
66 | export type AdvancedActivationLayerClassName =
67 |     AdvancedActivationLayerSerialization['class_name'];
68 | 
69 | // We can't easily extract a string[] from the string union type, but we can
70 | // recapitulate the list, enforcing at compile time that the values are valid.
71 | 
72 | /**
73 |  * A string array of valid AdvancedActivationLayer class names.
74 |  *
75 |  * Each entry must be an `AdvancedActivationLayerClassName`; completeness is
76 |  * not checked by the compiler.
77 |  */
78 | export const advancedActivationLayerClassNames:
79 |     AdvancedActivationLayerClassName[] = [
80 |       'ReLU',
81 |       'LeakyReLU',
82 |       'PReLU',
83 |       'ELU',
84 |       'ThresholdedReLU',
85 |       'Softmax',
86 |     ];
87 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/keras_format/layers/pooling_serialization.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | // Serialized forms (class_name + config) of pooling layers.
11 | import {DataFormatSerialization, PaddingMode} from '../common';
12 | import {BaseLayerSerialization, LayerConfig} from '../topology_config';
13 | 
14 | // Config fields are snake_case since these types describe serialized data.
15 | export interface Pooling1DLayerConfig extends LayerConfig {
16 |   pool_size?: [number];  // One-element tuple, e.g. `[2]`.
17 |   strides?: [number];  // One-element tuple.
18 |   padding?: PaddingMode;
19 | }
20 | 
21 | export type MaxPooling1DLayerSerialization =
22 |     BaseLayerSerialization<'MaxPooling1D', Pooling1DLayerConfig>;
23 | 
24 | export type AveragePooling1DLayerSerialization =
25 |     BaseLayerSerialization<'AveragePooling1D', Pooling1DLayerConfig>;
26 | 
27 | export interface Pooling2DLayerConfig extends LayerConfig {
28 |   pool_size?: number|[number, number];  // Scalar or per-dimension pair.
29 |   strides?: number|[number, number];  // Scalar or per-dimension pair.
30 |   padding?: PaddingMode;
31 |   data_format?: DataFormatSerialization;
32 | }
33 | 
34 | export type MaxPooling2DLayerSerialization =
35 |     BaseLayerSerialization<'MaxPooling2D', Pooling2DLayerConfig>;
36 | 
37 | export type AveragePooling2DLayerSerialization =
38 |     BaseLayerSerialization<'AveragePooling2D', Pooling2DLayerConfig>;
39 | // The 1D global pooling layers have no layer-specific config fields.
40 | export type GlobalAveragePooling1DLayerSerialization =
41 |     BaseLayerSerialization<'GlobalAveragePooling1D', LayerConfig>;
42 | 
43 | export type GlobalMaxPooling1DLayerSerialization =
44 |     BaseLayerSerialization<'GlobalMaxPooling1D', LayerConfig>;
45 | 
46 | export interface GlobalPooling2DLayerConfig extends LayerConfig {
47 |   data_format?: DataFormatSerialization;
48 | }
49 | 
50 | export type GlobalAveragePooling2DLayerSerialization = BaseLayerSerialization<
51 |     'GlobalAveragePooling2D', GlobalPooling2DLayerConfig>;
52 | 
53 | export type GlobalMaxPooling2DLayerSerialization =
54 |     BaseLayerSerialization<'GlobalMaxPooling2D', GlobalPooling2DLayerConfig>;
55 | 
56 | // Update poolingLayerClassNames below in concert with this.
57 | export type PoolingLayerSerialization = MaxPooling1DLayerSerialization|
58 |     AveragePooling1DLayerSerialization|MaxPooling2DLayerSerialization|
59 |     AveragePooling2DLayerSerialization|GlobalAveragePooling1DLayerSerialization|
60 |     GlobalMaxPooling1DLayerSerialization|
61 |     GlobalAveragePooling2DLayerSerialization|
62 |     GlobalMaxPooling2DLayerSerialization;
63 | 
64 | export type PoolingLayerClassName = PoolingLayerSerialization['class_name'];
65 | 
66 | // We can't easily extract a string[] from the string union type, but we can
67 | // recapitulate the list, enforcing at compile time that the values are valid.
68 | 
69 | /**
70 |  * A string array of valid PoolingLayer class names, in alphabetical order.
71 |  *
72 |  * Each entry must be a `PoolingLayerClassName`; completeness is not checked.
73 |  */
74 | export const poolingLayerClassNames: PoolingLayerClassName[] = [
75 |   'AveragePooling1D',
76 |   'AveragePooling2D',
77 |   'GlobalAveragePooling1D',
78 |   'GlobalAveragePooling2D',
79 |   'GlobalMaxPooling1D',
80 |   'GlobalMaxPooling2D',
81 |   'MaxPooling1D',
82 |   'MaxPooling2D',
83 | ];
84 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/regularizers_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | /* Unit tests for regularizers.ts */
12 | 
13 | import {scalar, serialization, Tensor, tensor1d} from '@tensorflow/tfjs-core';
14 | 
15 | import * as tfl from './index';
16 | import {deserializeRegularizer, getRegularizer, serializeRegularizer} from './regularizers';
17 | import {describeMathCPU, expectTensorsClose} from './utils/test_utils';
18 | 
19 | 
20 | describeMathCPU('Built-in Regularizers', () => {
21 |   it('l1_l2', () => {
22 |     const x = tensor1d([1, -2, 3, -4]);
23 |     const regularizer = tfl.regularizers.l1l2();
24 |     const score = regularizer.apply(x);
25 |     expectTensorsClose(
26 |         score, scalar(0.01 * (1 + 2 + 3 + 4) + 0.01 * (1 + 4 + 9 + 16)));
27 |   });
28 |   it('l1', () => {
29 |     const x = tensor1d([1, -2, 3, -4]);
30 |     const regularizer = tfl.regularizers.l1();
31 |     const score = regularizer.apply(x);
32 |     expectTensorsClose(score, scalar(0.01 * (1 + 2 + 3 + 4)));
33 |   });
34 |   it('l2', () => {
35 |     const x = tensor1d([1, -2, 3, -4]);
36 |     const regularizer = tfl.regularizers.l2();
37 |     const score = regularizer.apply(x);
38 |     expectTensorsClose(score, scalar(0.01 * (1 + 4 + 9 + 16)));
39 |   });
40 |   it('l1_l2 non default', () => {
41 |     const x = tensor1d([1, -2, 3, -4]);
42 |     const regularizer = tfl.regularizers.l1l2({l1: 1, l2: 2});
43 |     const score = regularizer.apply(x);
44 |     expectTensorsClose(
45 |         score, scalar(1 * (1 + 2 + 3 + 4) + 2 * (1 + 4 + 9 + 16)));
46 |   });
47 | });
48 | 
49 | describeMathCPU('regularizers.get', () => {
50 |   let x: Tensor;
51 |   beforeEach(() => {
52 |     x = tensor1d([1, -2, 3, -4]);
53 |   });
54 | 
55 |   it('by string - lower camel', () => {
56 |     const regularizer = getRegularizer('l1l2');
57 |     expectTensorsClose(regularizer.apply(x), tfl.regularizers.l1l2().apply(x));
58 |   });
59 |   it('by string - upper camel', () => {
60 |     const regularizer = getRegularizer('L1L2');
61 |     expectTensorsClose(regularizer.apply(x), tfl.regularizers.l1l2().apply(x));
62 |   });
63 | 
64 |   it('by existing object', () => {
65 |     const origReg = tfl.regularizers.l1l2({l1: 1, l2: 2});
66 |     const regularizer = getRegularizer(origReg);
67 |     expect(regularizer).toEqual(origReg);
68 |   });
69 |   it('by config dict', () => {
70 |     const origReg = tfl.regularizers.l1l2({l1: 1, l2: 2});
71 |     const regularizer = getRegularizer(
72 |         serializeRegularizer(origReg) as serialization.ConfigDict);
73 |     expectTensorsClose(regularizer.apply(x), origReg.apply(x));
74 |   });
75 | });
76 | 
77 | describeMathCPU('Regularizer Serialization', () => {
78 |   it('Built-ins', () => {
79 |     const regularizer = tfl.regularizers.l1l2({l1: 1, l2: 2});
80 |     const config =
81 |         serializeRegularizer(regularizer) as serialization.ConfigDict;
82 |     const reconstituted = deserializeRegularizer(config);
83 |     const roundTripConfig =
84 |         serializeRegularizer(reconstituted) as serialization.ConfigDict;
85 |     expect(roundTripConfig.className).toEqual('L1L2');
86 |     const nestedConfig = roundTripConfig.config as serialization.ConfigDict;
87 |     expect(nestedConfig.l1).toEqual(1);
88 |     expect(nestedConfig.l2).toEqual(2);
89 |   });
90 | });
91 | 
--------------------------------------------------------------------------------
/tfjs-layers/src/keras_format/optimizer_config.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC
4 |  *
5 |  * Use of this source code is governed by an MIT-style
6 |  * license that can be found in the LICENSE file or at
7 |  * https://opensource.org/licenses/MIT.
8 |  * =============================================================================
9 |  */
10 | 
11 | import {BaseSerialization} from './types';
12 | 
13 | // Because of the limitations in the current Keras spec, there is no clear
14 | // definition of what may or may not be the configuration of an optimizer.
15 | //
16 | // For now we'll represent the ones available in TF.js--but it will take more
17 | // thought to get this right in a cross-platform way.
18 | //
19 | // See internal issue: b/121033602
20 | 
21 | // TODO(soergel): This is a stopgap that needs further thought.
22 | // Does it belong here?
23 | // Does it belong in tfjs-core?
24 | // See also the dormant https://github.com/tensorflow/tfjs-core/pull/1404
25 | // Config fields are snake_case since these types describe serialized data.
26 | export type AdadeltaOptimizerConfig = {
27 |   learning_rate: number; rho: number; epsilon: number;
28 | };
29 | 
30 | export type AdadeltaSerialization =
31 |     BaseSerialization<'Adadelta', AdadeltaOptimizerConfig>;
32 | 
33 | export type AdagradOptimizerConfig = {
34 |   learning_rate: number;
35 |   initial_accumulator_value?: number;
36 | };
37 | 
38 | export type AdagradSerialization =
39 |     BaseSerialization<'Adagrad', AdagradOptimizerConfig>;
40 | 
41 | export type AdamOptimizerConfig = {
42 |   learning_rate: number; beta1: number; beta2: number;
43 |   epsilon?: number;
44 | };
45 | 
46 | export type AdamSerialization = BaseSerialization<'Adam', AdamOptimizerConfig>;
47 | 
48 | export type AdamaxOptimizerConfig = {
49 |   learning_rate: number; beta1: number; beta2: number;
50 |   epsilon?: number;
51 |   decay?: number;
52 | };
53 | 
54 | export type AdamaxSerialization =
55 |     BaseSerialization<'Adamax', AdamaxOptimizerConfig>;
56 | 
57 | export type MomentumOptimizerConfig = {
58 |   // Conceptually SGDOptimizerConfig (learning_rate) plus momentum fields.
59 |   learning_rate: number; momentum: number;
60 |   use_nesterov?: boolean;
61 | };
62 | 
63 | export type MomentumSerialization =
64 |     BaseSerialization<'Momentum', MomentumOptimizerConfig>;
65 | 
66 | export type RMSPropOptimizerConfig = {
67 |   learning_rate: number;
68 |   decay?: number;
69 |   momentum?: number;
70 |   epsilon?: number;
71 |   centered?: boolean;
72 | };
73 | 
74 | export type RMSPropSerialization =
75 |     BaseSerialization<'RMSProp', RMSPropOptimizerConfig>;
76 | 
77 | export type SGDOptimizerConfig = {
78 |   learning_rate: number;
79 | };
80 | 
81 | export type SGDSerialization = BaseSerialization<'SGD', SGDOptimizerConfig>;
82 | 
83 | // Update optimizerClassNames below in concert with this.
84 | export type OptimizerSerialization = AdadeltaSerialization|AdagradSerialization|
85 |     AdamSerialization|AdamaxSerialization|MomentumSerialization|
86 |     RMSPropSerialization|SGDSerialization;
87 | 
88 | export type OptimizerClassName = OptimizerSerialization['class_name'];
89 | 
90 | // We can't easily extract a string[] from the string union type, but we can
91 | // recapitulate the list, enforcing at compile time that the values are valid.
92 | 
93 | /**
94 |  * A string array of valid Optimizer class names.
95 |  *
96 |  * Each entry must be an `OptimizerClassName`; completeness is not checked.
97 |  */
98 | export const optimizerClassNames: OptimizerClassName[] =
99 |     ['Adadelta', 'Adagrad', 'Adam', 'Adamax', 'Momentum', 'RMSProp', 'SGD'];
100 | 
--------------------------------------------------------------------------------
/tfjs-layers/rollup.config.js:
--------------------------------------------------------------------------------
1 | /**
2 |  * @license
3 |  * Copyright 2018 Google LLC. All Rights Reserved.
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  * http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 | * ============================================================================= 16 | */ 17 | 18 | import commonjs from 'rollup-plugin-commonjs'; 19 | import node from 'rollup-plugin-node-resolve'; 20 | import typescript from 'rollup-plugin-typescript2'; 21 | import uglify from 'rollup-plugin-uglify'; 22 | 23 | const PREAMBLE = `/** 24 | * @license 25 | * Copyright ${(new Date).getFullYear()} Google LLC. All Rights Reserved. 26 | * Licensed under the Apache License, Version 2.0 (the "License"); 27 | * you may not use this file except in compliance with the License. 28 | * You may obtain a copy of the License at 29 | * 30 | * http://www.apache.org/licenses/LICENSE-2.0 31 | * 32 | * Unless required by applicable law or agreed to in writing, software 33 | * distributed under the License is distributed on an "AS IS" BASIS, 34 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 35 | * See the License for the specific language governing permissions and 36 | * limitations under the License. 37 | * ============================================================================= 38 | */`; 39 | 40 | function minify() { 41 | return uglify( 42 | {output: {preamble: PREAMBLE}}, 43 | ); 44 | } 45 | 46 | function config({plugins = [], output = {}}) { 47 | return { 48 | input: 'src/index.ts', 49 | plugins: [ 50 | typescript({ 51 | tsconfigOverride: { 52 | compilerOptions: { 53 | module: 'ES2015', 54 | } 55 | }, 56 | }), 57 | node(), 58 | // Polyfill require() from dependencies. 
59 | commonjs({ 60 | ignore: ['crypto'], 61 | include: 'node_modules/**', 62 | namedExports: { 63 | './node_modules/seedrandom/index.js': ['alea'], 64 | }, 65 | }), 66 | ...plugins 67 | ], 68 | output: { 69 | banner: PREAMBLE, 70 | sourcemap: true, 71 | globals: {'@tensorflow/tfjs-core': 'tf'}, 72 | ...output, 73 | }, 74 | external: ['crypto', '@tensorflow/tfjs-core'], 75 | onwarn: warning => { 76 | let {code} = warning; 77 | if (code === 'CIRCULAR_DEPENDENCY' || code === 'CIRCULAR') { 78 | return; 79 | } 80 | console.warn('WARNING: ', warning.toString()); 81 | } 82 | }; 83 | } 84 | 85 | export default [ 86 | config({ 87 | output: { 88 | format: 'umd', 89 | name: 'tf', 90 | extend: true, 91 | file: 'dist/tf-layers.js', 92 | } 93 | }), 94 | config({ 95 | plugins: [minify()], 96 | output: { 97 | format: 'umd', 98 | name: 'tf', 99 | extend: true, 100 | file: 'dist/tf-layers.min.js', 101 | } 102 | }), 103 | config({ 104 | plugins: [minify()], 105 | output: { 106 | format: 'es', 107 | file: 'dist/tf-layers.esm.js', 108 | } 109 | }) 110 | ]; 111 | -------------------------------------------------------------------------------- /tfjs-layers/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow.js Layers: High-Level Machine Learning Model API 2 | 3 | A part of the TensorFlow.js ecosystem, TensorFlow.js Layers is a high-level 4 | API built on [TensorFlow.js Core](https://github.com/tensorflow/tfjs-core), 5 | enabling users to build, train and execute deep learning models in the browser. 6 | TensorFlow.js Layers is modeled after 7 | [Keras](https://keras.io/) and 8 | [tf.keras](https://www.tensorflow.org/api_docs/python/tf/keras) and can 9 | load models saved from those libraries. 10 | 11 | ## Importing 12 | 13 | There are three ways to import TensorFlow.js Layers 14 | 15 | 1. 
You can access TensorFlow.js Layers through the union package 16 | between the TensorFlow.js Core and Layers: 17 | [@tensorflow/tfjs](https://www.npmjs.com/package/@tensorflow/tfjs) 18 | 2. You can get [TensorFlow.js](https://github.com/tensorflow/tfjs) Layers as a module: 19 | [@tensorflow/tfjs-layers](https://www.npmjs.com/package/@tensorflow/tfjs-layers). 20 | Note that `tfjs-layers` has peer dependency on tfjs-core, so if you import 21 | `@tensorflow/tfjs-layers`, you also need to import 22 | `@tensorflow/tfjs-core`. 23 | 3. As a standalone through [unpkg](https://unpkg.com/). 24 | 25 | Option 1 is the most convenient, but leads to a larger bundle size (we will be 26 | adding more packages to it in the future). Use option 2 if you care about bundle 27 | size. 28 | 29 | ## Getting started 30 | 31 | ### Building, training and executing a model 32 | 33 | The following example shows how to build a toy model with only one `dense` layer 34 | to perform linear regression. 35 | 36 | ```js 37 | import * as tf from '@tensorflow/tfjs'; 38 | 39 | // A sequential model is a container which you can add layers to. 40 | const model = tf.sequential(); 41 | 42 | // Add a dense layer with 1 output unit. 43 | model.add(tf.layers.dense({units: 1, inputShape: [1]})); 44 | 45 | // Specify the loss type and optimizer for training. 46 | model.compile({loss: 'meanSquaredError', optimizer: 'SGD'}); 47 | 48 | // Generate some synthetic data for training. 49 | const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]); 50 | const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]); 51 | 52 | // Train the model. 53 | await model.fit(xs, ys, {epochs: 500}); 54 | 55 | // Ater the training, perform inference. 
56 | const output = model.predict(tf.tensor2d([[5]], [1, 1])); 57 | output.print(); 58 | ``` 59 | 60 | ### Loading a pretrained Keras model 61 | 62 | You can also load a model previously trained and saved from elsewhere (e.g., 63 | from Python Keras) and use it for inference or transfer learning in the browser. 64 | 65 | For example, in Python, save your Keras model using 66 | [tensorflowjs](https://pypi.org/project/tensorflowjs/), 67 | which can be installed using `pip install tensorflowjs`. 68 | 69 | 70 | ```python 71 | import tensorflowjs as tfjs 72 | 73 | # ... Create and train your Keras model. 74 | 75 | # Save your Keras model in TensorFlow.js format. 76 | tfjs.converters.save_keras_model(model, '/path/to/tfjs_artifacts/') 77 | 78 | # Then use your favorite web server to serve the directory at a URL, say 79 | # http://foo.bar/tfjs_artifacts/model.json 80 | ``` 81 | 82 | To load the model with TensorFlow.js Layers: 83 | 84 | ```js 85 | import * as tf from '@tensorflow/tfjs'; 86 | 87 | const model = await tf.loadLayersModel('http://foo.bar/tfjs_artifacts/model.json'); 88 | // Now the model is ready for inference, evaluation or re-training. 89 | ``` 90 | 91 | ## For more information 92 | 93 | - [TensorFlow.js API documentation](https://js.tensorflow.org/api/latest/) 94 | - [TensorFlow.js Tutorials](https://js.tensorflow.org/tutorials/) 95 | -------------------------------------------------------------------------------- /tfjs-layers/scripts/switch-tfjs-core-version.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2018 Google LLC 4 | # 5 | # Use of this source code is governed by an MIT-style 6 | # license that can be found in the LICENSE file or at 7 | # https://opensource.org/licenses/MIT. 8 | # ============================================================================= 9 | 10 | # Switch between different versions of tfjs-core dependency.
11 | # 12 | # Usage examples: 13 | # 14 | # 1. To depend on the HEAD of the public GitHub repo of tfjs-core 15 | # ./scripts/switch-tfjs-core-version.sh --github 16 | # 17 | # 2. To depend on a given branch or tags of the public GitHub repo of 18 | # tfjs-core: 19 | # ./scripts/switch-tfjs-core-version.sh --github --branch tags/v0.5.0 20 | # 21 | # 3. To depend on tfjs-core built from a local repo, with any local 22 | # edits incorporated: 23 | # ./scripts/switch-tfjs-core-version.sh --local_path "${HOME}/my-dljs" 24 | 25 | set -e 26 | 27 | ORIGIN_DIR="$(pwd)" 28 | 29 | GITHUB=0 30 | GIT_BRANCH="" 31 | LOCAL_PATH="" 32 | 33 | DEFAULT_TMP_REPO_DIR="/tmp/dljs-github-clean" 34 | 35 | while [[ ! -z "$1" ]]; do 36 | if [[ "$1" == "--github" ]]; then 37 | GITHUB=1 38 | shift 1 39 | elif [[ "$1" == "--branch" ]]; then 40 | GIT_BRANCH="$2" 41 | shift 2 42 | elif [[ "$1" == "--local_path" ]]; then 43 | LOCAL_PATH="$2" 44 | if [[ -z "${LOCAL_PATH}" ]]; then 45 | echo "ERROR: Unspecified local path" 46 | exit 1 47 | fi 48 | shift 2 49 | else 50 | echo "ERROR: Unrecognized argument: $1" 51 | exit 1 52 | fi 53 | done 54 | 55 | # Do sanity checks on flags. 56 | if [[ "${GITHUB}" == 1 ]] && [[ ! -z "${LOCAL_PATH}" ]]; then 57 | echo "ERROR: --github and --local_path are mutually exclusive." 58 | exit 1 59 | fi 60 | 61 | if [[ ! -z "${GIT_BRANCH}" ]] && [[ "${GITHUB}" == 0 ]]; then 62 | echo "ERROR: --branch flag can only be used with the --github flag." 63 | exit 1 64 | fi 65 | 66 | # Check yarn is on path. 67 | if [[ -z "$(which yarn)" ]]; then 68 | echo "ERROR: switch-tfjs-core-version.sh relies on yarn." \ 69 | "But yarn is not found on path." \ 70 | "See https://yarnpkg.com/lang/en/docs/install/" 71 | exit 1 72 | fi 73 | 74 | if [[ ${GITHUB} == 1 ]]; then 75 | REPO_DIR="${DEFAULT_TMP_REPO_DIR}" 76 | 77 | if [[ !
-d "${REPO_DIR}/tfjs-core" ]]; then 78 | echo "Cloning tfjs-core git repo to: ${REPO_DIR}" 79 | echo 80 | mkdir -p "${REPO_DIR}" 81 | cd "${REPO_DIR}" 82 | git clone https://github.com/tensorflow/tfjs-core.git 83 | fi 84 | cd "${REPO_DIR}/tfjs-core" 85 | 86 | if [[ ! -z "${GIT_BRANCH}" ]]; then 87 | git checkout "${GIT_BRANCH}" 88 | fi 89 | if git symbolic-ref -q HEAD > /dev/null; then git pull; fi # Skip the pull on a detached HEAD (e.g. --branch tags/...), where it would fail and abort under set -e. 90 | elif [[ ! -z "${LOCAL_PATH}" ]]; then 91 | cd "${LOCAL_PATH}" 92 | else 93 | echo "Must specify either --github or --local_path " 94 | exit 1 95 | fi 96 | 97 | # Call yarn link / build in the tfjs-core source folder. 98 | # In case another tfjs-core repo has been registered. 99 | yarn unlink || echo "No tfjs-core is registered with yarn link." 100 | yarn link 101 | yarn 102 | yarn build 103 | 104 | # cd back to where we started and call yarn link tfjs-core. 105 | cd "${ORIGIN_DIR}" 106 | rm -rf node_modules/@tensorflow/tfjs-core 107 | yarn link @tensorflow/tfjs-core 108 | 109 | # Call yarn link tfjs-core for the demos/ directory. 110 | cd "${ORIGIN_DIR}/demos" 111 | rm -rf node_modules/@tensorflow/tfjs-core 112 | yarn link @tensorflow/tfjs-core 113 | 114 | echo "Linking to custom tfjs-core source is done." 115 | echo 116 | echo "To switch back to the default tfjs-core version, do:" 117 | echo " yarn unlink @tensorflow/tfjs-core && yarn" 118 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/core_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT.
8 | * ============================================================================= 9 | */ 10 | 11 | import {ActivationSerialization} from '../activation_config'; 12 | import {Shape} from '../common'; 13 | import {ConstraintSerialization} from '../constraint_config'; 14 | import {InitializerSerialization} from '../initializer_config'; 15 | import {RegularizerSerialization} from '../regularizer_config'; 16 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 17 | 18 | export interface DropoutLayerConfig extends LayerConfig { 19 | rate: number; 20 | noise_shape?: number[]; 21 | seed?: number; 22 | } 23 | 24 | export type DropoutLayerSerialization = 25 | BaseLayerSerialization<'Dropout', DropoutLayerConfig>; 26 | 27 | export interface DenseLayerConfig extends LayerConfig { 28 | units: number; 29 | activation?: ActivationSerialization; 30 | use_bias?: boolean; 31 | input_dim?: number; 32 | kernel_initializer?: InitializerSerialization; 33 | bias_initializer?: InitializerSerialization; 34 | kernel_constraint?: ConstraintSerialization; 35 | bias_constraint?: ConstraintSerialization; 36 | kernel_regularizer?: RegularizerSerialization; 37 | bias_regularizer?: RegularizerSerialization; 38 | activity_regularizer?: RegularizerSerialization; 39 | } 40 | 41 | export type DenseLayerSerialization = 42 | BaseLayerSerialization<'Dense', DenseLayerConfig>; 43 | 44 | export type FlattenLayerSerialization = 45 | BaseLayerSerialization<'Flatten', LayerConfig>; 46 | 47 | export interface ActivationLayerConfig extends LayerConfig { 48 | activation: ActivationSerialization; 49 | } 50 | 51 | export type ActivationLayerSerialization = 52 | BaseLayerSerialization<'Activation', ActivationLayerConfig>; 53 | 54 | export interface RepeatVectorLayerConfig extends LayerConfig { 55 | n: number; 56 | } 57 | 58 | export type RepeatVectorLayerSerialization = 59 | BaseLayerSerialization<'RepeatVector', RepeatVectorLayerConfig>; 60 | 61 | export interface ReshapeLayerConfig extends
LayerConfig { 62 | target_shape: Shape; 63 | } 64 | 65 | export type ReshapeLayerSerialization = 66 | BaseLayerSerialization<'Reshape', ReshapeLayerConfig>; 67 | 68 | export interface PermuteLayerConfig extends LayerConfig { 69 | dims: number[]; 70 | } 71 | 72 | export type PermuteLayerSerialization = 73 | BaseLayerSerialization<'Permute', PermuteLayerConfig>; 74 | 75 | export interface MaskingLayerConfig extends LayerConfig { 76 | maskValue: number; // NOTE(review): the other keys in this file are snake_case Keras names (e.g. target_shape); confirm whether this should be mask_value. 77 | } 78 | 79 | export type MaskingLayerSerialization = 80 | BaseLayerSerialization<'Masking', MaskingLayerConfig>; 81 | 82 | // Update coreLayerClassNames below in concert with this. 83 | export type CoreLayerSerialization = 84 | DropoutLayerSerialization|DenseLayerSerialization|FlattenLayerSerialization| 85 | ActivationLayerSerialization|RepeatVectorLayerSerialization| 86 | ReshapeLayerSerialization|PermuteLayerSerialization| 87 | MaskingLayerSerialization; 88 | 89 | export type CoreLayerClassName = CoreLayerSerialization['class_name']; 90 | 91 | // We can't easily extract a string[] from the string union type, but we can 92 | // recapitulate the list, enforcing at compile time that the values are valid. 93 | 94 | /** 95 | * A string array of valid CoreLayer class names. 96 | * 97 | * Every entry is checked against the `CoreLayerClassName` union type at compile time; completeness is not checked, so keep this list in sync with `CoreLayerSerialization` by hand. 98 | */ 99 | export const coreLayerClassNames: CoreLayerClassName[] = [ 100 | 'Activation', 101 | 'Dense', 102 | 'Dropout', 103 | 'Flatten', 104 | 'Masking', 105 | 'Permute', 106 | 'RepeatVector', 107 | 'Reshape', 108 | ]; 109 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/initializer_config.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT.
8 | * ============================================================================= 9 | */ 10 | 11 | import {BaseSerialization} from './types'; 12 | 13 | // TODO(soergel): Move the CamelCase versions back out of keras_format 14 | // e.g. to src/common.ts. Maybe even duplicate *all* of these to be pedantic? 15 | /** @docinline */ 16 | export type FanMode = 'fanIn'|'fanOut'|'fanAvg'; 17 | export const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg']; 18 | 19 | // These constants have a snake vs. camel distinction. 20 | export type FanModeSerialization = 'fan_in'|'fan_out'|'fan_avg'; 21 | 22 | /** @docinline */ 23 | export type Distribution = 'normal'|'uniform'|'truncatedNormal'; 24 | export const VALID_DISTRIBUTION_VALUES = 25 | ['normal', 'uniform', 'truncatedNormal']; 26 | // These constants have a snake vs. camel distinction. 27 | export type DistributionSerialization = 'normal'|'uniform'|'truncated_normal'; 28 | 29 | export type ZerosSerialization = BaseSerialization<'Zeros', {}>; 30 | 31 | export type OnesSerialization = BaseSerialization<'Ones', {}>; 32 | 33 | export type ConstantConfig = { 34 | value: number; 35 | }; 36 | 37 | export type ConstantSerialization = 38 | BaseSerialization<'Constant', ConstantConfig>; 39 | 40 | export type RandomNormalConfig = { 41 | mean?: number; 42 | stddev?: number; 43 | seed?: number; 44 | }; 45 | 46 | export type RandomNormalSerialization = 47 | BaseSerialization<'RandomNormal', RandomNormalConfig>; 48 | 49 | export type RandomUniformConfig = { 50 | minval?: number; 51 | maxval?: number; 52 | seed?: number; 53 | }; 54 | 55 | export type RandomUniformSerialization = 56 | BaseSerialization<'RandomUniform', RandomUniformConfig>; 57 | 58 | export type TruncatedNormalConfig = { 59 | mean?: number; 60 | stddev?: number; 61 | seed?: number; 62 | }; 63 | 64 | export type TruncatedNormalSerialization = 65 | BaseSerialization<'TruncatedNormal', TruncatedNormalConfig>; 66 | 67 | export type VarianceScalingConfig = { 68 | scale?:
number; 69 | 70 | mode?: FanModeSerialization; 71 | distribution?: DistributionSerialization; 72 | seed?: number; 73 | }; 74 | 75 | export type VarianceScalingSerialization = 76 | BaseSerialization<'VarianceScaling', VarianceScalingConfig>; 77 | 78 | export type OrthogonalConfig = { 79 | seed?: number; 80 | gain?: number; 81 | }; 82 | 83 | export type OrthogonalSerialization = 84 | BaseSerialization<'Orthogonal', OrthogonalConfig>; 85 | 86 | export type IdentityConfig = { 87 | gain?: number; 88 | }; 89 | 90 | export type IdentitySerialization = 91 | BaseSerialization<'Identity', IdentityConfig>; 92 | 93 | // Update initializerClassNames below in concert with this. 94 | export type InitializerSerialization = ZerosSerialization|OnesSerialization| 95 | ConstantSerialization|RandomUniformSerialization|RandomNormalSerialization| 96 | TruncatedNormalSerialization|IdentitySerialization| 97 | VarianceScalingSerialization|OrthogonalSerialization; 98 | 99 | export type InitializerClassName = InitializerSerialization['class_name']; 100 | 101 | // We can't easily extract a string[] from the string union type, but we can 102 | // recapitulate the list, enforcing at compile time that the values are valid. 103 | // Completeness is NOT checked: a union member omitted from the list still compiles. 104 | 105 | /** 106 | * A string array of valid Initializer class names. 107 | * 108 | * Every entry is checked against the `InitializerClassName` union type at compile time; completeness is not checked, so keep this list in sync with `InitializerSerialization` by hand.
109 | */ 110 | export const initializerClassNames: InitializerClassName[] = [ 111 | 'Zeros', 'Ones', 'Constant', 'RandomNormal', 'RandomUniform', 112 | 'TruncatedNormal', 'VarianceScaling', 'Orthogonal', 'Identity' 113 | ]; 114 | -------------------------------------------------------------------------------- /tfjs-layers/src/keras_format/layers/convolutional_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {DataFormatSerialization, PaddingMode} from '../common'; 12 | import {ConstraintSerialization} from '../constraint_config'; 13 | import {InitializerSerialization} from '../initializer_config'; 14 | import {RegularizerSerialization} from '../regularizer_config'; 15 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 16 | 17 | export interface BaseConvLayerConfig extends LayerConfig { 18 | kernel_size: number|number[]; 19 | strides?: number|number[]; 20 | padding?: PaddingMode; 21 | data_format?: DataFormatSerialization; 22 | dilation_rate?: number|[number]|[number, number]; 23 | activation?: string; 24 | use_bias?: boolean; 25 | kernel_initializer?: InitializerSerialization; 26 | bias_initializer?: InitializerSerialization; 27 | kernel_constraint?: ConstraintSerialization; 28 | bias_constraint?: ConstraintSerialization; 29 | kernel_regularizer?: RegularizerSerialization; 30 | bias_regularizer?: RegularizerSerialization; 31 | activity_regularizer?: RegularizerSerialization; 32 | } 33 | 34 | export interface ConvLayerConfig extends BaseConvLayerConfig { 35 | filters: number; 36 | } 37 | 38 | export type Conv1DLayerSerialization = 39 | BaseLayerSerialization<'Conv1D',
ConvLayerConfig>; 40 | 41 | export type Conv2DLayerSerialization = 42 | BaseLayerSerialization<'Conv2D', ConvLayerConfig>; 43 | 44 | export type Conv2DTransposeLayerSerialization = 45 | BaseLayerSerialization<'Conv2DTranspose', ConvLayerConfig>; 46 | 47 | export interface SeparableConvLayerConfig extends ConvLayerConfig { 48 | depth_multiplier?: number; 49 | depthwise_initializer?: InitializerSerialization; 50 | pointwise_initializer?: InitializerSerialization; 51 | depthwise_regularizer?: RegularizerSerialization; 52 | pointwise_regularizer?: RegularizerSerialization; 53 | depthwise_constraint?: ConstraintSerialization; 54 | pointwise_constraint?: ConstraintSerialization; 55 | } 56 | 57 | export type SeparableConv2DLayerSerialization = 58 | BaseLayerSerialization<'SeparableConv2D', SeparableConvLayerConfig>; 59 | 60 | 61 | export interface Cropping2DLayerConfig extends LayerConfig { 62 | cropping: number|[number, number]|[[number, number], [number, number]]; 63 | data_format?: DataFormatSerialization; 64 | } 65 | 66 | export type Cropping2DLayerSerialization = 67 | BaseLayerSerialization<'Cropping2D', Cropping2DLayerConfig>; 68 | 69 | export interface UpSampling2DLayerConfig extends LayerConfig { 70 | size?: number[]; 71 | data_format?: DataFormatSerialization; 72 | } 73 | 74 | export type UpSampling2DLayerSerialization = 75 | BaseLayerSerialization<'UpSampling2D', UpSampling2DLayerConfig>; 76 | 77 | // Update convolutionalLayerClassNames below in concert with this.
78 | export type ConvolutionalLayerSerialization = 79 | Conv1DLayerSerialization|Conv2DLayerSerialization| 80 | Conv2DTransposeLayerSerialization|SeparableConv2DLayerSerialization| 81 | Cropping2DLayerSerialization|UpSampling2DLayerSerialization; 82 | 83 | export type ConvolutionalLayerClassName = 84 | ConvolutionalLayerSerialization['class_name']; 85 | 86 | // We can't easily extract a string[] from the string union type, but we can 87 | // recapitulate the list, enforcing at compile time that the values are valid. 88 | 89 | /** 90 | * A string array of valid ConvolutionalLayer class names. 91 | * 92 | * Every entry is checked against the `ConvolutionalLayerClassName` union type at compile time; completeness is not checked, so keep this list in sync with `ConvolutionalLayerSerialization` by hand. 93 | */ 94 | export const convolutionalLayerClassNames: ConvolutionalLayerClassName[] = [ 95 | 'Conv1D', 96 | 'Conv2D', 97 | 'Conv2DTranspose', 98 | 'Cropping2D', 99 | 'SeparableConv2D', 100 | 'UpSampling2D', 101 | ]; 102 | -------------------------------------------------------------------------------- /tfjs-layers/integration_tests/tfjs2keras/tfjs2keras_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Use of this source code is governed by an MIT-style 4 | # license that can be found in the LICENSE file or at 5 | # https://opensource.org/licenses/MIT.
6 | # ============================================================================= 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import json 13 | import os 14 | import shutil 15 | import tempfile 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | import tensorflowjs as tfjs 20 | 21 | if os.environ['TFJS2KERAS_TEST_USING_TF_KERAS'] == '1': 22 | print('Using tensorflow.keras.') 23 | from tensorflow import keras 24 | else: 25 | print('Using keras-team/keras.') 26 | import keras 27 | 28 | 29 | class Tfjs2KerasExportTest(tf.test.TestCase): 30 | 31 | @classmethod 32 | def setUpClass(cls): 33 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 34 | cls._tmp_dir = os.path.join(curr_dir, 'test-data') 35 | 36 | def _loadAndTestModel(self, model_path): 37 | """Load a Keras Model from artifacts generated by tensorflow.js. 38 | 39 | This method tests: 40 | - Python Keras loading of the topology JSON file saved from TensorFlow.js. 41 | - Python Keras loading of the model's weight values. 42 | - The equality of the model.predict() output between Python Keras and 43 | TensorFlow.js (up to a certain numeric tolerance.) 44 | 45 | Args: 46 | model_path: Path to the model JSON file. 
47 | """ 48 | xs_shape_path = os.path.join( 49 | self._tmp_dir, model_path + '.xs-shapes.json') 50 | xs_data_path = os.path.join( 51 | self._tmp_dir, model_path + '.xs-data.json') 52 | with open(xs_shape_path, 'rt') as f: 53 | xs_shapes = json.load(f) 54 | with open(xs_data_path, 'rt') as f: 55 | xs_values = json.load(f) 56 | xs = [np.array(value, dtype=np.float32).reshape(shape) 57 | for value, shape in zip(xs_values, xs_shapes)] 58 | if len(xs) == 1: 59 | xs = xs[0] 60 | 61 | ys_shape_path = os.path.join( 62 | self._tmp_dir, model_path + '.ys-shapes.json') 63 | ys_data_path = os.path.join( 64 | self._tmp_dir, model_path + '.ys-data.json') 65 | with open(ys_shape_path, 'rt') as f: 66 | ys_shapes = json.load(f) 67 | with open(ys_data_path, 'rt') as f: 68 | ys_values = json.load(f) 69 | ys = [np.array(value, dtype=np.float32).reshape(shape) 70 | for value, shape in zip(ys_values, ys_shapes)] 71 | if len(ys) == 1: 72 | ys = ys[0] 73 | 74 | session = tf.Session() if hasattr(tf, 'Session') else tf.compat.v1.Session() 75 | with tf.Graph().as_default(), session: 76 | model_json_path = os.path.join(self._tmp_dir, model_path, 'model.json') 77 | print('Loading model from path %s' % model_json_path) 78 | model = tfjs.converters.load_keras_model(model_json_path) 79 | ys_new = model.predict(xs) 80 | if isinstance(ys, list): 81 | self.assertEqual(len(ys), len(ys_new)) 82 | for i, y in enumerate(ys): 83 | self.assertAllClose(y, ys_new[i]) 84 | else: 85 | self.assertAllClose(ys, ys_new) 86 | 87 | def testMLP(self): 88 | self._loadAndTestModel('mlp') 89 | 90 | def testCNN(self): 91 | self._loadAndTestModel('cnn') 92 | 93 | def testDepthwiseCNN(self): 94 | self._loadAndTestModel('depthwise_cnn') 95 | 96 | def testSimpleRNN(self): 97 | self._loadAndTestModel('simple_rnn') 98 | 99 | def testGRU(self): 100 | self._loadAndTestModel('gru') 101 | 102 | def testBidirectionalLSTM(self): 103 | self._loadAndTestModel('bidirectional_lstm') 104 | 105 | def testTimeDistributedLSTM(self): 106 | 
self._loadAndTestModel('time_distributed_lstm') 107 | 108 | def testOneDimensional(self): 109 | self._loadAndTestModel('one_dimensional') 110 | 111 | def testFunctionalMerge(self): 112 | self._loadAndTestModel('functional_merge.json') 113 | 114 | 115 | if __name__ == '__main__': 116 | tf.test.main() 117 | -------------------------------------------------------------------------------- /tfjs-layers/src/common.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Common functions for TensorFlow.js Layers. 13 | */ 14 | import {VALID_DATA_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES} from './keras_format/common'; 15 | import {checkStringTypeUnionValue} from './utils/generic_utils'; 16 | 17 | // A map from the requested scoped name of a Tensor to the number of Tensors 18 | // wanting that name so far. This allows enforcing name uniqueness by appending 19 | // an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc. 20 | const nameMap: Map = new Map(); 21 | 22 | export function checkDataFormat(value?: string): void { 23 | checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value); 24 | } 25 | 26 | export function checkPaddingMode(value?: string): void { 27 | checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value); 28 | } 29 | 30 | export function checkPoolMode(value?: string): void { 31 | checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value); 32 | } 33 | 34 | const _nameScopeStack: string[] = []; 35 | const _nameScopeDivider = '/'; 36 | 37 | /** 38 | * Enter namescope, which can be nested. 
39 | */ 40 | export function nameScope<T>(name: string, fn: () => T): T { 41 | _nameScopeStack.push(name); 42 | try { 43 | return fn(); 44 | } finally { 45 | // Runs whether fn() returns normally or throws, so the scope pushed 46 | // above is always removed and cannot leak into names generated after 47 | // this call. 48 | _nameScopeStack.pop(); 49 | } 50 | } 51 | 52 | /** 53 | * Get the current namescope as a flat, concatenated string. 54 | */ 55 | function currentNameScopePrefix(): string { 56 | if (_nameScopeStack.length === 0) { 57 | return ''; 58 | } else { 59 | return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider; 60 | } 61 | } 62 | 63 | /** 64 | * Get the name a Tensor (or Variable) would have if not uniqueified. 65 | * @param tensorName 66 | * @return Scoped name string. 67 | */ 68 | export function getScopedTensorName(tensorName: string): string { 69 | if (!isValidTensorName(tensorName)) { 70 | throw new Error('Not a valid tensor name: \'' + tensorName + '\''); 71 | } 72 | return currentNameScopePrefix() + tensorName; 73 | } 74 | 75 | /** 76 | * Get unique names for Tensors and Variables. 77 | * @param scopedName The fully-qualified name of the Tensor, i.e. as produced by 78 | * `getScopedTensorName()`. 79 | * @return A unique version of the given fully scoped name. 80 | * If this is the first time that the scoped name is seen in this session, 81 | * then the given `scopedName` is returned unaltered. If the same name is 82 | * seen again (producing a collision), an incrementing suffix is added to the 83 | * end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc. Suffixes whose composed name is already in use are skipped.
84 | */ 85 | export function getUniqueTensorName(scopedName: string): string { 86 | if (!isValidTensorName(scopedName)) { 87 | throw new Error('Not a valid tensor name: \'' + scopedName + '\''); 88 | } 89 | let index = nameMap.has(scopedName) ? nameMap.get(scopedName) : 0; 90 | if (index === 0) { 91 | nameMap.set(scopedName, 1); 92 | return scopedName; 93 | } 94 | // Skip suffixes that are already taken (e.g. "name_1" requested directly 95 | // earlier); without this the same name could be handed out twice. 96 | let result = scopedName + '_' + index; 97 | while (nameMap.has(result)) { 98 | result = scopedName + '_' + (++index); 99 | } 100 | nameMap.set(scopedName, index + 1); 101 | // Mark the composed name itself as used so it is never returned again. 102 | nameMap.set(result, 1); 103 | return result; 104 | } 105 | 106 | const tensorNameRegex = /^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/; 107 | 108 | /** 109 | * Determine whether a string is a valid tensor name. 110 | * @param name 111 | * @returns A Boolean indicating whether `name` is a valid tensor name. 112 | */ 113 | export function isValidTensorName(name: string): boolean { 114 | return tensorNameRegex.test(name); 115 | } 116 | -------------------------------------------------------------------------------- /tfjs-layers/src/user_defined_metadata.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** Utility functions related to user-defined metadata. */ 12 | 13 | // Maximum recommended serialized size for user-defined metadata. 14 | // Beyond this limit, a warning message will be printed during model loading and 15 | // saving. 16 | export const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024; 17 | 18 | /** 19 | * Check validity of user-defined metadata.
20 | * 21 | * @param userDefinedMetadata 22 | * @param modelName Name of the model that the user-defined metadata belongs to. 23 | * Used during construction of error messages. 24 | * @param checkSize Whether to check the size of the metadata is under 25 | * recommended limit. Default: `false`. If `true`, will try stringify the 26 | * JSON object and print a console warning if the serialzied size is above the 27 | * limit. 28 | * @throws Error if `userDefinedMetadata` is not a plain JSON object. 29 | */ 30 | export function checkUserDefinedMetadata( 31 | userDefinedMetadata: {}, modelName: string, checkSize = false): void { 32 | if (userDefinedMetadata == null || 33 | typeof userDefinedMetadata !== 'object' || 34 | Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype || 35 | !plainObjectCheck(userDefinedMetadata)) { 36 | throw new Error( 37 | 'User-defined metadata is expected to be a JSON object, but is not.'); 38 | } 39 | 40 | if (checkSize) { 41 | const out = JSON.stringify(userDefinedMetadata); 42 | if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) { 43 | console.warn( 44 | `User-defined metadata of model "${modelName}" is too large in ` + 45 | `size (length=${out.length} when serialized). It is not ` + 46 | `recommended to store such large objects in user-defined metadata. ` + 47 | `Please make sure its serialized length is <= ` + 48 | `${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`); 49 | } 50 | } 51 | } 52 | 53 | /** 54 | * Check if an input is plain JSON object or any valid subfield of it. 55 | * 56 | * @param x The input to be checked. 57 | * @param assertObject Whether to assert `x` is a JSON object, i.e., reject 58 | * cases of arrays and primitives. 59 | * @return Returns `true` if and only if `x` is a plain JSON object, 60 | * a JSON-valid primitive including string, number, boolean and null, 61 | * or an array of the said types. 
62 | */ 63 | // tslint:disable-next-line:no-any 64 | export function plainObjectCheck(x: any): boolean { 65 | if (x === null) { 66 | // Note: typeof `null` is 'object', and `null` is valid in JSON. 67 | return true; 68 | } else if (typeof x === 'object') { 69 | if (Object.getPrototypeOf(x) === Object.prototype) { 70 | // `x` is a JavaScript object and its prototype is Object. 71 | const keys = Object.keys(x); 72 | for (const key of keys) { 73 | if (typeof key !== 'string') { 74 | // JSON keys must be strings. 75 | return false; 76 | } 77 | if (!plainObjectCheck(x[key])) { // Recursive call. 78 | return false; 79 | } 80 | } 81 | return true; 82 | } else { 83 | // `x` is a JavaScript object but its prototype is not Object. 84 | if (Array.isArray(x)) { 85 | // `x` is a JavaScript array. 86 | for (const item of x) { 87 | if (!plainObjectCheck(item)) { // Recursive call. 88 | return false; 89 | } 90 | } 91 | return true; 92 | } else { 93 | // `x` is a JavaScript object and its prototype is not Object, 94 | // and it's not an Array. I.e., it's a complex object such as 95 | // `Error` and `Date`. 96 | return false; 97 | } 98 | } 99 | } else { 100 | // `x` is not a JavaScript object or `null`. 101 | const xType = typeof x; 102 | return xType === 'string' || xType === 'number' || xType === 'boolean'; 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /tfjs-layers/src/utils/math_utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Math utility functions. 
13 | * 14 | * This file contains some frequently used math functions that operate on 15 | * number[] or Float32Array and return a number. Many of these functions are 16 | * thin wrappers around TF.js Core functions, but they offer the 17 | * convenience of 18 | * 1) not having to convert the inputs into Tensors, 19 | * 2) not having to convert the returned Tensors to numbers. 20 | */ 21 | 22 | import * as tfc from '@tensorflow/tfjs-core'; 23 | import {scalar, Tensor1D, tensor1d} from '@tensorflow/tfjs-core'; 24 | import {ValueError} from '../errors'; 25 | 26 | export type ArrayTypes = Uint8Array | Int32Array | Float32Array; 27 | 28 | /** 29 | * Determine if a number is an integer. 30 | */ 31 | export function isInteger(x: number): boolean { 32 | return x === parseInt(x.toString(), 10); 33 | } 34 | 35 | /** 36 | * Calculate the product of an array of numbers. 37 | * @param array The array to calculate the product over. 38 | * @param begin Beginning index, inclusive. Defaults to 0. 39 | * @param end Ending index, exclusive. Defaults to `array.length`. 40 | * @return The product. 41 | */ 42 | export function arrayProd( 43 | array: number[] | ArrayTypes, begin?: number, end?: number): number { 44 | if (begin == null) { 45 | begin = 0; 46 | } 47 | if (end == null) { 48 | end = array.length; 49 | } 50 | 51 | let prod = 1; 52 | for (let i = begin; i < end; ++i) { 53 | prod *= array[i]; 54 | } 55 | return prod; 56 | } 57 | 58 | /** 59 | * A helper function that converts either input type to a Tensor1D, 60 | * so the return value can be fed directly into various TF.js Core functions. 61 | * @param array A number array or a Float32Array. 62 | */ 63 | function toArray1D(array: number[] | Float32Array): Tensor1D { 64 | array = Array.isArray(array) ? new Float32Array(array) : array; 65 | return tensor1d(array); 66 | } 67 | 68 | /** 69 | * Compute minimum value. 70 | * @param array The input values. 71 | * @return The minimum value.
72 | */ 73 | export function min(array: number[] | Float32Array): number { 74 | return tfc.min(toArray1D(array)).dataSync()[0]; 75 | } 76 | 77 | /** 78 | * Compute maximum value. 79 | * @param array 80 | * @return maximum value 81 | */ 82 | export function max(array: number[] | Float32Array): number { 83 | return tfc.max(toArray1D(array)).dataSync()[0]; 84 | } 85 | 86 | /** 87 | * Compute sum of array. 88 | * @param array 89 | * @return The sum. 90 | */ 91 | export function sum(array: number[] | Float32Array): number { 92 | return tfc.sum(toArray1D(array)).dataSync()[0]; 93 | } 94 | 95 | /** 96 | * Compute mean of array. 97 | * @param array 98 | * @return The mean. 99 | */ 100 | export function mean(array: number[] | Float32Array): number { 101 | return sum(array) / array.length; 102 | } 103 | 104 | /** 105 | * Compute variance of array. 106 | * @param array 107 | * @return The variance. 108 | */ 109 | export function variance(array: number[] | Float32Array): number { 110 | const demeaned = tfc.sub(toArray1D(array), scalar(mean(array))); 111 | const sumSquare = tfc.sum(tfc.mulStrict(demeaned, demeaned)).dataSync()[0]; 112 | return sumSquare / array.length; 113 | } 114 | 115 | /** 116 | * Compute median of array. 117 | * @param array 118 | * @return The median value. 119 | */ 120 | export function median(array: number[] | Float32Array): number { 121 | const arraySorted = array.slice().sort((a, b) => a - b); 122 | const lowIdx = Math.floor((arraySorted.length - 1) / 2); 123 | const highIdx = Math.ceil((arraySorted.length - 1) / 2); 124 | if (lowIdx === highIdx) { 125 | return arraySorted[lowIdx]; 126 | } 127 | return (arraySorted[lowIdx] + arraySorted[highIdx]) / 2; 128 | } 129 | 130 | /** 131 | * Generate an array of integers in [begin, end). 132 | * @param begin Beginning integer, inclusive. 133 | * @param end Ending integer, exclusive. 134 | * @returns Range array. 135 | * @throws ValueError, iff `end` < `begin`. 
136 | */ 137 | export function range(begin: number, end: number): number[] { 138 | if (end < begin) { 139 | throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`); 140 | } 141 | const out: number[] = []; 142 | for (let i = begin; i < end; ++i) { 143 | out.push(i); 144 | } 145 | return out; 146 | } 147 | -------------------------------------------------------------------------------- /tfjs-layers/src/utils/math_utils_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Unit tests for math_utils. 13 | */ 14 | 15 | import * as tfc from '@tensorflow/tfjs-core'; 16 | 17 | import * as math_utils from './math_utils'; 18 | import {describeMathCPU} from './test_utils'; 19 | 20 | describe('isInteger', () => { 21 | it('True cases', () => { 22 | expect(math_utils.isInteger(-103)).toBe(true); 23 | expect(math_utils.isInteger(0)).toBe(true); 24 | expect(math_utils.isInteger(1337)).toBe(true); 25 | }); 26 | 27 | it('False cases', () => { 28 | expect(math_utils.isInteger(-1.03)).toBe(false); 29 | expect(math_utils.isInteger(0.008)).toBe(false); 30 | expect(math_utils.isInteger(133.7)).toBe(false); 31 | }); 32 | }); 33 | 34 | describe('arrayProd', () => { 35 | it('Full length', () => { 36 | expect(math_utils.arrayProd([2, 3, 4])).toEqual(24); 37 | expect(math_utils.arrayProd(new Float32Array([2, 3, 4]))).toEqual(24); 38 | }); 39 | 40 | it('Partial from beginning', () => { 41 | expect(math_utils.arrayProd([2, 3, 4], null, 2)).toEqual(6); 42 | expect(math_utils.arrayProd([2, 3, 4], 0, 2)).toEqual(6); 43 | }); 44 | 45 | it('Partial to end', () => { 46 | expect(math_utils.arrayProd([2, 3, 4], 1)).toEqual(12); 
47 | expect(math_utils.arrayProd([2, 3, 4], 1, 3)).toEqual(12); 48 | }); 49 | 50 | it('Partial no beginninng no end', () => { 51 | expect(math_utils.arrayProd([2, 3, 4, 5], 1, 3)).toEqual(12); 52 | }); 53 | 54 | it('Empty array', () => { 55 | expect(math_utils.arrayProd([])).toEqual(1); 56 | }); 57 | }); 58 | 59 | describeMathCPU('min', () => { 60 | it('Number array', () => { 61 | expect(math_utils.min([-100, -200, 150])).toEqual(-200); 62 | }); 63 | 64 | it('Float32Array', () => { 65 | expect(math_utils.min(new Float32Array([-100, -200, 150]))).toEqual(-200); 66 | }); 67 | }); 68 | 69 | describeMathCPU('max', () => { 70 | it('Number array', () => { 71 | expect(math_utils.max([-100, -200, 150])).toEqual(150); 72 | }); 73 | 74 | it('Float32Array', () => { 75 | expect(math_utils.max(new Float32Array([-100, -200, 150]))).toEqual(150); 76 | }); 77 | }); 78 | 79 | describeMathCPU('sum', () => { 80 | it('Number array', () => { 81 | expect(math_utils.sum([-100, -200, 150])).toEqual(-150); 82 | }); 83 | 84 | it('Float32Array', () => { 85 | expect(math_utils.sum(new Float32Array([-100, -200, 150]))).toEqual(-150); 86 | }); 87 | }); 88 | 89 | describeMathCPU('mean', () => { 90 | it('Number array', () => { 91 | expect(math_utils.mean([-100, -200, 150])).toEqual(-50); 92 | }); 93 | 94 | it('Float32Array', () => { 95 | expect(math_utils.mean(new Float32Array([-100, -200, 150]))).toEqual(-50); 96 | }); 97 | }); 98 | 99 | 100 | describeMathCPU('variance', () => { 101 | it('Number array', () => { 102 | expect(math_utils.variance([-100, -200, 150, 50])).toEqual(18125); 103 | }); 104 | 105 | it('Float32Array', () => { 106 | expect(math_utils.variance(new Float32Array([-100, -200, 150, 50]))) 107 | .toEqual(18125); 108 | }); 109 | }); 110 | 111 | describeMathCPU('median', () => { 112 | it('Number array', () => { 113 | expect(math_utils.median([-100, -200, 150, 50])).toEqual(-25); 114 | }); 115 | 116 | it('Float32Array', () => { 117 | expect(math_utils.median(new Float32Array([-100, 
-200, 150, 50]))) 118 | .toEqual(-25); 119 | }); 120 | 121 | it('does not mutate input array', () => { 122 | const numbers = [-100, -200, 150, 50]; 123 | math_utils.median(numbers); 124 | tfc.test_util.expectArraysClose(numbers, [-100, -200, 150, 50]); 125 | }); 126 | }); 127 | 128 | describe('range', () => { 129 | it('end > begin', () => { 130 | expect(math_utils.range(0, 1)).toEqual([0]); 131 | expect(math_utils.range(0, 5)).toEqual([0, 1, 2, 3, 4]); 132 | expect(math_utils.range(-10, -5)).toEqual([-10, -9, -8, -7, -6]); 133 | expect(math_utils.range(-3, 3)).toEqual([-3, -2, -1, 0, 1, 2]); 134 | }); 135 | it('end === begin', () => { 136 | expect(math_utils.range(0, 0)).toEqual([]); 137 | expect(math_utils.range(-2, -2)).toEqual([]); 138 | }); 139 | it('end < begin throws error', () => { 140 | expect(() => math_utils.range(0, -2)).toThrowError(/.*-2.*0.*forbidden/); 141 | }); 142 | }); 143 | -------------------------------------------------------------------------------- /tfjs-layers/src/regularizers.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /* original source: keras/regularizers.py */ 12 | 13 | import * as tfc from '@tensorflow/tfjs-core'; 14 | import {abs, add, Scalar, serialization, sum, Tensor, tidy, zeros} from '@tensorflow/tfjs-core'; 15 | import * as K from './backend/tfjs_backend'; 16 | import {deserializeKerasObject, serializeKerasObject} from './utils/generic_utils'; 17 | 18 | /** 19 | * Regularizer base class. 
20 | */ 21 | export abstract class Regularizer extends serialization.Serializable { 22 | abstract apply(x: Tensor): Scalar; 23 | } 24 | 25 | export interface L1L2Args { 26 | /** L1 regularization rate. Defaults to 0.01. */ 27 | l1?: number; 28 | /** L2 regularization rate. Defaults to 0.01. */ 29 | l2?: number; 30 | } 31 | 32 | export interface L1Args { 33 | /** L1 regularization rate. Defaults to 0.01. */ 34 | l1: number; 35 | } 36 | 37 | export interface L2Args { 38 | /** L2 regularization rate. Defaults to 0.01. */ 39 | l2: number; 40 | } 41 | 42 | export class L1L2 extends Regularizer { 43 | /** @nocollapse */ 44 | static className = 'L1L2'; 45 | 46 | private readonly l1: number; 47 | private readonly l2: number; 48 | private readonly hasL1: boolean; 49 | private readonly hasL2: boolean; 50 | constructor(args?: L1L2Args) { 51 | super(); 52 | 53 | this.l1 = args == null || args.l1 == null ? 0.01 : args.l1; 54 | this.l2 = args == null || args.l2 == null ? 0.01 : args.l2; 55 | this.hasL1 = this.l1 !== 0; 56 | this.hasL2 = this.l2 !== 0; 57 | } 58 | 59 | /** 60 | * Porting note: Renamed from __call__. 61 | * @param x Variable of which to calculate the regularization score. 
62 | */ 63 | apply(x: Tensor): Scalar { 64 | return tidy(() => { 65 | let regularization: Tensor = zeros([1]); 66 | if (this.hasL1) { 67 | regularization = add(regularization, sum(tfc.mul(this.l1, abs(x)))); 68 | } 69 | if (this.hasL2) { 70 | regularization = 71 | add(regularization, sum(tfc.mul(this.l2, K.square(x)))); 72 | } 73 | return regularization.asScalar(); 74 | }); 75 | } 76 | 77 | getConfig(): serialization.ConfigDict { 78 | return {'l1': this.l1, 'l2': this.l2}; 79 | } 80 | 81 | /** @nocollapse */ 82 | static fromConfig( 83 | cls: serialization.SerializableConstructor, 84 | config: serialization.ConfigDict): T { 85 | return new cls({l1: config['l1'] as number, l2: config['l2'] as number}); 86 | } 87 | } 88 | serialization.registerClass(L1L2); 89 | 90 | export function l1(args?: L1Args) { 91 | return new L1L2({l1: args != null ? args.l1 : null, l2: 0}); 92 | } 93 | 94 | export function l2(args: L2Args) { 95 | return new L1L2({l2: args != null ? args.l2 : null, l1: 0}); 96 | } 97 | 98 | /** @docinline */ 99 | export type RegularizerIdentifier = 'l1l2'|string; 100 | 101 | // Maps the JavaScript-like identifier keys to the corresponding keras symbols. 
102 | export const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP: 103 | {[identifier in RegularizerIdentifier]: string} = { 104 | 'l1l2': 'L1L2' 105 | }; 106 | 107 | export function serializeRegularizer(constraint: Regularizer): 108 | serialization.ConfigDictValue { 109 | return serializeKerasObject(constraint); 110 | } 111 | 112 | export function deserializeRegularizer( 113 | config: serialization.ConfigDict, 114 | customObjects: serialization.ConfigDict = {}): Regularizer { 115 | return deserializeKerasObject( 116 | config, serialization.SerializationMap.getMap().classNameMap, 117 | customObjects, 'regularizer'); 118 | } 119 | 120 | export function getRegularizer(identifier: RegularizerIdentifier| 121 | serialization.ConfigDict| 122 | Regularizer): Regularizer { 123 | if (identifier == null) { 124 | return null; 125 | } 126 | if (typeof identifier === 'string') { 127 | const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? 128 | REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : 129 | identifier; 130 | const config = {className, config: {}}; 131 | return deserializeRegularizer(config); 132 | } else if (identifier instanceof Regularizer) { 133 | return identifier; 134 | } else { 135 | return deserializeRegularizer(identifier); 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /tfjs-layers/src/utils/test_utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | /** 12 | * Testing utilities. 
13 | */ 14 | 15 | import {memory, Tensor, test_util, util} from '@tensorflow/tfjs-core'; 16 | import {ALL_ENVS, describeWithFlags, registerTestEnv} from '@tensorflow/tfjs-core/dist/jasmine_util'; 17 | import {ValueError} from '../errors'; 18 | 19 | // Register test environments for the CPU and WebGL 2 backends. 20 | registerTestEnv({name: 'cpu', backendName: 'cpu'}); 21 | registerTestEnv({ 22 | name: 'webgl2', 23 | backendName: 'webgl', 24 | flags: { 25 | 'WEBGL_VERSION': 2, 26 | 'WEBGL_CPU_FORWARD': false, 27 | 'WEBGL_SIZE_UPLOAD_UNIFORM': 0 28 | } 29 | }); 30 | 31 | /** 32 | * Expect two Tensors or number arrays to be element-wise close. 33 | * @param actual The actual values. 34 | * @param expected The expected values (Tensor dtypes/shapes must match). 35 | */ 36 | export function expectTensorsClose( 37 | actual: Tensor|number[], expected: Tensor|number[], epsilon?: number) { 38 | if (actual == null) { 39 | throw new ValueError( 40 | 'First argument to expectTensorsClose() is not defined.'); 41 | } 42 | if (expected == null) { 43 | throw new ValueError( 44 | 'Second argument to expectTensorsClose() is not defined.'); 45 | } 46 | if (actual instanceof Tensor && expected instanceof Tensor) { 47 | if (actual.dtype !== expected.dtype) { 48 | throw new Error( 49 | `Data types do not match. Actual: '${actual.dtype}'. ` + 50 | `Expected: '${expected.dtype}'`); 51 | } 52 | if (!util.arraysEqual(actual.shape, expected.shape)) { 53 | throw new Error( 54 | `Shapes do not match. Actual: [${actual.shape}]. ` + 55 | `Expected: [${expected.shape}].`); 56 | } 57 | } 58 | const actualData = actual instanceof Tensor ? actual.dataSync() : actual; 59 | const expectedData = 60 | expected instanceof Tensor ? expected.dataSync() : expected; 61 | test_util.expectArraysClose(actualData, expectedData, epsilon); 62 | } 63 | 64 | /** 65 | * Expect all values of `actual` to lie within [low, high], inclusive.
66 | * @param actual The tensor to check; `low` and `high` are the inclusive 67 | * lower and upper bounds. 68 | */ 69 | export function expectTensorsValuesInRange( 70 | actual: Tensor, low: number, high: number) { 71 | if (actual == null) { 72 | throw new ValueError( 73 | 'First argument to expectTensorsClose() is not defined.'); 74 | } 75 | test_util.expectValuesInRange(actual.dataSync(), low, high); 76 | } 77 | 78 | 79 | /** 80 | * Describe tests to be run on CPU and GPU. 81 | * @param testName Name of the test suite. 82 | * @param tests Function that defines the tests. 83 | */ 84 | export function describeMathCPUAndGPU(testName: string, tests: () => void) { 85 | describeWithFlags(testName, ALL_ENVS, () => { 86 | tests(); 87 | }); 88 | } 89 | 90 | /** 91 | * Describe tests to be run on CPU only. 92 | * @param testName Name of the test suite. 93 | * @param tests Function that defines the tests. 94 | */ 95 | export function describeMathCPU(testName: string, tests: () => void) { 96 | describeWithFlags( 97 | testName, {predicate: testEnv => testEnv.backendName === 'cpu'}, () => { 98 | tests(); 99 | }); 100 | } 101 | 102 | /** 103 | * Describe tests to be run on GPU only. 104 | * @param testName Name of the test suite. 105 | * @param tests Function that defines the tests. 106 | */ 107 | export function describeMathGPU(testName: string, tests: () => void) { 108 | describeWithFlags( 109 | testName, {predicate: testEnv => testEnv.backendName === 'webgl'}, () => { 110 | tests(); 111 | }); 112 | } 113 | 114 | /** 115 | * Check that a function only generates the expected number of new Tensors. 116 | * 117 | * The test function is called twice, once to prime any regular constants and 118 | * once to ensure that additional copies aren't created/tensors aren't leaked. 119 | * 120 | * @param testFunc A fully curried (zero arg) version of the function to test. 121 | * @param numNewTensors The expected number of new Tensors that should exist.
122 | */ 123 | export function expectNoLeakedTensors( 124 | // tslint:disable-next-line:no-any 125 | testFunc: () => any, numNewTensors: number) { 126 | testFunc(); 127 | const numTensorsBefore = memory().numTensors; 128 | testFunc(); 129 | const numTensorsAfter = memory().numTensors; 130 | const actualNewTensors = numTensorsAfter - numTensorsBefore; 131 | if (actualNewTensors !== numNewTensors) { 132 | throw new ValueError( 133 | `Created an unexpected number of new ` + 134 | `Tensors. Expected: ${numNewTensors}, created : ${ 135 | actualNewTensors}. ` + 136 | `Please investigate the discrepency and/or use tidy.`); 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /tfjs-layers/src/engine/dataset_fakes_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {NamedTensorMap, Tensor} from '@tensorflow/tfjs-core'; 12 | 13 | import {describeMathCPUAndGPU} from '../utils/test_utils'; 14 | 15 | import {FakeNumericDataset} from './dataset_fakes'; 16 | 17 | describeMathCPUAndGPU('FakeNumericDataset', () => { 18 | it('1D features, 1D targets', async () => { 19 | const dataset = new FakeNumericDataset( 20 | {xShape: [3], yShape: [1], batchSize: 8, numBatches: 5}); 21 | for (let k = 0; k < 2; ++k) { 22 | // Run twice to make sure that calling iteartor() multiple times works. 
23 | const iterator = await dataset.iterator(); 24 | for (let i = 0; i < 5; ++i) { 25 | const result = await iterator.next(); 26 | expect((result.value.xs as Tensor).shape).toEqual([8, 3]); 27 | expect((result.value.ys as Tensor).shape).toEqual([8, 1]); 28 | expect(result.done).toEqual(false); 29 | } 30 | for (let i = 0; i < 3; ++i) { 31 | const result = await iterator.next(); 32 | expect(result.value).toBeNull(); 33 | expect(result.done).toEqual(true); 34 | } 35 | } 36 | }); 37 | 38 | it('2D features, 1D targets', async () => { 39 | const dataset = new FakeNumericDataset( 40 | {xShape: [3, 4], yShape: [2], batchSize: 8, numBatches: 5}); 41 | for (let k = 0; k < 2; ++k) { 42 | // Run twice to make sure that calling iteartor() multiple times works. 43 | const iterator = await dataset.iterator(); 44 | for (let i = 0; i < 5; ++i) { 45 | const result = await iterator.next(); 46 | expect((result.value.xs as Tensor).shape).toEqual([8, 3, 4]); 47 | expect((result.value.ys as Tensor).shape).toEqual([8, 2]); 48 | expect(result.done).toEqual(false); 49 | } 50 | for (let i = 0; i < 3; ++i) { 51 | const result = await iterator.next(); 52 | expect(result.value).toBeNull(); 53 | expect(result.done).toEqual(true); 54 | } 55 | } 56 | }); 57 | 58 | it('Multiple 2D features, 1D targets', async () => { 59 | const dataset = new FakeNumericDataset({ 60 | xShape: {'input1': [3, 4], 'input2': [2, 3]}, 61 | yShape: [2], 62 | batchSize: 8, 63 | numBatches: 5 64 | }); 65 | for (let k = 0; k < 2; ++k) { 66 | // Run twice to make sure that calling iteartor() multiple times works. 
67 | const iterator = await dataset.iterator(); 68 | for (let i = 0; i < 5; ++i) { 69 | const result = await iterator.next(); 70 | const xs = result.value.xs as NamedTensorMap; 71 | expect(xs['input1'].shape).toEqual([8, 3, 4]); 72 | expect(xs['input2'].shape).toEqual([8, 2, 3]); 73 | expect((result.value.ys as Tensor).shape).toEqual([8, 2]); 74 | expect(result.done).toEqual(false); 75 | } 76 | for (let i = 0; i < 3; ++i) { 77 | const result = await iterator.next(); 78 | expect(result.value).toBeNull(); 79 | expect(result.done).toEqual(true); 80 | } 81 | } 82 | }); 83 | 84 | it('Invalid batchSize leads to Error', () => { 85 | expect( 86 | () => new FakeNumericDataset( 87 | {xShape: [3], yShape: [1], batchSize: -8, numBatches: 5})) 88 | .toThrow(); 89 | expect( 90 | () => new FakeNumericDataset( 91 | {xShape: [3], yShape: [1], batchSize: 8.5, numBatches: 5})) 92 | .toThrow(); 93 | expect( 94 | () => new FakeNumericDataset( 95 | {xShape: [3], yShape: [1], batchSize: 0, numBatches: 5})) 96 | .toThrow(); 97 | expect( 98 | () => new FakeNumericDataset( 99 | // tslint:disable-next-line:no-any 100 | {xShape: [3], yShape: [1], batchSize: 'foo' as any, numBatches: 5})) 101 | .toThrow(); 102 | }); 103 | 104 | it('Invalid numBatches leads to Error', () => { 105 | expect( 106 | () => new FakeNumericDataset( 107 | {xShape: [3], yShape: [1], batchSize: 8, numBatches: -5})) 108 | .toThrow(); 109 | expect( 110 | () => new FakeNumericDataset( 111 | {xShape: [3], yShape: [1], batchSize: 8, numBatches: 5.5})) 112 | .toThrow(); 113 | expect( 114 | () => new FakeNumericDataset( 115 | {xShape: [3], yShape: [1], batchSize: 8, numBatches: 0})) 116 | .toThrow(); 117 | expect( 118 | () => new FakeNumericDataset( 119 | // tslint:disable-next-line:no-any 120 | {xShape: [3], yShape: [1], batchSize: 8, numBatches: 'foo' as any})) 121 | .toThrow(); 122 | }); 123 | }); 124 | -------------------------------------------------------------------------------- 
/tfjs-layers/src/keras_format/layers/recurrent_serialization.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | import {ActivationSerialization} from '../activation_config'; 12 | import {ConstraintSerialization} from '../constraint_config'; 13 | import {InitializerSerialization} from '../initializer_config'; 14 | import {RegularizerSerialization} from '../regularizer_config'; 15 | import {BaseLayerSerialization, LayerConfig} from '../topology_config'; 16 | import {BaseSerialization} from '../types'; 17 | 18 | export interface BaseRNNLayerConfig extends LayerConfig { 19 | cell?: RNNCellSerialization|RNNCellSerialization[]; 20 | return_sequences?: boolean; 21 | return_state?: boolean; 22 | go_backwards?: boolean; 23 | stateful?: boolean; 24 | unroll?: boolean; 25 | input_dim?: number; 26 | input_length?: number; 27 | } 28 | 29 | export interface SimpleRNNCellConfig extends LayerConfig { 30 | units: number; 31 | activation?: ActivationSerialization; 32 | use_bias?: boolean; 33 | kernel_initializer?: InitializerSerialization; 34 | recurrent_initializer?: InitializerSerialization; 35 | bias_initializer?: InitializerSerialization; 36 | kernel_regularizer?: RegularizerSerialization; 37 | recurrent_regularizer?: RegularizerSerialization; 38 | bias_regularizer?: RegularizerSerialization; 39 | kernel_constraint?: ConstraintSerialization; 40 | recurrent_constraint?: ConstraintSerialization; 41 | bias_constraint?: ConstraintSerialization; 42 | dropout?: number; 43 | recurrent_dropout?: number; 44 | } 45 | 46 | export type SimpleRNNCellSerialization = 47 | BaseSerialization<'SimpleRNNCell', SimpleRNNCellConfig>; 48 | 49 | export 
interface SimpleRNNLayerConfig extends BaseRNNLayerConfig { 50 | units: number; 51 | activation?: ActivationSerialization; 52 | use_bias?: boolean; 53 | kernel_initializer?: InitializerSerialization; 54 | recurrent_initializer?: InitializerSerialization; 55 | bias_initializer?: InitializerSerialization; 56 | kernel_regularizer?: RegularizerSerialization; 57 | recurrent_regularizer?: RegularizerSerialization; 58 | bias_regularizer?: RegularizerSerialization; 59 | kernel_constraint?: ConstraintSerialization; 60 | recurrent_constraint?: ConstraintSerialization; 61 | bias_constraint?: ConstraintSerialization; 62 | dropout?: number; 63 | recurrent_dropout?: number; 64 | } 65 | 66 | export type SimpleRNNLayerSerialization = 67 | BaseLayerSerialization<'SimpleRNN', SimpleRNNLayerConfig>; 68 | 69 | export interface GRUCellConfig extends SimpleRNNCellConfig { 70 | recurrent_activation?: string; 71 | implementation?: number; 72 | } 73 | 74 | export type GRUCellSerialization = BaseSerialization<'GRUCell', GRUCellConfig>; 75 | 76 | export interface GRULayerConfig extends SimpleRNNLayerConfig { 77 | recurrent_activation?: ActivationSerialization; 78 | implementation?: number; 79 | } 80 | 81 | export type GRULayerSerialization = 82 | BaseLayerSerialization<'GRU', GRULayerConfig>; 83 | 84 | export interface LSTMCellConfig extends SimpleRNNCellConfig { 85 | recurrent_activation?: ActivationSerialization; 86 | unit_forget_bias?: boolean; 87 | implementation?: number; 88 | } 89 | 90 | export type LSTMCellSerialization = 91 | BaseSerialization<'LSTMCell', LSTMCellConfig>; 92 | 93 | export interface LSTMLayerConfig extends SimpleRNNLayerConfig { 94 | recurrent_activation?: ActivationSerialization; 95 | unit_forget_bias?: boolean; 96 | implementation?: number; 97 | } 98 | export type LSTMLayerSerialization = 99 | BaseLayerSerialization<'LSTM', LSTMLayerConfig>; 100 | 101 | export interface StackedRNNCellsConfig extends LayerConfig { 102 | // TODO(soergel): consider whether we can avoid 
improperly mixing 103 | // Simple / LSTM / GRU cells here and in the above Layer serializations. 104 | cells: RNNCellSerialization[]; 105 | } 106 | 107 | export type StackedRNNCellsSerialization = 108 | BaseSerialization<'StackedRNNCells', StackedRNNCellsConfig>; 109 | 110 | export type RNNCellSerialization = SimpleRNNCellSerialization| 111 | GRUCellSerialization|LSTMCellSerialization|StackedRNNCellsSerialization; 112 | 113 | // Update recurrentLayerClassNames below in concert with this. 114 | export type RecurrentLayerSerialization = 115 | SimpleRNNLayerSerialization|LSTMLayerSerialization|GRULayerSerialization; 116 | 117 | export type RecurrentLayerClassName = RecurrentLayerSerialization['class_name']; 118 | 119 | // We can't easily extract a string[] from the string union type, but we can 120 | // recapitulate the list, enforcing at compile time that the values are valid. 121 | 122 | /** 123 | * A string array of valid RecurrentLayer class names. 124 | * 125 | * This is guaranteed to match the `RecurrentLayerClassName` union type. 126 | */ 127 | export const recurrentLayerClassNames: RecurrentLayerClassName[] = [ 128 | 'GRU', 129 | 'LSTM', 130 | 'SimpleRNN', 131 | ]; 132 | -------------------------------------------------------------------------------- /tfjs-layers/src/utils/serialization_utils.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC 4 | * 5 | * Use of this source code is governed by an MIT-style 6 | * license that can be found in the LICENSE file or at 7 | * https://opensource.org/licenses/MIT. 8 | * ============================================================================= 9 | */ 10 | 11 | // Porting note: This file doesn't exist in PyKeras. 12 | // Its purpose here is to centralize the boundary layer between 13 | // tfjs-layers's internal Config TS-Centric format and PyKeras's 14 | // serialized Python Config format. 

import {serialization} from '@tensorflow/tfjs-core';

import {PyJsonDict, PyJsonValue} from '../keras_format/types';
import * as generic_utils from '../utils/generic_utils';
// tslint:enable

/**
 * Test whether a value in an array is the name of a LayersModel or Layer.
 *
 * The converters below copy such names verbatim instead of case-converting
 * them.
 *
 * @param key The key name that the value is found under (the converters below
 *     pass the camelCase TypeScript form of the key). Note that the key may
 *     not be at the level immediately above the value, if the value is in a
 *     nested array.
 * @param index Index of the value in the Array that it is found in.
 * @param value The value object.
 * @returns A boolean indicating whether value is a name.
 */
function isArrayItemInputOrOutputName<T>(
    key: string, index: number, value: T): boolean {
  return (key === 'inboundNodes' || key === 'outputLayers' ||
          key === 'inputLayers') &&
      index === 0 && typeof value === 'string';
}

/**
 * Convert a Pythonic config object to TypeScript config object.
 *
 * String keys and string values are converted from snake_case to camelCase.
 * Two exceptions are copied verbatim:
 *   - string values stored under a `name` key;
 *   - strings at index 0 of arrays found under the `inboundNodes`,
 *     `inputLayers` or `outputLayers` keys.
 *
 * @param pythonicConfig The config object to convert. `undefined` is treated
 *     the same as `null`.
 * @param key Optional key name of the object being converted.
 * @returns Result of the conversion.
 */
export function convertPythonicToTs(
    pythonicConfig: PyJsonValue, key?: string): serialization.ConfigDictValue {
  // `== null` also matches `undefined`. Without it, `undefined` would fall
  // through to the dict branch, where `Object.keys(undefined)` throws.
  if (pythonicConfig == null) {
    return null;
  } else if (typeof pythonicConfig === 'string') {
    return generic_utils.toCamelCase(pythonicConfig);
  } else if (
      (typeof pythonicConfig === 'number') ||
      (typeof pythonicConfig === 'boolean')) {
    return pythonicConfig;
  } else if (Array.isArray(pythonicConfig)) {
    const tsArray = [];
    for (let i = 0; i < pythonicConfig.length; ++i) {
      const item = pythonicConfig[i];
      if (isArrayItemInputOrOutputName(key, i, item)) {
        tsArray.push(item);
      } else {
        // Nested arrays inherit the parent key so that layer names inside
        // e.g. `inbound_nodes` are still recognized.
        tsArray.push(convertPythonicToTs(item, key));
      }
    }
    return tsArray;
  } else {
    const tsDict: serialization.ConfigDict = {};
    for (const pythonicKey of Object.keys(pythonicConfig)) {
      const pythonicValue = pythonicConfig[pythonicKey];
      if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
        // Special case the 'name' key with a string value. Name values, such as
        // the names of LayersModel and Layer instances, should not undergo the
        // camel-case conversion.
        tsDict[pythonicKey] = pythonicValue;
      } else {
        const tsKey = generic_utils.toCamelCase(pythonicKey);
        tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
      }
    }
    return tsDict;
  }
}

/**
 * Convert a TypeScript config object to Python config object.
 *
 * String keys and string values are converted from camelCase to snake_case.
 * Two exceptions are copied verbatim:
 *   - string values stored under a `name` or `className` key;
 *   - strings at index 0 of arrays found under the `inboundNodes`,
 *     `inputLayers` or `outputLayers` keys.
 *
 * @param tsConfig The config object to convert. `null` and `undefined` both
 *     map to `null`.
 * @param key Optional key name of the object being converted.
 * @returns Result of the conversion.
 */
export function convertTsToPythonic(
    tsConfig: serialization.ConfigDictValue, key?: string): PyJsonValue {
  if (tsConfig === null || tsConfig === undefined) {
    return null;
  } else if (typeof tsConfig === 'string') {
    return generic_utils.toSnakeCase(tsConfig);
  } else if (
      (typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) {
    return tsConfig;
  } else if (Array.isArray(tsConfig)) {
    const pyArray = [];
    for (let i = 0; i < tsConfig.length; ++i) {
      const item = tsConfig[i];
      if (isArrayItemInputOrOutputName(key, i, item)) {
        pyArray.push(item);
      } else {
        // Nested arrays inherit the parent key (see isArrayItemInputOrOutputName).
        pyArray.push(convertTsToPythonic(item, key));
      }
    }
    return pyArray;
  } else {
    const pyDict: PyJsonDict = {};
    for (const tsKey of Object.keys(tsConfig)) {
      const tsValue = tsConfig[tsKey];
      const pyKey = generic_utils.toSnakeCase(tsKey);
      if ((tsKey === 'name' || tsKey === 'className') &&
          typeof tsValue === 'string') {
        // Special case string values under the 'name' and 'className' keys.
        // Names of LayersModel and Layer instances, and class names such as
        // 'Conv2D', must not undergo the snake-case conversion.
        pyDict[pyKey] = tsValue;
      } else {
        pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
      }
    }
    return pyDict;
  }
}

--------------------------------------------------------------------------------
/tfjs-layers/src/keras_format/types.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
* =============================================================================
 */

/**
 * A value within the JSON-serialized form of a serializable object.
 *
 * The keys of any nested dicts should be in snake_case (i.e., using Python
 * naming conventions) for compatibility with Python Keras.
 *
 * @see PyJsonDict
 */
export type PyJsonValue = boolean|number|string|null|PyJsonArray|PyJsonDict;

/**
 * A key-value dict within the JSON-serialized form of a serializable object.
 *
 * Serialization/deserialization uses stringified-JSON as the storage
 * representation. Typically this should be used for materialized JSON
 * stored on disk or sent/received over the wire.
 *
 * The keys of this dict and of any nested dicts should be in snake_case (i.e.,
 * using Python naming conventions) for compatibility with Python Keras.
 *
 * Internally this is normally converted to a ConfigDict that has camelCase keys
 * (using TypeScript naming conventions) and support for Enums.
 */
export interface PyJsonDict {
  [key: string]: PyJsonValue;
}

/**
 * A key-value dict like @see PyJsonDict, but with restricted keys.
 *
 * This makes it possible to create subtypes that have only the specified
 * fields, while requiring that the values are JSON-compatible.
 *
 * That is in contrast to extending `PyJsonDict`, or using an intersection type
 * `Foo & PyJsonDict`. In both of those cases, the fields of Foo are actually
 * allowed to be of types that are incompatible with `PyJsonValue`. Worse, the
 * index signature of `PyJsonDict` means that *any* key is accepted: e.g.
 * `const foo: Foo = ...; foo.bogus = 12; const x = foo.bogus` works for both
 * reading and assignment, even if `bogus` is not a field of the type `Foo`,
 * because the index signature inherited from `PyJsonDict` accepts all strings.
 *
 * Here, we *both* restrict the keys to known values, *and* guarantee that the
 * values associated with those keys are compatible with `PyJsonValue`.
 *
 * This guarantee is easiest to apply via an additional incantation:
 *
 * ```
 * export interface Foo extends PyJson<Extract<keyof Foo, string>> {
 *   a: SomeType;
 *   b: SomeOtherType;
 * }
 * ```
 *
 * Now instances of `Foo` have *only* the fields `a` and `b`, and furthermore,
 * if either the type `SomeType` or `SomeOtherType` is incompatible with
 * `PyJsonValue`, the compiler produces a typing error.
 */
export type PyJson<Keys extends string> = {
  [x in Keys]?: PyJsonValue;
};

/**
 * An array of values within the JSON-serialized form of a serializable object.
 *
 * The keys of any nested dicts should be in snake_case (i.e., using Python
 * naming conventions) for compatibility with Python Keras.
 *
 * @see PyJsonDict
 */
export interface PyJsonArray extends Array<PyJsonValue> {}

/**
 * A Keras JSON entry representing a Keras object such as a Layer.
 *
 * The Keras JSON convention is to provide the `class_name` (e.g., the layer
 * type) at the top level, and then to place the class-specific configuration in
 * a `config` subtree. These class-specific configurations are provided by
 * subtypes of `PyJsonDict`. Thus, this `*Serialization` has a type parameter
 * giving the specific type of the wrapped `PyJsonDict`.
 */
export interface BaseSerialization<
    N extends string, T extends PyJson<Extract<keyof T, string>>> extends
    PyJsonDict {
  // The above type voodoo does this:
  // * `keyof T` obtains the known keys of the specific config type.
  //   `keyof` returns `string | number | symbol`; see
  //   (https://www.typescriptlang.org/docs/handbook/release-notes/typescript-2-9.html)
  // * `Extract<keyof T, string>` selects the string values. This amounts to
  //   assuming that we are dealing with a type with string keys, as opposed to
  //   an array. In our usage, this assumption always holds.
  // * `PyJson<Extract<keyof T, string>>` is a type whose keys are constrained
  //   to the provided ones, and whose values must be JSON-compatible.
  // * `T extends PyJson<Extract<keyof T, string>>` means that we can provide
  //   any config type with known keys, provided that the associated values are
  //   JSON-compatible.
  //
  // The upshot is that we can extend `BaseSerialization` with whatever config
  // type `T` that we like-- remaining confident that the result can be
  // trivially rendered as JSON, because the compiler will produce a typing
  // error if that guarantee does not hold.
  //
  // To test this, try adding a field with a non-JSON-like value (e.g., Tensor)
  // to any subclass of `BaseSerialization`. A compilation error will result.
  class_name: N;
  config: T;
}

--------------------------------------------------------------------------------
/tfjs-layers/src/common_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */

/**
 * Unit tests for common.ts.
*/

import {checkDataFormat, checkPaddingMode, checkPoolMode, getUniqueTensorName, isValidTensorName} from './common';
import {VALID_DATA_FORMAT_VALUES, VALID_PADDING_MODE_VALUES, VALID_POOL_MODE_VALUES} from './keras_format/common';

/**
 * Asserts that `fn` throws, and that the error message names `typeName` and
 * lists every entry of `validValues`.
 *
 * Assertions placed only inside a `catch` clause would pass vacuously if `fn`
 * did not throw at all. This helper fails explicitly in that case.
 */
function expectErrorListingValidValues(
    fn: () => void, typeName: string,
    validValues: ReadonlyArray<string>): void {
  let caughtError: Error;
  try {
    fn();
  } catch (e) {
    caughtError = e;
  }
  if (caughtError == null) {
    fail(`Expected invalid ${typeName} value to throw, but it did not.`);
    return;
  }
  expect(caughtError.message).toMatch(typeName);
  // Test that the error message contains the list of valid values.
  for (const validValue of validValues) {
    expect(caughtError.message).toMatch(validValue);
  }
}

describe('checkDataFormat', () => {
  it('Valid values', () => {
    const extendedValues = VALID_DATA_FORMAT_VALUES.concat([undefined, null]);
    for (const validValue of extendedValues) {
      // Using implicit "expect().toNotThrow()" for valid values
      checkDataFormat(validValue);
    }
  });
  it('Invalid values', () => {
    // Test invalid values are rejected, and reported in the error.
    expect(() => checkDataFormat('foo')).toThrowError(/foo/);
    expectErrorListingValidValues(
        () => checkDataFormat('bad'), 'DataFormat', VALID_DATA_FORMAT_VALUES);
  });
});

describe('checkPaddingMode', () => {
  it('Valid values', () => {
    const extendedValues = VALID_PADDING_MODE_VALUES.concat([undefined, null]);
    for (const validValue of extendedValues) {
      // Using implicit "expect().toNotThrow()" for valid values
      checkPaddingMode(validValue);
    }
  });
  it('Invalid values', () => {
    // Test invalid values are rejected, and reported in the error.
    expect(() => checkPaddingMode('foo')).toThrowError(/foo/);
    expectErrorListingValidValues(
        () => checkPaddingMode('bad'), 'PaddingMode',
        VALID_PADDING_MODE_VALUES);
  });
});

describe('checkPoolMode', () => {
  it('Valid values', () => {
    const extendedValues = VALID_POOL_MODE_VALUES.concat([undefined, null]);
    for (const validValue of extendedValues) {
      // Using implicit "expect().toNotThrow()" for valid values
      checkPoolMode(validValue);
    }
  });
  it('Invalid values', () => {
    // Test invalid values are rejected, and reported in the error.
    expect(() => checkPoolMode('foo')).toThrowError(/foo/);
    expectErrorListingValidValues(
        () => checkPoolMode('bad'), 'PoolMode', VALID_POOL_MODE_VALUES);
  });
});


describe('isValidTensorName', () => {
  it('Valid tensor names', () => {
    expect(isValidTensorName('a')).toEqual(true);
    expect(isValidTensorName('A')).toEqual(true);
    expect(isValidTensorName('foo1')).toEqual(true);
    expect(isValidTensorName('Foo2')).toEqual(true);
    expect(isValidTensorName('n_1')).toEqual(true);
    expect(isValidTensorName('n.1')).toEqual(true);
    expect(isValidTensorName('n_1_2')).toEqual(true);
    expect(isValidTensorName('n.1.2')).toEqual(true);
    expect(isValidTensorName('a/B/c')).toEqual(true);
    expect(isValidTensorName('z_1/z_2/z.3')).toEqual(true);
    expect(isValidTensorName('z-1/z-2/z.3')).toEqual(true);
    expect(isValidTensorName('1Qux')).toEqual(true);
    expect(isValidTensorName('5-conv/kernel')).toEqual(true);
  });

  it('Invalid tensor names: empty', () => {
    expect(isValidTensorName('')).toEqual(false);
  });

  it('Invalid tensor names: whitespaces', () => {
    expect(isValidTensorName('a b')).toEqual(false);
    expect(isValidTensorName('ab ')).toEqual(false);
  });

  it('Invalid tensor names: forbidden characters', () => {
    expect(isValidTensorName('-foo1')).toEqual(false);
    expect(isValidTensorName('-foo2-')).toEqual(false);
    expect(isValidTensorName('bar3!4')).toEqual(false);
  });

  it('Invalid tensor names: invalid first characters', () => {
    expect(isValidTensorName('/foo/bar')).toEqual(false);
    expect(isValidTensorName('.baz')).toEqual(false);
    expect(isValidTensorName('_baz')).toEqual(false);
  });

  it('Invalid tensor names: non-ASCII', () => {
    expect(isValidTensorName('フ')).toEqual(false);
    expect(isValidTensorName('ξ')).toEqual(false);
  });
});

describe('getUniqueTensorName', () => {
  it('Adds unique suffixes to tensor names', () => {
    expect(getUniqueTensorName('xx')).toEqual('xx');
    expect(getUniqueTensorName('xx')).toEqual('xx_1');
    expect(getUniqueTensorName('xx')).toEqual('xx_2');
    expect(getUniqueTensorName('xx')).toEqual('xx_3');
  });

  it('Correctly handles preexisting unique suffixes on tensor names', () => {
    expect(getUniqueTensorName('yy')).toEqual('yy');
    expect(getUniqueTensorName('yy')).toEqual('yy_1');
    expect(getUniqueTensorName('yy_1')).toEqual('yy_1_1');
    expect(getUniqueTensorName('yy')).toEqual('yy_2');
    expect(getUniqueTensorName('yy_1')).toEqual('yy_1_2');
    expect(getUniqueTensorName('yy_2')).toEqual('yy_2_1');
    expect(getUniqueTensorName('yy')).toEqual('yy_3');
    expect(getUniqueTensorName('yy_1_1')).toEqual('yy_1_1_1');
  });
});

--------------------------------------------------------------------------------
/tfjs-layers/src/constraints_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be
found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */

/* Unit tests for constraints */

import {serialization, Tensor1D, tensor1d} from '@tensorflow/tfjs-core';

import {ConstraintIdentifier, deserializeConstraint, getConstraint, serializeConstraint} from './constraints';
import * as tfl from './index';
import {describeMathCPU, expectNoLeakedTensors, expectTensorsClose} from './utils/test_utils';


describeMathCPU('Built-in Constraints', () => {
  let initVals: Tensor1D;
  beforeEach(() => {
    initVals = tensor1d(new Float32Array([-1, 2, 0, 4, -5, 6]));
  });

  it('NonNeg', () => {
    const constraint = getConstraint('NonNeg');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(
        postConstraint, tensor1d(new Float32Array([0, 2, 0, 4, 0, 6])));
    expectNoLeakedTensors(() => constraint.apply(initVals), 1);
  });

  it('MaxNorm', () => {
    const constraint = getConstraint('MaxNorm');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(postConstraint, tensor1d(new Float32Array([
                         -0.2208630521, 0.4417261043, 0, 0.8834522086,
                         -1.104315261, 1.325178313
                       ])));
    expectNoLeakedTensors(() => constraint.apply(initVals), 1);
  });
  it('UnitNorm', () => {
    const constraint = getConstraint('UnitNorm');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(postConstraint, tensor1d(new Float32Array([
                         -0.2208630521 / 2, 0.4417261043 / 2, 0,
                         0.8834522086 / 2, -1.104315261 / 2, 1.325178313 / 2
                       ])));
    expectNoLeakedTensors(() => constraint.apply(initVals), 1);
  });
  it('MinMaxNorm', () => {
    const constraint = getConstraint('MinMaxNorm');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(postConstraint, tensor1d(new Float32Array([
                         -0.2208630521 / 2, 0.4417261043 / 2, 0,
                         0.8834522086 / 2, -1.104315261 / 2, 1.325178313 / 2
                       ])));
    expectNoLeakedTensors(() => constraint.apply(initVals), 1);
  });

  // Lower camel case.
  it('nonNeg', () => {
    const constraint = getConstraint('nonNeg');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(
        postConstraint, tensor1d(new Float32Array([0, 2, 0, 4, 0, 6])));
  });

  it('maxNorm', () => {
    const constraint = getConstraint('maxNorm');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(postConstraint, tensor1d(new Float32Array([
                         -0.2208630521, 0.4417261043, 0, 0.8834522086,
                         -1.104315261, 1.325178313
                       ])));
  });
  it('unitNorm', () => {
    const constraint = getConstraint('unitNorm');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(postConstraint, tensor1d(new Float32Array([
                         -0.2208630521 / 2, 0.4417261043 / 2, 0,
                         0.8834522086 / 2, -1.104315261 / 2, 1.325178313 / 2
                       ])));
  });
  it('minMaxNorm', () => {
    const constraint = getConstraint('minMaxNorm');
    const postConstraint = constraint.apply(initVals);
    expectTensorsClose(postConstraint, tensor1d(new Float32Array([
                         -0.2208630521 / 2, 0.4417261043 / 2, 0,
                         0.8834522086 / 2, -1.104315261 / 2, 1.325178313 / 2
                       ])));
  });
});

describeMathCPU('constraints.get', () => {
  it('by string', () => {
    const constraint = getConstraint('maxNorm');
    const config = serializeConstraint(constraint) as serialization.ConfigDict;
    const nestedConfig = config.config as serialization.ConfigDict;
    expect(nestedConfig.maxValue).toEqual(2);
    expect(nestedConfig.axis).toEqual(0);
  });

  it('by string, upper case', () => {
    // Uses the upper-camel-case identifier, as the test name says.
    const constraint = getConstraint('MaxNorm');
    const config = serializeConstraint(constraint) as serialization.ConfigDict;
    const nestedConfig = config.config as serialization.ConfigDict;
    expect(nestedConfig.maxValue).toEqual(2);
    expect(nestedConfig.axis).toEqual(0);
  });

  it('by existing object', () => {
    const origConstraint = tfl.constraints.nonNeg();
    expect(getConstraint(origConstraint)).toEqual(origConstraint);
  });
  it('by config dict', () => {
    const origConstraint = tfl.constraints.minMaxNorm(
        {minValue: 0, maxValue: 2, rate: 3, axis: 4});
    const constraint = getConstraint(
        serializeConstraint(origConstraint) as serialization.ConfigDict);
    expect(serializeConstraint(constraint))
        .toEqual(serializeConstraint(origConstraint));
  });
});

describe('Constraints Serialization', () => {
  it('Built-ins', () => {
    // Test both types of capitalization.
    const constraints: ConstraintIdentifier[] = [
      'maxNorm', 'nonNeg', 'unitNorm', 'minMaxNorm', 'MaxNorm', 'NonNeg',
      'UnitNorm', 'MinMaxNorm'
    ];
    for (const name of constraints) {
      const constraint = getConstraint(name);
      const config =
          serializeConstraint(constraint) as serialization.ConfigDict;
      const reconstituted = deserializeConstraint(config);
      expect(reconstituted).toEqual(constraint);
    }
  });
});

--------------------------------------------------------------------------------
/tfjs-layers/src/utils/serialization_utils_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
* =============================================================================
 */

/**
 * Unit Tests for serialization_utils
 * Porting Note: serialization_utils is a tfjs-layers only file, not found
 * in the original PyKeras.
 */
import {serialization} from '@tensorflow/tfjs-core';

import {PyJsonValue} from '../keras_format/types';

import {convertPythonicToTs, convertTsToPythonic} from './serialization_utils';

describe('convertPythonicToTs', () => {
  it('primitives', () => {
    expect(convertPythonicToTs(null)).toEqual(null);
    expect(convertPythonicToTs(true)).toEqual(true);
    expect(convertPythonicToTs(false)).toEqual(false);
    expect(convertPythonicToTs(4)).toEqual(4);
  });
  it('strings w/o name tags', () => {
    expect(convertPythonicToTs('abc')).toEqual('abc');
    expect(convertPythonicToTs('ABC')).toEqual('ABC');
    expect(convertPythonicToTs('one_two')).toEqual('oneTwo');
    expect(convertPythonicToTs('OneTwo')).toEqual('OneTwo');
  });
  it('simple arrays', () => {
    expect(convertPythonicToTs([])).toEqual([]);
    expect(convertPythonicToTs([null])).toEqual([null]);
    expect(convertPythonicToTs(['one_two'])).toEqual(['oneTwo']);
    expect(convertPythonicToTs([null, true, false, 4, 'abc'])).toEqual([
      null, true, false, 4, 'abc'
    ]);
    expect(convertPythonicToTs([[[]]])).toEqual([[[]]]);
  });
  // We have to special case strings that are layer_names to not be converted
  it('layer tuple (array) with key', () => {
    for (const key of ['inboundNodes', 'inputLayers', 'outputLayers']) {
      expect(convertPythonicToTs(['layer_name', 'meta_data', 0], key)).toEqual([
        'layer_name', 'metaData', 0
      ]);
    }
  });
  it('dictionary', () => {
    expect(convertPythonicToTs({})).toEqual({});
    expect(convertPythonicToTs({key: null})).toEqual({key: null});
    expect(convertPythonicToTs({key_one: 4})).toEqual({keyOne: 4});
    expect(convertPythonicToTs({key_two: 'abc_def'})).toEqual({
      keyTwo: 'abcDef'
    });
    expect(convertPythonicToTs({key_one: true, key_two: false}))
        .toEqual({keyOne: true, keyTwo: false});
    // values keyed by 'name' are special and don't get converted.
    expect(convertPythonicToTs({name: 'layer_name'})).toEqual({
      name: 'layer_name'
    });
  });
  it('dictionary keys are passed down the stack', () => {
    const dict: PyJsonValue = {inbound_nodes: ['DoNotChange_Me', 0, null]};
    expect(convertPythonicToTs(dict)).toEqual({
      inboundNodes: ['DoNotChange_Me', 0, null]
    });
  });
  // Enum-like string values (e.g. 'fan_out') are camel-cased like any other
  // string value.
  it('enum promotion', () => {
    expect(convertPythonicToTs({mode: 'fan_out'})).toEqual({mode: 'fanOut'});
    expect(convertPythonicToTs({distribution: 'normal'})).toEqual({
      distribution: 'normal'
    });
    expect(convertPythonicToTs({data_format: 'channels_last'})).toEqual({
      dataFormat: 'channelsLast'
    });
    expect(convertPythonicToTs({padding: 'valid'})).toEqual({padding: 'valid'});
  });
});


describe('convertTsToPythonic', () => {
  it('primitives', () => {
    expect(convertTsToPythonic(null)).toEqual(null);
    expect(convertTsToPythonic(true)).toEqual(true);
    expect(convertTsToPythonic(false)).toEqual(false);
    expect(convertTsToPythonic(4)).toEqual(4);
  });
  it('strings w/o name tags', () => {
    expect(convertTsToPythonic('abc')).toEqual('abc');
    expect(convertTsToPythonic('ABC')).toEqual('abc');
    expect(convertTsToPythonic('oneTwo')).toEqual('one_two');
    expect(convertTsToPythonic('OneTwo')).toEqual('one_two');
  });
  it('simple arrays', () => {
    expect(convertTsToPythonic([])).toEqual([]);
    expect(convertTsToPythonic([null])).toEqual([null]);
    expect(convertTsToPythonic(['oneTwo'])).toEqual(['one_two']);
    expect(convertTsToPythonic([null, true, false, 4, 'abc'])).toEqual([
      null, true, false, 4, 'abc'
    ]);
    expect(convertTsToPythonic([[[]]])).toEqual([[[]]]);
  });
  // We have to special case strings that are layer_names to not be converted
  it('layer tuple (array) with key', () => {
    for (const key of ['inboundNodes', 'inputLayers', 'outputLayers']) {
      expect(convertTsToPythonic(['layerName', 'metaData', 0], key)).toEqual([
        'layerName', 'meta_data', 0
      ]);
    }
  });
  it('dictionary', () => {
    expect(convertTsToPythonic({})).toEqual({});
    expect(convertTsToPythonic({key: null})).toEqual({key: null});
    expect(convertTsToPythonic({keyOne: 4})).toEqual({key_one: 4});
    expect(convertTsToPythonic({keyTwo: 'abcDef'})).toEqual({
      key_two: 'abc_def'
    });
    expect(convertTsToPythonic({keyOne: true, keyTwo: false}))
        .toEqual({key_one: true, key_two: false});
    // values keyed by 'name' are special and don't get converted.
    expect(convertTsToPythonic({name: 'layerName'})).toEqual({
      name: 'layerName'
    });
    // String values keyed by 'className' are likewise left unconverted, while
    // the key itself is still snake-cased.
    expect(convertTsToPythonic({className: 'Conv2D'})).toEqual({
      class_name: 'Conv2D'
    });
  });
  it('dictionary keys are passed down the stack', () => {
    const dict: serialization.ConfigDictValue = {
      inboundNodes: ['DoNotChange_Me', 0, null]
    };
    expect(convertTsToPythonic(dict)).toEqual({
      inbound_nodes: ['DoNotChange_Me', 0, null]
    });
  });
  // Enum-like string values are snake-cased like any other string value.
  it('enum promotion', () => {
    expect(convertTsToPythonic({mode: 'fanOut'})).toEqual({mode: 'fan_out'});
    expect(convertTsToPythonic({distribution: 'normal'})).toEqual({
      distribution: 'normal'
    });
    expect(convertTsToPythonic({dataFormat: 'channelsLast'})).toEqual({
      data_format: 'channels_last'
    });
    expect(convertTsToPythonic({padding: 'valid'})).toEqual({padding: 'valid'});
  });
});

--------------------------------------------------------------------------------
/tfjs-layers/src/metrics.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */

/**
 * Built-in metrics.
 */

import * as tfc from '@tensorflow/tfjs-core';
import {Tensor, tidy} from '@tensorflow/tfjs-core';

import * as K from './backend/tfjs_backend';
import {NotImplementedError, ValueError} from './errors';
import {categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
import {binaryCrossentropy as lossBinaryCrossentropy} from './losses';
import {lossesMap} from './losses';
import {LossOrMetricFn} from './types';
import * as util from './utils/generic_utils';

/**
 * Binary accuracy: thresholds `yPred` at 0.5, casts the result to
 * `yTrue.dtype`, and returns the mean of element-wise equality with `yTrue`
 * over the last axis.
 */
export function binaryAccuracy(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(() => {
    const threshold = tfc.mul(.5, tfc.onesLike(yPred));
    const yPredThresholded = K.cast(tfc.greater(yPred, threshold), yTrue.dtype);
    return tfc.mean(tfc.equal(yTrue, yPredThresholded), -1);
  });
}

/**
 * Categorical accuracy: 1 where the argmax over the last axis of `yTrue`
 * equals that of `yPred`, else 0, as a float32 tensor.
 */
export function categoricalAccuracy(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(
      () => K.cast(
          tfc.equal(tfc.argMax(yTrue, -1), tfc.argMax(yPred, -1)), 'float32'));
}

// The three counters below compare against exact 0/1 values, so they assume
// `yPred` has already been binarized (e.g. not raw probabilities).

/** Number of positions where `yTrue` is 1 and `yPred` is 1, as float32. */
function truePositives(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(() => {
    return tfc.logicalAnd(yTrue.equal(1), yPred.equal(1)).sum().cast('float32');
  });
}

/** Number of positions where `yTrue` is 1 and `yPred` is 0, as float32. */
function falseNegatives(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(() => {
    return tfc.logicalAnd(yTrue.equal(1), yPred.equal(0)).sum().cast('float32');
  });
}

/** Number of positions where `yTrue` is 0 and `yPred` is 1, as float32. */
function falsePositives(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(() => {
    return tfc.logicalAnd(yTrue.equal(0), yPred.equal(1)).sum().cast('float32');
  });
}

/**
 * Precision: tp / (tp + fp) over all elements, or 0 when there are no
 * predicted positives (avoids dividing by zero).
 */
export function precision(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(() => {
    const tp = truePositives(yTrue, yPred);
    const fp = falsePositives(yTrue, yPred);

    const denominator = tp.add(fp);

    return tfc.where(tfc.greater(denominator, 0), tp.div(denominator), 0)
        .cast('float32');
  });
}

/**
 * Recall: tp / (tp + fn) over all elements, or 0 when there are no actual
 * positives (avoids dividing by zero).
 */
export function recall(yTrue: Tensor, yPred: Tensor): Tensor {
  return tidy(() => {
    const tp = truePositives(yTrue, yPred);
    const fn = falseNegatives(yTrue, yPred);

    const denominator = tp.add(fn);

    return tfc.where(tfc.greater(denominator, 0), tp.div(denominator), 0)
        .cast('float32');
  });
}

/** Binary cross-entropy metric; delegates to the loss of the same name. */
export function binaryCrossentropy(yTrue: Tensor, yPred: Tensor): Tensor {
  return lossBinaryCrossentropy(yTrue, yPred);
}

/**
 * Sparse categorical accuracy: 1.0 where the integer label in `yTrue` equals
 * the argmax over the last axis of `yPred`, else 0.0.
 *
 * If `yTrue` has the same rank as `yPred`, its trailing axis is squeezed first
 * (presumably a size-1 label axis; `squeeze` would reject any other size).
 */
export function sparseCategoricalAccuracy(
    yTrue: Tensor, yPred: Tensor): Tensor {
  // tidy() disposes the intermediate squeeze/argMax/cast tensors, as every
  // other metric in this file does; previously they leaked on each call.
  return tidy(() => {
    let labels = yTrue;
    if (labels.rank === yPred.rank) {
      labels = labels.squeeze([labels.rank - 1]);
    }
    let predictedClasses: Tensor = yPred.argMax(-1);
    if (predictedClasses.dtype !== labels.dtype) {
      predictedClasses = predictedClasses.asType(labels.dtype);
    }
    return tfc.equal(labels, predictedClasses).asType('float32');
  });
}

/** Not implemented yet; always throws NotImplementedError. */
export function topKCategoricalAccuracy(yTrue: Tensor, yPred: Tensor): Tensor {
  throw new NotImplementedError();
}

/** Not implemented yet; always throws NotImplementedError. */
export function sparseTopKCategoricalAccuracy(
    yTrue: Tensor, yPred: Tensor): Tensor {
  throw new NotImplementedError();
}

// Aliases.
export const mse = meanSquaredError;
export const MSE = meanSquaredError;
export const mae = meanAbsoluteError;
export const MAE = meanAbsoluteError;
export const mape = meanAbsolutePercentageError;
export const MAPE = meanAbsolutePercentageError;
export const categoricalCrossentropy = categoricalCrossentropyLoss;
export const cosine = cosineProximity;
export const sparseCategoricalCrossentropy = sparseCategoricalCrossentropyLoss;

// TODO(cais, nielsene): Add serialize().

/**
 * Metric functions that can be requested by name via `get()`.
 */
export const metricsMap: {[functionName: string]: LossOrMetricFn} = {
  binaryAccuracy,
  categoricalAccuracy,
  precision,
  // `recall` is defined alongside `precision` above; without this entry it
  // could not be requested by name.
  recall,
  categoricalCrossentropy,
  sparseCategoricalCrossentropy,
  mse,
  MSE,
  mae,
  MAE,
  mape,
  MAPE,
  cosine
};

/**
 * Resolve a metric identifier.
 *
 * @param identifier A key of `metricsMap`, or a metric function (returned
 *     as-is).
 * @returns The metric function.
 * @throws ValueError if `identifier` is an unknown string, null or undefined.
 */
export function get(identifier: string|LossOrMetricFn): LossOrMetricFn {
  if (typeof identifier === 'string' && identifier in metricsMap) {
    return metricsMap[identifier];
  } else if (typeof identifier !== 'string' && identifier != null) {
    return identifier;
  } else {
    throw new ValueError(`Unknown metric ${identifier}`);
  }
}

/**
 * Returns the first key of `map`, in `Object.keys` order, whose value is
 * identical to `fn`, or `undefined` if there is none.
 */
function findFunctionName(
    map: {[functionName: string]: LossOrMetricFn},
    fn: LossOrMetricFn): string|undefined {
  for (const key of Object.keys(map)) {
    if (map[key] === fn) {
      return key;
    }
  }
  return undefined;
}

/**
 * Get the shortcut function name.
 *
 * If `fn` is a string, return it directly.
 * If the function is included in lossesMap or metricsMap, return its key.
 *   - If the function is registered under multiple keys, the first key found
 *     is returned.
 *   - If the function exists in both lossesMap and metricsMap, lossesMap is
 *     searched first.
 * Otherwise, return the function's own `name`.
 *
 * @param fn loss function, metric function, or shortcut name.
 * @returns Loss or Metric name in string.
 * @throws If `fn` is null or undefined (via `util.assert`).
 */
export function getLossOrMetricName(fn: string|LossOrMetricFn): string {
  // `!= null` rejects `undefined` as well as `null`. An `undefined` would
  // otherwise reach `(fn as Function).name` and fail with an opaque TypeError.
  util.assert(fn != null, `Unknown LossOrMetricFn ${fn}`);
  if (typeof fn === 'string') {
    return fn;
  }
  const lossName = findFunctionName(lossesMap, fn);
  if (lossName !== undefined) {
    return lossName;
  }
  const metricName = findFunctionName(metricsMap, fn);
  if (metricName !== undefined) {
    return metricName;
  }
  return (fn as Function).name;
}

--------------------------------------------------------------------------------
/tfjs-layers/src/engine/training_utils_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */

import {memory, tensor1d, tensor2d} from '@tensorflow/tfjs-core';

import {describeMathCPU, expectTensorsClose} from '../utils/test_utils';

import {ClassWeight, ClassWeightMap, standardizeClassWeights, standardizeWeights} from './training_utils';

describeMathCPU('standardizeWeights', () => {
  it('classWeights with 1D class-index target', async () => {
    const y = tensor1d([0, 1, 2, 1, 0]);
    const classWeight: ClassWeight = {0: 10, 1: 1, 2: 0.1};
    const numTensors0 = memory().numTensors;
    const classSampleWeight = await standardizeWeights(y, null, classWeight);
    // Assert no memory leak. The extra tensor is `classSampleWeight` itself.
    expect(memory().numTensors).toEqual(numTensors0 + 1);
    expectTensorsClose(classSampleWeight, tensor1d([10, 1, 0.1, 1, 10]));
    expect(y.isDisposed).toEqual(false);
  });

  it('classWeights with 2D class-index target', async () => {
    const y = tensor2d([[3], [2], [0]]);
    const classWeight: ClassWeight = {0: 10, 1: 1, 2: 0.1, 3: 0.01};
    const numTensors0 = memory().numTensors;
    const classSampleWeight = await standardizeWeights(y, null, classWeight);
    // Assert no memory leak. The extra tensor is `classSampleWeight` itself.
    expect(memory().numTensors).toEqual(numTensors0 + 1);
    expectTensorsClose(classSampleWeight, tensor1d([0.01, 0.1, 10]));
    expect(y.isDisposed).toEqual(false);
  });

  it('classWeights with 2D one-hot target', async () => {
    const y = tensor2d([[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0]]);
    const classWeight: ClassWeight = {0: 10, 1: 1, 2: 0.1, 3: 0.01};
    const numTensors0 = memory().numTensors;
    const classSampleWeight = await standardizeWeights(y, null, classWeight);
    // Assert no memory leak. The extra tensor is `classSampleWeight` itself.
    expect(memory().numTensors).toEqual(numTensors0 + 1);
    expectTensorsClose(classSampleWeight, tensor1d([0.01, 0.1, 10]));
    expect(y.isDisposed).toEqual(false);
  });

  it('classWeights with 1D class-index target: Missing class', async () => {
    const y = tensor1d([0, 1, 2, 3, 2, 1, 0]);
    const classWeight: ClassWeight = {0: 10, 1: 1, 2: 0.1};

    let caughtError: Error;
    try {
      await standardizeWeights(y, null, classWeight);
    } catch (error) {
      caughtError = error;
    }
    expect(caughtError.message)
        .toMatch(/classWeight must contain all classes.* class 3 .*/);
  });

  it('classWeights with 2D class-index target: Missing class', async () => {
    const y = tensor2d([[3], [2], [0], [4]]);
    const classWeight: ClassWeight = {0: 10, 1: 1, 2: 0.1, 3: 0.01};

    let caughtError: Error;
    try {
      await standardizeWeights(y, null, classWeight);
    } catch (error) {
      caughtError = error;
    }
    expect(caughtError.message)
        .toMatch(/classWeight must contain all classes.* class 4 .*/);
  });


  it('classWeights with 2D one-hot target: missing weight', async () => {
    const y = tensor2d([[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0]]);
    const classWeight: ClassWeight = {0: 10, 1: 1, 3: 0.01};

    let caughtError: Error;
    try {
      await standardizeWeights(y, null, classWeight);
    } catch (error) {
      caughtError = error;
    }
    expect(caughtError.message)
        .toMatch(/classWeight must contain all classes.* class 2 .*/);
  });
});

describe('standardizeClassWeights', () => {
  it('One output, ClassWeight singleton', () => {
    const outputNames = ['output1'];
    const classWeight: ClassWeight = {0: 1, 1: 2};
    const output = standardizeClassWeights(classWeight, outputNames);
    expect(output).toEqual([{0: 1, 1: 2}]);
  });

  it('One output, ClassWeight array', () => {
    const outputNames
= ['output1']; 105 | const classWeight: ClassWeight[] = [{0: 1, 1: 2}]; 106 | const output = standardizeClassWeights(classWeight, outputNames); 107 | expect(output).toEqual([{0: 1, 1: 2}]); 108 | }); 109 | 110 | it('One output, ClassWeight dict', () => { 111 | const outputNames = ['output1']; 112 | const classWeight: ClassWeightMap = {'output1': {0: 1, 1: 2}}; 113 | const output = standardizeClassWeights(classWeight, outputNames); 114 | expect(output).toEqual([{0: 1, 1: 2}]); 115 | }); 116 | 117 | it('Two outputs, ClassWeight array', () => { 118 | const outputNames = ['output1', 'output2']; 119 | const classWeight: ClassWeight[] = [{0: 1, 1: 2}, {0: 10, 1: 20}]; 120 | const output = standardizeClassWeights(classWeight, outputNames); 121 | expect(output).toEqual([{0: 1, 1: 2}, {0: 10, 1: 20}]); 122 | }); 123 | 124 | it('Two outputs, ClassWeight dict', () => { 125 | const outputNames = ['output1', 'output2']; 126 | const classWeight: 127 | ClassWeightMap = {'output2': {0: 10, 1: 20}, 'output1': {0: 1, 1: 2}}; 128 | const output = standardizeClassWeights(classWeight, outputNames); 129 | expect(output).toEqual([{0: 1, 1: 2}, {0: 10, 1: 20}]); 130 | }); 131 | 132 | it('Two outputs, ClassWeight singleton leads to Error', () => { 133 | const outputNames = ['output1', 'output2']; 134 | const classWeight: ClassWeight = {0: 10, 1: 20}; 135 | expect(() => standardizeClassWeights(classWeight, outputNames)) 136 | .toThrowError(/.*has multiple \(2\) outputs.*/); 137 | }); 138 | 139 | it('Three outputs, ClassWeight array missing element', () => { 140 | const outputNames = ['output1', 'output2', 'output3']; 141 | const classWeight: ClassWeight[] = [{0: 1, 1: 2}, {0: 10, 1: 20}]; 142 | expect(() => standardizeClassWeights(classWeight, outputNames)) 143 | .toThrowError( 144 | /.*classWeight is an array of 2 element.* model has 3 outputs/); 145 | }); 146 | 147 | it('Three outputs, ClassWeight dict missing element is okay', () => { 148 | const outputNames = ['output1', 'output2', 
'output3']; 149 | const classWeight: 150 | ClassWeightMap = {'output1': {0: 1, 1: 2}, 'output3': {0: 10, 1: 20}}; 151 | const output = standardizeClassWeights(classWeight, outputNames); 152 | expect(output).toEqual([{0: 1, 1: 2}, null, {0: 10, 1: 20}]); 153 | }); 154 | }); 155 | --------------------------------------------------------------------------------