├── .gitignore
├── CONTRIBUTING.md
├── DEVELOPMENT.md
├── ISSUE_TEMPLATE.md
├── LICENSE
├── README.md
├── WORKSPACE
├── cloudbuild.yml
├── package.json
├── pull_request_template.md
├── scripts
├── diff.js
├── run-build.sh
└── test-util.js
├── tfjs-backend-nodegl
├── .vscode
│ └── settings.json
├── README.md
├── demo
│ ├── README.md
│ ├── dog.jpg
│ ├── package.json
│ ├── run_mobilenet_inference.js
│ └── yarn.lock
├── package.json
├── src
│ ├── index.ts
│ ├── run_tests.ts
│ └── version.ts
├── tasks.json
├── tsconfig.json
├── tslint.json
└── yarn.lock
├── tfjs-backend-wasm
├── .npmrc
├── .vscode
│ ├── c_cpp_properties.json
│ └── settings.json
├── README.md
├── karma.conf.js
├── package.json
├── rollup.config.js
├── scripts
│ ├── build-npm.sh
│ └── build-wasm.sh
├── src
│ ├── backend_wasm.ts
│ ├── index.ts
│ ├── index_test.ts
│ ├── kernels.cc
│ ├── kernels.h
│ ├── lib.cc
│ ├── setup_test.ts
│ └── util.h
├── tsconfig.json
├── tslint.json
├── wasm-out
│ └── tfjs-backend-wasm.d.ts
└── yarn.lock
├── tfjs-core
├── .bazelignore
├── .bazelrc
├── .npmignore
├── .npmrc
├── .vscode
│ ├── c_cpp_properties.json
│ ├── launch.json
│ ├── settings.json
│ └── tasks.json
├── BUILD.bazel
├── README.md
├── benchmarks
│ └── index.html
├── cloudbuild.yml
├── karma.conf.js
├── package.json
├── rollup.config.js
├── scripts
│ ├── build-npm.sh
│ ├── cloud_funcs
│ │ ├── README.md
│ │ ├── send_email
│ │ │ ├── .gcloudignore
│ │ │ ├── config.json
│ │ │ ├── index.js
│ │ │ ├── package-lock.json
│ │ │ └── package.json
│ │ └── trigger_nightly
│ │ │ ├── .gcloudignore
│ │ │ ├── index.js
│ │ │ ├── package-lock.json
│ │ │ └── package.json
│ ├── enumerate-tests.js
│ ├── make-version
│ ├── publish-npm.sh
│ ├── tag-version.js
│ ├── test-bundle-size.js
│ ├── test-ci.sh
│ ├── test-integration.js
│ ├── test-integration.sh
│ └── test_snippets
│ │ ├── test_snippets.ts
│ │ ├── tsconfig.json
│ │ └── util.ts
├── src
│ ├── BUILD.bazel
│ ├── backends
│ │ ├── backend.ts
│ │ ├── backend_test.ts
│ │ ├── backend_util.ts
│ │ ├── complex_util.ts
│ │ ├── complex_util_test.ts
│ │ ├── cpu
│ │ │ ├── backend_cpu.ts
│ │ │ ├── backend_cpu_test.ts
│ │ │ └── backend_cpu_test_registry.ts
│ │ ├── non_max_suppression_impl.ts
│ │ ├── packing_util.ts
│ │ ├── split_shared.ts
│ │ ├── tile_impl.ts
│ │ ├── topk_impl.ts
│ │ ├── webgl
│ │ │ ├── addn_gpu.ts
│ │ │ ├── addn_packed_gpu.ts
│ │ │ ├── argminmax_gpu.ts
│ │ │ ├── argminmax_packed_gpu.ts
│ │ │ ├── avg_pool_backprop_gpu.ts
│ │ │ ├── backend_webgl.ts
│ │ │ ├── backend_webgl_test.ts
│ │ │ ├── backend_webgl_test_registry.ts
│ │ │ ├── batchnorm_gpu.ts
│ │ │ ├── batchnorm_packed_gpu.ts
│ │ │ ├── binaryop_complex_gpu.ts
│ │ │ ├── binaryop_gpu.ts
│ │ │ ├── binaryop_packed_gpu.ts
│ │ │ ├── canvas_util.ts
│ │ │ ├── canvas_util_test.ts
│ │ │ ├── clip_gpu.ts
│ │ │ ├── clip_packed_gpu.ts
│ │ │ ├── complex_abs_gpu.ts
│ │ │ ├── concat_gpu.ts
│ │ │ ├── concat_packed_gpu.ts
│ │ │ ├── conv_backprop_gpu.ts
│ │ │ ├── conv_backprop_gpu_depthwise.ts
│ │ │ ├── conv_gpu.ts
│ │ │ ├── conv_gpu_depthwise.ts
│ │ │ ├── conv_packed_gpu_depthwise.ts
│ │ │ ├── crop_and_resize_gpu.ts
│ │ │ ├── cumsum_gpu.ts
│ │ │ ├── decode_matrix_gpu.ts
│ │ │ ├── decode_matrix_packed_gpu.ts
│ │ │ ├── depth_to_space_gpu.ts
│ │ │ ├── diag_gpu.ts
│ │ │ ├── encode_float_gpu.ts
│ │ │ ├── encode_float_packed_gpu.ts
│ │ │ ├── encode_matrix_gpu.ts
│ │ │ ├── encode_matrix_packed_gpu.ts
│ │ │ ├── fft_gpu.ts
│ │ │ ├── fill_gpu.ts
│ │ │ ├── flags_webgl.ts
│ │ │ ├── flags_webgl_test.ts
│ │ │ ├── from_pixels_gpu.ts
│ │ │ ├── from_pixels_packed_gpu.ts
│ │ │ ├── gather_gpu.ts
│ │ │ ├── gather_nd_gpu.ts
│ │ │ ├── glsl_version.ts
│ │ │ ├── gpgpu_context.ts
│ │ │ ├── gpgpu_context_test.ts
│ │ │ ├── gpgpu_math.ts
│ │ │ ├── gpgpu_util.ts
│ │ │ ├── gpgpu_util_test.ts
│ │ │ ├── im2col_packed_gpu.ts
│ │ │ ├── lrn_gpu.ts
│ │ │ ├── lrn_grad_gpu.ts
│ │ │ ├── lrn_packed_gpu.ts
│ │ │ ├── max_pool_backprop_gpu.ts
│ │ │ ├── mulmat_packed_gpu.ts
│ │ │ ├── multinomial_gpu.ts
│ │ │ ├── onehot_gpu.ts
│ │ │ ├── pack_gpu.ts
│ │ │ ├── pad_gpu.ts
│ │ │ ├── pad_packed_gpu.ts
│ │ │ ├── pool_gpu.ts
│ │ │ ├── reduce_gpu.ts
│ │ │ ├── reshape_packed_gpu.ts
│ │ │ ├── reshape_packed_test.ts
│ │ │ ├── resize_bilinear_backprop_gpu.ts
│ │ │ ├── resize_bilinear_gpu.ts
│ │ │ ├── resize_bilinear_packed_gpu.ts
│ │ │ ├── resize_nearest_neighbor_backprop_gpu.ts
│ │ │ ├── resize_nearest_neighbor_gpu.ts
│ │ │ ├── reverse_gpu.ts
│ │ │ ├── reverse_packed_gpu.ts
│ │ │ ├── scatter_gpu.ts
│ │ │ ├── segment_gpu.ts
│ │ │ ├── select_gpu.ts
│ │ │ ├── shader_compiler.ts
│ │ │ ├── shader_compiler_util.ts
│ │ │ ├── shader_compiler_util_test.ts
│ │ │ ├── slice_gpu.ts
│ │ │ ├── slice_packed_gpu.ts
│ │ │ ├── strided_slice_gpu.ts
│ │ │ ├── tex_util.ts
│ │ │ ├── tex_util_test.ts
│ │ │ ├── texture_manager.ts
│ │ │ ├── tile_gpu.ts
│ │ │ ├── transpose_gpu.ts
│ │ │ ├── transpose_packed_gpu.ts
│ │ │ ├── unaryop_gpu.ts
│ │ │ ├── unaryop_packed_gpu.ts
│ │ │ ├── unpack_gpu.ts
│ │ │ ├── webgl_batchnorm_test.ts
│ │ │ ├── webgl_custom_op_test.ts
│ │ │ ├── webgl_ops_test.ts
│ │ │ ├── webgl_types.ts
│ │ │ ├── webgl_util.ts
│ │ │ └── webgl_util_test.ts
│ │ └── where_impl.ts
│ ├── browser_util.ts
│ ├── browser_util_test.ts
│ ├── buffer_test.ts
│ ├── debug_mode_test.ts
│ ├── device_util.ts
│ ├── engine.ts
│ ├── engine_test.ts
│ ├── environment.ts
│ ├── environment_test.ts
│ ├── flags.ts
│ ├── flags_test.ts
│ ├── globals.ts
│ ├── globals_test.ts
│ ├── gradients.ts
│ ├── gradients_test.ts
│ ├── index.ts
│ ├── io
│ │ ├── browser_files.ts
│ │ ├── browser_files_test.ts
│ │ ├── http.ts
│ │ ├── http_test.ts
│ │ ├── indexed_db.ts
│ │ ├── indexed_db_test.ts
│ │ ├── io.ts
│ │ ├── io_utils.ts
│ │ ├── io_utils_test.ts
│ │ ├── local_storage.ts
│ │ ├── local_storage_test.ts
│ │ ├── model_management.ts
│ │ ├── model_management_test.ts
│ │ ├── passthrough.ts
│ │ ├── passthrough_test.ts
│ │ ├── progress.ts
│ │ ├── progress_test.ts
│ │ ├── router_registry.ts
│ │ ├── router_registry_test.ts
│ │ ├── types.ts
│ │ ├── weights_loader.ts
│ │ └── weights_loader_test.ts
│ ├── jasmine_util.ts
│ ├── jasmine_util_test.ts
│ ├── log.ts
│ ├── math.ts
│ ├── model_types.ts
│ ├── ops
│ │ ├── arithmetic_test.ts
│ │ ├── array_ops.ts
│ │ ├── array_ops_test.ts
│ │ ├── array_ops_util.ts
│ │ ├── axis_util.ts
│ │ ├── axis_util_test.ts
│ │ ├── batchnorm.ts
│ │ ├── batchnorm_test.ts
│ │ ├── binary_ops.ts
│ │ ├── binary_ops_test.ts
│ │ ├── boolean_mask.ts
│ │ ├── boolean_mask_test.ts
│ │ ├── broadcast_util.ts
│ │ ├── broadcast_util_test.ts
│ │ ├── browser.ts
│ │ ├── clone_test.ts
│ │ ├── compare.ts
│ │ ├── compare_ops_test.ts
│ │ ├── complex_ops.ts
│ │ ├── complex_ops_test.ts
│ │ ├── concat_split.ts
│ │ ├── concat_test.ts
│ │ ├── concat_util.ts
│ │ ├── concat_util_test.ts
│ │ ├── confusion_matrix.ts
│ │ ├── confusion_matrix_test.ts
│ │ ├── conv.ts
│ │ ├── conv1d_test.ts
│ │ ├── conv2d_depthwise_test.ts
│ │ ├── conv2d_separable_test.ts
│ │ ├── conv2d_test.ts
│ │ ├── conv2d_transpose_test.ts
│ │ ├── conv3d_test.ts
│ │ ├── conv3d_transpose_test.ts
│ │ ├── conv_util.ts
│ │ ├── conv_util_test.ts
│ │ ├── diag.ts
│ │ ├── diag_test.ts
│ │ ├── dropout.ts
│ │ ├── dropout_test.ts
│ │ ├── dropout_util.ts
│ │ ├── dropout_util_test.ts
│ │ ├── erf_util.ts
│ │ ├── fused_ops.ts
│ │ ├── fused_test.ts
│ │ ├── fused_util.ts
│ │ ├── gather_nd.ts
│ │ ├── gather_nd_test.ts
│ │ ├── gather_nd_util.ts
│ │ ├── image_ops.ts
│ │ ├── image_ops_test.ts
│ │ ├── in_top_k.ts
│ │ ├── in_top_k_test.ts
│ │ ├── linalg_ops.ts
│ │ ├── linalg_ops_test.ts
│ │ ├── logical_ops.ts
│ │ ├── logical_ops_test.ts
│ │ ├── loss_ops.ts
│ │ ├── loss_ops_test.ts
│ │ ├── lrn.ts
│ │ ├── lrn_test.ts
│ │ ├── lstm.ts
│ │ ├── lstm_test.ts
│ │ ├── matmul.ts
│ │ ├── matmul_test.ts
│ │ ├── moving_average.ts
│ │ ├── moving_average_test.ts
│ │ ├── multinomial_test.ts
│ │ ├── norm.ts
│ │ ├── operation.ts
│ │ ├── operation_test.ts
│ │ ├── ops.ts
│ │ ├── pad_test.ts
│ │ ├── pool.ts
│ │ ├── pool_test.ts
│ │ ├── rand.ts
│ │ ├── rand_test.ts
│ │ ├── rand_util.ts
│ │ ├── reduce_util.ts
│ │ ├── reduction_ops.ts
│ │ ├── reduction_ops_test.ts
│ │ ├── relu_ops.ts
│ │ ├── resize_bilinear_test.ts
│ │ ├── resize_nearest_neighbor_test.ts
│ │ ├── reverse.ts
│ │ ├── reverse_test.ts
│ │ ├── scatter_nd.ts
│ │ ├── scatter_nd_test.ts
│ │ ├── scatter_nd_util.ts
│ │ ├── segment_ops.ts
│ │ ├── segment_ops_test.ts
│ │ ├── segment_util.ts
│ │ ├── selu_util.ts
│ │ ├── signal_ops.ts
│ │ ├── signal_ops_test.ts
│ │ ├── slice.ts
│ │ ├── slice_test.ts
│ │ ├── slice_util.ts
│ │ ├── slice_util_test.ts
│ │ ├── softmax.ts
│ │ ├── softmax_test.ts
│ │ ├── sparse_to_dense.ts
│ │ ├── sparse_to_dense_test.ts
│ │ ├── sparse_to_dense_util.ts
│ │ ├── spectral_ops.ts
│ │ ├── spectral_ops_test.ts
│ │ ├── strided_slice.ts
│ │ ├── strided_slice_test.ts
│ │ ├── tensor_ops.ts
│ │ ├── topk.ts
│ │ ├── topk_test.ts
│ │ ├── transpose.ts
│ │ ├── transpose_test.ts
│ │ ├── unary_ops.ts
│ │ └── unary_ops_test.ts
│ ├── optimizers
│ │ ├── adadelta_optimizer.ts
│ │ ├── adadelta_optimizer_test.ts
│ │ ├── adagrad_optimizer.ts
│ │ ├── adagrad_optimizer_test.ts
│ │ ├── adam_optimizer.ts
│ │ ├── adam_optimizer_test.ts
│ │ ├── adamax_optimizer.ts
│ │ ├── adamax_optimizer_test.ts
│ │ ├── momentum_optimizer.ts
│ │ ├── momentum_optimizer_test.ts
│ │ ├── optimizer.ts
│ │ ├── optimizer_constructors.ts
│ │ ├── optimizer_test.ts
│ │ ├── rmsprop_optimizer.ts
│ │ ├── rmsprop_optimizer_test.ts
│ │ ├── sgd_optimizer.ts
│ │ └── sgd_optimizer_test.ts
│ ├── platforms
│ │ ├── platform.ts
│ │ ├── platform_browser.ts
│ │ ├── platform_browser_test.ts
│ │ ├── platform_node.ts
│ │ └── platform_node_test.ts
│ ├── profiler.ts
│ ├── profiler_test.ts
│ ├── serialization.ts
│ ├── serialization_test.ts
│ ├── setup_test.ts
│ ├── tape.ts
│ ├── tape_test.ts
│ ├── tensor.ts
│ ├── tensor_format.ts
│ ├── tensor_test.ts
│ ├── tensor_types.ts
│ ├── tensor_util.ts
│ ├── tensor_util_env.ts
│ ├── tensor_util_test.ts
│ ├── test_async_backends.ts
│ ├── test_node.ts
│ ├── test_util.ts
│ ├── test_util_test.ts
│ ├── tests.ts
│ ├── train.ts
│ ├── types.ts
│ ├── types_test.ts
│ ├── util.ts
│ ├── util_test.ts
│ ├── variable_test.ts
│ ├── version.ts
│ ├── version_test.ts
│ ├── webgl.ts
│ ├── worker_node_test.ts
│ └── worker_test.ts
├── tsconfig.json
├── tslint.json
└── yarn.lock
├── tfjs-react-native
├── .npmignore
├── .vscode
│ └── settings.json
├── README.md
├── cloudbuild.yml
├── karma.conf.js
├── package.json
├── rollup.config.js
├── scripts
│ └── test-ci.sh
├── src
│ ├── async_storage_io.ts
│ ├── async_storage_io_test.ts
│ ├── bundle_resource_io.ts
│ ├── bundle_resource_io_test.ts
│ ├── index.ts
│ ├── platform_react_native.ts
│ ├── platform_react_native_test.ts
│ └── test_utils
│ │ ├── async_storage_mock.ts
│ │ ├── gl_view_mock.ts
│ │ └── react_native_mock.ts
├── tsconfig.json
├── tslint.json
└── yarn.lock
├── tfjs-webgpu
├── .npmignore
├── .vscode
│ └── settings.json
├── README.md
├── cloudbuild.yml
├── karma.conf.js
├── package.json
├── rollup.config.js
├── scripts
│ └── test-ci.sh
├── src
│ ├── backend_webgpu.ts
│ ├── backend_webgpu_test.ts
│ ├── benchmark_ops_test.ts
│ ├── buffer_manager.ts
│ ├── flags_webgpu.ts
│ ├── index.ts
│ ├── kernels
│ │ ├── argminmax_webgpu.ts
│ │ ├── binary_op_webgpu.ts
│ │ ├── concat_webgpu.ts
│ │ ├── conv2d_mm_webgpu.ts
│ │ ├── conv2d_naive_webgpu.ts
│ │ ├── from_pixels_webgpu.ts
│ │ ├── matmul_packed_webgpu.ts
│ │ ├── matmul_webgpu.ts
│ │ ├── maxpool_webgpu.ts
│ │ ├── pad_webgpu.ts
│ │ ├── resize_bilinear_webgpu.ts
│ │ ├── transpose_webgpu.ts
│ │ ├── unary_op_webgpu.ts
│ │ └── webgpu_program.ts
│ ├── matmul_test.ts
│ ├── setup_test.ts
│ ├── shader_preprocessor.ts
│ ├── shader_util.ts
│ ├── shader_util_test.ts
│ ├── test_util.ts
│ ├── webgpu_util.ts
│ └── webgpu_util_test.ts
├── tsconfig.json
├── tslint.json
└── yarn.lock
├── tfjs.code-workspace
├── tsconfig.json
├── tslint.json
└── yarn.lock
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules/
2 | coverage/
3 | npm-debug.log
4 | yarn-error.log
5 | local.log
6 | .DS_Store
7 | dist/
8 | bazel-out/
9 | .idea/
10 | *.tgz
11 | **/*.pyc
12 | .yalc/
13 | yalc.lock
14 | .rpt2_cache/
15 | package/
16 | */diff
17 |
18 | tfjs-backend-wasm/dist
19 | tfjs-backend-wasm/wasm-out
20 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Adding functionality
19 |
20 | One way to ensure that your PR will be accepted is to add functionality that
21 | has been requested in GitHub issues. If there is something you think is
22 | important and we're missing it, but it does not show up in GitHub issues, it
23 | would be good to file an issue there first so we can have the discussion before
24 | sending us a PR.
25 |
26 | In general, we're trying to add functionality when driven by use-cases instead of
27 | adding functionality for the sake of parity with TensorFlow Python. This will
28 | help us keep the bundle size smaller and have less to maintain, especially as we
29 | add new backends.
30 |
31 | ### Adding an op
32 |
33 | When adding ops to the library and deciding whether to write a kernel
34 | implementation in [backend.ts](https://github.com/tensorflow/tfjs-core/blob/master/src/backends/backend.ts),
35 | be sure to check out the TensorFlow ops list [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/ops.pbtxt).
36 | This list shows the kernels available for the TensorFlow C API. So that we can
37 | bind to these kernels from Node.js, our backend.ts interface should match the
38 | ops in the TensorFlow C API.
39 |
40 | ## Code reviews
41 |
42 | All submissions, including submissions by project members, require review. We
43 | use GitHub pull requests for this purpose. Consult
44 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
45 | information on using pull requests.
46 |
47 | We require unit tests for most code; instructions for running our unit test
48 | suites are in the documentation.
49 |
--------------------------------------------------------------------------------
/DEVELOPMENT.md:
--------------------------------------------------------------------------------
1 | ## Development
2 |
3 | To build **TensorFlow.js Core API** from source, we need to clone the project and prepare
4 | the dev environment:
5 |
6 | ```bash
7 | $ git clone https://github.com/tensorflow/tfjs-core.git
8 | $ cd tfjs-core
9 | $ yarn # Installs dependencies.
10 | ```
11 |
12 | #### Yarn
13 | We use yarn, and if you are adding or removing dependencies you should use yarn
14 | to keep the `yarn.lock` file up to date.
15 |
16 | #### Code editor
17 | We recommend using [Visual Studio Code](https://code.visualstudio.com/) for
18 | development. Make sure to install
19 | [TSLint VSCode extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.vscode-typescript-tslint-plugin)
20 | and the npm [clang-format](https://github.com/angular/clang-format) `1.2.2` or later
21 | with the
22 | [Clang-Format VSCode extension](https://marketplace.visualstudio.com/items?itemName=xaver.clang-format)
23 | for auto-formatting.
24 |
25 | #### Testing
26 | Before submitting a pull request, make sure the code passes all the tests and is clean of lint errors:
27 |
28 | ```bash
29 | $ yarn test
30 | $ yarn lint
31 | ```
32 |
33 | To run a subset of tests and/or on a specific browser:
34 |
35 | ```bash
36 | $ yarn test --browsers=Chrome --grep='multinomial'
37 |
38 | > ...
39 | > Chrome 62.0.3202 (Mac OS X 10.12.6): Executed 28 of 1891 (skipped 1863) SUCCESS (6.914 secs / 0.634 secs)
40 | ```
41 |
42 | To run the tests once and exit the karma process (helpful on Windows):
43 |
44 | ```bash
45 | $ yarn test --single-run
46 | ```
47 |
48 | To run the tests in an environment that does not have GPU support (such as Chrome Remote Desktop):
49 |
50 | ```bash
51 | $ yarn test --testEnv cpu
52 | ```
53 |
54 | Available test environments: cpu, webgl1, webgl2.
55 |
56 | #### Packaging (browser and npm)
57 |
58 | ```bash
59 | $ yarn build-npm
60 | > Stored standalone library at dist/tf-core(.min).js
61 | > Stored also tensorflow-tf-core-VERSION.tgz
62 | ```
63 |
64 | To install it locally, run `yarn add ./tensorflow-tf-core-VERSION.tgz`.
65 |
66 | > On Windows, use bash (available through git) to use the scripts above.
67 |
68 | Looking to contribute, and don't know where to start? Check out our "stat:contributions welcome" [issues](https://github.com/tensorflow/tfjs/labels/stat%3Acontributions%20welcome).
69 |
--------------------------------------------------------------------------------
/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | If you would like to get help from the community, we encourage using Stack Overflow and the [`tensorflow.js`](https://stackoverflow.com/questions/tagged/tensorflow.js) tag.
2 |
3 | GitHub issues for this repository are tracked in the [tfjs union repository](https://github.com/tensorflow/tfjs/issues).
4 |
5 | Please file your issue there, following the guidance in [that issue template](https://github.com/tensorflow/tfjs/blob/master/ISSUE_TEMPLATE.md).
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # This repository has been archived in favor of tensorflow/tfjs.
2 |
3 | This repo will remain around for some time to keep history but all future PRs should be sent to [tensorflow/tfjs](https://github.com/tensorflow/tfjs) inside the tfjs-core folder.
4 |
5 | All history and contributions have been preserved in the monorepo.
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
2 | http_archive(
3 | name = "build_bazel_rules_nodejs",
4 | sha256 = "88e5e579fb9edfbd19791b8a3c6bfbe16ae3444dba4b428e5efd36856db7cf16",
5 | urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/0.27.8/rules_nodejs-0.27.8.tar.gz"],
6 | )
7 |
8 | load("@build_bazel_rules_nodejs//:defs.bzl", "yarn_install")
9 | yarn_install(
10 | name = "npm",
11 | package_json = "//:package.json",
12 | yarn_lock = "//:yarn.lock",
13 | )
14 |
15 | load("@npm//:install_bazel_dependencies.bzl", "install_bazel_dependencies")
16 | install_bazel_dependencies()
17 |
18 | # Setup TypeScript toolchain
19 | load("@npm_bazel_typescript//:index.bzl", "ts_setup_workspace")
20 | ts_setup_workspace()
21 |
--------------------------------------------------------------------------------
/cloudbuild.yml:
--------------------------------------------------------------------------------
1 | steps:
2 | # Install top-level deps.
3 | - name: 'node:10'
4 | entrypoint: 'yarn'
5 | id: 'yarn'
6 | args: ['install']
7 |
8 | # Run diff to find modified files in each folder.
9 | - name: 'node:10'
10 | entrypoint: 'yarn'
11 | id: 'diff'
12 | args: ['diff']
13 | waitFor: ['yarn']
14 |
15 | # Core.
16 | - name: 'gcr.io/cloud-builders/gcloud'
17 | entrypoint: 'bash'
18 | id: 'core'
19 | args: ['./scripts/run-build.sh', 'tfjs-core']
20 | waitFor: ['diff']
21 |
22 | # WebGPU.
23 | - name: 'gcr.io/cloud-builders/gcloud'
24 | entrypoint: 'bash'
25 | id: 'webgpu'
26 | args: ['./scripts/run-build.sh', 'tfjs-webgpu']
27 | waitFor: ['diff']
28 |
29 | # React Native.
30 | - name: 'gcr.io/cloud-builders/gcloud'
31 | entrypoint: 'bash'
32 | id: 'react-native'
33 | args: ['./scripts/run-build.sh', 'tfjs-react-native']
34 | waitFor: ['diff']
35 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "devDependencies": {
3 | "clang-format": "~1.2.4",
4 | "shelljs": "~0.8.3"
5 | },
6 | "scripts": {
7 | "diff": "./scripts/diff.js"
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## This repository has been archived in favor of the monorepo at https://github.com/tensorflow/tfjs. Please send pull requests there.
--------------------------------------------------------------------------------
/scripts/diff.js:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | // Copyright 2019 Google LLC. All Rights Reserved.
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // =============================================================================
16 |
17 | const {exec} = require('./test-util');
18 | const {readdirSync, statSync, writeFileSync} = require('fs');
19 | const {join} = require('path');
20 |
21 | const CLONE_PATH = 'clone';
22 |
23 | const dirs = readdirSync('.').filter(f => {
24 | return f !== 'node_modules' && f !== '.git' && statSync(f).isDirectory();
25 | });
26 |
27 | exec(
28 | `git clone --depth=1 --single-branch ` +
29 | `https://github.com/tensorflow/tfjs-core.git ${CLONE_PATH}`);
30 |
31 |
32 | dirs.forEach(dir => {
33 | const diffCmd = `diff -rq ${CLONE_PATH}/${dir}/ ./${dir}/`;
34 | const diffOutput = exec(diffCmd, {silent: true}, true).stdout.trim();
35 |
36 | if (diffOutput !== '') {
37 | console.log(`${dir} has modified files.`);
38 | writeFileSync(join(dir, 'diff'), diffOutput);
39 | } else {
40 | console.log(`No modified files found in ${dir}`);
41 | }
42 | });
43 |
--------------------------------------------------------------------------------
/scripts/run-build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2019 Google LLC. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e
18 | # Submit a build for folder $1 only if "$1/diff" exists (see scripts/diff.js).
19 | DIR="$1"
20 | if test -f "$DIR/diff"; then
21 |   gcloud builds submit . --config="$DIR/cloudbuild.yml"
22 | fi
23 |
--------------------------------------------------------------------------------
/scripts/test-util.js:
--------------------------------------------------------------------------------
1 | // Copyright 2019 Google LLC. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | // =============================================================================
15 |
16 | const shell = require('shelljs');
17 |
18 | function exec(command, opt, ignoreCode) {
19 | const res = shell.exec(command, opt);
20 | if (!ignoreCode && res.code !== 0) {
21 | shell.echo('command', command, 'returned code', res.code);
22 | process.exit(1);
23 | }
24 | return res;
25 | }
26 |
27 | exports.exec = exec;
28 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | // Place your settings in this file to overwrite default and user settings.
2 | {
3 | "search.exclude": {
4 | "**/node_modules": true,
5 | "**/coverage/": true,
6 | "**/dist/": true,
7 | "**/yarn.lock": true,
8 | "**/.rpt2_cache/": true,
9 | "**/.yalc/**/*": true
10 | },
11 | "tslint.configFile": "tslint.json",
12 | "files.trimTrailingWhitespace": true,
13 | "editor.tabSize": 2,
14 | "editor.insertSpaces": true,
15 | "[typescript]": {
16 | "editor.formatOnSave": true
17 | },
18 | "[javascript]": {
19 | "editor.formatOnSave": true
20 | },
21 | "editor.rulers": [80],
22 | "clang-format.style": "Google",
23 | "files.insertFinalNewline": true,
24 | "editor.detectIndentation": false,
25 | "editor.wrappingIndent": "none",
26 | "typescript.tsdk": "./node_modules/typescript/lib",
27 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
28 | }
29 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/README.md:
--------------------------------------------------------------------------------
1 | # Headless WebGL backend for TensorFlow.js via Node.js
2 |
3 | **This project is under heavy development**
4 |
5 | This new backend will provide a light-weight headless WebGL runtime for TensorFlow.js running under Node.js. This new backend is powered by the [node-gles](https://github.com/google/node-gles) module which uses [ANGLE](https://github.com/google/angle) to provide an integration layer to the system GL runtime. This package aims to provide a thin acceleration engine for IoT, desktop, and Node.js applications where CUDA (size/OS compatibility) is not an option.
6 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/demo/README.md:
--------------------------------------------------------------------------------
1 | # MobileNet tfjs-backend-nodegl Demo
2 |
3 | *This is a very early demo to show how tfjs-backend-nodegl can be used for headless WebGL acceleration.*
4 |
5 | To run this demo, perform the following:
6 |
7 | 1. Move into `tfjs-backend-nodegl` (parent directory of this demo folder):
8 | ```sh
9 | $ cd tfjs-backend-nodegl
10 | ```
11 |
12 | 2. Build package and compile TypeScript:
13 | ```sh
14 | $ yarn && yarn tsc
15 | ```
16 |
17 | 3. Move into the demo directory:
18 | ```sh
19 | $ cd demo
20 | ```
21 |
22 | 4. Prep and build demo:
23 | ```sh
24 | $ yarn
25 | ```
26 |
27 | 5. Run demo:
28 | ```sh
29 | $ node run_mobilenet_inference.js dog.jpg
30 | ```
31 |
32 | Expected output:
33 | ```sh
34 | $ node run_mobilenet_inference.js dog.jpg
35 | Platform node has already been set. Overwriting the platform with [object Object].
36 | - gl.VERSION: OpenGL ES 3.0 (ANGLE 2.1.0.9512a0ef062a)
37 | - gl.RENDERER: ANGLE (Intel Inc., Intel(R) Iris(TM) Plus Graphics 640, OpenGL 4.1 core)
38 | - Loading model...
39 | - Mobilenet load: 6450.763924002647ms
40 | - Coldstarting model...
41 | - Mobilenet cold start: 297.92842200398445ms
42 | - Running inference (100x) ...
43 | - Mobilenet inference: (100x) : 35.75772546708584ms
44 | ```
45 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/demo/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/tfjs-core/1b88a535f7fa30166167463a16ebacb4cd40c797/tfjs-backend-nodegl/demo/dog.jpg
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/demo/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tfjs-backend-nodegl-demo",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "index.js",
6 | "scripts": {
7 | "test": "echo \"Error: no test specified\" && exit 1"
8 | },
9 | "dependencies": {
10 | "@tensorflow-models/mobilenet": "^1.0.1",
11 | "@tensorflow/tfjs": "^1.2.2",
12 | "jpeg-js": "^0.3.5"
13 | },
14 | "author": "",
15 | "license": "ISC"
16 | }
17 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tfjs-backend-nodegl",
3 | "version": "0.0.1",
4 | "description": "",
5 | "main": "dist/index.js",
6 | "types": "dist/index.d.ts",
7 | "scripts": {
8 | "build": "tsc",
9 | "lint": "tslint -p . -t verbose",
10 | "test": "ts-node src/run_tests.ts"
11 | },
12 | "author": "",
13 | "license": "Apache-2.0",
14 | "dependencies": {
15 | "@tensorflow/tfjs-core": "1.2.2",
16 | "node-gles": "^0.0.13"
17 | },
18 | "devDependencies": {
19 | "@types/jasmine": "~2.8.6",
20 | "@types/node": "^10.5.1",
21 | "@types/rimraf": "~2.0.2",
22 | "clang-format": "^1.2.4",
23 | "jasmine": "~3.1.0",
24 | "ts-node": "^8.1.0",
25 | "tslint": "~5.9.1",
26 | "typescript": "^3.4.5"
27 | }
28 | }
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/src/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '@tensorflow/tfjs-core';
19 | // TODO(kreeger): Fix binding definition in node-gles before switching to a
20 | // non-require import style.
21 | // tslint:disable-next-line:no-require-imports
22 | const nodeGles = require('node-gles');
23 |
24 | const nodeGl = nodeGles.binding.createWebGLRenderingContext();
25 |
26 | // TODO(kreeger): These GL integration flags are hard-coded. They need to be
27 | // updated to work on all systems, with proper exception reporting.
28 | tf.ENV.set('WEBGL_VERSION', 2);
29 | tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', true);
30 | tf.ENV.set('WEBGL_DOWNLOAD_FLOAT_ENABLED', true);
31 | tf.ENV.set('WEBGL_FENCE_API_ENABLED', true); // OpenGL ES 3.0 and higher.
32 | tf.ENV.set(
33 | 'WEBGL_MAX_TEXTURE_SIZE', nodeGl.getParameter(nodeGl.MAX_TEXTURE_SIZE));
34 | tf.webgl.setWebGLContext(2, nodeGl);
35 |
36 | tf.registerBackend('headless-nodegl', () => {
37 | // TODO(kreeger): Consider moving all GL creation here. However, a weak ref to
38 | // the GL context tends to cause issues when running unit tests:
39 | // https://github.com/tensorflow/tfjs/issues/1732
40 | return new tf.webgl.MathBackendWebGL(new tf.webgl.GPGPUContext(nodeGl));
41 | }, 3 /* priority */);
42 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/src/version.ts:
--------------------------------------------------------------------------------
1 | /** @license See the LICENSE file. */
2 |
3 | // This code is auto-generated, do not modify this file!
4 | const version = '0.0.1';
5 | export {version};
6 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "2.0.0",
3 | "tasks": [
4 | {
5 | "command": "yarn",
6 | "label": "lint",
7 | "type": "shell",
8 | "args": [
9 | "lint"
10 | ],
11 | "problemMatcher": {
12 | "base": "$tslint5",
13 | "owner": "tslint-type-checked",
14 | "fileLocation": "absolute"
15 | }
16 | },
17 | {
18 | "command": "yarn",
19 | "label": "build",
20 | "type": "shell",
21 | "args": ["build", "--pretty", "false", "--noEmit"],
22 | "problemMatcher": [
23 | "$tsc"
24 | ]
25 | }
26 | ]
27 | }
28 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tsconfig",
3 | "include": [
4 | "src/"
5 | ],
6 | "exclude": [
7 | "node_modules/"
8 | ],
9 | "compilerOptions": {
10 | "outDir": "./dist"
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/tfjs-backend-nodegl/tslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tslint.json"
3 | }
4 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/.npmrc:
--------------------------------------------------------------------------------
1 | package-lock=false
2 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/.vscode/c_cpp_properties.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "name": "Linux",
5 | "includePath": [
6 | "${workspaceFolder}/src/**",
7 | "~/emsdk/fastcomp/emscripten/system/include/**"
8 | ],
9 | "defines": [],
10 | "compilerPath": "/usr/bin/clang",
11 | "cStandard": "c11",
12 | "cppStandard": "c++11",
13 | "intelliSenseMode": "clang-x64"
14 | }
15 | ],
16 | "version": 4
17 | }
--------------------------------------------------------------------------------
/tfjs-backend-wasm/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | // Place your settings in this file to overwrite default and user settings.
2 | {
3 | "search.exclude": {
4 | "**/node_modules": true,
5 | "**/coverage/": true,
6 | "**/dist/": true,
7 | "**/yarn.lock": true,
8 | "**/.rpt2_cache/": true,
9 | "**/.yalc/**/*": true
10 | },
11 | "tslint.configFile": "tslint.json",
12 | "files.trimTrailingWhitespace": true,
13 | "editor.tabSize": 2,
14 | "editor.insertSpaces": true,
15 | "[typescript]": {
16 | "editor.formatOnSave": true
17 | },
18 | "[javascript]": {
19 | "editor.formatOnSave": true
20 | },
21 | "[cpp]": {
22 | "editor.formatOnSave": true
23 | },
24 | "editor.rulers": [80],
25 | "clang-format.style": "Google",
26 | "files.insertFinalNewline": true,
27 | "editor.detectIndentation": false,
28 | "editor.wrappingIndent": "none",
29 | "typescript.tsdk": "./node_modules/typescript/lib",
30 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format",
31 | "files.associations": {
32 | "vector": "cpp",
33 | "map": "cpp",
34 | "algorithm": "cpp"
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/README.md:
--------------------------------------------------------------------------------
1 | # Emscripten installation
2 |
3 | Install emscripten [here](https://emscripten.org/docs/getting_started/downloads.html)
4 |
5 |
6 | # Building
7 |
8 | ```sh
9 | yarn build
10 | ```
11 |
12 | # Testing
13 |
14 | ```sh
15 | yarn test
16 | ```
17 |
18 | # Deployment
19 | ```sh
20 | ./scripts/build-npm.sh
21 | npm publish
22 | ```
23 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@tensorflow/tfjs-backend-wasm",
3 | "version": "0.0.1",
4 | "main": "dist/index.js",
5 | "types": "dist/index.d.ts",
6 | "jsnext:main": "dist/tf-wasm.esm.js",
7 | "module": "dist/tf-wasm.esm.js",
8 | "unpkg": "dist/tf-wasm.min.js",
9 | "jsdelivr": "dist/tf-wasm.min.js",
10 | "scripts": {
11 | "build": "rimraf dist/ && ./scripts/build-wasm.sh && tsc && cp wasm-out/*.wasm dist/",
12 | "lint": "tslint -p . -t verbose",
13 | "test": "./scripts/build-wasm.sh && karma start"
14 | },
15 | "peerDependencies": {
16 | "@tensorflow/tfjs-core": "~1.2.7"
17 | },
18 | "devDependencies": {
19 | "@tensorflow/tfjs-core": "~1.2.7",
20 | "@types/emscripten": "~0.0.34",
21 | "clang-format": "^1.2.4",
22 | "jasmine-core": "~3.1.0",
23 | "karma": "~4.0.0",
24 | "karma-browserstack-launcher": "~1.4.0",
25 | "karma-chrome-launcher": "~2.2.0",
26 | "karma-firefox-launcher": "~1.1.0",
27 | "karma-jasmine": "~1.1.1",
28 | "karma-typescript": "~4.0.0",
29 | "rimraf": "~2.6.2",
30 | "rollup": "^1.17.0",
31 | "rollup-plugin-commonjs": "^10.0.1",
32 | "rollup-plugin-node-resolve": "^5.2.0",
33 | "rollup-plugin-terser": "^5.1.1",
34 | "rollup-plugin-typescript2": "^0.22.1",
35 | "tslint": "~5.11.0",
36 | "tslint-no-circular-imports": "^0.5.0",
37 | "typescript": "3.3.3333",
38 | "yalc": "~1.0.0-pre.21"
39 | },
40 | "license": "Apache-2.0"
41 | }
42 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/scripts/build-npm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2019 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e
18 |
19 | yarn rimraf dist/
20 | yarn
21 | yarn build
22 | yarn rollup -c
23 |
24 | echo "Stored standalone library at dist/tf-backend-wasm(.min).js"
25 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/scripts/build-wasm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2019 Google LLC. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e
18 |
19 | SRCS="src/lib.cc src/kernels.cc"
20 |
21 | # Add these optimization flags in production: -g0 -O3 --llvm-lto 3
22 |
23 | emcc $SRCS \
24 | -std=c++11 \
25 | -fno-rtti \
26 | -g \
27 | -fno-exceptions \
28 | -I./src/ \
29 | -o wasm-out/tfjs-backend-wasm.js \
30 | -s ALLOW_MEMORY_GROWTH=1 \
31 | -s DEFAULT_LIBRARY_FUNCS_TO_INCLUDE=[] \
32 | -s DISABLE_EXCEPTION_CATCHING=1 \
33 | -s FILESYSTEM=0 \
34 | -s EXIT_RUNTIME=0 \
35 | -s EXPORTED_FUNCTIONS='["_malloc"]' \
36 | -s EXTRA_EXPORTED_RUNTIME_METHODS='["cwrap"]' \
37 | -s ENVIRONMENT=web \
38 | -s MODULARIZE=1 \
39 | -s EXPORT_NAME=WasmBackendModule \
40 | -s MALLOC=emmalloc
41 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/src/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {BackendWasm} from './backend_wasm';
19 | export {BackendWasm};
20 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/src/index_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '@tensorflow/tfjs-core';
19 | import {test_util} from '@tensorflow/tfjs-core';
20 | import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';
21 |
22 | import {BackendWasm} from './index';
23 |
24 | /**
25 | * Tests specific to the wasm backend. The name of these tests must start with
26 | * 'wasm' so that they are always included in the test runner. See
27 | * `env.specFilter` in `setup_test.ts` for details.
28 | */
29 | describeWithFlags('wasm', ALL_ENVS, () => {
30 | it('write and read values', async () => {
31 | const x = tf.tensor1d([1, 2, 3]);
32 | test_util.expectArraysClose([1, 2, 3], await x.data());
33 | });
34 |
35 | it('allocate repetitively and confirm reuse of heap space', () => {
36 | const backend = tf.backend() as BackendWasm;
37 | const size = 100;
38 | // Allocate for the first time, record the memory offset and dispose.
39 | const t1 = tf.zeros([size]);
40 | const memOffset1 = backend.getMemoryOffset(t1.dataId);
41 | t1.dispose();
42 |
43 | // Allocate again and make sure the offset is the same (memory was reused).
44 | const t2 = tf.zeros([size]);
45 | const memOffset2 = backend.getMemoryOffset(t2.dataId);
46 | // This should fail in case of a memory leak.
47 | expect(memOffset1).toBe(memOffset2);
48 | });
49 | });
50 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/src/kernels.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 Google Inc. All Rights Reserved.
2 | * Licensed under the Apache License, Version 2.0 (the "License");
3 | * you may not use this file except in compliance with the License.
4 | * You may obtain a copy of the License at
5 | *
6 | * http://www.apache.org/licenses/LICENSE-2.0
7 | *
8 | * Unless required by applicable law or agreed to in writing, software
9 | * distributed under the License is distributed on an "AS IS" BASIS,
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | * See the License for the specific language governing permissions and
12 | * limitations under the License.
13 | * ===========================================================================*/
14 |
15 | #include "kernels.h"
16 |
17 | #include <algorithm>
18 |
19 | namespace tfjs {
20 | namespace kernels {
21 |
22 | // TODO(smilkov): Consider inlining small methods.
23 |
24 | template <typename T>
25 | void add(T* a_buf, int a_size, T* b_buf, int b_size, T* out_buf) {
26 | int size = std::max(a_size, b_size);
27 | for (int i = 0; i < size; ++i) {
28 | out_buf[i] = a_buf[i % a_size] + b_buf[i % b_size];
29 | }
30 | }
31 |
32 | // Templates need explicit instantiation when implemented in a .cc file.
33 | template void add(float* a_buf, int a_size, float* b_buf, int b_size,
34 | float* out_buf);
35 | template void add(int* a_buf, int a_size, int* b_buf, int b_size,
36 | int* out_buf);
37 | template void add(bool* a_buf, int a_size, bool* b_buf, int b_size,
38 | bool* out_buf);
39 |
40 | } // namespace kernels
41 | } // namespace tfjs
42 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/src/kernels.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 Google Inc. All Rights Reserved.
2 | * Licensed under the Apache License, Version 2.0 (the "License");
3 | * you may not use this file except in compliance with the License.
4 | * You may obtain a copy of the License at
5 | *
6 | * http://www.apache.org/licenses/LICENSE-2.0
7 | *
8 | * Unless required by applicable law or agreed to in writing, software
9 | * distributed under the License is distributed on an "AS IS" BASIS,
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | * See the License for the specific language governing permissions and
12 | * limitations under the License.
13 | * ===========================================================================*/
14 |
15 | #ifndef TFJS_WASM_KERNELS_H_
16 | #define TFJS_WASM_KERNELS_H_
17 |
18 | namespace tfjs {
19 | namespace kernels {
20 |
21 | template <typename T>
22 | // Element-wise add of two tensors.
23 | void add(T* a_buf, int a_size, T* b_buf, int b_size, T* out_buf);
24 |
25 | } // namespace kernels
26 | } // namespace tfjs
27 |
28 | #endif // TFJS_WASM_KERNELS_H_
29 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tsconfig",
3 | "include": [
4 | "src/"
5 | ],
6 | "exclude": [
7 | "node_modules/"
8 | ],
9 | "compilerOptions": {
10 | "outDir": "./dist"
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/tslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tslint.json"
3 | }
4 |
--------------------------------------------------------------------------------
/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | export type BackendWasmModule = EmscriptenModule&{
19 | onRuntimeInitialized: () => void;
20 | // Using the tfjs namespace to avoid conflict with emscripten's API.
21 | tfjs: {
22 | registerTensor(
23 | dataId: number, shape: Uint8Array, shapeLength: number, dtype: number,
24 | memoryOffset: number): void;
25 | // Disposes the data behind the data bucket.
26 | disposeData(dataId: number): void;
27 | // Disposes the backend and all of its associated data.
28 | dispose(): void;
29 |
30 | // Kernels.
31 | add(aId: number, bId: number, outId: number): void;
32 | }
33 | };
34 |
35 | declare var moduleFactory: () => BackendWasmModule;
36 | export default moduleFactory;
37 |
--------------------------------------------------------------------------------
/tfjs-core/.bazelignore:
--------------------------------------------------------------------------------
1 | node_modules
2 |
--------------------------------------------------------------------------------
/tfjs-core/.bazelrc:
--------------------------------------------------------------------------------
1 | build --symlink_prefix=dist/
2 |
--------------------------------------------------------------------------------
/tfjs-core/.npmignore:
--------------------------------------------------------------------------------
1 | # Ignore other packages in the mono-repo.
2 | tfjs-webgpu/
3 | tfjs-backend-nodegl/
4 | tfjs-react-native/
5 |
6 | .vscode/
7 | .rpt2_cache/
8 | src/**/*_test.ts
9 | integration_tests/
10 |
11 | dist/backends/**/*_test.js
12 | models/
13 | coverage/
14 | package/
15 | **/node_modules/
16 | karma.conf.js
17 | *.tgz
18 | *.log
19 | .travis.yml
20 | CONTRIBUTING.md
21 | tslint.json
22 | yarn.lock
23 | DEVELOPMENT.md
24 | ISSUE_TEMPLATE.md
25 | PULL_REQUEST_TEMPLATE.md
26 | rollup.config.js
27 | tsconfig.json
28 | .yalc/
29 | yalc.lock
30 | tfjs-react-native/
31 | tfjs-backend-nodegl/
32 |
--------------------------------------------------------------------------------
/tfjs-core/.npmrc:
--------------------------------------------------------------------------------
1 | package-lock=false
2 |
--------------------------------------------------------------------------------
/tfjs-core/.vscode/c_cpp_properties.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "name": "Mac",
5 | "includePath": [
6 | "${workspaceFolder}/**",
7 | "~/emsdk/fastcomp/emscripten/system/include/**"
8 | ],
9 | "defines": [],
10 | "macFrameworkPath": [
11 | "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.14.sdk/System/Library/Frameworks"
12 | ],
13 | "compilerPath": "/usr/bin/clang",
14 | "cStandard": "c11",
15 | "cppStandard": "c++11",
16 | "intelliSenseMode": "clang-x64"
17 | }
18 | ],
19 | "version": 4
20 | }
21 |
--------------------------------------------------------------------------------
/tfjs-core/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "type": "chrome",
9 | "request": "attach",
10 | "name": "Attach Karma Chrome",
11 | "address": "localhost",
12 | "port": 9333,
13 | "pathMapping": {
14 | "/": "${workspaceRoot}",
15 | "/base/": "${workspaceRoot}/"
16 | }
17 | }
18 | ]
19 | }
20 |
--------------------------------------------------------------------------------
/tfjs-core/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | // Place your settings in this file to overwrite default and user settings.
2 | {
3 | "search.exclude": {
4 | "**/node_modules": true,
5 | "**/coverage/": true,
6 | "**/dist/": true,
7 | "**/yarn.lock": true,
8 | "**/.rpt2_cache/": true,
9 | "**/.yalc/**/*": true
10 | },
11 | "tslint.configFile": "tslint.json",
12 | "files.trimTrailingWhitespace": true,
13 | "editor.tabSize": 2,
14 | "editor.insertSpaces": true,
15 | "[typescript]": {
16 | "editor.formatOnSave": true
17 | },
18 | "[javascript]": {
19 | "editor.formatOnSave": true
20 | },
21 | "editor.defaultFormatter": "xaver.clang-format",
22 | "editor.rulers": [80],
23 | "clang-format.style": "Google",
24 | "files.insertFinalNewline": true,
25 | "editor.detectIndentation": false,
26 | "editor.wrappingIndent": "none",
27 | "typescript.tsdk": "./node_modules/typescript/lib",
28 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
29 | }
30 |
--------------------------------------------------------------------------------
/tfjs-core/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "2.0.0",
3 | "tasks": [
4 | {
5 | "command": "yarn",
6 | "label": "lint",
7 | "type": "shell",
8 | "args": [
9 | "lint"
10 | ],
11 | "problemMatcher": {
12 | "base": "$tslint5",
13 | "owner": "tslint-type-checked",
14 | "fileLocation": "absolute"
15 | }
16 | },
17 | {
18 | "command": "yarn",
19 | "label": "build",
20 | "type": "shell",
21 | "args": ["build", "--pretty", "false", "--noEmit"],
22 | "problemMatcher": [
23 | "$tsc"
24 | ]
25 | }
26 | ]
27 | }
28 |
--------------------------------------------------------------------------------
/tfjs-core/BUILD.bazel:
--------------------------------------------------------------------------------
1 | # Allow typescript rules in any package to reference this file
2 | exports_files(["tsconfig.json"])
3 |
--------------------------------------------------------------------------------
/tfjs-core/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow.js Core API
2 |
3 | A part of the TensorFlow.js ecosystem, this repo hosts `@tensorflow/tfjs-core`,
4 | the TensorFlow.js Core API, which provides low-level, hardware-accelerated
5 | linear algebra operations and an eager API for automatic differentiation.
6 |
7 | Check out [js.tensorflow.org](https://js.tensorflow.org) for more
8 | information about the library, tutorials and API docs.
9 |
10 | To keep track of issues we use the [tensorflow/tfjs](https://github.com/tensorflow/tfjs) GitHub repo.
11 |
12 | ## Importing
13 |
14 | You can install TensorFlow.js via yarn or npm. We recommend using the
15 | [@tensorflow/tfjs](https://www.npmjs.com/package/@tensorflow/tfjs) npm package,
16 | which gives you both this Core API and the higher-level
17 | [Layers API](https://github.com/tensorflow/tfjs-layers):
18 |
19 | ```js
20 | import * as tf from '@tensorflow/tfjs';
21 | // You have the Core API: tf.matMul(), tf.softmax(), ...
22 | // You also have Layers API: tf.model(), tf.layers.dense(), ...
23 | ```
24 |
25 | On the other hand, if you care about the bundle size and you do not use the
26 | Layers API, you can import only the Core API:
27 |
28 | ```js
29 | import * as tfc from '@tensorflow/tfjs-core';
30 | // You have the Core API: tfc.matMul(), tfc.softmax(), ...
31 | // No Layers API.
32 | ```
33 |
34 | For info about development, check out [DEVELOPMENT.md](./DEVELOPMENT.md).
35 |
36 | ## For more information
37 |
38 | - [TensorFlow.js API documentation](https://js.tensorflow.org/api/latest/)
39 | - [TensorFlow.js Tutorials](https://js.tensorflow.org/tutorials/)
40 |
41 | Thanks to BrowserStack for providing testing support.
42 |
--------------------------------------------------------------------------------
/tfjs-core/cloudbuild.yml:
--------------------------------------------------------------------------------
1 | steps:
2 | # Install common dependencies.
3 | - name: 'node:10'
4 | id: 'yarn-common'
5 | entrypoint: 'yarn'
6 | args: ['install']
7 |
8 | # Install tfjs-core dependencies.
9 | - name: 'node:10'
10 | dir: 'tfjs-core'
11 | id: 'yarn'
12 | entrypoint: 'yarn'
13 | args: ['install']
14 | waitFor: ['yarn-common']
15 |
16 | # Build
17 | - name: 'node:10'
18 | dir: 'tfjs-core'
19 | id: 'build'
20 | entrypoint: 'yarn'
21 | args: ['build-ci']
22 | waitFor: ['yarn']
23 |
24 | # Run unit tests.
25 | - name: 'node:10'
26 | dir: 'tfjs-core'
27 | id: 'test'
28 | entrypoint: 'yarn'
29 | args: ['test-ci']
30 | waitFor: ['build']
31 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1']
32 | secretEnv: ['BROWSERSTACK_KEY']
33 |
34 | # Run integration tests of other packages against core.
35 | - name: 'node:10'
36 | dir: 'tfjs-core'
37 | id: 'test-integration'
38 | entrypoint: 'yarn'
39 | args: ['test-integration']
40 | waitFor: ['build']
41 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1', 'NIGHTLY=$_NIGHTLY']
42 | secretEnv: ['BROWSERSTACK_KEY']
43 |
44 | # bundle size check
45 | - name: 'node:10'
46 | dir: 'tfjs-core'
47 | id: 'test-bundle-size'
48 | entrypoint: 'yarn'
49 | args: ['test-bundle-size']
50 | waitFor: ['yarn']
51 |
52 | # test doc snippets
53 | - name: 'node:10'
54 | dir: 'tfjs-core'
55 | id: 'test-snippets'
56 | entrypoint: 'yarn'
57 | args: ['test-snippets']
58 | waitFor: ['yarn']
59 |
60 | # test Async backends
61 | - name: 'node:10'
62 | dir: 'tfjs-core'
63 | id: 'test-async-backends'
64 | entrypoint: 'yarn'
65 | args: ['test-async-backends-ci']
66 | waitFor: ['build']
67 |
68 | # General configuration
69 | secrets:
70 | - kmsKeyName: projects/learnjs-174218/locations/global/keyRings/tfjs/cryptoKeys/enc
71 | secretEnv:
72 | BROWSERSTACK_KEY: CiQAkwyoIW0LcnxymzotLwaH4udVTQFBEN4AEA5CA+a3+yflL2ASPQAD8BdZnGARf78MhH5T9rQqyz9HNODwVjVIj64CTkFlUCGrP1B2HX9LXHWHLmtKutEGTeFFX9XhuBzNExA=
73 | timeout: 1800s
74 | logsBucket: 'gs://tfjs-build-logs'
75 | substitutions:
76 | _NIGHTLY: ''
77 | options:
78 | logStreamingOption: 'STREAM_ON'
79 | machineType: 'N1_HIGHCPU_8'
80 | substitution_option: 'ALLOW_LOOSE'
81 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/build-npm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2017 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e
18 |
19 | yarn rimraf dist/
20 | yarn
21 |
22 | yarn build
23 | yarn build-test-snippets
24 | yarn rollup -c --visualize
25 |
26 | # Use minified files for miniprogram
27 | mkdir dist/miniprogram
28 | cp dist/tf-core.min.js dist/miniprogram/index.js
29 | cp dist/tf-core.min.js.map dist/miniprogram/index.js.map
30 |
31 | echo "Stored standalone library at dist/tf-core(.min).js"
32 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/README.md:
--------------------------------------------------------------------------------
1 | This directory contains the following Google Cloud Functions.
2 |
3 | ### `trigger_nightly`
4 | Programmatically triggers a Cloud Build on master. This function is called by the Cloud Scheduler at 3am EST every day (configurable via the Cloud Scheduler UI).
5 | You can also trigger the function manually via the Cloud UI.
6 |
7 | Command to re-deploy:
8 | ```sh
9 | gcloud functions deploy nightly \
10 | --runtime nodejs8 \
11 | --trigger-topic nightly
12 | ```
13 |
14 | If a build was triggered by nightly, there is a substitution variable `_NIGHTLY=true`.
15 | You can forward the substitution as the `NIGHTLY` environment variable so the scripts can use it, by specifying `env: ['NIGHTLY=$_NIGHTLY']` in `cloudbuild.yml`. E.g. `test-integration` uses the `NIGHTLY` bit to always run on nightly.
16 |
17 | ### `send_email`
18 | Sends an email and a chat message with the nightly build status. Every build sends a message to the `cloud-builds` topic with its build information. The `send_email` function is subscribed to that topic and ignores all builds (e.g. builds triggered by pull requests) **except** for the nightly build, for which it sends an email to an internal mailing list with the build status around 3:10am.
19 |
20 | Command to re-deploy:
21 |
22 | ```sh
23 | gcloud functions deploy send_email \
24 | --runtime nodejs8 \
25 | --stage-bucket learnjs-174218_cloudbuild \
26 | --trigger-topic cloud-builds \
27 | --set-env-vars MAILGUN_API_KEY="[API_KEY_HERE]",HANGOUTS_URL="[URL_HERE]"
28 | ```
29 |
30 | ### The pipeline
31 |
32 | The pipeline looks like this:
33 |
34 | 1) At 3am, Cloud Scheduler writes to `nightly` topic
35 | 2) That triggers the `nightly` function, which starts a build programmatically
36 | 3) That build runs and writes its status to `cloud-builds` topic
37 | 4) That triggers the `send_email` function, which sends email and chat with the build status.
38 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/send_email/.gcloudignore:
--------------------------------------------------------------------------------
1 | # This file specifies files that are *not* uploaded to Google Cloud Platform
2 | # using gcloud. It follows the same syntax as .gitignore, with the addition of
3 | # "#!include" directives (which insert the entries of the given .gitignore-style
4 | # file at that point).
5 | #
6 | # For more information, run:
7 | # $ gcloud topic gcloudignore
8 | #
9 | .gcloudignore
10 | # If you would like to upload your .git directory, .gitignore file or files
11 | # from your .gitignore file, remove the corresponding line
12 | # below:
13 | .git
14 | .gitignore
15 |
16 | node_modules
17 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/send_email/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "MAILGUN_DOMAIN":"sandbox497e7af39a1b4dee92fb92d9bfe5a686.mailgun.org",
3 | "MAILGUN_FROM":"Cloud Build ",
4 | "MAILGUN_TO":"tfjs-builds@google.com"
5 | }
6 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/send_email/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "cloudbuild-email",
3 | "version": "0.0.1",
4 | "description": "Email integration for Google Cloud Build, using Google Cloud Functions",
5 | "main": "index.js",
6 | "dependencies": {
7 | "humanize-duration": "3.10.0",
8 | "mailgun-js": "^0.22.0",
9 | "node-fetch": "^2.5.0",
10 | "request": "^2.88.0",
11 | "request-promise-native": "^1.0.7"
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/trigger_nightly/.gcloudignore:
--------------------------------------------------------------------------------
1 | # This file specifies files that are *not* uploaded to Google Cloud Platform
2 | # using gcloud. It follows the same syntax as .gitignore, with the addition of
3 | # "#!include" directives (which insert the entries of the given .gitignore-style
4 | # file at that point).
5 | #
6 | # For more information, run:
7 | # $ gcloud topic gcloudignore
8 | #
9 | .gcloudignore
10 | # If you would like to upload your .git directory, .gitignore file or files
11 | # from your .gitignore file, remove the corresponding line
12 | # below:
13 | .git
14 | .gitignore
15 |
16 | node_modules
17 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/trigger_nightly/index.js:
--------------------------------------------------------------------------------
1 | // Copyright 2019 Google LLC. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | // =============================================================================
15 |
16 | const {google} = require('googleapis');
17 |
/**
 * Cloud Function entry point: runs the Cloud Build trigger configured below
 * against the `master` branch and logs the API response.
 *
 * @param data Payload handed over by the invoker; this function never reads it.
 */
module.exports.nightly = async data => {
  const runRequest = {
    'projectId': 'learnjs-174218',
    'triggerId': '7423c985-2fd2-40f3-abe7-94d4c353eed0',
    'resource': {'branchName': 'master'}
  };

  const cloudbuild = google.cloudbuild('v1');

  // Obtain a client scoped to the cloud-platform API and make it the default
  // credential for every subsequent googleapis call.
  const authScopes = ['https://www.googleapis.com/auth/cloud-platform'];
  const auth = await google.auth.getClient({scopes: authScopes});
  google.options({auth});

  const resp = await cloudbuild.projects.triggers.run(runRequest);
  console.log(resp);
};
30 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/cloud_funcs/trigger_nightly/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "cloudbuild",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "index.js",
6 | "dependencies": {
7 | "googleapis": "^39.2.0"
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/make-version:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | // Copyright 2018 Google LLC. All Rights Reserved.
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // =============================================================================
16 |
17 |
18 | // Run this script from the base directory (not the script directory):
19 | // ./scripts/make-version
20 |
const fs = require('fs');

// package.json is the single source of truth for the version string. The path
// is relative to the working directory, hence the "run from the base
// directory" requirement stated above.
const version = JSON.parse(fs.readFileSync('package.json', 'utf8')).version;

// Full contents of the generated TypeScript module.
const versionCode =
    `/** @license See the LICENSE file. */

// This code is auto-generated, do not modify this file!
const version = '${version}';
export {version};
`;

// Overwrite src/version.ts. A write failure is rethrown from the callback so
// it surfaces as an uncaught exception instead of being silently ignored.
fs.writeFile('src/version.ts', versionCode, err => {
  if (err) {
    throw new Error(`Could not save version file ${version}: ${err}`);
  }
  console.log(`Version file for version ${version} saved successfully.`);
});
38 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/publish-npm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2017 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | # Before you run this script, do this:
18 | # 1) Update the version in package.json
19 | # 2) Run ./scripts/make-version from the base dir of the project.
20 | # 3) Run `yarn` to update `yarn.lock`, in case you updated dependencies
21 | # 4) Commit to the master branch.
22 |
23 | # Then:
24 | # 5) Checkout the master branch of this repo.
25 | # 6) Run this script as `./scripts/publish-npm.sh` from the project base dir.
26 |
# Abort as soon as any command below fails.
set -e

BRANCH=$(git rev-parse --abbrev-ref HEAD)
ORIGIN=$(git config --get remote.origin.url)
CHANGES=$(git status --porcelain)

# Every failed precondition exits with a non-zero status. A bare `exit` would
# return the status of the preceding `echo` (0) and make the aborted publish
# look like a success to the caller.
if [ "$BRANCH" != "master" ]; then
  echo "Error: Switch to the master branch before publishing."
  exit 1
fi

if ! [[ "$ORIGIN" =~ tensorflow/tfjs-core ]]; then
  echo "Error: Switch to the main repo (tensorflow/tfjs-core) before publishing."
  exit 1
fi

if [ -n "$CHANGES" ]; then
  echo "Make sure the master branch is clean. Found changes:"
  # Quoted so the one-entry-per-line `git status --porcelain` output is kept.
  echo "$CHANGES"
  exit 1
fi

yarn build-npm
./scripts/make-version  # This is for safety in case you forgot to do 2).
./scripts/tag-version.js
npm publish
echo 'Yay! Published a new package to npm.'
55 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/tag-version.js:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | // Copyright 2018 Google LLC. All Rights Reserved.
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // =============================================================================
16 |
17 | // Run this script from the base directory (not the script directory):
18 | // ./scripts/tag-version.js
19 |
const fs = require('fs');
const {exec} = require('child_process');

// The tag name is derived from the version in package.json, read relative to
// the working directory (see the "run from the base directory" note above).
const version = JSON.parse(fs.readFileSync('package.json', 'utf8')).version;
const tag = `v${version}`;

exec(`git tag ${tag}`, (err, stdout, stderr) => {
  console.log('\x1b[36m%s\x1b[0m', 'git tag command stdout:');
  console.log(stdout);
  console.log('\x1b[31m%s\x1b[0m', 'git tag command stderr:');
  console.log(stderr);

  if (err) {
    throw new Error(`Could not git tag with ${tag}: ${err.message}.`);
  }
  console.log(`Successfully tagged with ${tag}.`);

  // exec() is asynchronous, so the push is started from inside this callback:
  // it only runs once `git tag` has finished successfully. Launching both
  // commands side by side could push before the new tag exists.
  exec(`git push --tags`, (pushErr, pushStdout, pushStderr) => {
    console.log('\x1b[36m%s\x1b[0m', 'git push tags command stdout:');
    console.log(pushStdout);
    console.log('\x1b[41m%s\x1b[0m', 'git push tags command stderr:');
    console.log(pushStderr);

    if (pushErr) {
      throw new Error(`Could not push git tags: ${pushErr.message}.`);
    }
    console.log(`Successfully pushed tags.`);
  });
});
49 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/test-ci.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2018 Google LLC. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e
18 |
19 | yarn lint
20 | # Test in node (headless environment).
21 | yarn test-node-ci
22 |
23 | # Run the first karma separately so it can download the BrowserStack binary
24 | # without conflicting with others.
25 | yarn run-browserstack --browsers=bs_safari_mac,bs_ios_11 --testEnv webgl1 --flags '{"WEBGL_CPU_FORWARD": false, "WEBGL_SIZE_UPLOAD_UNIFORM": 0}'
26 |
27 | # Run the rest of the karma tests in parallel. These runs will reuse the
28 | # already downloaded binary.
29 | npm-run-all -p -c --aggregate-output \
30 | "run-browserstack --browsers=bs_safari_mac,bs_ios_11 --flags '{\"HAS_WEBGL\": false}' --testEnv cpu" \
31 | "run-browserstack --browsers=bs_firefox_mac,bs_chrome_mac" \
32 | "run-browserstack --browsers=bs_chrome_mac,win_10_chrome --testEnv webgl2 --flags '{\"WEBGL_CPU_FORWARD\": false, \"WEBGL_SIZE_UPLOAD_UNIFORM\": 0}'"
33 |
34 | ### The next section tests TF.js in a webworker.
35 | # Make a dist/tf-core.min.js file to be imported by the web worker.
36 | yarn rollup -c --ci
37 | # Safari doesn't have offscreen canvas so test cpu in a webworker.
38 | # Chrome has offscreen canvas, so test webgl in a webworker.
39 | yarn test-webworker --browsers=bs_safari_mac,bs_chrome_mac
40 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/test-integration.js:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | // Copyright 2019 Google LLC. All Rights Reserved.
3 | //
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 | // =============================================================================
16 |
17 | const {exec} = require('../../scripts/test-util');
18 |
const dirName = 'tfjs-core-integration';

/**
 * Clones the published tfjs-core repository into `dirName` and reports whether
 * the local `src/version.ts` shows up as modified when the clone's `src/` is
 * diffed against the local `src/`.
 */
function versionFileWasModified() {
  exec(
      `git clone --depth=1 --single-branch ` +
      `https://github.com/tensorflow/tfjs-core.git ${dirName}`);
  // NOTE(review): the trailing `true` presumably tells the helper to tolerate a
  // non-zero exit status (git diff reports differences that way) - confirm in
  // scripts/test-util.
  const diff = exec(
      `git diff --name-only --diff-filter=M --no-index ${dirName}/src/ src/`,
      {silent: true}, true);
  return diff.stdout.trim().split('\n').includes('src/version.ts');
}

// Nightly runs always execute the integration tests; otherwise they only run
// when the version file changed. The clone is skipped entirely for nightlies.
const shouldRunIntegration =
    process.env.NIGHTLY === 'true' || versionFileWasModified();
if (shouldRunIntegration) {
  exec('./scripts/test-integration.sh');
}
41 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/test_snippets/test_snippets.ts:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | /**
3 | * @license
4 | * Copyright 2019 Google LLC. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | * =============================================================================
17 | */
18 | import * as tf from '../../src/index';
19 | import {parseAndEvaluateSnippets} from './util';
20 |
// Hands the tfjs-core namespace built from this checkout (`../../src/index`)
// to the snippet runner in ./util.
// NOTE(review): presumably the runner extracts the documentation snippets and
// executes them against `tf` - confirm in ./util.
parseAndEvaluateSnippets(tf);
22 |
--------------------------------------------------------------------------------
/tfjs-core/scripts/test_snippets/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../../tsconfig.json",
3 | "compilerOptions": {
4 | "outDir": "../../dist/scripts/test_snippets/"
5 | },
6 | "include": [
7 | "util.ts"
8 | ]
9 | }
10 |
--------------------------------------------------------------------------------
/tfjs-core/src/BUILD.bazel:
--------------------------------------------------------------------------------
load("@npm_bazel_typescript//:defs.bzl", "ts_library")

# Glob patterns that match test-only sources.
# NOTE(review): currently unused - its only reference, the `exclude` argument
# of the glob below, is commented out, so test files are compiled into the
# library target as well. Confirm whether that is intended.
TEST_SRCS = [
    "jasmine_*",
    "test_*",
    "**/*_test.ts"
]

# TypeScript library built from every .ts file under this package.
ts_library(
    name = "src",
    srcs = glob(
        ["**/*.ts"],
        # exclude = TEST_SRCS,
    ),
    deps = [
        "@npm//@types",
    ]
)
19 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/backend_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 | import * as tf from '../index';
18 | import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
19 | import {EPSILON_FLOAT16, EPSILON_FLOAT32} from './backend';
20 |
// Checks that the epsilon reported by the active backend matches the constant
// for its float precision and that it is non-zero.
describeWithFlags('epsilon', ALL_ENVS, () => {
  it('Epsilon is a function of float precision', () => {
    let expectedEpsilon = EPSILON_FLOAT16;
    if (tf.backend().floatPrecision() === 32) {
      expectedEpsilon = EPSILON_FLOAT32;
    }
    expect(tf.backend().epsilon()).toBe(expectedEpsilon);
  });

  it('abs(epsilon) > 0', async () => {
    const absEpsilon = await tf.abs(tf.backend().epsilon()).array();
    expect(absEpsilon).toBeGreaterThan(0);
  });
});
33 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/cpu/backend_cpu_test_registry.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {Constraints, registerTestEnv} from '../../jasmine_util';
19 |
/** Constraint that selects only test environments whose backend is 'cpu'. */
export const CPU_ENVS: Constraints = {
  predicate: testEnv => testEnv.backendName === 'cpu'
};

// Importing this module has a side effect: it registers the 'cpu' test
// environment.
// NOTE(review): `isDataSync` presumably marks backends whose data can be read
// synchronously - confirm in jasmine_util.
registerTestEnv({name: 'cpu', backendName: 'cpu', isDataSync: true});
25 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/packing_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
/**
 * Builds component accessors (`name.x`, `name.y`, ...) for the first `rank`
 * of the six components x, y, z, w, u, v.
 *
 * @param name Variable name placed in front of every component.
 * @param rank Number of leading components to return.
 * @returns At most six strings of the form `${name}.${component}`.
 */
export function getVecChannels(name: string, rank: number): string[] {
  const components = ['x', 'y', 'z', 'w', 'u', 'v'];
  const channels: string[] = [];
  for (const component of components.slice(0, rank)) {
    channels.push(`${name}.${component}`);
  }
  return channels;
}
21 |
/**
 * Returns the channel expressions for a variable of the given rank: the bare
 * `name` for rank 1, otherwise the component accessors produced by
 * `getVecChannels`.
 *
 * @param name Variable name.
 * @param rank Rank of the variable.
 */
export function getChannels(name: string, rank: number): string[] {
  return rank === 1 ? [name] : getVecChannels(name, rank);
}
28 |
/**
 * Joins the first `rank` entries of `dims` into a comma-separated coordinate
 * list. Rank 1 is a special case that always yields the literal `rc`.
 *
 * @param rank Number of coordinates to emit.
 * @param dims Coordinate expressions, indexed by dimension.
 * @returns The comma-separated list without spaces, or `'rc'` for rank 1.
 */
export function getSourceCoords(rank: number, dims: string[]): string {
  if (rank === 1) {
    return 'rc';
  }

  const parts: string[] = [];
  for (let i = 0; i < rank; i++) {
    // Template conversion matches string concatenation, also for missing
    // entries.
    parts.push(`${dims[i]}`);
  }
  return parts.join(',');
}
--------------------------------------------------------------------------------
/tfjs-core/src/backends/split_shared.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {Tensor} from '../tensor';
19 |
/**
 * Shared implementation of the split kernel across WebGL and CPU.
 *
 * Cuts `x` into consecutive pieces along `axis`. Piece `i` starts where piece
 * `i - 1` ended, has extent `sizeSplits[i]` along `axis`, and keeps the full
 * extent of `x` along every other axis.
 *
 * @param x The tensor to split.
 * @param sizeSplits Extent of each piece along `axis`, in order.
 * @param axis Index of the dimension to split along.
 * @returns One slice of `x` per entry of `sizeSplits`.
 */
export function split<T extends Tensor>(
    x: T, sizeSplits: number[], axis: number): T[] {
  // `begin` and `size` are reused for every piece; only their `axis` entry
  // changes from one piece to the next.
  const begin = new Array(x.rank).fill(0);
  const size = x.shape.slice();
  return sizeSplits.map(s => {
    size[axis] = s;
    const slice = x.slice(begin, size);
    begin[axis] += s;
    return slice;
  });
}
32 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/tile_impl.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | /**
19 | * An implementation of the tile kernel shared between webgl and cpu for string
20 | * tensors only.
21 | */
22 |
23 | import {buffer} from '../ops/array_ops';
24 | import {Tensor, TensorBuffer} from '../tensor';
25 | import {DataType, Rank} from '../types';
26 |
/**
 * Repeats the contents of `xBuf` `reps[d]` times along every dimension `d`.
 *
 * @param xBuf Buffer holding the values to tile.
 * @param reps Repetition count per dimension; indexed like `xBuf.shape`.
 * @returns A tensor of shape `xBuf.shape[d] * reps[d]` per dimension, in which
 *     each element is copied from the input location obtained by wrapping its
 *     coordinates modulo the input shape.
 */
export function tile<R extends Rank>(
    xBuf: TensorBuffer<R, DataType>, reps: number[]): Tensor<R> {
  const newShape: number[] = new Array(xBuf.rank);
  for (let i = 0; i < newShape.length; i++) {
    newShape[i] = xBuf.shape[i] * reps[i];
  }
  const result = buffer(newShape, xBuf.dtype);
  for (let i = 0; i < result.values.length; ++i) {
    const newLoc = result.indexToLoc(i);

    // Wrap every output coordinate back into the bounds of the input. The
    // dimension index gets its own name so it does not shadow the flat
    // output index `i`.
    const originalLoc: number[] = new Array(xBuf.rank);
    for (let dim = 0; dim < originalLoc.length; dim++) {
      originalLoc[dim] = newLoc[dim] % xBuf.shape[dim];
    }

    const originalIndex = xBuf.locToIndex(originalLoc);

    result.values[i] = xBuf.values[originalIndex];
  }
  return result.toTensor() as Tensor<R>;
}
48 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/addn_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 |
/**
 * Shader program that adds N input tensors element-wise. The inputs are
 * exposed to the shader as `T0`, `T1`, ... in the order of `shapes`.
 */
export class AddNProgram implements GPGPUProgram {
  variableNames: string[];
  outputShape: number[] = [];
  userCode: string;

  /**
   * @param outputShape Shape of the summed result.
   * @param shapes One entry per input tensor; only the number of entries is
   *     used, to name the inputs.
   */
  constructor(outputShape: number[], shapes: number[][]) {
    this.outputShape = outputShape;
    this.variableNames = shapes.map((_, i) => `T${i}`);

    // One float local per input, read at the current output coordinates.
    const loads = this.variableNames.map(
        name => `float v${name} = get${name}AtOutCoords();`);

    // Expression that adds all of those locals together.
    const sum = this.variableNames.map(name => `v${name}`).join(' + ');

    this.userCode = `
void main() {
${loads.join('\n ')}

float result = ${sum};
setOutput(result);
}
`;
  }
}
52 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/addn_packed_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 |
/**
 * Packed-texture variant of the N-way element-wise add: every input is read
 * as a `vec4` and the sum is written as a `vec4`. The inputs are exposed to
 * the shader as `T0`, `T1`, ... in the order of `shapes`.
 */
export class AddNPackedProgram implements GPGPUProgram {
  variableNames: string[];
  outputShape: number[] = [];
  userCode: string;
  usesPackedTextures = true;

  /**
   * @param outputShape Shape of the summed result.
   * @param shapes One entry per input tensor; only the number of entries is
   *     used, to name the inputs.
   */
  constructor(outputShape: number[], shapes: number[][]) {
    this.outputShape = outputShape;
    this.variableNames = shapes.map((_, i) => `T${i}`);

    // One vec4 local per input, read at the current output coordinates.
    const loads = this.variableNames.map(
        name => `vec4 v${name} = get${name}AtOutCoords();`);

    // Expression that adds all of those locals together.
    const sum = this.variableNames.map(name => `v${name}`).join(' + ');

    this.userCode = `
void main() {
${loads.join('\n ')}

vec4 result = ${sum};
setOutput(result);
}
`;
  }
}
53 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/argminmax_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {ReduceInfo} from '../../ops/reduce_util';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
/**
 * Shader program for one pass of an arg-min / arg-max reduction. Each output
 * element scans a window of `windowSize` candidates and stores the index of
 * the best value it saw.
 */
export class ArgMinMaxProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;

  /**
   * @param reduceInfo Supplies `windowSize`, `batchSize` and `inSize`; the
   *     output has `ceil(inSize / windowSize)` columns per batch row.
   * @param op `'max'` keeps the larger candidate, `'min'` the smaller one.
   * @param firstPass When false, a second input `bestIndicesA` is declared and
   *     candidate indices are read from it instead of being the loop position.
   */
  constructor(reduceInfo: ReduceInfo, op: 'max'|'min', firstPass: boolean) {
    const {windowSize, batchSize, inSize} = reduceInfo;
    this.outputShape = [batchSize, Math.ceil(inSize / windowSize)];

    // Both snippets carry their own trailing ';', so the generated
    // `int inIdx = ...;` line ends in ';;'.
    let indexSnippet = 'inOffset + i;';
    if (!firstPass) {
      this.variableNames.push('bestIndicesA');
      indexSnippet = 'round(getBestIndicesA(batch, inOffset + i));';
    }

    let compOp = '<';
    if (op === 'max') {
      compOp = '>';
    }

    this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};

int bestIndex = inOffset;
float bestValue = getA(batch, bestIndex);

for (int i = 0; i < ${windowSize}; i++) {
int inIdx = ${indexSnippet};
float candidate = getA(batch, inIdx);
if (candidate ${compOp} bestValue) {
bestValue = candidate;
bestIndex = inIdx;
}
}
setOutput(float(bestIndex));
}
`;
  }
}
63 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/backend_webgl_test_registry.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {Constraints, registerTestEnv} from '../../jasmine_util';
19 |
/** Constraint that selects only test environments whose backend is 'webgl'. */
export const WEBGL_ENVS: Constraints = {
  predicate: testEnv => testEnv.backendName === 'webgl'
};
/** Constraint that requires the `WEBGL_PACK` flag to be true. */
export const PACKED_ENVS: Constraints = {
  flags: {'WEBGL_PACK': true}
};
26 |
// Importing this module registers one test environment per WebGL version,
// `webgl1` and `webgl2`. They differ only in their name and in the
// `WEBGL_VERSION` flag.
[1, 2].forEach(webglVersion => {
  registerTestEnv({
    name: `webgl${webglVersion}`,
    backendName: 'webgl',
    flags: {
      'WEBGL_VERSION': webglVersion,
      'WEBGL_CPU_FORWARD': false,
      'WEBGL_SIZE_UPLOAD_UNIFORM': 0
    },
    isDataSync: true
  });
});
48 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/binaryop_complex_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as broadcast_util from '../../ops/broadcast_util';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
// Complex multiplication with A = Ar + i*Ai and B = Br + i*Bi:
//   (Ar + i*Ai)(Br + i*Bi) = ArBr + i*ArBi + i*AiBr + i^2*AiBi
//                          = (ArBr - AiBi) + i*(ArBi + AiBr)
//   Yr = ArBr - AiBi
//   Yi = ArBi + AiBr
// Each entry is a GLSL return statement over `areal`, `aimag`, `breal` and
// `bimag`, the parameter names of `binaryOpComplex` in BinaryOpComplexProgram.
export const COMPLEX_MULTIPLY = {
  REAL: 'return areal * breal - aimag * bimag;',
  IMAG: 'return areal * bimag + aimag * breal;'
};
29 |
/**
 * Shader program that computes one real-valued component of a binary
 * operation on two complex tensors, each supplied as separate real and
 * imaginary inputs.
 */
export class BinaryOpComplexProgram implements GPGPUProgram {
  variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
  userCode: string;
  outputShape: number[];

  /**
   * @param op GLSL statement(s) used as the body of `binaryOpComplex`; they
   *     may reference `areal`, `aimag`, `breal` and `bimag`.
   * @param aShape Shape of the first operand.
   * @param bShape Shape of the second operand.
   */
  constructor(op: string, aShape: number[], bShape: number[]) {
    this.outputShape =
        broadcast_util.assertAndGetBroadcastShape(aShape, bShape);

    // One read per input; the GLSL local is the lower-cased input name
    // (AReal -> areal), matching the parameters of binaryOpComplex.
    const reads =
        this.variableNames
            .map(
                input =>
                    `float ${input.toLowerCase()} = get${input}AtOutCoords();`)
            .join('\n');

    this.userCode = `
float binaryOpComplex(
float areal, float aimag, float breal, float bimag) {
${op}
}

void main() {
${reads}
setOutput(binaryOpComplex(areal, aimag, breal, bimag));
}
`;
  }
}
55 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/canvas_util_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {ENV} from '../../environment';
19 | import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util';
20 |
21 | import {getWebGLContext} from './canvas_util';
22 |
23 | // getWebGLContext must return a usable context for the BROWSER_ENVS envs.
24 | describeWithFlags('canvas_util', BROWSER_ENVS, () => {
25 |   it('Returns a valid canvas', () => {
26 |     const version = ENV.getNumber('WEBGL_VERSION');
27 |     const canvas = getWebGLContext(version).canvas as (
28 |         HTMLCanvasElement | OffscreenCanvas);
29 |     // Either canvas flavor is acceptable.
30 |     const isCanvas = (canvas instanceof HTMLCanvasElement) ||
31 |         (canvas instanceof OffscreenCanvas);
32 |     expect(isCanvas).toBe(true);
33 |   });
34 |
35 |   it('Returns a valid gl context', () => {
36 |     const context = getWebGLContext(ENV.getNumber('WEBGL_VERSION'));
37 |     const contextLost = context.isContextLost();
38 |     expect(contextLost).toBe(false);
39 |   });
40 | });
41 |
42 | // Asking for a WebGL 1 context while WEBGL_VERSION is 2 must still work.
43 | describeWithFlags('canvas_util webgl2', {flags: {WEBGL_VERSION: 2}}, () => {
44 |   it('is ok when the user requests webgl 1 canvas', () => {
45 |     const {canvas} = getWebGLContext(1);
46 |     const isHtmlCanvas = canvas instanceof HTMLCanvasElement;
47 |     expect(isHtmlCanvas).toBe(true);
48 |   });
49 | });
50 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/clip_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUContext} from './gpgpu_context';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Shader program that clamps every element of `A` into [minVal, maxVal].
23 |  * NaN elements are written through unchanged. The bounds are uniforms, set
24 |  * by the callback returned from `getCustomSetupFunc`.
25 |  */
26 | export class ClipProgram implements GPGPUProgram {
27 |   variableNames = ['A'];
28 |   userCode: string;
29 |   outputShape: number[];
30 |
31 |   // Uniform locations, cached after the first lookup.
32 |   minLoc: WebGLUniformLocation;
33 |   maxLoc: WebGLUniformLocation;
34 |
35 |   /** @param aShape Shape of the input, reused as the output shape. */
36 |   constructor(aShape: number[]) {
37 |     this.outputShape = aShape;
38 |     this.userCode = `
39 |       uniform float minVal;
40 |       uniform float maxVal;
41 |
42 |       void main() {
43 |         float value = getAAtOutCoords();
44 |         if (isnan(value)) {
45 |           setOutput(value);
46 |           return;
47 |         }
48 |
49 |         setOutput(clamp(value, minVal, maxVal));
50 |       }
51 |     `;
52 |   }
53 |
54 |   /**
55 |    * Builds the callback that uploads the clip bounds to the program.
56 |    *
57 |    * @param min Value written to the `minVal` uniform.
58 |    * @param max Value written to the `maxVal` uniform.
59 |    * @returns A function taking the GPGPU context and the WebGL program.
60 |    */
61 |   getCustomSetupFunc(min: number, max: number) {
62 |     return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
63 |       const locationsMissing = this.minLoc == null;
64 |       if (locationsMissing) {
65 |         this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
66 |         this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
67 |       }
68 |       const {gl} = gpgpu;
69 |       gl.uniform1f(this.minLoc, min);
70 |       gl.uniform1f(this.maxLoc, max);
71 |     };
72 |   }
73 | }
74 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/clip_packed_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUContext} from './gpgpu_context';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Packed-texture variant of the clip shader: clamps all four channels of
23 |  * each texel of `A` into [minVal, maxVal]. A texel with a NaN in any channel
24 |  * is written through unchanged.
25 |  */
26 | export class ClipPackedProgram implements GPGPUProgram {
27 |   variableNames = ['A'];
28 |   usesPackedTextures = true;
29 |   userCode: string;
30 |   outputShape: number[];
31 |
32 |   // Uniform locations, cached after the first lookup.
33 |   minLoc: WebGLUniformLocation;
34 |   maxLoc: WebGLUniformLocation;
35 |
36 |   /** @param aShape Shape of the input, reused as the output shape. */
37 |   constructor(aShape: number[]) {
38 |     this.outputShape = aShape;
39 |     this.userCode = `
40 |       uniform float minVal;
41 |       uniform float maxVal;
42 |
43 |       void main() {
44 |         vec4 value = getAAtOutCoords();
45 |
46 |         if (any(isnan(value))) {
47 |           setOutput(value);
48 |           return;
49 |         }
50 |
51 |         setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
52 |       }
53 |     `;
54 |   }
55 |
56 |   /**
57 |    * Builds the callback that uploads the clip bounds to the program.
58 |    *
59 |    * @param min Value written to the `minVal` uniform.
60 |    * @param max Value written to the `maxVal` uniform.
61 |    * @returns A function taking the GPGPU context and the WebGL program.
62 |    */
63 |   getCustomSetupFunc(min: number, max: number) {
64 |     return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
65 |       const locationsMissing = this.minLoc == null;
66 |       if (locationsMissing) {
67 |         this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
68 |         this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
69 |       }
70 |       const {gl} = gpgpu;
71 |       gl.uniform1f(this.minLoc, min);
72 |       gl.uniform1f(this.maxLoc, max);
73 |     };
74 |   }
75 | }
76 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/complex_abs_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 |
20 | /**
21 |  * Shader program computing the magnitude of complex values whose real and
22 |  * imaginary parts are supplied as the two inputs `real` and `imag`.
23 |  */
24 | export class ComplexAbsProgram implements GPGPUProgram {
25 |   variableNames = ['real', 'imag'];
26 |   userCode: string;
27 |   outputShape: number[];
28 |
29 |   /** @param shape Shape assigned to the program output. */
30 |   constructor(shape: number[]) {
31 |     // The magnitude is evaluated as mx * length(1, mn / mx) with
32 |     // mx = max(|re|, |im|) and mn = min(|re|, |im|); mx == 0 is special-cased
33 |     // so the division is never performed with a zero denominator.
34 |     const shaderSource = `
35 |       void main() {
36 |         float re = abs(getRealAtOutCoords());
37 |         float im = abs(getImagAtOutCoords());
38 |         float mx = max(re, im);
39 |
40 |         // sadly the length function in glsl is not underflow-safe
41 |         // (at least not on Intel GPUs). So the safe solution is
42 |         // to ensure underflow-safety in all cases.
43 |         setOutput(
44 |           mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
45 |         );
46 |       }
47 |     `;
48 |     this.userCode = shaderSource;
49 |     this.outputShape = shape;
50 |   }
51 | }
52 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/concat_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as concat_util from '../../ops/concat_util';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Shader program that concatenates 2-D inputs `T0`, `T1`, ... along axis 1.
23 |  * See comments in MathBackendWebGL.concat().
24 |  */
25 | export class ConcatProgram implements GPGPUProgram {
26 |   variableNames: string[];
27 |   outputShape: number[] = [];
28 |   userCode: string;
29 |
30 |   /**
31 |    * @param shapes [rows, columns] of every input, in concatenation order.
32 |    *     A single shape is accepted and produces a plain copy of `T0`.
33 |    */
34 |   constructor(shapes: Array<[number, number]>) {
35 |     this.outputShape = concat_util.computeOutShape(shapes, 1 /* axis */);
36 |     this.variableNames = shapes.map((_, i) => `T${i}`);
37 |
38 |     // offsets[i] is the first output column past input i. The last input
39 |     // needs no entry because it takes every remaining column.
40 |     const offsets: number[] = [];
41 |     let columnsSoFar = 0;
42 |     for (let i = 0; i < shapes.length - 1; i++) {
43 |       columnsSoFar += shapes[i][1];
44 |       offsets.push(columnsSoFar);
45 |     }
46 |
47 |     const snippets: string[] = [];
48 |     if (offsets.length === 0) {
49 |       // Only one input: there is no boundary to test, and an `else` branch
50 |       // would call getT1, which has no matching entry in variableNames.
51 |       snippets.push(`setOutput(getT0(yR, yC));`);
52 |     } else {
53 |       snippets.push(`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`);
54 |       for (let i = 1; i < offsets.length; i++) {
55 |         const shift = offsets[i - 1];
56 |         snippets.push(
57 |             `else if (yC < ${offsets[i]}) ` +
58 |             `setOutput(getT${i}(yR, yC-${shift}));`);
59 |       }
60 |       const lastIndex = offsets.length;
61 |       const lastShift = offsets[offsets.length - 1];
62 |       snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
63 |     }
64 |
65 |     this.userCode = `
66 |       void main() {
67 |         ivec2 coords = getOutputCoords();
68 |         int yR = coords.x;
69 |         int yC = coords.y;
70 |
71 |         ${snippets.join('\n ')}
72 |       }
73 |     `;
74 |   }
75 | }
76 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/decode_matrix_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getGlslDifferences} from './glsl_version';
19 | import {GPGPUProgram} from './gpgpu_math';
20 | import * as shader_util from './shader_compiler_util';
21 |
22 | /**
23 |  * Shader program that reads the logical 3-D input `A` and writes its values
24 |  * densely: each output texel receives four consecutive flat-index entries,
25 |  * one per component of the result vec4.
26 |  */
27 | export class DecodeMatrixProgram implements GPGPUProgram {
28 |   variableNames = ['A'];
29 |   userCode: string;
30 |   outputShape: [number, number, number];
31 |
32 |   /**
33 |    * @param outputShape Logical shape used to turn flat indices into
34 |    *     (r, c, d) coordinates.
35 |    * @param texShape Rows and columns of the texture being written.
36 |    */
37 |   constructor(outputShape: [number, number, number], texShape: [
38 |     number, number
39 |   ]) {
40 |     const glsl = getGlslDifferences();
41 |     const [texRows, texCols] = texShape;
42 |     const flatIndexToCoords = shader_util.getLogicalCoordinatesFromFlatIndex(
43 |         ['r', 'c', 'd'], outputShape);
44 |     this.outputShape = outputShape;
45 |
46 |     this.userCode = `
47 |       ivec3 outCoordsFromFlatIndex(int index) {
48 |         ${flatIndexToCoords}
49 |         return ivec3(r, c, d);
50 |       }
51 |
52 |       void main() {
53 |         ivec2 resTexRC = ivec2(resultUV.yx *
54 |           vec2(${texRows}, ${texCols}));
55 |         int index = 4 * (resTexRC.x * ${texCols} + resTexRC.y);
56 |
57 |         vec4 result = vec4(0.);
58 |
59 |         for (int i=0; i<4; i++) {
60 |           int flatIndex = index + i;
61 |           ivec3 rc = outCoordsFromFlatIndex(flatIndex);
62 |           result[i] = getA(rc.x, rc.y, rc.z);
63 |         }
64 |
65 |         ${glsl.output} = result;
66 |       }
67 |     `;
68 |   }
69 | }
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/decode_matrix_packed_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getGlslDifferences} from './glsl_version';
19 | import {GPGPUProgram} from './gpgpu_math';
20 | import * as shader_util from './shader_compiler_util';
21 |
22 | /**
23 |  * Packed-input variant of the matrix decoder: each output texel receives
24 |  * four consecutive flat-index entries of `A`, where every entry is picked
25 |  * out of a packed texel with `getChannel`.
26 |  */
27 | export class DecodeMatrixPackedProgram implements GPGPUProgram {
28 |   variableNames = ['A'];
29 |   userCode: string;
30 |   usesPackedTextures = true;
31 |   outputShape: [number, number, number];
32 |
33 |   /**
34 |    * @param outputShape Logical shape used to turn flat indices into
35 |    *     (r, c, d) coordinates.
36 |    * @param texShape Rows and columns of the texture being written.
37 |    */
38 |   constructor(outputShape: [number, number, number], texShape: [
39 |     number, number
40 |   ]) {
41 |     const glsl = getGlslDifferences();
42 |     const [texRows, texCols] = texShape;
43 |     const flatIndexToCoords = shader_util.getLogicalCoordinatesFromFlatIndex(
44 |         ['r', 'c', 'd'], outputShape);
45 |     this.outputShape = outputShape;
46 |
47 |     this.userCode = `
48 |       ivec3 outCoordsFromFlatIndex(int index) {
49 |         ${flatIndexToCoords}
50 |         return ivec3(r, c, d);
51 |       }
52 |
53 |       void main() {
54 |         ivec2 resTexRC = ivec2(resultUV.yx *
55 |           vec2(${texRows}, ${texCols}));
56 |         int index = 4 * (resTexRC.x * ${texCols} + resTexRC.y);
57 |
58 |         vec4 result = vec4(0.);
59 |
60 |         for (int i=0; i<4; i++) {
61 |           int flatIndex = index + i;
62 |           ivec3 rc = outCoordsFromFlatIndex(flatIndex);
63 |           result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
64 |         }
65 |
66 |         ${glsl.output} = result;
67 |       }
68 |     `;
69 |   }
70 | }
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/diag_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 |
20 | /**
21 |  * Shader program that builds a square matrix holding the values of the
22 |  * input `X` on its main diagonal and zeros everywhere else.
23 |  */
24 | export class DiagProgram implements GPGPUProgram {
25 |   variableNames = ['X'];
26 |   outputShape: number[];
27 |   userCode: string;
28 |
29 |   /** @param size Number of rows and of columns of the output. */
30 |   constructor(size: number) {
31 |     const squareShape = [size, size];
32 |     this.outputShape = squareShape;
33 |     this.userCode = `
34 |       void main() {
35 |         ivec2 coords = getOutputCoords();
36 |         float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
37 |         setOutput(val);
38 |       }
39 |     `;
40 |   }
41 | }
42 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/encode_float_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getGlslDifferences} from './glsl_version';
19 | import {GPGPUProgram} from './gpgpu_math';
20 | import {ENCODE_FLOAT_SNIPPET} from './shader_compiler_util';
21 |
22 | /**
23 |  * Shader program that passes every element of `A` through `encode_float`
24 |  * and assigns the result to the shader output variable.
25 |  */
26 | export class EncodeFloatProgram implements GPGPUProgram {
27 |   variableNames = ['A'];
28 |   userCode: string;
29 |   outputShape: number[];
30 |
31 |   /** @param outputShape Shape assigned to the program output. */
32 |   constructor(outputShape: number[]) {
33 |     const {output: outputVariable} = getGlslDifferences();
34 |     this.outputShape = outputShape;
35 |     this.userCode = `
36 |       ${ENCODE_FLOAT_SNIPPET}
37 |
38 |       void main() {
39 |         float x = getAAtOutCoords();
40 |         ${outputVariable} = encode_float(x);
41 |       }
42 |     `;
43 |   }
44 | }
45 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/encode_float_packed_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getGlslDifferences} from './glsl_version';
19 | import {GPGPUProgram} from './gpgpu_math';
20 | import {ENCODE_FLOAT_SNIPPET} from './shader_compiler_util';
21 |
22 | /**
23 |  * Packed-input variant of the float encoder: takes the packed texel of `A`
24 |  * at the output coordinates, picks one channel of it based on the last two
25 |  * coordinates, and assigns its `encode_float` result to the shader output.
26 |  */
27 | export class EncodeFloatPackedProgram implements GPGPUProgram {
28 |   variableNames = ['A'];
29 |   userCode: string;
30 |   outputShape: number[];
31 |   usesPackedTextures = true;
32 |
33 |   /** @param outputShape Shape assigned to the program output. */
34 |   constructor(outputShape: [number, number, number]) {
35 |     const {output: outputVariable} = getGlslDifferences();
36 |     this.outputShape = outputShape;
37 |     this.userCode = `
38 |       ${ENCODE_FLOAT_SNIPPET}
39 |
40 |       void main() {
41 |         ivec3 coords = getOutputCoords();
42 |         float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
43 |         ${outputVariable} = encode_float(x);
44 |       }
45 |     `;
46 |   }
47 | }
48 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/fill_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUContext} from './gpgpu_context';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Shader program that writes one scalar to every output element. The scalar
23 |  * comes from the `value` uniform, which the callback returned by
24 |  * `getCustomSetupFunc` sets.
25 |  */
26 | export class FillProgram implements GPGPUProgram {
27 |   variableNames: string[];
28 |   outputShape: number[] = [];
29 |   userCode: string;
30 |
31 |   // Location of the `value` uniform, cached after the first lookup.
32 |   valueLoc: WebGLUniformLocation;
33 |
34 |   /**
35 |    * @param shape Shape of the output.
36 |    * @param value Not read here; the fill value is supplied through
37 |    *     `getCustomSetupFunc`.
38 |    */
39 |   constructor(shape: number[], value: number) {
40 |     this.variableNames = ['x'];
41 |     this.outputShape = shape;
42 |
43 |     this.userCode = `
44 |       uniform float value;
45 |       void main() {
46 |         // Input can be obtained from uniform value.
47 |         setOutput(value);
48 |       }
49 |     `;
50 |   }
51 |
52 |   /**
53 |    * Builds the callback that uploads the fill value to the program.
54 |    *
55 |    * @param value Value written to the `value` uniform.
56 |    * @returns A function taking the GPGPU context and the WebGL program.
57 |    */
58 |   getCustomSetupFunc(value: number) {
59 |     return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
60 |       const locationMissing = this.valueLoc == null;
61 |       if (locationMissing) {
62 |         this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
63 |       }
64 |       gpgpu.gl.uniform1f(this.valueLoc, value);
65 |     };
66 |   }
67 | }
68 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/from_pixels_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getGlslDifferences} from './glsl_version';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Shader program that reads the pixel texture `A` and outputs one element
23 |  * per pixel channel: the sampled channel multiplied by 255 and rounded to
24 |  * the nearest integer.
25 |  */
26 | export class FromPixelsProgram implements GPGPUProgram {
27 |   variableNames = ['A'];
28 |   userCode: string;
29 |   outputShape: number[];
30 |
31 |   /**
32 |    * @param outputShape Output shape; entries 0 and 1 are used as the pixel
33 |    *     height and width when computing texture coordinates.
34 |    */
35 |   constructor(outputShape: number[]) {
36 |     const glsl = getGlslDifferences();
37 |     const height = outputShape[0];
38 |     const width = outputShape[1];
39 |     this.outputShape = outputShape;
40 |     this.userCode = `
41 |       void main() {
42 |         ivec3 coords = getOutputCoords();
43 |         int texR = coords[0];
44 |         int texC = coords[1];
45 |         int depth = coords[2];
46 |         vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);
47 |
48 |         vec4 values = ${glsl.texture2D}(A, uv);
49 |         float value;
50 |         if (depth == 0) {
51 |           value = values.r;
52 |         } else if (depth == 1) {
53 |           value = values.g;
54 |         } else if (depth == 2) {
55 |           value = values.b;
56 |         } else if (depth == 3) {
57 |           value = values.a;
58 |         }
59 |
60 |         setOutput(floor(value * 255.0 + 0.5));
61 |       }
62 |     `;
63 |   }
64 | }
65 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/from_pixels_packed_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getGlslDifferences} from './glsl_version';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Variant of the from-pixels shader that assigns a whole vec4 per
23 |  * invocation: it samples the pixel texture `A` four times (a 2x2 loop over
24 |  * `row` and `col`) and stores each sampled channel, multiplied by 255 and
25 |  * rounded, in its own component of the result.
26 |  */
27 | export class FromPixelsPackedProgram implements GPGPUProgram {
28 |   variableNames = ['A'];
29 |   userCode: string;
30 |   outputShape: number[];
31 |
32 |   /**
33 |    * @param outputShape Output shape; entries 0 and 1 are used as the pixel
34 |    *     height and width when computing texture coordinates.
35 |    */
36 |   constructor(outputShape: number[]) {
37 |     const glsl = getGlslDifferences();
38 |     const height = outputShape[0];
39 |     const width = outputShape[1];
40 |     this.outputShape = outputShape;
41 |     this.userCode = `
42 |       void main() {
43 |         ivec3 coords = getOutputCoords();
44 |         int texR = coords[0];
45 |         int texC = coords[1];
46 |         int depth = coords[2];
47 |
48 |         vec4 result = vec4(0.);
49 |
50 |         for(int row=0; row<=1; row++) {
51 |           for(int col=0; col<=1; col++) {
52 |             texC = coords[1] + row;
53 |             depth = coords[2] + col;
54 |
55 |             vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${
56 |         height}.0);
57 |             vec4 values = ${glsl.texture2D}(A, uv);
58 |             float value;
59 |             if (depth == 0) {
60 |               value = values.r;
61 |             } else if (depth == 1) {
62 |               value = values.g;
63 |             } else if (depth == 2) {
64 |               value = values.b;
65 |             } else if (depth == 3) {
66 |               value = values.a;
67 |             }
68 |
69 |             result[row * 2 + col] = floor(value * 255.0 + 0.5);
70 |           }
71 |         }
72 |
73 |         ${glsl.output} = result;
74 |       }
75 |     `;
76 |   }
77 | }
78 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/gather_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 | import {getCoordsDataType} from './shader_compiler';
20 |
21 | /**
22 |  * Shader program that gathers from `A` along one axis: the output
23 |  * coordinate on `axis` is replaced by the value looked up in `indices`.
24 |  */
25 | export class GatherProgram implements GPGPUProgram {
26 |   variableNames = ['A', 'indices'];
27 |   outputShape: number[];
28 |   userCode: string;
29 |   rank: number;
30 |
31 |   /**
32 |    * @param aShape Shape of `A`.
33 |    * @param indicesLength Size of the output along `axis`.
34 |    * @param axis Axis of `A` to gather along.
35 |    */
36 |   constructor(aShape: number[], indicesLength: number, axis: number) {
37 |     // Same shape as A, except that `axis` has one entry per index.
38 |     const gatheredShape: number[] = aShape.slice();
39 |     gatheredShape[axis] = indicesLength;
40 |     this.outputShape = gatheredShape;
41 |     this.rank = gatheredShape.length;
42 |
43 |     const coordsType = getCoordsDataType(this.rank);
44 |     const sourceCoords = getSourceCoords(aShape, axis);
45 |     this.userCode = `
46 |       void main() {
47 |         ${coordsType} resRC = getOutputCoords();
48 |         setOutput(getA(${sourceCoords}));
49 |       }
50 |     `;
51 |   }
52 | }
53 |
54 | /**
55 |  * Builds the comma-separated argument list for `getA`: the output
56 |  * coordinates `resRC`, with the `axis` component swapped for an `indices`
57 |  * lookup.
58 |  *
59 |  * @throws Error if `aShape` has more than 4 dimensions.
60 |  */
61 | function getSourceCoords(aShape: number[], axis: number): string {
62 |   const rank = aShape.length;
63 |   if (rank > 4) {
64 |     throw Error(`Gather for rank ${rank} is not yet supported`);
65 |   }
66 |   if (rank === 1) {
67 |     // NOTE(review): presumably a rank-1 `resRC` is a scalar without
68 |     // components, hence the special case -- confirm against the compiler.
69 |     return `int(getIndices(resRC))`;
70 |   }
71 |
72 |   const components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
73 |   return components.slice(0, rank)
74 |       .map((component, i) => {
75 |         return i === axis ? `int(getIndices(${component}))` : component;
76 |       })
77 |       .join();
78 | }
79 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/gather_nd_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 | import {GPGPUProgram} from './gpgpu_math';
18 | import {getCoordsDataType} from './shader_compiler';
19 |
20 | /**
21 |  * Shader program for gather-nd: each output row selects one row of `x`. The
22 |  * row number is the sum of the first `sliceDim` entries of the matching
23 |  * `indices` row, each multiplied by its stride.
24 |  */
25 | export class GatherNDProgram implements GPGPUProgram {
26 |   variableNames = ['x', 'indices'];
27 |   outputShape: number[];
28 |   userCode: string;
29 |
30 |   /**
31 |    * @param sliceDim Number of index components read per output row.
32 |    * @param strides Multiplier applied to each index component.
33 |    * @param shape Shape of the output.
34 |    */
35 |   constructor(
36 |       private sliceDim: number, private strides: number[], shape: number[]) {
37 |     this.outputShape = shape;
38 |     const stridesType = getCoordsDataType(strides.length);
39 |     const coordsType = getCoordsDataType(shape.length);
40 |     // With a sliceDim of 1 the shader uses `strides` as is, without indexing.
41 |     const strideExpr = this.sliceDim > 1 ? 'strides[j]' : 'strides';
42 |     this.userCode = `
43 |       ${stridesType} strides = ${stridesType}(${this.strides});
44 |       void main() {
45 |         ${coordsType} coords = getOutputCoords();
46 |         int flattenIndex = 0;
47 |         for (int j = 0; j < ${this.sliceDim}; j++) {
48 |           int index = round(getIndices(coords[0], j));
49 |           flattenIndex += index * ${strideExpr};
50 |         }
51 |         setOutput(getX(flattenIndex, coords[1]));
52 |       }
53 |     `;
54 |   }
55 | }
56 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/multinomial_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUContext} from './gpgpu_context';
19 | import {GPGPUProgram} from './gpgpu_math';
20 |
21 | /**
22 |  * Shader program that picks one outcome per output element from the
23 |  * per-batch rows of `probs`: a random number is compared with the running
24 |  * sum of the probabilities, and the first outcome whose sum exceeds it wins.
25 |  */
26 | export class MultinomialProgram implements GPGPUProgram {
27 |   variableNames = ['probs'];
28 |   outputShape: number[];
29 |   userCode: string;
30 |
31 |   // Location of the `seed` uniform, cached after the first lookup.
32 |   seedLoc: WebGLUniformLocation;
33 |
34 |   /**
35 |    * @param batchSize Number of rows of the output.
36 |    * @param numOutcomes Number of outcomes per row of `probs`.
37 |    * @param numSamples Number of columns of the output.
38 |    */
39 |   constructor(batchSize: number, numOutcomes: number, numSamples: number) {
40 |     this.outputShape = [batchSize, numSamples];
41 |     // The last outcome is the fallback, so the loop never has to test it.
42 |     const lastOutcome = numOutcomes - 1;
43 |
44 |     this.userCode = `
45 |       uniform float seed;
46 |
47 |       void main() {
48 |         ivec2 coords = getOutputCoords();
49 |         int batch = coords[0];
50 |
51 |         float r = random(seed);
52 |         float cdf = 0.0;
53 |
54 |         for (int i = 0; i < ${lastOutcome}; i++) {
55 |           cdf += getProbs(batch, i);
56 |
57 |           if (r < cdf) {
58 |             setOutput(float(i));
59 |             return;
60 |           }
61 |         }
62 |
63 |         // If no other event happened, last event happened.
64 |         setOutput(float(${lastOutcome}));
65 |       }
66 |     `;
67 |   }
68 |
69 |   /**
70 |    * Builds the callback that uploads `seed` to the program's `seed` uniform.
71 |    *
72 |    * @param seed Value written to the uniform.
73 |    * @returns A function taking the GPGPU context and the WebGL program.
74 |    */
75 |   getCustomSetupFunc(seed: number) {
76 |     return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
77 |       const locationMissing = this.seedLoc == null;
78 |       if (locationMissing) {
79 |         this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed');
80 |       }
81 |       gpgpu.gl.uniform1f(this.seedLoc, seed);
82 |     };
83 |   }
84 | }
85 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/onehot_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 |
/**
 * GPGPU program producing a `[numIndices, depth]` output where an element is
 * `onValue` when its column equals the rounded index for its row, and
 * `offValue` otherwise.
 */
export class OneHotProgram implements GPGPUProgram {
  variableNames = ['indices'];
  outputShape: number[];
  userCode: string;

  // NOTE(review): declared but never assigned or read inside this class.
  seedLoc: WebGLUniformLocation;

  /**
   * @param numIndices Number of rows in the output.
   * @param depth Number of columns in the output.
   * @param onValue Value written where the column matches the index.
   * @param offValue Value written everywhere else.
   */
  constructor(
      numIndices: number, depth: number, onValue: number, offValue: number) {
    const offLiteral = `float(${offValue})`;
    const onLiteral = `float(${onValue})`;

    this.outputShape = [numIndices, depth];
    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int index = round(getIndices(coords.x));
        setOutput(mix(${offLiteral}, ${onLiteral},
                      float(index == coords.y)));
      }
    `;
  }
}
42 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/reverse_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 | import {getCoordsDataType} from './shader_compiler';
20 |
/**
 * GPGPU program that reads `x` with the coordinates of the listed axes
 * mirrored. Supports tensors up to rank 4.
 */
export class ReverseProgram implements GPGPUProgram {
  variableNames = ['x'];
  outputShape: number[];
  userCode: string;

  /**
   * @param xShape Shape of the input, reused as the output shape.
   * @param axis Axes to mirror. Not consulted for rank-1 inputs, which are
   *     always mirrored.
   * @throws Error when the rank of `xShape` exceeds 4.
   */
  constructor(xShape: number[], axis: number[]) {
    const rank = xShape.length;
    if (rank > 4) {
      throw new Error(
          `WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
    }
    this.outputShape = xShape;

    if (rank === 1) {
      this.userCode = `
        void main() {
          int coord = getOutputCoords();
          setOutput(getX(${xShape[0]} - coord - 1));
        }
      `;
      return;
    }

    const sourceCoords =
        xShape
            .map((dimSize, dim) => {
              // Size-1 dimensions read the same element either way, so they
              // keep the plain coordinate.
              const isMirrored = axis.indexOf(dim) !== -1 && dimSize !== 1;
              return isMirrored ? `${dimSize} - coords[${dim}] - 1` :
                                  `coords[${dim}]`;
            })
            .join(',');
    const coordsType = getCoordsDataType(rank);

    this.userCode = `
      void main() {
        ${coordsType} coords = getOutputCoords();
        setOutput(getX(${sourceCoords}));
      }
    `;
  }
}
60 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/select_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 | import {getCoordsDataType} from './shader_compiler';
20 |
/**
 * GPGPU program that outputs the element of `a` where the condition `c` is
 * `>= 1.0`, and the element of `b` otherwise.
 */
export class SelectProgram implements GPGPUProgram {
  variableNames = ['c', 'a', 'b'];
  outputShape: number[];
  userCode: string;

  /**
   * @param cRank Number of leading output coordinates used to index `c`.
   * @param shape Output shape.
   * @param rank Rank used to pick the coordinate type.
   * @throws Error when `rank` exceeds 4.
   */
  constructor(cRank: number, shape: number[], rank: number) {
    this.outputShape = shape;

    if (rank > 4) {
      throw Error(`Where for rank ${rank} is not yet supported`);
    }

    let conditionCoords: string;
    let valueCoords: string;
    if (rank === 1) {
      // A rank-1 output coordinate is a scalar, used directly for all inputs.
      valueCoords = `resRC`;
      conditionCoords = `resRC`;
    } else {
      const channels = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
      const perAxis = shape.map((_, i) => `${channels[i]}`);
      valueCoords = perAxis.join();
      conditionCoords = perAxis.filter((_, i) => i < cRank).join();
    }

    const coordsType = getCoordsDataType(rank);

    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        float cVal = getC(${conditionCoords});
        if (cVal >= 1.0) {
          setOutput(getA(${valueCoords}));
        } else {
          setOutput(getB(${valueCoords}));
        }
      }
    `;
  }
}
66 | }
67 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/shader_compiler_util_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {describeWithFlags} from '../../jasmine_util';
19 | import {WEBGL_ENVS} from './backend_webgl_test_registry';
20 | import {dotify, getLogicalCoordinatesFromFlatIndex} from './shader_compiler_util';
21 |
// Specs for the GLSL string helpers in shader_compiler_util.
describeWithFlags('shader compiler', WEBGL_ENVS, () => {
  // The concatenated descriptions need a trailing space on the first piece,
  // otherwise the reported spec name runs two words together.
  it('dotify takes two arrays of coordinates and produces ' +
         'the glsl that finds the dot product of those coordinates',
     () => {
       const coords1 = ['r', 'g', 'b', 'a'];
       const coords2 = ['x', 'y', 'z', 'w'];

       expect(dotify(coords1, coords2))
           .toEqual('dot(vec4(r,g,b,a), vec4(x,y,z,w))');
     });

  it('dotify should split up arrays into increments of vec4s', () => {
    const coords1 = ['a', 'b', 'c', 'd', 'e', 'f', 'g'];
    const coords2 = ['h', 'i', 'j', 'k', 'l', 'm', 'n'];

    expect(dotify(coords1, coords2))
        .toEqual(
            'dot(vec4(a,b,c,d), vec4(h,i,j,k))+dot(vec3(e,f,g), vec3(l,m,n))');
  });

  it('getLogicalCoordinatesFromFlatIndex produces glsl that takes ' +
         'a flat index and finds its coordinates within that shape',
     () => {
       const coords = ['r', 'c', 'd'];
       const shape = [1, 2, 3];

       expect(getLogicalCoordinatesFromFlatIndex(coords, shape))
           .toEqual(
               'int r = index / 6; index -= r * 6;' +
               'int c = index / 3; int d = index - c * 3;');
     });
});
54 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/strided_slice_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 | import {getCoordsDataType} from './shader_compiler';
20 |
/**
 * GPGPU program that reads `x` at `coords * strides + begin`, holding the
 * axes listed in `shrinkAxis` fixed at their `begin` value and dropping them
 * from the output shape.
 */
export class StridedSliceProgram implements GPGPUProgram {
  variableNames = ['x'];
  outputShape: number[];
  userCode: string;

  /**
   * @param begin Per-axis start offset into `x`.
   * @param strides Per-axis step.
   * @param size Per-axis extent before shrinking; its length is the input rank.
   * @param shrinkAxis Axes removed from the output shape.
   */
  constructor(
      begin: number[], strides: number[], size: number[],
      shrinkAxis: number[]) {
    const isKept = (axis: number) => shrinkAxis.indexOf(axis) === -1;
    const keptShape = size.filter((_, axis) => isKept(axis));
    this.outputShape = keptShape;

    const inputRank = size.length;
    const inputDtype = getCoordsDataType(inputRank);
    const outputDtype = getCoordsDataType(keptShape.length);

    let sourceCoords = '';
    if (inputRank === 1) {
      sourceCoords = 'coords * strides + begin';
    } else {
      const perAxis: string[] = [];
      // Index of the next output coordinate to consume; shrunk axes do not
      // advance it.
      let nextOutputAxis = 0;
      for (let axis = 0; axis < inputRank; axis++) {
        if (!isKept(axis)) {
          perAxis.push(`begin[${axis}]`);
          continue;
        }
        // A rank-1 output coordinate is a scalar and cannot be indexed.
        const coord =
            keptShape.length === 1 ? 'coords' : `coords[${nextOutputAxis}]`;
        nextOutputAxis++;
        perAxis.push(`${coord} * strides[${axis}] + begin[${axis}]`);
      }
      sourceCoords = perAxis.join(',');
    }

    this.userCode = `
      ${inputDtype} begin = ${inputDtype}(${begin});
      ${inputDtype} strides = ${inputDtype}(${strides});

      void main() {
        ${outputDtype} coords = getOutputCoords();
        setOutput(getX(${sourceCoords}));
      }
    `;
  }
}
65 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/tile_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 | import {getCoordsDataType} from './shader_compiler';
20 |
/**
 * GPGPU program whose output shape is `aShape` scaled per axis by `reps`,
 * reading `A` at the coordinates produced by `getSourceCoords`.
 */
export class TileProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  rank: number;

  /**
   * @param aShape Shape of the input `A`.
   * @param reps Per-axis repetition count.
   */
  constructor(aShape: number[], reps: number[]) {
    const tiledShape: number[] = [];
    for (let axis = 0; axis < aShape.length; axis++) {
      tiledShape.push(aShape[axis] * reps[axis]);
    }
    this.outputShape = tiledShape;
    this.rank = tiledShape.length;

    const coordsType = getCoordsDataType(this.rank);
    const inputCoords = getSourceCoords(aShape);

    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        setOutput(getA(${inputCoords}));
      }
    `;
  }
}
45 |
/**
 * Builds the comma-separated GLSL coordinates that wrap each component of
 * `resRC` into the bounds of `aShape` using `imod`.
 *
 * @param aShape Shape of the tensor being read.
 * @throws Error when the rank of `aShape` exceeds 5.
 */
function getSourceCoords(aShape: number[]): string {
  const rank = aShape.length;
  if (rank > 5) {
    throw Error(`Tile for rank ${rank} is not yet supported`);
  }
  if (rank === 1) {
    // A rank-1 coordinate is a scalar, so there is no component to select.
    return `imod(resRC, ${aShape[0]})`;
  }

  const channels = ['x', 'y', 'z', 'w', 'u'];
  return aShape.map((dim, axis) => `imod(resRC.${channels[axis]}, ${dim})`)
      .join();
}
63 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/transpose_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 | import {getCoordsDataType} from './shader_compiler';
20 |
/**
 * GPGPU program whose output axis `i` has the size of input axis `newDim[i]`,
 * reading `A` at the coordinates produced by `getSwitchedCoords`.
 */
export class TransposeProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  rank: number;

  /**
   * @param aShape Shape of the input `A`.
   * @param newDim For each output axis, the input axis it is taken from.
   */
  constructor(aShape: number[], newDim: number[]) {
    const permutedShape: number[] = [];
    for (let axis = 0; axis < aShape.length; axis++) {
      permutedShape.push(aShape[newDim[axis]]);
    }
    this.outputShape = permutedShape;
    this.rank = permutedShape.length;

    const coordsType = getCoordsDataType(this.rank);
    const inputCoords = getSwitchedCoords(newDim);

    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        setOutput(getA(${inputCoords}));
      }
    `;
  }
}
45 |
/**
 * Builds the comma-separated GLSL coordinates for reading the input: slot
 * `newDim[i]` receives the `i`-th component of `resRC`.
 *
 * @param newDim For each output axis, the input axis it is taken from.
 * @throws Error when `newDim` has more than 6 entries.
 */
function getSwitchedCoords(newDim: number[]): string {
  const rank = newDim.length;
  if (rank > 6) {
    throw Error(`Transpose for rank ${rank} is not yet supported`);
  }
  const outputComponents =
      ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
  const inputCoords = new Array(rank);
  newDim.forEach((inputAxis, outputAxis) => {
    inputCoords[inputAxis] = outputComponents[outputAxis];
  });
  return inputCoords.join();
}
59 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/unaryop_packed_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {GPGPUProgram} from './gpgpu_math';
19 |
/** Snippet body that returns the packed input unchanged. */
export const LINEAR = `return x;`;

/**
 * Snippet body computing `log(x)` per channel, substituting `NAN` in every
 * channel whose input is below zero.
 */
export const LOG = `
  vec4 result = log(x);
  vec4 isNaN = vec4(lessThan(x, vec4(0.0)));
  result.r = isNaN.r == 1.0 ? NAN : result.r;
  result.g = isNaN.g == 1.0 ? NAN : result.g;
  result.b = isNaN.b == 1.0 ? NAN : result.b;
  result.a = isNaN.a == 1.0 ? NAN : result.a;

  return result;
`;

/**
 * Snippet body that zeroes channels below zero and passes the input through
 * in every channel where `isnan(x)` is true.
 */
export const RELU = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
44 |
/**
 * GPGPU program that applies a caller-supplied GLSL snippet to each packed
 * `vec4` of `A` and writes the result.
 */
export class UnaryOpPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];
  usesPackedTextures = true;

  /**
   * @param aShape Shape of the input, reused as the output shape.
   * @param opSnippet Body of `vec4 unaryOperation(vec4 x)`.
   */
  constructor(aShape: number[], opSnippet: string) {
    const operationDefinition = `
      vec4 unaryOperation(vec4 x) {
        ${opSnippet}
      }
    `;
    const mainDefinition = `
      void main() {
        vec4 x = getAAtOutCoords();
        vec4 y = unaryOperation(x);

        setOutput(y);
      }
    `;

    this.outputShape = aShape;
    this.userCode = `${operationDefinition}${mainDefinition}`;
  }
}
67 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/unpack_gpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {getChannels, getSourceCoords} from '../packing_util';
19 |
20 | import {GPGPUProgram} from './gpgpu_math';
21 | import {getCoordsDataType} from './shader_compiler';
22 |
/**
 * GPGPU program that reads a packed `vec4` from `A` and writes the single
 * channel selected by the output coordinates.
 */
export class UnpackProgram implements GPGPUProgram {
  variableNames = ['A'];
  usesPackedTextures = true;
  outputShape: number[];
  userCode: string;

  /** @param outputShape Shape of the output. */
  constructor(outputShape: number[]) {
    this.outputShape = outputShape;

    const rank = outputShape.length;
    const channels = getChannels('rc', rank);
    const coordsType = getCoordsDataType(rank);
    const packedCoords = getSourceCoords(rank, channels);

    // Channel selection uses the two innermost coordinates; for rank <= 1 the
    // coordinate variable is passed as is.
    const innermost = channels.slice(-2);
    let channelSelector: string;
    if (rank <= 1) {
      channelSelector = 'rc';
    } else {
      channelSelector = `vec2(${innermost.join(',')})`;
    }

    this.userCode = `
      void main() {
        ${coordsType} rc = getOutputCoords();
        vec4 packedInput = getA(${packedCoords});

        setOutput(getChannel(packedInput, ${channelSelector}));
      }
    `;
  }
}
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/webgl_types.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | // TODO(nsthorat): Move these to the webgl official typings.
/**
 * Numeric constants read from the WebGL2 disjoint timer query extension
 * object.
 */
export interface WebGL2DisjointQueryTimerExtension {
  GPU_DISJOINT_EXT: number;
  TIME_ELAPSED_EXT: number;
}
23 |
/**
 * Constants and query functions read from the WebGL1 disjoint timer query
 * extension object.
 */
export interface WebGL1DisjointQueryTimerExtension {
  // Numeric constants.
  GPU_DISJOINT_EXT: number;
  QUERY_RESULT_AVAILABLE_EXT: number;
  QUERY_RESULT_EXT: number;
  TIME_ELAPSED_EXT: number;

  // Query functions.
  createQueryEXT: () => {};
  isQueryEXT: (query: WebGLQuery) => boolean;
  beginQueryEXT: (ext: number, query: WebGLQuery) => void;
  endQueryEXT: (ext: number) => void;
  getQueryObjectEXT:
      (query: WebGLQuery, queryResultAvailableExt: number) => number;
  deleteQueryEXT: (query: WebGLQuery) => void;
}
37 |
/** Optional boolean flags describing a WebGL context; all may be omitted. */
export interface WebGLContextAttributes {
  alpha?: boolean;
  premultipliedAlpha?: boolean;
  antialias?: boolean;
  depth?: boolean;
  stencil?: boolean;
  preserveDrawingBuffer?: boolean;
  failIfMajorPerformanceCaveat?: boolean;
}
47 |
--------------------------------------------------------------------------------
/tfjs-core/src/backends/where_impl.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | /** An implementation of the Where kernel shared between cpu and webgl */
19 |
20 | import {buffer} from '../ops/array_ops';
21 | import {Tensor2D} from '../tensor';
22 | import {TypedArray} from '../types';
23 |
/**
 * Collects the multi-dimensional locations of every truthy entry in
 * `condVals`.
 *
 * @param condShape Shape used to convert flat indices into locations.
 * @param condVals Flat condition values.
 * @returns An int32 tensor of shape `[numTruthy, condShape.length]` with one
 *     location per row.
 */
export function whereImpl(condShape: number[], condVals: TypedArray): Tensor2D {
  const truthyFlatIndices: number[] = [];
  for (let flat = 0; flat < condVals.length; flat++) {
    if (!condVals[flat]) {
      continue;
    }
    truthyFlatIndices.push(flat);
  }

  // Only used for its flat-index to location conversion.
  const locator = buffer(condShape, 'int32');
  const rank = condShape.length;

  const result = buffer([truthyFlatIndices.length, rank], 'int32');
  truthyFlatIndices.forEach((flat, row) => {
    result.values.set(locator.indexToLoc(flat), row * rank);
  });
  return result.toTensor() as Tensor2D;
}
42 |
--------------------------------------------------------------------------------
/tfjs-core/src/browser_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
// Scheduler chosen once at module load: `requestAnimationFrame` when defined,
// otherwise `setImmediate`, otherwise a direct synchronous call.
const delayCallback: Function = (() => {
  if (typeof requestAnimationFrame !== 'undefined') {
    return requestAnimationFrame;
  }
  return typeof setImmediate !== 'undefined' ?
      setImmediate :
      (callback: Function) => callback();  // no delays
})();
26 |
/**
 * Returns a promise that resolves when a requestAnimationFrame has completed.
 *
 * On Node.js this uses setImmediate instead of requestAnimationFrame. When
 * neither is available the promise is resolved without any delay.
 *
 * This is simply a sugar method so that users can do the following:
 * `await tf.nextFrame();`
 *
 * @returns A promise that resolves with no value once the scheduled callback
 *     has run.
 */
/** @doc {heading: 'Performance', subheading: 'Timing'} */
function nextFrame(): Promise<void> {
  // `Promise` is generic and needs its type argument spelled out; the promise
  // carries no value.
  return new Promise<void>(resolve => delayCallback(() => resolve()));
}
39 |
40 | export {nextFrame};
41 |
--------------------------------------------------------------------------------
/tfjs-core/src/browser_util_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from './index';
19 | import {ALL_ENVS, describeWithFlags} from './jasmine_util';
20 |
// Specs for tf.nextFrame().
describeWithFlags('nextFrame', ALL_ENVS, () => {
  it('basic usage', async () => {
    const t0 = tf.util.now();
    await tf.nextFrame();
    const t1 = tf.util.now();
    // tf.util.now should give sufficient accuracy on all supported envs.
    // A bare `expect(...)` without a matcher asserts nothing, so the
    // comparison is made through a matcher.
    expect(t1).toBeGreaterThan(t0);
  });

  it('does not block timers', async () => {
    let flag = false;
    setTimeout(() => {
      flag = true;
    }, 50);
    const t0 = tf.util.now();
    expect(flag).toBe(false);
    // Keep yielding frames until the timer fires, giving up after one second.
    while (tf.util.now() - t0 < 1000 && !flag) {
      await tf.nextFrame();
    }
    expect(flag).toBe(true);
  });
});
43 |
--------------------------------------------------------------------------------
/tfjs-core/src/log.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {ENV} from './environment';
19 |
/**
 * Forwards its arguments to `console.warn`, unless the `IS_TEST` environment
 * flag is set.
 */
export function warn(...msg: Array<{}>): void {
  const isTest = ENV.getBool('IS_TEST');
  if (isTest) {
    return;
  }
  console.warn(...msg);
}
25 |
/**
 * Forwards its arguments to `console.log`, unless the `IS_TEST` environment
 * flag is set.
 */
export function log(...msg: Array<{}>): void {
  const isTest = ENV.getBool('IS_TEST');
  if (isTest) {
    return;
  }
  console.log(...msg);
}
31 |
--------------------------------------------------------------------------------
/tfjs-core/src/math.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | /**
19 | * Exports under the tf.math.* namespace.
20 | */
21 |
22 | import {confusionMatrix} from './ops/confusion_matrix';
23 |
24 | export {confusionMatrix};
25 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/clone_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '../index';
19 | import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20 | import {expectArraysClose} from '../test_util';
21 |
// Specs for tf.clone().
describeWithFlags('clone', ALL_ENVS, () => {
  it('returns a tensor with the same shape and value', async () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]);
    const aPrime = tf.clone(a);
    // The shape is asserted once; the original repeated this exact check
    // after the data comparison.
    expect(aPrime.shape).toEqual(a.shape);
    expectArraysClose(await aPrime.data(), await a.data());
  });

  it('accepts a tensor-like object', async () => {
    const res = tf.clone([[1, 2, 3], [4, 5, 6]]);
    expect(res.dtype).toBe('float32');
    expect(res.shape).toEqual([2, 3]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });
});
38 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/concat_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as util from '../util';
19 |
/**
 * Validates that a list of tensor shapes can be concatenated along `axis`.
 *
 * Three conditions are checked, in order, via `util.assert`:
 *   1. every shape has the same rank as `shapes[0]`;
 *   2. `axis` lies in `[0, rank)`;
 *   3. every dimension other than `axis` equals the matching dimension of
 *      `shapes[0]`.
 *
 * @param shapes Shapes of the tensors to concatenate. `shapes[0]` is used as
 *     the reference shape.
 * @param axis The axis along which the tensors are concatenated.
 */
export function assertParamsConsistent(shapes: number[][], axis: number) {
  const rank = shapes[0].length;
  shapes.forEach((shape, i) => {
    util.assert(
        shape.length === rank,
        () =>
            `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
            `as the rank of the rest (${rank})`);
  });

  util.assert(
      axis >= 0 && axis < rank,
      () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);

  const firstShape = shapes[0];
  shapes.forEach((shape, i) => {
    for (let r = 0; r < rank; r++) {
      util.assert(
          (r === axis) || (shape[r] === firstShape[r]),
          // Report the offending axis `r`; the original message printed the
          // tensor index `i` here by mistake.
          () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
              `does not match the shape of the rest (${firstShape}) ` +
              `along the non-concatenated axis ${r}.`);
    }
  });
}
45 |
/**
 * Computes the shape produced by concatenating tensors along `axis`.
 *
 * The result is a copy of the first shape whose `axis` entry is the sum of
 * the `axis` entries of all input shapes.
 *
 * @param shapes Shapes of the tensors being concatenated.
 * @param axis The concatenation axis.
 * @returns A new array holding the output shape.
 */
export function computeOutShape(shapes: number[][], axis: number): number[] {
  const [first, ...rest] = shapes;
  const result = first.slice();
  rest.forEach(shape => {
    result[axis] += shape[axis];
  });
  return result;
}
53 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/diag.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {ENGINE} from '../engine';
19 | import {Tensor} from '../tensor';
20 | import {convertToTensor} from '../tensor_util_env';
21 | import {op} from './operation';
22 |
/**
 * Returns a diagonal tensor with the given diagonal values.
 *
 * Given a diagonal, this operation returns a tensor with the diagonal and
 * everything else padded with zeros.
 *
 * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
 * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * tf.diag(x).print()
 * ```
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
 *
 * tf.diag(x).print()
 * ```
 * @param x The input tensor.
 */
function diag_(x: Tensor): Tensor {
  const $input = convertToTensor(x, 'x', 'diag');
  // The backend kernel receives the flattened input; the result is reshaped
  // to the doubled shape afterwards.
  const $x = $input.flatten();
  // Read the shape from the converted tensor rather than the raw argument so
  // that tensor-like inputs (which have no `shape`) are handled as well.
  const outShape = [...$input.shape, ...$input.shape];
  return ENGINE.runKernel(backend => backend.diag($x), {$x}).reshape(outShape);
}

export const diag = op({diag_});
51 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/dropout_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {Tensor} from '../tensor';
19 | import * as util from '../util';
20 |
/**
 * Resolves the noise shape to use for a tensor.
 *
 * - With no `noiseShape`, a copy of `x.shape` is returned.
 * - If `noiseShape` has the same rank as `x` but differs from `x.shape`, a new
 *   array is built in which every `null`/`undefined` entry of `noiseShape` is
 *   replaced by the matching dimension of `x` (when that dimension is set).
 * - In every other case `noiseShape` itself is returned.
 *
 * @param x Tensor.
 * @param noiseShape The shape for the randomly generated keep/drop flags, as
 *     an array of numbers. Optional.
 * @returns Normalized noise shape.
 */
export function getNoiseShape(x: Tensor, noiseShape?: number[]): number[] {
  if (noiseShape == null) {
    return x.shape.slice();
  }

  const sameRank = x.shape.length === noiseShape.length;
  if (util.arraysEqual(x.shape, noiseShape) || !sameRank) {
    return noiseShape;
  }

  // Same rank, different values: fill unspecified entries from `x.shape`.
  return x.shape.map(
      (dim, i) =>
          (noiseShape[i] == null && dim != null) ? dim : noiseShape[i]);
}
50 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/dropout_util_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '../index';
19 | import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20 | import {getNoiseShape} from './dropout_util';
21 |
/** Tests for `getNoiseShape`, one case per branch of the helper. */
describeWithFlags('getNoiseShape', ALL_ENVS, () => {
  it('x.shape == noiseShape', async () => {
    const input = tf.ones([2, 3]);
    expect(getNoiseShape(input, [2, 3])).toEqual([2, 3]);
  });

  it('x.shape and noiseShape have same length, different value', async () => {
    const input = tf.ones([2, 3]);
    expect(getNoiseShape(input, [2, 1])).toEqual([2, 1]);
  });

  it('noiseShape has null value', async () => {
    const input = tf.ones([2, 3]);
    // The null entry is filled in from the input's own dimension.
    expect(getNoiseShape(input, [2, null])).toEqual([2, 3]);
  });

  it('x.shape and noiseShape has different length', async () => {
    const input = tf.ones([2, 3, 4]);
    expect(getNoiseShape(input, [2, 3])).toEqual([2, 3]);
  });

  it('noiseShape is null', async () => {
    const input = tf.ones([2, 3]);
    expect(getNoiseShape(input, null)).toEqual([2, 3]);
  });
});
57 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/erf_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
// Numeric constants shared by implementations of the error function `erf`.
// NOTE(review): the values match the p and a1..a5 coefficients of the
// Abramowitz & Stegun approximation (formula 7.1.26) — confirm against the
// kernels that consume them.
export const ERF_P = 0.3275911;
export const ERF_A1 = 0.254829592;
export const ERF_A2 = -0.284496736;
export const ERF_A3 = 1.421413741;
export const ERF_A4 = -1.453152027;
export const ERF_A5 = 1.061405429;
24 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/fused_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {Tensor, Tensor3D} from '../tensor';
19 |
/** Names of the activations that can be requested from a fused op. */
export type Activation = 'linear'|'relu'|'prelu';

/**
 * Argument bundle for a fused batch matMul.
 *
 * `a`, `b` and the two transpose flags are required; `bias`, `activation` and
 * `preluActivationWeights` are optional.
 * NOTE(review): presumably `preluActivationWeights` is only consulted when
 * `activation` is 'prelu' — confirm against the backends.
 */
export type FusedBatchMatMulConfig = {
  a: Tensor3D,
  b: Tensor3D,
  transposeA: boolean,
  transposeB: boolean,
  bias?: Tensor,
  activation?: Activation,
  preluActivationWeights?: Tensor
};
31 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/operation.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 | import {ENGINE} from '../engine';
18 |
/**
 * Used for wrapping functions that perform math operations on
 * Tensors. The function will be wrapped in a named scope that cleans all
 * memory usage after the function is done.
 *
 * @param f An object with exactly one key. The key is the operation name (a
 *     single trailing underscore is stripped) and the value is the function
 *     to wrap.
 * @returns A function with the same call signature as the wrapped one, whose
 *     `name` property is the operation name.
 * @throws Error if `f` does not have exactly one key.
 */
export function op<T extends Function>(f: {[name: string]: T}): T {
  const keys = Object.keys(f);
  if (keys.length !== 1) {
    throw new Error(
        `Please provide an object with a single key ` +
        `(operation name) mapping to a function. Got an object with ` +
        `${keys.length} keys.`);
  }

  let opName = keys[0];
  const fn = f[opName];

  // Strip the underscore from the end of the function name.
  if (opName.endsWith('_')) {
    opName = opName.substring(0, opName.length - 1);
  }

  // tslint:disable-next-line:no-any
  const f2 = (...args: any[]) => {
    ENGINE.startScope(opName);
    try {
      const result = fn(...args);
      if (result instanceof Promise) {
        console.error('Cannot return a Promise inside of tidy.');
      }
      ENGINE.endScope(result);
      return result;
    } catch (ex) {
      // Close the scope opened above before propagating the error.
      ENGINE.endScope(null);
      throw ex;
    }
  };
  // Arrow functions get their name from the binding (`f2`); override it so
  // the wrapper reports the operation name instead.
  Object.defineProperty(f2, 'name', {value: opName, configurable: true});

  // tslint:disable-next-line:no-any
  return f2 as any as T;
}
61 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/operation_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 | import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
18 | import {op} from './operation';
19 |
/** Tests for the `op` wrapper: naming, underscore stripping, validation. */
describeWithFlags('operation', ALL_ENVS, () => {
  // Trivial function wrapped by every case below.
  const returnsTwo = () => 2;

  it('executes and preserves function name', () => {
    const wrapped = op({'opName': returnsTwo});

    expect(wrapped.name).toBe('opName');
    expect(wrapped()).toBe(2);
  });

  it('executes, preserves function name, strips underscore', () => {
    const wrapped = op({'opName_': returnsTwo});

    expect(wrapped.name).toBe('opName');
    expect(wrapped()).toBe(2);
  });

  it('throws when passing an object with multiple keys', () => {
    const wrapTwoKeys = () =>
        op({'opName_': returnsTwo, 'opName2_': returnsTwo});
    expect(wrapTwoKeys)
        .toThrowError(/Please provide an object with a single key/);
  });
});
43 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/ops.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | export * from './batchnorm';
19 | export * from './boolean_mask';
20 | export * from './complex_ops';
21 | export * from './concat_split';
22 | export * from './conv';
23 | export * from './matmul';
24 | export * from './reverse';
25 | export * from './pool';
26 | export * from './slice';
27 | export * from './unary_ops';
28 | export * from './reduction_ops';
29 | export * from './compare';
30 | export * from './binary_ops';
31 | export * from './relu_ops';
32 | export * from './logical_ops';
33 | export * from './array_ops';
34 | export * from './tensor_ops';
35 | export * from './transpose';
36 | export * from './softmax';
37 | export * from './lrn';
38 | export * from './norm';
39 | export * from './segment_ops';
40 | export * from './lstm';
41 | export * from './moving_average';
42 | export * from './strided_slice';
43 | export * from './topk';
44 | export * from './scatter_nd';
45 | export * from './spectral_ops';
46 | export * from './sparse_to_dense';
47 | export * from './gather_nd';
48 | export * from './diag';
49 | export * from './dropout';
50 | export * from './signal_ops';
51 | export * from './in_top_k';
52 |
53 | export {op} from './operation';
54 |
55 | // Second level exports.
56 | import * as losses from './loss_ops';
57 | import * as linalg from './linalg_ops';
58 | import * as image from './image_ops';
59 | import * as spectral from './spectral_ops';
60 | import * as fused from './fused_ops';
61 | import * as signal from './signal_ops';
62 |
63 | export {image, linalg, losses, spectral, fused, signal};
64 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/reduce_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
import {nearestDivisor} from '../util';

/**
 * Inputs of size above this threshold will be parallelized by calling multiple
 * shader programs.
 */
export const PARALLELIZE_THRESHOLD = 30;
25 |
/**
 * Sizes describing a reduction.
 * NOTE(review): field meanings are not visible in this file; presumably
 * `inSize` is the number of elements reduced per output, `windowSize` the
 * number handled in one pass, and `batchSize` the number of independent
 * reductions — confirm against the callers.
 */
export interface ReduceInfo {
  windowSize: number;
  batchSize: number;
  inSize: number;
}
31 |
/**
 * Picks the window size to use when reducing `inSize` elements.
 *
 * Sizes up to `PARALLELIZE_THRESHOLD` are returned unchanged. Larger sizes are
 * passed to `nearestDivisor` together with the floor of their square root.
 *
 * @param inSize Number of elements to reduce.
 * @returns The chosen window size.
 */
export function computeOptimalWindowSize(inSize: number): number {
  const withinThreshold = inSize <= PARALLELIZE_THRESHOLD;
  return withinThreshold ?
      inSize :
      nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
}
38 |
--------------------------------------------------------------------------------
/tfjs-core/src/ops/selu_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
// Constants for the SELU activation. The literals carry more digits than a
// double can hold, so they are rounded when parsed.
// NOTE(review): SELU_SCALEALPHA equals SELU_SCALE multiplied by the SELU
// alpha (about 1.6732632) — confirm against the kernels that use it.
export const SELU_SCALEALPHA = 1.7580993408473768599402175208123;
export const SELU_SCALE = 1.0507009873554804934193349852946;
20 |
--------------------------------------------------------------------------------
/tfjs-core/src/platforms/platform.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {RequestDetails} from '../io/types';
19 |
/**
 * At any given time a single platform is active and represents an
 * implementation of this interface. In practice, a platform is an environment
 * where TensorFlow.js can be executed, e.g. the browser or Node.js.
 */
export interface Platform {
  /**
   * Makes an HTTP request.
   * @param path The URL path to make a request to
   * @param requestInits The request init. See init here:
   *     https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
   * @param options Additional request details; see `RequestDetails` in
   *     `io/types`.
   * @returns A promise that resolves to the response.
   */
  fetch(path: string, requestInits?: RequestInit, options?: RequestDetails):
      Promise<Response>;

  /**
   * Returns the current high-resolution time in milliseconds relative to an
   * arbitrary time in the past. It works across different platforms (node.js,
   * browsers).
   */
  now(): number;

  /**
   * Encode the provided string into an array of bytes using the provided
   * encoding.
   */
  encode(text: string, encoding: string): Uint8Array;
  /** Decode the provided bytes into a string using the provided encoding. */
  decode(bytes: Uint8Array, encoding: string): string;
}
50 |
--------------------------------------------------------------------------------
/tfjs-core/src/platforms/platform_browser.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 | import {ENV} from '../environment';
18 | import {Platform} from './platform';
19 |
/**
 * `Platform` implementation backed by browser globals: `fetch`,
 * `performance`, `TextEncoder` and `TextDecoder`.
 */
export class PlatformBrowser implements Platform {
  private textEncoder: TextEncoder;

  constructor() {
    // According to the spec, the built-in encoder can do only UTF-8 encoding.
    // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
    this.textEncoder = new TextEncoder();
  }

  /**
   * Forwards the request to the global `fetch`.
   * @param path The URL to request.
   * @param init Optional request init passed through unchanged.
   */
  fetch(path: string, init?: RequestInit): Promise<Response> {
    return fetch(path, init);
  }

  /** Returns `performance.now()`. */
  now(): number {
    return performance.now();
  }

  /**
   * Encodes `text` as UTF-8 bytes.
   * @throws Error if `encoding` is neither 'utf-8' nor 'utf8'.
   */
  encode(text: string, encoding: string): Uint8Array {
    if (encoding !== 'utf-8' && encoding !== 'utf8') {
      throw new Error(
          `Browser's encoder only supports utf-8, but got ${encoding}`);
    }
    return this.textEncoder.encode(text);
  }

  /** Decodes `bytes` with a `TextDecoder` created for `encoding`. */
  decode(bytes: Uint8Array, encoding: string): string {
    return new TextDecoder(encoding).decode(bytes);
  }
}
48 |
// Module side effect: when the environment reports IS_BROWSER, install a
// PlatformBrowser instance as the platform named 'browser'.
if (ENV.get('IS_BROWSER')) {
  ENV.setPlatform('browser', new PlatformBrowser());
}
52 |
--------------------------------------------------------------------------------
/tfjs-core/src/setup_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | /**
19 | * This file is necessary so we register all test environments before we start
20 | * executing tests.
21 | */
22 | import './backends/cpu/backend_cpu_test_registry';
23 | import './backends/webgl/backend_webgl_test_registry';
24 |
25 | import {parseTestEnvFromKarmaFlags, setTestEnvs, TEST_ENVS} from './jasmine_util';
26 |
// `__karma__` is only declared here, never defined by this file; the `typeof`
// guard below skips the block when no such global exists.
// tslint:disable-next-line:no-any
declare let __karma__: any;
if (typeof __karma__ !== 'undefined') {
  // If the Karma args select a test environment, run only that one.
  const testEnv = parseTestEnvFromKarmaFlags(__karma__.config.args, TEST_ENVS);
  if (testEnv != null) {
    setTestEnvs([testEnv]);
  }
}
35 |
--------------------------------------------------------------------------------
/tfjs-core/src/tensor_types.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {Tensor, Variable} from './tensor';
19 | import {DataType} from './types';
20 |
/** @docalias {[name: string]: Tensor} */
export type NamedTensorMap = {
  [name: string]: Tensor;
};

/** A tensor paired with its name. */
export interface NamedTensor {
  name: string;
  tensor: Tensor;
}

/** Maps a name to a `Variable`. */
export type NamedVariableMap = {
  [name: string]: Variable;
};

/** A function that receives a list of tensors and returns nothing. */
export type GradSaveFunc = (save: Tensor[]) => void;
36 |
/**
 * @docalias void|number|string|TypedArray|Tensor|Tensor[]|{[key:
 * string]:Tensor|number|string}
 */
export type TensorContainer =
    void|Tensor|string|number|boolean|TensorContainerObject|
    TensorContainerArray|Float32Array|Int32Array|Uint8Array;
/** An object whose values are themselves `TensorContainer`s. */
export interface TensorContainerObject {
  [x: string]: TensorContainer;
}
/**
 * An array whose elements are themselves `TensorContainer`s. Declared as an
 * interface so that the type can refer to itself through `TensorContainer`.
 */
export interface TensorContainerArray extends Array<TensorContainer> {}
48 |
/** Descriptive metadata for a tensor: its name, dtype and optional shape. */
export interface TensorInfo {
  // Name of the tensor.
  name: string;
  // Tensor shape information, Optional.
  shape?: number[];
  // Data type of the tensor.
  dtype: DataType;
}
57 |
--------------------------------------------------------------------------------
/tfjs-core/src/test_node.ts:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | /**
3 | * @license
4 | * Copyright 2019 Google LLC. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | * =============================================================================
17 | */
18 |
19 | import {setTestEnvs} from './jasmine_util';
20 |
// Entry script that runs the compiled Jasmine specs against the 'cpu' backend.
// tslint:disable-next-line:no-require-imports
const jasmine = require('jasmine');

// Rethrow unhandled promise rejections instead of letting them pass silently.
process.on('unhandledRejection', e => {
  throw e;
});

// Register a single test environment named 'node' that uses the 'cpu' backend.
setTestEnvs([{name: 'node', backendName: 'cpu'}]);

const runner = new jasmine();
// Load every compiled *_test.js under dist/ and run the specs in a fixed order.
runner.loadConfig({spec_files: ['dist/**/**_test.js'], random: false});
runner.execute();
33 |
--------------------------------------------------------------------------------
/tfjs-core/src/train.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | // So typings can propagate.
19 | import {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
20 | import {AdagradOptimizer} from './optimizers/adagrad_optimizer';
21 | import {AdamOptimizer} from './optimizers/adam_optimizer';
22 | import {AdamaxOptimizer} from './optimizers/adamax_optimizer';
23 | import {MomentumOptimizer} from './optimizers/momentum_optimizer';
24 | import {OptimizerConstructors} from './optimizers/optimizer_constructors';
25 | import {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
26 | import {SGDOptimizer} from './optimizers/sgd_optimizer';
27 |
// The array expression below only references the imported optimizer classes;
// its value is discarded.
// NOTE(review): presumably this keeps the imports from being treated as
// unused so the typings propagate — confirm.
// tslint:disable-next-line:no-unused-expression
[MomentumOptimizer, SGDOptimizer, AdadeltaOptimizer, AdagradOptimizer,
 RMSPropOptimizer, AdamaxOptimizer, AdamOptimizer];

/**
 * Optimizer factories, keyed by optimizer name. Each entry is the matching
 * static member of `OptimizerConstructors`.
 */
export const train = {
  sgd: OptimizerConstructors.sgd,
  momentum: OptimizerConstructors.momentum,
  adadelta: OptimizerConstructors.adadelta,
  adagrad: OptimizerConstructors.adagrad,
  rmsprop: OptimizerConstructors.rmsprop,
  adamax: OptimizerConstructors.adamax,
  adam: OptimizerConstructors.adam
};
41 |
--------------------------------------------------------------------------------
/tfjs-core/src/types_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2018 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {upcastType} from './types';
19 |
/**
 * Tests for `upcastType`. Each case pairs a dtype with a target dtype and
 * checks that the result is the target; the string cases check that mixing
 * 'string' with any other dtype throws.
 */
describe('upcastType', () => {
  it('upcasts bool to bool', () => {
    expect(upcastType('bool', 'bool')).toBe('bool');
  });

  it('upcasts bool/int32 to int32', () => {
    expect(upcastType('bool', 'int32')).toBe('int32');
    expect(upcastType('int32', 'int32')).toBe('int32');
  });

  it('upcasts bool/int32/float32 to float32', () => {
    expect(upcastType('bool', 'float32')).toBe('float32');
    expect(upcastType('int32', 'float32')).toBe('float32');
    expect(upcastType('float32', 'float32')).toBe('float32');
  });

  it('upcasts bool/int32/float32/complex64 to complex64', () => {
    expect(upcastType('bool', 'complex64')).toBe('complex64');
    expect(upcastType('int32', 'complex64')).toBe('complex64');
    expect(upcastType('float32', 'complex64')).toBe('complex64');
    expect(upcastType('complex64', 'complex64')).toBe('complex64');
  });

  it('fails to upcast anything other than string with string', () => {
    expect(() => upcastType('bool', 'string')).toThrowError();
    expect(() => upcastType('int32', 'string')).toThrowError();
    expect(() => upcastType('float32', 'string')).toThrowError();
    expect(() => upcastType('complex64', 'string')).toThrowError();
    // Ok upcasting string to string.
    expect(upcastType('string', 'string')).toBe('string');
  });
});
52 |
--------------------------------------------------------------------------------
/tfjs-core/src/version.ts:
--------------------------------------------------------------------------------
1 | /** @license See the LICENSE file. */
2 |
3 | // This code is auto-generated, do not modify this file!
4 | const version = '1.2.7';
5 | export {version};
6 |
--------------------------------------------------------------------------------
/tfjs-core/src/version_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {version_core} from './index';
19 |
20 | describe('version', () => {
21 |   it('version is contained', () => {
22 |     // tslint:disable-next-line:no-require-imports
23 |     const {version: packageVersion} = require('../package.json');
24 |     expect(version_core).toBe(packageVersion);
25 |   });
26 | });
27 |
--------------------------------------------------------------------------------
/tfjs-core/src/webgl.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as gpgpu_util from './backends/webgl/gpgpu_util';
19 | import * as webgl_util from './backends/webgl/webgl_util';
20 |
21 | export {MathBackendWebGL, WebGLMemoryInfo, WebGLTimingInfo} from './backends/webgl/backend_webgl';
22 | export {setWebGLContext} from './backends/webgl/canvas_util';
23 | export {GPGPUContext} from './backends/webgl/gpgpu_context';
24 | export {GPGPUProgram} from './backends/webgl/gpgpu_math';
25 | // WebGL specific utils.
26 | export {gpgpu_util, webgl_util};
27 |
--------------------------------------------------------------------------------
/tfjs-core/src/worker_node_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {describeWithFlags, HAS_NODE_WORKER} from './jasmine_util';
19 | import {expectArraysClose} from './test_util';
20 | // tslint:disable:no-require-imports
21 |
22 | // Turns `fn` into source text for an expression that calls it immediately.
23 | const fn2String = (fn: Function): string => {
24 |   return `(${fn.toString()})()`;
25 | };
26 |
27 | // Body of the node worker. It is turned into source text by fn2String, so it
28 | // must be self-contained and cannot rely on bindings from this module.
29 | const workerTestNode = () => {
30 |   // Worker scripts in node resolve paths against the CWD rather than the
31 |   // directory of the file that spawned them.
32 |   const tf = require('./dist/index.js');
33 |   const {parentPort} = require('worker_threads');
34 |   const lhs = tf.tensor1d([1, 2, 3]);
35 |   const rhs = tf.tensor1d([3, 2, 1]);
36 |   parentPort.postMessage({data: lhs.add(rhs).dataSync()});
37 | };
38 |
39 | describeWithFlags('computation in worker (node env)', HAS_NODE_WORKER, () => {
40 |   it('tensor in worker', (done) => {
41 |     const {Worker} = require('worker_threads');
42 |     const workerSource = fn2String(workerTestNode);
43 |     const worker = new Worker(workerSource, {eval: true});
44 |     // tslint:disable-next-line:no-any
45 |     worker.on('message', ({data}: any) => {
46 |       expectArraysClose(data, [4, 4, 4]);
47 |       done();
48 |     });
49 |   });
50 | });
51 |
--------------------------------------------------------------------------------
/tfjs-core/src/worker_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from './index';
19 | import {describeWithFlags, HAS_WORKER} from './jasmine_util';
20 | import {expectArraysClose} from './test_util';
21 |
22 | const fn2workerURL = (fn: Function): string => {
23 |   const source = '(' + fn.toString() + ')()';
24 |   const blob = new Blob([source], {type: 'application/javascript'});
25 |   return URL.createObjectURL(blob);
26 | };
27 |
28 | // Body of the web worker; fn2workerURL turns it into source text via toString().
29 | const workerTest = () => {
30 | //@ts-ignore
31 | importScripts('http://bs-local.com:12345/base/dist/tf-core.min.js');
32 | let a = tf.tensor1d([1, 2, 3]);  // NOTE(review): presumably the global `tf` defined by the imported bundle — confirm
33 | const b = tf.tensor1d([3, 2, 1]);
34 | a = a.add(b);
35 | //@ts-ignore
36 | self.postMessage({data: a.dataSync()});
37 | };
38 |
39 | describeWithFlags('computation in worker', HAS_WORKER, () => {
40 |   it('tensor in worker', (done) => {
41 |     const workerURL = fn2workerURL(workerTest);
42 |     const worker = new Worker(workerURL);
43 |     worker.onmessage = ({data: payload}) => {
44 |       expectArraysClose(payload.data, [4, 4, 4]);
45 |       done();
46 |     };
47 |   });
48 | });
49 |
--------------------------------------------------------------------------------
/tfjs-core/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tsconfig",
3 | "include": [
4 | "src/"
5 | ],
6 | "exclude": [
7 | "node_modules/"
8 | ],
9 | "compilerOptions": {
10 | "outDir": "./dist"
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/tfjs-core/tslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tslint.json"
3 | }
4 |
--------------------------------------------------------------------------------
/tfjs-react-native/.npmignore:
--------------------------------------------------------------------------------
1 | .babelrc
2 | .DS_Store
3 | .idea/
4 | .rpt2_cache
5 | .travis.yml
6 | .vscode
7 | *.tgz
8 | *.txt
9 | **.yalc
10 | **yalc.lock
11 | cloudbuild.yml
12 | coverage/
13 | demo/
14 | dist/**/*_test.d.ts
15 | dist/**/*_test.js
16 | karma.conf.js
17 | node_modules/
18 | npm-debug.log
19 | package-lock.json
20 | package/
21 | rollup.config.js
22 | scripts/
23 | src/**/*_test.ts
24 | tsconfig.json
25 | tslint.json
26 | yarn-error.log
27 | yarn.lock
28 |
--------------------------------------------------------------------------------
/tfjs-react-native/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "search.exclude": {
3 | "**/node_modules": true,
4 | "coverage/": true,
5 | "dist/": true,
6 | "**/yarn.lock": true,
7 | ".rpt2_cache/": true,
8 | ".yalc/": true
9 | },
10 | "tslint.configFile": "tslint.json",
11 | "files.trimTrailingWhitespace": true,
12 | "editor.tabSize": 2,
13 | "editor.insertSpaces": true,
14 | "[typescript]": {
15 | "editor.formatOnSave": true
16 | },
17 | "[javascript]": {
18 | "editor.formatOnSave": true
19 | },
20 | "editor.rulers": [80],
21 | "clang-format.style": "Google",
22 | "files.insertFinalNewline": true,
23 | "editor.detectIndentation": false,
24 | "editor.wrappingIndent": "none",
25 | "typescript.tsdk": "./node_modules/typescript/lib",
26 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
27 | }
28 |
--------------------------------------------------------------------------------
/tfjs-react-native/README.md:
--------------------------------------------------------------------------------
1 | # Platform Adapter for React Native
2 |
3 | Status: Early development.
4 |
--------------------------------------------------------------------------------
/tfjs-react-native/cloudbuild.yml:
--------------------------------------------------------------------------------
1 | steps:
2 | # Install common dependencies.
3 | - name: 'node:10'
4 | id: 'yarn-common'
5 | entrypoint: 'yarn'
6 | args: ['install']
7 |
8 | # Run the react native CI script (yarn test-ci).
9 | - name: 'node:10'
10 | dir: 'tfjs-react-native'
11 | entrypoint: 'yarn'
12 | id: 'test-react-native'
13 | args: ['test-ci']
14 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1']
15 | secretEnv: ['BROWSERSTACK_KEY']
16 | waitFor: ['yarn-common']
17 |
18 | # General configuration
19 | secrets:
20 | - kmsKeyName: projects/learnjs-174218/locations/global/keyRings/tfjs/cryptoKeys/enc
21 | secretEnv:
22 | BROWSERSTACK_KEY: CiQAkwyoIW0LcnxymzotLwaH4udVTQFBEN4AEA5CA+a3+yflL2ASPQAD8BdZnGARf78MhH5T9rQqyz9HNODwVjVIj64CTkFlUCGrP1B2HX9LXHWHLmtKutEGTeFFX9XhuBzNExA=
23 | timeout: 1800s
24 | logsBucket: 'gs://tfjs-build-logs'
25 | substitutions:
26 | _NIGHTLY: ''
27 | options:
28 | logStreamingOption: 'STREAM_ON'
29 | substitution_option: 'ALLOW_LOOSE'
30 |
--------------------------------------------------------------------------------
/tfjs-react-native/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@tensorflow/tfjs-platform-react-native",
3 | "version": "0.1.0",
4 | "description": "TensorFlow.js platform implementation for React Native",
5 | "main": "dist/index.js",
6 | "types": "dist/index.d.ts",
7 | "jsnext:main": "dist/tf-react-native.esm.js",
8 | "module": "dist/tf-react-native.esm.js",
9 | "unpkg": "dist/tf-react-native.min.js",
10 | "jsdelivr": "dist/tf-react-native.min.js",
11 | "license": "Apache-2.0",
12 | "private": true,
13 | "scripts": {
14 | "publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push",
15 | "build": "tsc",
16 | "link-local": "yalc link",
17 | "unlink-local": "yalc remove",
18 | "lint": "tslint -p . -t verbose",
19 | "test": "karma start",
20 | "test-ci": "./scripts/test-ci.sh"
21 | },
22 | "devDependencies": {
23 | "@react-native-community/async-storage": "^1.4.2",
24 | "@tensorflow/tfjs-core": "^1.2.7",
25 | "@types/base64-js": "^1.2.5",
26 | "@types/jasmine": "~2.5.53",
27 | "@types/react-native": "^0.60.2",
28 | "clang-format": "~1.2.2",
29 | "expo-gl": "^5.0.1",
30 | "jasmine": "~3.1.0",
31 | "jasmine-core": "~3.1.0",
32 | "karma": "~4.2.0",
33 | "karma-browserstack-launcher": "~1.4.0",
34 | "karma-chrome-launcher": "~2.2.0",
35 | "karma-jasmine": "~1.1.0",
36 | "karma-typescript": "~4.1.1",
37 | "karma-verbose-reporter": "^0.0.6",
38 | "rimraf": "~2.6.2",
39 | "rollup": "^0.58.2",
40 | "rollup-plugin-commonjs": "9.1.3",
41 | "rollup-plugin-node-resolve": "3.3.0",
42 | "rollup-plugin-typescript2": "0.13.0",
43 | "rollup-plugin-uglify": "~3.0.0",
44 | "tslint": "~5.11.0",
45 | "tslint-no-circular-imports": "^0.5.0",
46 | "typescript": "3.3.3333",
47 | "yalc": "^1.0.0-pre.32"
48 | },
49 | "dependencies": {
50 | "base64-js": "^1.3.0",
51 | "buffer": "^5.2.1"
52 | },
53 | "peerDependencies": {
54 | "@react-native-community/async-storage": "^1.4.2",
55 | "@tensorflow/tfjs-core": ">=1.2.7",
56 | "expo-gl": "^5.0.1",
57 | "react-native": ">=0.59.0"
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/tfjs-react-native/scripts/test-ci.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2019 Google LLC. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e  # Stop at the first command that fails.
18 |
19 | yarn  # Install dependencies.
20 | yarn lint
21 | yarn build
22 | karma start --browserstack --browsers=bs_chrome_mac  # Run karma with the --browserstack flag on the bs_chrome_mac browser.
23 |
24 |
--------------------------------------------------------------------------------
/tfjs-react-native/src/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import './platform_react_native';
19 |
20 | export {asyncStorageIO} from './async_storage_io';
21 | export {bundleResourceIO} from './bundle_resource_io';
22 | export {fetch} from './platform_react_native';
23 |
--------------------------------------------------------------------------------
/tfjs-react-native/src/platform_react_native_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '@tensorflow/tfjs-core';
19 | import {PlatformReactNative} from './platform_react_native';
20 |
21 | describe('PlatformReactNative', () => {
22 |   it('tf.util.fetch calls platform.fetch', async () => {
23 |     const url = 'test/url';
24 |     const platform = new PlatformReactNative();
25 |     tf.setPlatform('rn-test-platform', platform);
26 |     // Replace fetch with a spy so only the call itself is recorded.
27 |     spyOn(platform, 'fetch');
28 |     await tf.util.fetch(url, {method: 'GET'});
29 |     expect(platform.fetch).toHaveBeenCalledWith(url, {method: 'GET'});
30 |   });
31 | });
32 |
--------------------------------------------------------------------------------
/tfjs-react-native/src/test_utils/async_storage_mock.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | // The mocked library cannot be loaded in a browser, yet we still want JS-only
19 | // unit tests, so the browser's localStorage is used in its place.
20 |
21 | // A default export is used to match the library being mocked.
22 | // @ts-ignore (use of window)
23 | const browserStorage = window.localStorage;
24 | // tslint:disable-next-line
25 | export default browserStorage;
26 |
--------------------------------------------------------------------------------
/tfjs-react-native/src/test_utils/gl_view_mock.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | // Mock of gl-view, which cannot be loaded in a browser. It deliberately
19 | // exports an empty object: gl-view is not exercised by the unit tests, but
20 | // we still want to run JS-only unit tests.
21 | // tslint:disable-next-line
22 | export default {};
23 |
--------------------------------------------------------------------------------
/tfjs-react-native/src/test_utils/react_native_mock.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | // We mock this library as it cannot be loaded in a browser yet we do want
19 | // to do JS only unit tests.
20 |
21 | /** Shape of the value returned by `Image.resolveAssetSource` in this mock. */
22 | interface ImageResolvedAssetSource {
23 |   uri: string;
24 | }
25 |
26 | // tslint:disable-next-line
27 | export const Image = {
28 |   resolveAssetSource(resourceId: string|number): ImageResolvedAssetSource {
29 |     const uri = `http://localhost/assets/${resourceId}`;
30 |     return {uri};
31 |   },
32 | };
33 |
--------------------------------------------------------------------------------
/tfjs-react-native/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tsconfig",
3 | "compilerOptions": {
4 | "target": "es2017",
5 | "lib": [
6 | "es2017"
7 | ],
8 | "outDir": "./dist/",
9 | "skipLibCheck": true,
10 | },
11 | "include": [
12 | "src/",
13 | ],
14 | "exclude": [
15 | "node_modules"
16 | ],
17 | }
18 |
--------------------------------------------------------------------------------
/tfjs-react-native/tslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tslint.json"
3 | }
4 |
--------------------------------------------------------------------------------
/tfjs-webgpu/.npmignore:
--------------------------------------------------------------------------------
1 | .babelrc
2 | .DS_Store
3 | .idea/
4 | .rpt2_cache
5 | .travis.yml
6 | .vscode
7 | *.tgz
8 | *.txt
9 | **.yalc
10 | **yalc.lock
11 | cloudbuild.yml
12 | coverage/
13 | demo/
14 | dist/**/*_test.d.ts
15 | dist/**/*_test.js
16 | karma.conf.js
17 | node_modules/
18 | npm-debug.log
19 | package-lock.json
20 | package/
21 | rollup.config.js
22 | scripts/
23 | src/**/*_test.ts
24 | tsconfig.json
25 | tslint.json
26 | yarn-error.log
27 | yarn.lock
28 |
--------------------------------------------------------------------------------
/tfjs-webgpu/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "search.exclude": {
3 | "**/node_modules": true,
4 | "coverage/": true,
5 | "dist/": true,
6 | "**/yarn.lock": true,
7 | ".rpt2_cache/": true,
8 | ".yalc/": true
9 | },
10 | "tslint.configFile": "tslint.json",
11 | "files.trimTrailingWhitespace": true,
12 | "editor.tabSize": 2,
13 | "editor.insertSpaces": true,
14 | "[typescript]": {
15 | "editor.formatOnSave": true
16 | },
17 | "[javascript]": {
18 | "editor.formatOnSave": true
19 | },
20 | "editor.rulers": [80],
21 | "clang-format.style": "Google",
22 | "files.insertFinalNewline": true,
23 | "editor.detectIndentation": false,
24 | "editor.wrappingIndent": "none",
25 | "typescript.tsdk": "./node_modules/typescript/lib",
26 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
27 | }
28 |
--------------------------------------------------------------------------------
/tfjs-webgpu/README.md:
--------------------------------------------------------------------------------
1 | This is an experimental WebGPU backend for TensorFlow.js.
2 |
3 | ```
4 | $ yarn # to install dependencies
5 | $ yarn test
6 | ```
7 |
8 | # To run the test suite:
9 | The `$CHROME_BIN` environment variable must be set to the path of the Chrome Canary executable.
10 |
11 | e.g. in `~/.bash_profile`:
12 |
13 | `export CHROME_BIN="$HOME/Documents/PROJECTS/tfjs-core-wrapper/Google Chrome Canary.app/Contents/MacOS/Google Chrome Canary"`
14 |
--------------------------------------------------------------------------------
/tfjs-webgpu/cloudbuild.yml:
--------------------------------------------------------------------------------
1 | steps:
2 | # Install common dependencies.
3 | - name: 'node:10'
4 | id: 'yarn-common'
5 | entrypoint: 'yarn'
6 | args: ['install']
7 |
8 | # Run the webgpu CI script (yarn test-ci).
9 | - name: 'node:10'
10 | dir: 'tfjs-webgpu'
11 | entrypoint: 'yarn'
12 | id: 'test-webgpu'
13 | args: ['test-ci']
14 | waitFor: ['yarn-common']
15 |
16 | # General configuration
17 | timeout: 1800s
18 | logsBucket: 'gs://tfjs-build-logs'
19 | substitutions:
20 | _NIGHTLY: ''
21 | options:
22 | logStreamingOption: 'STREAM_ON'
23 | substitution_option: 'ALLOW_LOOSE'
24 |
--------------------------------------------------------------------------------
/tfjs-webgpu/karma.conf.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | const karmaTypescriptConfig = {
19 | tsconfig: 'tsconfig.json',
20 | // Disable coverage reports and instrumentation by default for tests
21 | coverageOptions: {instrumentation: false},
22 | reports: {}
23 | };
24 |
25 | module.exports = function(config) {
26 | const args = [];
27 | if (config.grep) {
28 | args.push('--grep', config.grep);
29 | }
30 | if (config.flags) {
31 | args.push('--flags', config.flags);
32 | }
33 | config.set({
34 | basePath: '',
35 | frameworks: ['jasmine', 'karma-typescript'],
36 | files: [
37 | 'src/setup_test.ts', // Setup the environment for the tests.
38 | {pattern: 'src/**/*.ts'}, // Import all tests.
39 | ],
40 | preprocessors: {'**/*.ts': ['karma-typescript']},
41 | karmaTypescriptConfig,
42 | reporters: ['progress', 'karma-typescript'],
43 | port: 9876,
44 | colors: true,
45 | autoWatch: false,
46 | browsers: ['Chrome', 'chrome_webgpu'],
47 | singleRun: true,
48 | customLaunchers: {
49 | chrome_webgpu: {
50 | base: 'Chrome',
51 | flags: ['--enable-unsafe-webgpu'],
52 | }
53 | },
54 | client: {jasmine: {random: false}, args: args}
55 | })
56 | }
57 |
--------------------------------------------------------------------------------
/tfjs-webgpu/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "@tensorflow/tfjs-backend-webgpu",
3 | "version": "0.0.1",
4 | "main": "dist/index.js",
5 | "types": "dist/index.d.ts",
6 | "jsnext:main": "dist/tf-webgpu.esm.js",
7 | "module": "dist/tf-webgpu.esm.js",
8 | "unpkg": "dist/tf-webgpu.min.js",
9 | "jsdelivr": "dist/tf-webgpu.min.js",
10 | "scripts": {
11 | "publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push",
12 | "build": "tsc",
13 | "link-local": "yalc link",
14 | "unlink-local": "yalc remove",
15 | "lint": "tslint -p . -t verbose",
16 | "test": "karma start --browsers='chrome_webgpu'",
17 | "test-ci": "./scripts/test-ci.sh"
18 | },
19 | "license": "Apache-2.0",
20 | "devDependencies": {
21 | "@tensorflow/tfjs-core": "1.2.1",
22 | "@types/jasmine": "~2.5.53",
23 | "clang-format": "~1.2.2",
24 | "http-server": "~0.10.0",
25 | "jasmine-core": "~3.1.0",
26 | "karma": "~4.0.0",
27 | "karma-browserstack-launcher": "~1.4.0",
28 | "karma-chrome-launcher": "~2.2.0",
29 | "karma-firefox-launcher": "~1.1.0",
30 | "karma-jasmine": "~1.1.1",
31 | "karma-typescript": "~4.1.1",
32 | "rimraf": "~2.6.2",
33 | "rollup": "^0.58.2",
34 | "rollup-plugin-commonjs": "9.1.3",
35 | "rollup-plugin-node-resolve": "3.3.0",
36 | "rollup-plugin-typescript2": "0.13.0",
37 | "rollup-plugin-uglify": "~3.0.0",
38 | "tslint": "~5.11.0",
39 | "tslint-no-circular-imports": "^0.5.0",
40 | "typescript": "3.3.3333",
41 | "yalc": "~1.0.0-pre.21"
42 | },
43 | "dependencies": {
44 | "@webgpu/shaderc": "0.0.6",
45 | "@webgpu/types": "0.0.6"
46 | },
47 | "peerDependencies": {
48 | "@tensorflow/tfjs-core": "1.2.1"
49 | }
50 | }
--------------------------------------------------------------------------------
/tfjs-webgpu/scripts/test-ci.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Copyright 2019 Google LLC. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # =============================================================================
16 |
17 | set -e  # Stop at the first command that fails.
18 |
19 | yarn  # Install dependencies.
20 | yarn lint
21 | yarn build  # NOTE(review): no test command runs after the build — confirm this is intended.
22 |
23 |
--------------------------------------------------------------------------------
/tfjs-webgpu/src/flags_webgpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {ENV} from '@tensorflow/tfjs-core';
19 |
20 | /** Whether commands are submitted to the device queue immediately. */
21 | ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true);
22 |
23 | /**
24 | * Thread register block size for the matmul kernel (registered value: 4).
25 | * If 0, the version of matMul without register blocking is used.
26 | */
27 | ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4);
28 |
29 | /**
30 | * Work per thread (WPT) for conv2d (registered value: 2). -1: conv2d_naive;
31 | * 0: conv2d_mm with matmul without register blocking;
32 | * >0: conv2d_mm with matmul_packed with WPT=this value.
33 | */
34 | ENV.registerFlag('WEBGPU_CONV2D_WORK_PER_THREAD', () => 2);
35 |
--------------------------------------------------------------------------------
/tfjs-webgpu/src/index.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import * as tf from '@tensorflow/tfjs-core';
19 | import * as Shaderc from '@webgpu/shaderc';
20 |
21 | import {WebGPUBackend} from './backend_webgpu';
22 |
/**
 * Builds the WebGPU backend: instantiates the shaderc module first, then
 * requests an adapter and a device from `navigator.gpu`, and finally wraps the
 * device and shaderc instance in a `WebGPUBackend`.
 */
async function createWebGPUBackend() {
  const shaderc = await Shaderc.instantiate();
  // @ts-ignore navigator.gpu is required
  const adapter = await navigator.gpu.requestAdapter({});
  const device = await adapter.requestDevice({});
  return new WebGPUBackend(device, shaderc);
}

// Make the backend available to tfjs-core under the name 'webgpu'.
tf.registerBackend('webgpu', createWebGPUBackend, 3 /*priority*/);
30 |
--------------------------------------------------------------------------------
/tfjs-webgpu/src/kernels/binary_op_webgpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {backend_util} from '@tensorflow/tfjs-core';
19 |
20 | import {computeDispatch} from '../webgpu_util';
21 |
22 | import {WebGPUProgram} from './webgpu_program';
23 |
// GLSL statement snippets for element-wise binary operations. Each snippet
// reads two floats named `a` and `b` and ends by returning a float.
export const ADD = 'return a + b;';
export const SUB = 'return a - b;';
export const MUL = 'return a * b;';

// Integer division: both operands are rounded to ints and handed to `idiv`
// together with the product of their signs. `idiv` is not defined in this
// file.
export const INT_DIV = `
  float s = sign(a) * sign(b);
  int ia = int(round(a));
  int ib = int(round(b));
  return float(idiv(ia, ib, s));
`;
34 |
/**
 * Compute program applying a binary GLSL operation to inputs `A` and `B`.
 *
 * The output shape is whatever `backend_util.assertAndGetBroadcastShape`
 * returns for the two input shapes, and every output dimension is mapped to
 * the x dispatch axis.
 */
export class BinaryOpProgram implements WebGPUProgram {
  outputShape: number[];
  userCode: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['A', 'B'];

  /**
   * @param op GLSL statements forming the body of
   *     `float binaryOperation(float a, float b)`.
   * @param aShape Shape of input `A`.
   * @param bShape Shape of input `B`.
   */
  constructor(op: string, aShape: number[], bShape: number[]) {
    const broadcastShape =
        backend_util.assertAndGetBroadcastShape(aShape, bShape);
    const allDims = broadcastShape.map((_, dim) => dim);

    this.outputShape = broadcastShape;
    this.dispatchLayout = {x: allDims};
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);

    // getAAtOutCoords, getBAtOutCoords and setOutput are not declared in this
    // snippet; presumably the shader preamble supplies them.
    this.userCode = `
      float binaryOperation(float a, float b) {
        ${op}
      }

      void main() {
        uint index = gl_GlobalInvocationID.x;
        float a = getAAtOutCoords();
        float b = getBAtOutCoords();
        setOutput(index, binaryOperation(a, b));
      }
    `;
  }
}
62 |
--------------------------------------------------------------------------------
/tfjs-webgpu/src/kernels/from_pixels_webgpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {computeDispatch, flatDispatchLayout} from '../webgpu_util';
19 |
20 | import {WebGPUProgram} from './webgpu_program';
21 |
/**
 * Compute program whose shader fetches one texel of input `A` per output
 * coordinate, picks the r/g/b/a channel selected by the third coordinate and
 * writes `floor(value * 255.0 + 0.5)`.
 */
export class FromPixelsProgram implements WebGPUProgram {
  outputShape: number[];
  userCode: string;
  variableNames = ['A'];
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];

  /**
   * @param outputShape Output shape; its first two entries are interpolated
   *     into the shader as height and width.
   */
  constructor(outputShape: number[]) {
    const height = outputShape[0];
    const width = outputShape[1];

    this.outputShape = outputShape;
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);

    // NOTE(review): `halfCR` is used below but not declared in this snippet,
    // and `value` is never assigned when depth > 3 -- confirm that the shader
    // preamble and the callers cover both cases.
    this.userCode = `
      void main() {
        ivec3 coords = getOutputCoords();
        int texR = coords[0];
        int texC = coords[1];
        int depth = coords[2];
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);

        vec4 values = texelFetch(A, uv);
        float value;
        if (depth == 0) {
          value = values.r;
        } else if (depth == 1) {
          value = values.g;
        } else if (depth == 2) {
          value = values.b;
        } else if (depth == 3) {
          value = values.a;
        }

        setOutput(floor(value * 255.0 + 0.5));
      }
    `;
  }
}
--------------------------------------------------------------------------------
/tfjs-webgpu/src/kernels/unary_op_webgpu.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {computeDispatch} from '../webgpu_util';
19 | import {WebGPUProgram} from './webgpu_program';
20 |
// GLSL statement snippets for element-wise unary operations. Each snippet
// reads a float named `a` and ends by returning a float.

// Rectified linear unit: the larger of `a` and zero.
export const RELU = 'return max(a, 0.0);';

// Logistic sigmoid of `a`.
export const SIGMOID = 'return 1.0 / (1.0 + exp(-1.0 * a));';
24 |
/**
 * Compute program applying a unary GLSL operation to input `A`. Every output
 * dimension is mapped to the x dispatch axis.
 */
export class UnaryOpProgram implements WebGPUProgram {
  outputShape: number[];
  userCode: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['A'];

  /**
   * @param outputShape Shape of the output.
   * @param op GLSL statements forming the body of
   *     `float unaryOperation(float a)`.
   */
  constructor(outputShape: number[], op: string) {
    const allDims = outputShape.map((_, dim) => dim);

    this.outputShape = outputShape;
    this.dispatchLayout = {x: allDims};
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);

    // getAAtOutCoords and setOutput are not declared in this snippet;
    // presumably the shader preamble supplies them.
    this.userCode = `
      float unaryOperation(float a) {
        ${op}
      }

      void main() {
        uint index = gl_GlobalInvocationID.x;
        float a = getAAtOutCoords();
        setOutput(index, unaryOperation(a));
      }
    `;
  }
}
50 |
--------------------------------------------------------------------------------
/tfjs-webgpu/src/shader_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
/**
 * Generates GLSL expressions that compute strides.
 *
 * Each entry of `indicesArr` is turned into the shape expression
 * `variableName[index]`. The last stride is the last shape expression, and
 * every earlier stride is the following stride multiplied by the following
 * shape expression.
 *
 * @param indicesArr Indices into `variableName`; no index may exceed 3.
 * @param variableName Name of the GLSL variable holding the shape.
 * @returns One stride expression per coordinate except the last, so
 *     `indicesArr.length - 1` entries, or an empty array when there are fewer
 *     than two coordinates.
 * @throws Error if any index is greater than 3.
 */
export function symbolicallyComputeStrides(
    indicesArr: number[], variableName: string): string[] {
  if (Math.max(...indicesArr) > 3) {
    throw new Error('Cannot symbolically compute strides for rank > 4 tensor.');
  }

  const numCoords = indicesArr.length;
  // Fewer than two coordinates need no strides. Returning early avoids
  // writing to index -1 (one coordinate) and `new Array(-1)` (none).
  if (numCoords < 2) {
    return [];
  }

  const shape = indicesArr.map(d => `${variableName}[${d}]`);
  const strides: string[] = new Array(numCoords - 1);
  strides[numCoords - 2] = shape[numCoords - 1];
  for (let i = numCoords - 3; i >= 0; --i) {
    strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`;
  }

  return strides;
}
--------------------------------------------------------------------------------
/tfjs-webgpu/src/shader_util_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {symbolicallyComputeStrides} from './shader_util';
19 |
// Unit tests for the GLSL stride generator in shader_util.
describe('shader util', () => {
  it('symbolicallyComputeStrides takes in array of dimensions ' +
         'and returns GLSL to compute strides for those dimensions',
     () => {
       const [outerStride, innerStride] =
           symbolicallyComputeStrides([0, 2, 1], 'output');

       expect(outerStride).toEqual('(output[1] * output[2])');
       expect(innerStride).toEqual('output[1]');
     });

  it('symbolicallyComputeStrides throws if given a dimension ' +
         'that cannot be accessed from a GLSL data type',
     () => {
       // Index 5 is larger than the largest index the function accepts.
       const invalidLayout = [0, 5, 2];
       const compute = () =>
           symbolicallyComputeStrides(invalidLayout, 'output');

       expect(compute).toThrowError();
     });
});
--------------------------------------------------------------------------------
/tfjs-webgpu/src/test_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {ALL_ENVS, describeWithFlags, TestEnv} from '@tensorflow/tfjs-core/dist/jasmine_util';
19 |
/**
 * Declares a test suite through `describeWithFlags`, prefixing the suite name
 * with 'webgpu ' and passing `ALL_ENVS` as the constraints.
 *
 * @param name Suite name, appended to the 'webgpu ' prefix.
 * @param tests Suite body, handed through to `describeWithFlags`.
 */
export function describeWebGPU(name: string, tests: (env: TestEnv) => void) {
  const suiteName = `webgpu ${name}`;
  describeWithFlags(suiteName, ALL_ENVS, tests);
}
23 |
--------------------------------------------------------------------------------
/tfjs-webgpu/src/webgpu_util.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
// Multiplies all entries of `arr` together; an empty array yields 1.
const arrayProduct = (arr: number[]) =>
    arr.reduce((product, factor) => product * factor, 1);
25 |
/**
 * Computes dispatch geometry based on layout of output dimensions and
 * workGroupSize.
 *
 * For each dispatch axis, the sizes of the output dimensions assigned to it
 * are multiplied together, divided by (workGroupSize * elementsPerThread) for
 * that axis, and rounded up. A missing `y` or `z` layout gives 1.
 *
 * @param layout Output dimension indices assigned to each dispatch axis.
 * @param outputShape Shape that the layout indices refer to.
 * @param workGroupSize Workgroup size per axis; defaults to [1, 1, 1].
 * @param elementsPerThread Elements handled by one thread per axis; defaults
 *     to [1, 1, 1].
 * @returns The [x, y, z] dispatch dimensions.
 */
export function computeDispatch(
    layout: {x: number[], y?: number[], z?: number[]}, outputShape: number[],
    workGroupSize: [number, number, number] = [1, 1, 1],
    elementsPerThread: [number, number, number] =
        [1, 1, 1]): [number, number, number] {
  // Workgroups needed along `axis` to cover the output dimensions in `dims`.
  const groupsFor = (dims: number[], axis: number) => {
    const elementCount = arrayProduct(dims.map(d => outputShape[d]));
    const elementsPerGroup = workGroupSize[axis] * elementsPerThread[axis];
    return Math.ceil(elementCount / elementsPerGroup);
  };

  const groupsX = groupsFor(layout.x, 0);
  const groupsY = layout.y ? groupsFor(layout.y, 1) : 1;
  const groupsZ = layout.z ? groupsFor(layout.z, 2) : 1;
  return [groupsX, groupsY, groupsZ];
}
47 |
/**
 * Builds a dispatch layout that assigns every dimension of `shape` to the x
 * axis.
 *
 * @param shape Shape whose dimension indices (0 .. rank - 1) are listed.
 * @returns `{x: [0, 1, ..., shape.length - 1]}`.
 */
export function flatDispatchLayout(shape: number[]) {
  const allDims = shape.map((_, dimIndex) => dimIndex);
  return {x: allDims};
}
--------------------------------------------------------------------------------
/tfjs-webgpu/src/webgpu_util_test.ts:
--------------------------------------------------------------------------------
1 | /**
2 | * @license
3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =============================================================================
16 | */
17 |
18 | import {computeDispatch} from './webgpu_util';
19 |
// Unit tests for the dispatch geometry helper in webgpu_util.
describe('webgpu util', () => {
  it('computeDispatch returns dispatch dimensions based on layout of ' +
         'output dimensions and workGroupSize.',
     () => {
       const shape = [1, 2, 3, 2];
       const axes = {x: [0], y: [1], z: [2, 3]};
       const groupSize: [number, number, number] = [2, 2, 1];

       expect(computeDispatch(axes, shape, groupSize)).toEqual([1, 1, 6]);
     });

  it('computeDispatch returns dispatch dimensions based on layout of ' +
         'output dimensions, workGroupSize, and elementsPerThread.',
     () => {
       const shape = [4, 8, 12, 2];
       const axes = {x: [0], y: [1], z: [2, 3]};
       const groupSize: [number, number, number] = [2, 1, 1];
       const perThread: [number, number, number] = [2, 2, 3];

       const result = computeDispatch(axes, shape, groupSize, perThread);
       expect(result).toEqual([1, 4, 8]);
     });
});
47 |
--------------------------------------------------------------------------------
/tfjs-webgpu/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tsconfig",
3 | "include": [
4 | "src/"
5 | ],
6 | "exclude": [
7 | "node_modules/"
8 | ],
9 | "compilerOptions": {
10 | "outDir": "./dist"
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/tfjs-webgpu/tslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": "../tslint.json"
3 | }
4 |
--------------------------------------------------------------------------------
/tfjs.code-workspace:
--------------------------------------------------------------------------------
1 | {
2 | "folders": [
3 | {
4 | "name": "Common",
5 | "path": "."
6 | },
7 | {
8 | "name": "Core",
9 | "path": "tfjs-core"
10 | },
11 | {
12 | "name": "Wasm",
13 | "path": "tfjs-backend-wasm"
14 | },
15 | {
16 | "name": "NodeGL",
17 | "path": "tfjs-backend-nodegl"
18 | },
19 | {
20 | "name": "ReactNative",
21 | "path": "tfjs-react-native"
22 | },
23 | {
24 | "name": "WebGPU",
25 | "path": "tfjs-webgpu"
26 | }
27 | ],
28 | "settings": {
29 | "search.exclude": {
30 | "**/node_modules": true,
31 | "**/coverage/": true,
32 | "**/dist/": true,
33 | "**/yarn.lock": true,
34 | "**/.rpt2_cache/": true,
35 | "**/.yalc/**/*": true
36 | },
37 | "tslint.configFile": "tslint.json",
38 | "files.trimTrailingWhitespace": true,
39 | "editor.tabSize": 2,
40 | "editor.insertSpaces": true,
41 | "[typescript]": {
42 | "editor.formatOnSave": true
43 | },
44 | "[javascript]": {
45 | "editor.formatOnSave": true
46 | },
47 | "[cpp]": {
48 | "editor.formatOnSave": true
49 | },
50 | "editor.defaultFormatter": "xaver.clang-format",
51 | "editor.rulers": [
52 | 80
53 | ],
54 | "clang-format.style": "Google",
55 | "files.insertFinalNewline": true,
56 | "editor.detectIndentation": false,
57 | "editor.wrappingIndent": "none",
58 | "typescript.tsdk": "./node_modules/typescript/lib",
59 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format"
60 | },
61 | "extensions": {
62 | "recommendations": [
63 | // Formats typescript, javascript and c++ code.
64 | "xaver.clang-format",
65 | // Lints typescript code.
66 | "ms-vscode.vscode-typescript-tslint-plugin"
67 | ]
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "module": "commonjs",
4 | "moduleResolution": "node",
5 | "noImplicitAny": true,
6 | "sourceMap": true,
7 | "removeComments": false,
8 | "preserveConstEnums": true,
9 | "declaration": true,
10 | "target": "es5",
11 | "lib": [
12 | "es2015",
13 | "dom"
14 | ],
15 | "outDir": "./dist",
16 | "noUnusedLocals": true,
17 | "noImplicitReturns": true,
18 | "noImplicitThis": true,
19 | "alwaysStrict": true,
20 | "noUnusedParameters": false,
21 | "pretty": true,
22 | "noFallthroughCasesInSwitch": true,
23 | "allowUnreachableCode": false
24 | },
25 | "include": [
26 | "src/"
27 | ],
28 | "exclude": [
29 | "node_modules/"
30 | ]
31 | }
32 |
--------------------------------------------------------------------------------
/tslint.json:
--------------------------------------------------------------------------------
1 | {
2 | "extends": ["tslint-no-circular-imports"],
3 | "rules": {
4 | "array-type": [true, "array-simple"],
5 | "arrow-return-shorthand": true,
6 | "ban": [true,
7 | ["fit"],
8 | ["fdescribe"],
9 | ["xit"],
10 | ["xdescribe"],
11 | ["fitAsync"],
12 | ["xitAsync"],
13 | ["fitFakeAsync"],
14 | ["xitFakeAsync"]
15 | ],
16 | "ban-types": [true,
17 | ["Object", "Use {} instead."],
18 | ["String", "Use 'string' instead."],
19 | ["Number", "Use 'number' instead."],
20 | ["Boolean", "Use 'boolean' instead."]
21 | ],
22 | "no-return-await": true,
23 | "class-name": true,
24 | "curly": true,
25 | "interface-name": [true, "never-prefix"],
26 | "jsdoc-format": true,
27 | "forin": false,
28 | "label-position": true,
29 | "max-line-length": {
30 | "options": {"limit": 80, "ignore-pattern": "^import |^export "}
31 | },
32 | "new-parens": true,
33 | "no-angle-bracket-type-assertion": true,
34 | "no-any": true,
35 | "no-construct": true,
36 | "no-consecutive-blank-lines": true,
37 | "no-debugger": true,
38 | "no-default-export": true,
39 | "no-inferrable-types": true,
40 | "no-namespace": [true, "allow-declarations"],
41 | "no-reference": true,
42 | "no-require-imports": true,
43 | "no-string-throw": true,
44 | "no-unused-expression": true,
45 | "no-var-keyword": true,
46 | "object-literal-shorthand": true,
47 | "only-arrow-functions": [true, "allow-declarations", "allow-named-functions"],
48 | "prefer-const": true,
49 | "quotemark": [true, "single"],
50 | "radix": true,
51 | "restrict-plus-operands": true,
52 | "semicolon": [true, "always", "ignore-bound-class-methods"],
53 | "switch-default": true,
54 | "triple-equals": [true, "allow-null-check"],
55 | "use-isnan": true,
56 | "use-default-type-parameter": true,
57 | "variable-name": [
58 | true,
59 | "check-format",
60 | "ban-keywords",
61 | "allow-leading-underscore",
62 | "allow-trailing-underscore"
63 | ]
64 | }
65 | }
66 |
--------------------------------------------------------------------------------