├── .gitignore ├── CONTRIBUTING.md ├── DEVELOPMENT.md ├── ISSUE_TEMPLATE.md ├── LICENSE ├── README.md ├── WORKSPACE ├── cloudbuild.yml ├── package.json ├── pull_request_template.md ├── scripts ├── diff.js ├── run-build.sh └── test-util.js ├── tfjs-backend-nodegl ├── .vscode │ └── settings.json ├── README.md ├── demo │ ├── README.md │ ├── dog.jpg │ ├── package.json │ ├── run_mobilenet_inference.js │ └── yarn.lock ├── package.json ├── src │ ├── index.ts │ ├── run_tests.ts │ └── version.ts ├── tasks.json ├── tsconfig.json ├── tslint.json └── yarn.lock ├── tfjs-backend-wasm ├── .npmrc ├── .vscode │ ├── c_cpp_properties.json │ └── settings.json ├── README.md ├── karma.conf.js ├── package.json ├── rollup.config.js ├── scripts │ ├── build-npm.sh │ └── build-wasm.sh ├── src │ ├── backend_wasm.ts │ ├── index.ts │ ├── index_test.ts │ ├── kernels.cc │ ├── kernels.h │ ├── lib.cc │ ├── setup_test.ts │ └── util.h ├── tsconfig.json ├── tslint.json ├── wasm-out │ └── tfjs-backend-wasm.d.ts └── yarn.lock ├── tfjs-core ├── .bazelignore ├── .bazelrc ├── .npmignore ├── .npmrc ├── .vscode │ ├── c_cpp_properties.json │ ├── launch.json │ ├── settings.json │ └── tasks.json ├── BUILD.bazel ├── README.md ├── benchmarks │ └── index.html ├── cloudbuild.yml ├── karma.conf.js ├── package.json ├── rollup.config.js ├── scripts │ ├── build-npm.sh │ ├── cloud_funcs │ │ ├── README.md │ │ ├── send_email │ │ │ ├── .gcloudignore │ │ │ ├── config.json │ │ │ ├── index.js │ │ │ ├── package-lock.json │ │ │ └── package.json │ │ └── trigger_nightly │ │ │ ├── .gcloudignore │ │ │ ├── index.js │ │ │ ├── package-lock.json │ │ │ └── package.json │ ├── enumerate-tests.js │ ├── make-version │ ├── publish-npm.sh │ ├── tag-version.js │ ├── test-bundle-size.js │ ├── test-ci.sh │ ├── test-integration.js │ ├── test-integration.sh │ └── test_snippets │ │ ├── test_snippets.ts │ │ ├── tsconfig.json │ │ └── util.ts ├── src │ ├── BUILD.bazel │ ├── backends │ │ ├── backend.ts │ │ ├── backend_test.ts │ │ ├── backend_util.ts │ 
│ ├── complex_util.ts │ │ ├── complex_util_test.ts │ │ ├── cpu │ │ │ ├── backend_cpu.ts │ │ │ ├── backend_cpu_test.ts │ │ │ └── backend_cpu_test_registry.ts │ │ ├── non_max_suppression_impl.ts │ │ ├── packing_util.ts │ │ ├── split_shared.ts │ │ ├── tile_impl.ts │ │ ├── topk_impl.ts │ │ ├── webgl │ │ │ ├── addn_gpu.ts │ │ │ ├── addn_packed_gpu.ts │ │ │ ├── argminmax_gpu.ts │ │ │ ├── argminmax_packed_gpu.ts │ │ │ ├── avg_pool_backprop_gpu.ts │ │ │ ├── backend_webgl.ts │ │ │ ├── backend_webgl_test.ts │ │ │ ├── backend_webgl_test_registry.ts │ │ │ ├── batchnorm_gpu.ts │ │ │ ├── batchnorm_packed_gpu.ts │ │ │ ├── binaryop_complex_gpu.ts │ │ │ ├── binaryop_gpu.ts │ │ │ ├── binaryop_packed_gpu.ts │ │ │ ├── canvas_util.ts │ │ │ ├── canvas_util_test.ts │ │ │ ├── clip_gpu.ts │ │ │ ├── clip_packed_gpu.ts │ │ │ ├── complex_abs_gpu.ts │ │ │ ├── concat_gpu.ts │ │ │ ├── concat_packed_gpu.ts │ │ │ ├── conv_backprop_gpu.ts │ │ │ ├── conv_backprop_gpu_depthwise.ts │ │ │ ├── conv_gpu.ts │ │ │ ├── conv_gpu_depthwise.ts │ │ │ ├── conv_packed_gpu_depthwise.ts │ │ │ ├── crop_and_resize_gpu.ts │ │ │ ├── cumsum_gpu.ts │ │ │ ├── decode_matrix_gpu.ts │ │ │ ├── decode_matrix_packed_gpu.ts │ │ │ ├── depth_to_space_gpu.ts │ │ │ ├── diag_gpu.ts │ │ │ ├── encode_float_gpu.ts │ │ │ ├── encode_float_packed_gpu.ts │ │ │ ├── encode_matrix_gpu.ts │ │ │ ├── encode_matrix_packed_gpu.ts │ │ │ ├── fft_gpu.ts │ │ │ ├── fill_gpu.ts │ │ │ ├── flags_webgl.ts │ │ │ ├── flags_webgl_test.ts │ │ │ ├── from_pixels_gpu.ts │ │ │ ├── from_pixels_packed_gpu.ts │ │ │ ├── gather_gpu.ts │ │ │ ├── gather_nd_gpu.ts │ │ │ ├── glsl_version.ts │ │ │ ├── gpgpu_context.ts │ │ │ ├── gpgpu_context_test.ts │ │ │ ├── gpgpu_math.ts │ │ │ ├── gpgpu_util.ts │ │ │ ├── gpgpu_util_test.ts │ │ │ ├── im2col_packed_gpu.ts │ │ │ ├── lrn_gpu.ts │ │ │ ├── lrn_grad_gpu.ts │ │ │ ├── lrn_packed_gpu.ts │ │ │ ├── max_pool_backprop_gpu.ts │ │ │ ├── mulmat_packed_gpu.ts │ │ │ ├── multinomial_gpu.ts │ │ │ ├── onehot_gpu.ts │ │ │ ├── pack_gpu.ts │ │ │ 
├── pad_gpu.ts │ │ │ ├── pad_packed_gpu.ts │ │ │ ├── pool_gpu.ts │ │ │ ├── reduce_gpu.ts │ │ │ ├── reshape_packed_gpu.ts │ │ │ ├── reshape_packed_test.ts │ │ │ ├── resize_bilinear_backprop_gpu.ts │ │ │ ├── resize_bilinear_gpu.ts │ │ │ ├── resize_bilinear_packed_gpu.ts │ │ │ ├── resize_nearest_neighbor_backprop_gpu.ts │ │ │ ├── resize_nearest_neighbor_gpu.ts │ │ │ ├── reverse_gpu.ts │ │ │ ├── reverse_packed_gpu.ts │ │ │ ├── scatter_gpu.ts │ │ │ ├── segment_gpu.ts │ │ │ ├── select_gpu.ts │ │ │ ├── shader_compiler.ts │ │ │ ├── shader_compiler_util.ts │ │ │ ├── shader_compiler_util_test.ts │ │ │ ├── slice_gpu.ts │ │ │ ├── slice_packed_gpu.ts │ │ │ ├── strided_slice_gpu.ts │ │ │ ├── tex_util.ts │ │ │ ├── tex_util_test.ts │ │ │ ├── texture_manager.ts │ │ │ ├── tile_gpu.ts │ │ │ ├── transpose_gpu.ts │ │ │ ├── transpose_packed_gpu.ts │ │ │ ├── unaryop_gpu.ts │ │ │ ├── unaryop_packed_gpu.ts │ │ │ ├── unpack_gpu.ts │ │ │ ├── webgl_batchnorm_test.ts │ │ │ ├── webgl_custom_op_test.ts │ │ │ ├── webgl_ops_test.ts │ │ │ ├── webgl_types.ts │ │ │ ├── webgl_util.ts │ │ │ └── webgl_util_test.ts │ │ └── where_impl.ts │ ├── browser_util.ts │ ├── browser_util_test.ts │ ├── buffer_test.ts │ ├── debug_mode_test.ts │ ├── device_util.ts │ ├── engine.ts │ ├── engine_test.ts │ ├── environment.ts │ ├── environment_test.ts │ ├── flags.ts │ ├── flags_test.ts │ ├── globals.ts │ ├── globals_test.ts │ ├── gradients.ts │ ├── gradients_test.ts │ ├── index.ts │ ├── io │ │ ├── browser_files.ts │ │ ├── browser_files_test.ts │ │ ├── http.ts │ │ ├── http_test.ts │ │ ├── indexed_db.ts │ │ ├── indexed_db_test.ts │ │ ├── io.ts │ │ ├── io_utils.ts │ │ ├── io_utils_test.ts │ │ ├── local_storage.ts │ │ ├── local_storage_test.ts │ │ ├── model_management.ts │ │ ├── model_management_test.ts │ │ ├── passthrough.ts │ │ ├── passthrough_test.ts │ │ ├── progress.ts │ │ ├── progress_test.ts │ │ ├── router_registry.ts │ │ ├── router_registry_test.ts │ │ ├── types.ts │ │ ├── weights_loader.ts │ │ └── 
weights_loader_test.ts │ ├── jasmine_util.ts │ ├── jasmine_util_test.ts │ ├── log.ts │ ├── math.ts │ ├── model_types.ts │ ├── ops │ │ ├── arithmetic_test.ts │ │ ├── array_ops.ts │ │ ├── array_ops_test.ts │ │ ├── array_ops_util.ts │ │ ├── axis_util.ts │ │ ├── axis_util_test.ts │ │ ├── batchnorm.ts │ │ ├── batchnorm_test.ts │ │ ├── binary_ops.ts │ │ ├── binary_ops_test.ts │ │ ├── boolean_mask.ts │ │ ├── boolean_mask_test.ts │ │ ├── broadcast_util.ts │ │ ├── broadcast_util_test.ts │ │ ├── browser.ts │ │ ├── clone_test.ts │ │ ├── compare.ts │ │ ├── compare_ops_test.ts │ │ ├── complex_ops.ts │ │ ├── complex_ops_test.ts │ │ ├── concat_split.ts │ │ ├── concat_test.ts │ │ ├── concat_util.ts │ │ ├── concat_util_test.ts │ │ ├── confusion_matrix.ts │ │ ├── confusion_matrix_test.ts │ │ ├── conv.ts │ │ ├── conv1d_test.ts │ │ ├── conv2d_depthwise_test.ts │ │ ├── conv2d_separable_test.ts │ │ ├── conv2d_test.ts │ │ ├── conv2d_transpose_test.ts │ │ ├── conv3d_test.ts │ │ ├── conv3d_transpose_test.ts │ │ ├── conv_util.ts │ │ ├── conv_util_test.ts │ │ ├── diag.ts │ │ ├── diag_test.ts │ │ ├── dropout.ts │ │ ├── dropout_test.ts │ │ ├── dropout_util.ts │ │ ├── dropout_util_test.ts │ │ ├── erf_util.ts │ │ ├── fused_ops.ts │ │ ├── fused_test.ts │ │ ├── fused_util.ts │ │ ├── gather_nd.ts │ │ ├── gather_nd_test.ts │ │ ├── gather_nd_util.ts │ │ ├── image_ops.ts │ │ ├── image_ops_test.ts │ │ ├── in_top_k.ts │ │ ├── in_top_k_test.ts │ │ ├── linalg_ops.ts │ │ ├── linalg_ops_test.ts │ │ ├── logical_ops.ts │ │ ├── logical_ops_test.ts │ │ ├── loss_ops.ts │ │ ├── loss_ops_test.ts │ │ ├── lrn.ts │ │ ├── lrn_test.ts │ │ ├── lstm.ts │ │ ├── lstm_test.ts │ │ ├── matmul.ts │ │ ├── matmul_test.ts │ │ ├── moving_average.ts │ │ ├── moving_average_test.ts │ │ ├── multinomial_test.ts │ │ ├── norm.ts │ │ ├── operation.ts │ │ ├── operation_test.ts │ │ ├── ops.ts │ │ ├── pad_test.ts │ │ ├── pool.ts │ │ ├── pool_test.ts │ │ ├── rand.ts │ │ ├── rand_test.ts │ │ ├── rand_util.ts │ │ ├── reduce_util.ts │ │ ├── 
reduction_ops.ts │ │ ├── reduction_ops_test.ts │ │ ├── relu_ops.ts │ │ ├── resize_bilinear_test.ts │ │ ├── resize_nearest_neighbor_test.ts │ │ ├── reverse.ts │ │ ├── reverse_test.ts │ │ ├── scatter_nd.ts │ │ ├── scatter_nd_test.ts │ │ ├── scatter_nd_util.ts │ │ ├── segment_ops.ts │ │ ├── segment_ops_test.ts │ │ ├── segment_util.ts │ │ ├── selu_util.ts │ │ ├── signal_ops.ts │ │ ├── signal_ops_test.ts │ │ ├── slice.ts │ │ ├── slice_test.ts │ │ ├── slice_util.ts │ │ ├── slice_util_test.ts │ │ ├── softmax.ts │ │ ├── softmax_test.ts │ │ ├── sparse_to_dense.ts │ │ ├── sparse_to_dense_test.ts │ │ ├── sparse_to_dense_util.ts │ │ ├── spectral_ops.ts │ │ ├── spectral_ops_test.ts │ │ ├── strided_slice.ts │ │ ├── strided_slice_test.ts │ │ ├── tensor_ops.ts │ │ ├── topk.ts │ │ ├── topk_test.ts │ │ ├── transpose.ts │ │ ├── transpose_test.ts │ │ ├── unary_ops.ts │ │ └── unary_ops_test.ts │ ├── optimizers │ │ ├── adadelta_optimizer.ts │ │ ├── adadelta_optimizer_test.ts │ │ ├── adagrad_optimizer.ts │ │ ├── adagrad_optimizer_test.ts │ │ ├── adam_optimizer.ts │ │ ├── adam_optimizer_test.ts │ │ ├── adamax_optimizer.ts │ │ ├── adamax_optimizer_test.ts │ │ ├── momentum_optimizer.ts │ │ ├── momentum_optimizer_test.ts │ │ ├── optimizer.ts │ │ ├── optimizer_constructors.ts │ │ ├── optimizer_test.ts │ │ ├── rmsprop_optimizer.ts │ │ ├── rmsprop_optimizer_test.ts │ │ ├── sgd_optimizer.ts │ │ └── sgd_optimizer_test.ts │ ├── platforms │ │ ├── platform.ts │ │ ├── platform_browser.ts │ │ ├── platform_browser_test.ts │ │ ├── platform_node.ts │ │ └── platform_node_test.ts │ ├── profiler.ts │ ├── profiler_test.ts │ ├── serialization.ts │ ├── serialization_test.ts │ ├── setup_test.ts │ ├── tape.ts │ ├── tape_test.ts │ ├── tensor.ts │ ├── tensor_format.ts │ ├── tensor_test.ts │ ├── tensor_types.ts │ ├── tensor_util.ts │ ├── tensor_util_env.ts │ ├── tensor_util_test.ts │ ├── test_async_backends.ts │ ├── test_node.ts │ ├── test_util.ts │ ├── test_util_test.ts │ ├── tests.ts │ ├── train.ts │ ├── types.ts 
│ ├── types_test.ts │ ├── util.ts │ ├── util_test.ts │ ├── variable_test.ts │ ├── version.ts │ ├── version_test.ts │ ├── webgl.ts │ ├── worker_node_test.ts │ └── worker_test.ts ├── tsconfig.json ├── tslint.json └── yarn.lock ├── tfjs-react-native ├── .npmignore ├── .vscode │ └── settings.json ├── README.md ├── cloudbuild.yml ├── karma.conf.js ├── package.json ├── rollup.config.js ├── scripts │ └── test-ci.sh ├── src │ ├── async_storage_io.ts │ ├── async_storage_io_test.ts │ ├── bundle_resource_io.ts │ ├── bundle_resource_io_test.ts │ ├── index.ts │ ├── platform_react_native.ts │ ├── platform_react_native_test.ts │ └── test_utils │ │ ├── async_storage_mock.ts │ │ ├── gl_view_mock.ts │ │ └── react_native_mock.ts ├── tsconfig.json ├── tslint.json └── yarn.lock ├── tfjs-webgpu ├── .npmignore ├── .vscode │ └── settings.json ├── README.md ├── cloudbuild.yml ├── karma.conf.js ├── package.json ├── rollup.config.js ├── scripts │ └── test-ci.sh ├── src │ ├── backend_webgpu.ts │ ├── backend_webgpu_test.ts │ ├── benchmark_ops_test.ts │ ├── buffer_manager.ts │ ├── flags_webgpu.ts │ ├── index.ts │ ├── kernels │ │ ├── argminmax_webgpu.ts │ │ ├── binary_op_webgpu.ts │ │ ├── concat_webgpu.ts │ │ ├── conv2d_mm_webgpu.ts │ │ ├── conv2d_naive_webgpu.ts │ │ ├── from_pixels_webgpu.ts │ │ ├── matmul_packed_webgpu.ts │ │ ├── matmul_webgpu.ts │ │ ├── maxpool_webgpu.ts │ │ ├── pad_webgpu.ts │ │ ├── resize_bilinear_webgpu.ts │ │ ├── transpose_webgpu.ts │ │ ├── unary_op_webgpu.ts │ │ └── webgpu_program.ts │ ├── matmul_test.ts │ ├── setup_test.ts │ ├── shader_preprocessor.ts │ ├── shader_util.ts │ ├── shader_util_test.ts │ ├── test_util.ts │ ├── webgpu_util.ts │ └── webgpu_util_test.ts ├── tsconfig.json ├── tslint.json └── yarn.lock ├── tfjs.code-workspace ├── tsconfig.json ├── tslint.json └── yarn.lock /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | coverage/ 3 | npm-debug.log 4 | yarn-error.log 5 | local.log 6 | .DS_Store 7 
| dist/ 8 | bazel-out/ 9 | .idea/ 10 | *.tgz 11 | **/*.pyc 12 | .yalc/ 13 | yalc.lock 14 | .rpt2_cache/ 15 | package/ 16 | */diff 17 | 18 | tfjs-backend-wasm/dist 19 | tfjs-backend-wasm/wasm-out 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Adding functionality 19 | 20 | One way to ensure that your PR will be accepted is to add functionality that 21 | has been requested in GitHub issues. If there is something you think is 22 | important and we're missing it but it does not show up in GitHub issues, it would 23 | be good to file an issue there first so we can have the discussion before 24 | sending us a PR. 25 | 26 | In general, we're trying to add functionality when driven by use cases instead of 27 | adding functionality for the sake of parity with TensorFlow Python. This will 28 | help us keep the bundle size smaller and have less to maintain, especially as we 29 | add new backends.
30 | 31 | ### Adding an op 32 | 33 | When adding ops to the library and deciding whether to write a kernel 34 | implementation in [backend.ts](https://github.com/tensorflow/tfjs-core/blob/master/src/backends/backend.ts), 35 | be sure to check out the TensorFlow ops list [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/ops.pbtxt). 36 | This list shows the kernels available for the TensorFlow C API. To ensure that 37 | we can bind to this with node.js, we should ensure that our backend.ts 38 | interface matches ops in the TensorFlow C API. 39 | 40 | ## Code reviews 41 | 42 | All submissions, including submissions by project members, require review. We 43 | use GitHub pull requests for this purpose. Consult 44 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 45 | information on using pull requests. 46 | 47 | We require unit tests for most code, instructions for running our unit test 48 | suites are in the documentation. 49 | -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | ## Development 2 | 3 | To build **TensorFlow.js Core API** from source, we need to clone the project and prepare 4 | the dev environment: 5 | 6 | ```bash 7 | $ git clone https://github.com/tensorflow/tfjs-core.git 8 | $ cd tfjs-core 9 | $ yarn # Installs dependencies. 10 | ``` 11 | 12 | #### Yarn 13 | We use yarn, and if you are adding or removing dependencies you should use yarn 14 | to keep the `yarn.lock` file up to date. 15 | 16 | #### Code editor 17 | We recommend using [Visual Studio Code](https://code.visualstudio.com/) for 18 | development. 
Make sure to install 19 | [TSLint VSCode extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.vscode-typescript-tslint-plugin) 20 | and the npm [clang-format](https://github.com/angular/clang-format) `1.2.2` or later 21 | with the 22 | [Clang-Format VSCode extension](https://marketplace.visualstudio.com/items?itemName=xaver.clang-format) 23 | for auto-formatting. 24 | 25 | #### Testing 26 | Before submitting a pull request, make sure the code passes all the tests and is clean of lint errors: 27 | 28 | ```bash 29 | $ yarn test 30 | $ yarn lint 31 | ``` 32 | 33 | To run a subset of tests and/or on a specific browser: 34 | 35 | ```bash 36 | $ yarn test --browsers=Chrome --grep='multinomial' 37 |   38 | > ... 39 | > Chrome 62.0.3202 (Mac OS X 10.12.6): Executed 28 of 1891 (skipped 1863) SUCCESS (6.914 secs / 0.634 secs) 40 | ``` 41 | 42 | To run the tests once and exit the karma process (helpful on Windows): 43 | 44 | ```bash 45 | $ yarn test --single-run 46 | ``` 47 | 48 | To run the tests in an environment that does not have GPU support (such as Chrome Remote Desktop): 49 | 50 | ```bash 51 | $ yarn test --testEnv cpu 52 | ``` 53 | 54 | Available test environments: cpu, webgl1, webgl2. 55 | 56 | #### Packaging (browser and npm) 57 | 58 | ```bash 59 | $ yarn build-npm 60 | > Stored standalone library at dist/tf-core(.min).js 61 | > Stored also tensorflow-tf-core-VERSION.tgz 62 | ``` 63 | 64 | To install it locally, run `yarn add ./tensorflow-tf-core-VERSION.tgz`. 65 | 66 | > On Windows, use bash (available through git) to use the scripts above. 67 | 68 | Looking to contribute, and don't know where to start? Check out our "stat:contributions welcome" [issues](https://github.com/tensorflow/tfjs/labels/stat%3Acontributions%20welcome). 
69 | -------------------------------------------------------------------------------- /ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | If you would like to get help from the community, we encourage using Stack Overflow and the [`tensorflow.js`](https://stackoverflow.com/questions/tagged/tensorflow.js) tag. 2 | 3 | GitHub issues for this repository are tracked in the [tfjs union repository](https://github.com/tensorflow/tfjs/issues). 4 | 5 | Please file your issue there, following the guidance in [that issue template](https://github.com/tensorflow/tfjs/blob/master/ISSUE_TEMPLATE.md). 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # This repository has been archived in favor of tensorflow/tfjs. 2 | 3 | This repo will remain around for some time to keep history but all future PRs should be sent to [tensorflow/tfjs](https://github.com/tensorflow/tfjs) inside the tfjs-core folder. 4 | 5 | All history and contributions have been preserved in the monorepo. 
-------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | http_archive( 3 | name = "build_bazel_rules_nodejs", 4 | sha256 = "88e5e579fb9edfbd19791b8a3c6bfbe16ae3444dba4b428e5efd36856db7cf16", 5 | urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/0.27.8/rules_nodejs-0.27.8.tar.gz"], 6 | ) 7 | 8 | load("@build_bazel_rules_nodejs//:defs.bzl", "yarn_install") 9 | yarn_install( 10 | name = "npm", 11 | package_json = "//:package.json", 12 | yarn_lock = "//:yarn.lock", 13 | ) 14 | 15 | load("@npm//:install_bazel_dependencies.bzl", "install_bazel_dependencies") 16 | install_bazel_dependencies() 17 | 18 | # Setup TypeScript toolchain 19 | load("@npm_bazel_typescript//:index.bzl", "ts_setup_workspace") 20 | ts_setup_workspace() 21 | -------------------------------------------------------------------------------- /cloudbuild.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | # Install top-level deps. 3 | - name: 'node:10' 4 | entrypoint: 'yarn' 5 | id: 'yarn' 6 | args: ['install'] 7 | 8 | # Run diff to find modified files in each folder. 9 | - name: 'node:10' 10 | entrypoint: 'yarn' 11 | id: 'diff' 12 | args: ['diff'] 13 | waitFor: ['yarn'] 14 | 15 | # Core. 16 | - name: 'gcr.io/cloud-builders/gcloud' 17 | entrypoint: 'bash' 18 | id: 'core' 19 | args: ['./scripts/run-build.sh', 'tfjs-core'] 20 | waitFor: ['diff'] 21 | 22 | # WebGPU. 23 | - name: 'gcr.io/cloud-builders/gcloud' 24 | entrypoint: 'bash' 25 | id: 'webgpu' 26 | args: ['./scripts/run-build.sh', 'tfjs-webgpu'] 27 | waitFor: ['diff'] 28 | 29 | # React Native. 
30 | - name: 'gcr.io/cloud-builders/gcloud' 31 | entrypoint: 'bash' 32 | id: 'react-native' 33 | args: ['./scripts/run-build.sh', 'tfjs-react-native'] 34 | waitFor: ['diff'] 35 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "devDependencies": { 3 | "clang-format": "~1.2.4", 4 | "shelljs": "~0.8.3" 5 | }, 6 | "scripts": { 7 | "diff": "./scripts/diff.js" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## This repository has been archived in favor of the monorepo at https://github.com/tensorflow/tfjs. Please send pull requests there. -------------------------------------------------------------------------------- /scripts/diff.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | // Copyright 2019 Google LLC. All Rights Reserved. 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // ============================================================================= 16 | 17 | const {exec} = require('./test-util'); 18 | const {readdirSync, statSync, writeFileSync} = require('fs'); 19 | const {join} = require('path'); 20 | 21 | const CLONE_PATH = 'clone'; 22 | 23 | const dirs = readdirSync('.').filter(f => { 24 | return f !== 'node_modules' && f !== '.git' && statSync(f).isDirectory(); 25 | }); 26 | 27 | exec( 28 | `git clone --depth=1 --single-branch ` + 29 | `https://github.com/tensorflow/tfjs-core.git ${CLONE_PATH}`); 30 | 31 | 32 | dirs.forEach(dir => { 33 | const diffCmd = `diff -rq ${CLONE_PATH}/${dir}/ ./${dir}/`; 34 | const diffOutput = exec(diffCmd, {silent: true}, true).stdout.trim(); 35 | 36 | if (diffOutput !== '') { 37 | console.log(`${dir} has modified files.`); 38 | writeFileSync(join(dir, 'diff'), diffOutput); 39 | } else { 40 | console.log(`No modified files found in ${dir}`); 41 | } 42 | }); 43 | -------------------------------------------------------------------------------- /scripts/run-build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================= 16 | 17 | set -e 18 | 19 | DIR=$1 20 | if test -f "$DIR/diff"; then 21 | gcloud builds submit . 
--config=$DIR/cloudbuild.yml 22 | fi 23 | -------------------------------------------------------------------------------- /scripts/test-util.js: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Google LLC. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | const shell = require('shelljs'); 17 | 18 | function exec(command, opt, ignoreCode) { 19 | const res = shell.exec(command, opt); 20 | if (!ignoreCode && res.code !== 0) { 21 | shell.echo('command', command, 'returned code', res.code); 22 | process.exit(1); 23 | } 24 | return res; 25 | } 26 | 27 | exports.exec = exec; 28 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | // Place your settings in this file to overwrite default and user settings. 
2 | { 3 | "search.exclude": { 4 | "**/node_modules": true, 5 | "**/coverage/": true, 6 | "**/dist/": true, 7 | "**/yarn.lock": true, 8 | "**/.rpt2_cache/": true, 9 | "**/.yalc/**/*": true 10 | }, 11 | "tslint.configFile": "tslint.json", 12 | "files.trimTrailingWhitespace": true, 13 | "editor.tabSize": 2, 14 | "editor.insertSpaces": true, 15 | "[typescript]": { 16 | "editor.formatOnSave": true 17 | }, 18 | "[javascript]": { 19 | "editor.formatOnSave": true 20 | }, 21 | "editor.rulers": [80], 22 | "clang-format.style": "Google", 23 | "files.insertFinalNewline": true, 24 | "editor.detectIndentation": false, 25 | "editor.wrappingIndent": "none", 26 | "typescript.tsdk": "./node_modules/typescript/lib", 27 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format" 28 | } 29 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/README.md: -------------------------------------------------------------------------------- 1 | # Headless WebGL backend for TensorFlow.js via Node.js 2 | 3 | **This project is under heavy development** 4 | 5 | This new backend will provide a lightweight headless WebGL runtime for TensorFlow.js running under Node.js. This new backend is powered by the [node-gles](https://github.com/google/node-gles) module which uses [ANGLE](https://github.com/google/angle) to provide an integration layer to the system GL runtime. This package aims to provide a thin acceleration engine for IoT, desktop, and Node.js applications where CUDA (size/OS compatibility) is not an option. 6 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/demo/README.md: -------------------------------------------------------------------------------- 1 | # MobileNet tfjs-backend-nodegl Demo 2 | 3 | *This is a very early demo to show how tfjs-backend-nodegl can be used for headless WebGL acceleration.* 4 | 5 | To run this demo, perform the following: 6 | 7 | 1.
Move into `tfjs-backend-nodegl` (parent directory of this demo folder): 8 | ```sh 9 | $ cd tfjs-backend-nodegl 10 | ``` 11 | 12 | 2. Build package and compile TypeScript: 13 | ```sh 14 | $ yarn && yarn tsc 15 | ``` 16 | 17 | 3. Move into the demo directory: 18 | ```sh 19 | $ cd demo 20 | ``` 21 | 22 | 4. Prep and build demo: 23 | ```sh 24 | $ yarn 25 | ``` 26 | 27 | 5. Run demo: 28 | ```sh 29 | $ node run_mobilenet_inference.js dog.jpg 30 | ``` 31 | 32 | Expected output: 33 | ```sh 34 | $ node run_mobilenet_inference.js dog.jpg 35 | Platform node has already been set. Overwriting the platform with [object Object]. 36 | - gl.VERSION: OpenGL ES 3.0 (ANGLE 2.1.0.9512a0ef062a) 37 | - gl.RENDERER: ANGLE (Intel Inc., Intel(R) Iris(TM) Plus Graphics 640, OpenGL 4.1 core) 38 | - Loading model... 39 | - Mobilenet load: 6450.763924002647ms 40 | - Coldstarting model... 41 | - Mobilenet cold start: 297.92842200398445ms 42 | - Running inference (100x) ... 43 | - Mobilenet inference: (100x) : 35.75772546708584ms 44 | ``` 45 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/demo/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/tfjs-core/1b88a535f7fa30166167463a16ebacb4cd40c797/tfjs-backend-nodegl/demo/dog.jpg -------------------------------------------------------------------------------- /tfjs-backend-nodegl/demo/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tfjs-backend-nodegl-demo", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1" 8 | }, 9 | "dependencies": { 10 | "@tensorflow-models/mobilenet": "^1.0.1", 11 | "@tensorflow/tfjs": "^1.2.2", 12 | "jpeg-js": "^0.3.5" 13 | }, 14 | "author": "", 15 | "license": "ISC" 16 | } 17 | 
-------------------------------------------------------------------------------- /tfjs-backend-nodegl/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tfjs-backend-nodegl", 3 | "version": "0.0.1", 4 | "description": "", 5 | "main": "dist/index.js", 6 | "types": "dist/index.d.ts", 7 | "scripts": { 8 | "build": "tsc", 9 | "lint": "tslint -p . -t verbose", 10 | "test": "ts-node src/run_tests.ts" 11 | }, 12 | "author": "", 13 | "license": "Apache-2.0", 14 | "dependencies": { 15 | "@tensorflow/tfjs-core": "1.2.2", 16 | "node-gles": "^0.0.13" 17 | }, 18 | "devDependencies": { 19 | "@types/jasmine": "~2.8.6", 20 | "@types/node": "^10.5.1", 21 | "@types/rimraf": "~2.0.2", 22 | "clang-format": "^1.2.4", 23 | "jasmine": "~3.1.0", 24 | "ts-node": "^8.1.0", 25 | "tslint": "~5.9.1", 26 | "typescript": "^3.4.5" 27 | } 28 | } -------------------------------------------------------------------------------- /tfjs-backend-nodegl/src/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs-core'; 19 | // TODO(kreeger): Fix binding definition in node-gles before switching to a 20 | // non-require import style. 21 | // tslint:disable-next-line:no-require-imports 22 | const nodeGles = require('node-gles'); 23 | 24 | const nodeGl = nodeGles.binding.createWebGLRenderingContext(); 25 | 26 | // TODO(kreeger): These are hard-coded GL integration flags. These need to be 27 | // updated to ensure they work on all systems with proper exception reporting. 28 | tf.ENV.set('WEBGL_VERSION', 2); 29 | tf.ENV.set('WEBGL_RENDER_FLOAT32_ENABLED', true); 30 | tf.ENV.set('WEBGL_DOWNLOAD_FLOAT_ENABLED', true); 31 | tf.ENV.set('WEBGL_FENCE_API_ENABLED', true); // OpenGL ES 3.0 and higher.. 32 | tf.ENV.set( 33 | 'WEBGL_MAX_TEXTURE_SIZE', nodeGl.getParameter(nodeGl.MAX_TEXTURE_SIZE)); 34 | tf.webgl.setWebGLContext(2, nodeGl); 35 | 36 | tf.registerBackend('headless-nodegl', () => { 37 | // TODO(kreeger): Consider moving all GL creation here. However, weak-ref to 38 | // GL context tends to cause an issue when running unit tests: 39 | // https://github.com/tensorflow/tfjs/issues/1732 40 | return new tf.webgl.MathBackendWebGL(new tf.webgl.GPGPUContext(nodeGl)); 41 | }, 3 /* priority */); 42 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/src/version.ts: -------------------------------------------------------------------------------- 1 | /** @license See the LICENSE file. */ 2 | 3 | // This code is auto-generated, do not modify this file! 
4 | const version = '0.0.1'; 5 | export {version}; 6 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "command": "yarn", 6 | "label": "lint", 7 | "type": "shell", 8 | "args": [ 9 | "lint" 10 | ], 11 | "problemMatcher": { 12 | "base": "$tslint5", 13 | "owner": "tslint-type-checked", 14 | "fileLocation": "absolute" 15 | } 16 | }, 17 | { 18 | "command": "yarn", 19 | "label": "build", 20 | "type": "shell", 21 | "args": ["build", "--pretty", "false", "--noEmit"], 22 | "problemMatcher": [ 23 | "$tsc" 24 | ] 25 | } 26 | ] 27 | } 28 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tsconfig", 3 | "include": [ 4 | "src/" 5 | ], 6 | "exclude": [ 7 | "node_modules/" 8 | ], 9 | "compilerOptions": { 10 | "outDir": "./dist" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /tfjs-backend-nodegl/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tslint.json" 3 | } 4 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/.npmrc: -------------------------------------------------------------------------------- 1 | package-lock=false 2 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/src/**", 7 | "~/emsdk/fastcomp/emscripten/system/include/**" 8 | ], 9 | "defines": [], 10 | "compilerPath": 
"/usr/bin/clang", 11 | "cStandard": "c11", 12 | "cppStandard": "c++11", 13 | "intelliSenseMode": "clang-x64" 14 | } 15 | ], 16 | "version": 4 17 | } -------------------------------------------------------------------------------- /tfjs-backend-wasm/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | // Place your settings in this file to overwrite default and user settings. 2 | { 3 | "search.exclude": { 4 | "**/node_modules": true, 5 | "**/coverage/": true, 6 | "**/dist/": true, 7 | "**/yarn.lock": true, 8 | "**/.rpt2_cache/": true, 9 | "**/.yalc/**/*": true 10 | }, 11 | "tslint.configFile": "tslint.json", 12 | "files.trimTrailingWhitespace": true, 13 | "editor.tabSize": 2, 14 | "editor.insertSpaces": true, 15 | "[typescript]": { 16 | "editor.formatOnSave": true 17 | }, 18 | "[javascript]": { 19 | "editor.formatOnSave": true 20 | }, 21 | "[cpp]": { 22 | "editor.formatOnSave": true 23 | }, 24 | "editor.rulers": [80], 25 | "clang-format.style": "Google", 26 | "files.insertFinalNewline": true, 27 | "editor.detectIndentation": false, 28 | "editor.wrappingIndent": "none", 29 | "typescript.tsdk": "./node_modules/typescript/lib", 30 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format", 31 | "files.associations": { 32 | "vector": "cpp", 33 | "map": "cpp", 34 | "algorithm": "cpp" 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/README.md: -------------------------------------------------------------------------------- 1 | # Emscripten installation 2 | 3 | Install emscripten [here](https://emscripten.org/docs/getting_started/downloads.html) 4 | 5 | 6 | # Building 7 | 8 | ```sh 9 | yarn build 10 | ``` 11 | 12 | # Testing 13 | 14 | ```sh 15 | yarn test 16 | ``` 17 | 18 | # Deployment 19 | ```sh 20 | ./scripts/build-npm.sh 21 | npm publish 22 | ``` 23 | 
-------------------------------------------------------------------------------- /tfjs-backend-wasm/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@tensorflow/tfjs-backend-wasm", 3 | "version": "0.0.1", 4 | "main": "dist/index.js", 5 | "types": "dist/index.d.ts", 6 | "jsnext:main": "dist/tf-wasm.esm.js", 7 | "module": "dist/tf-wasm.esm.js", 8 | "unpkg": "dist/tf-wasm.min.js", 9 | "jsdelivr": "dist/tf-wasm.min.js", 10 | "scripts": { 11 | "build": "rimraf dist/ && ./scripts/build-wasm.sh && tsc && cp wasm-out/*.wasm dist/", 12 | "lint": "tslint -p . -t verbose", 13 | "test": "./scripts/build-wasm.sh && karma start" 14 | }, 15 | "peerDependencies": { 16 | "@tensorflow/tfjs-core": "~1.2.7" 17 | }, 18 | "devDependencies": { 19 | "@tensorflow/tfjs-core": "~1.2.7", 20 | "@types/emscripten": "~0.0.34", 21 | "clang-format": "^1.2.4", 22 | "jasmine-core": "~3.1.0", 23 | "karma": "~4.0.0", 24 | "karma-browserstack-launcher": "~1.4.0", 25 | "karma-chrome-launcher": "~2.2.0", 26 | "karma-firefox-launcher": "~1.1.0", 27 | "karma-jasmine": "~1.1.1", 28 | "karma-typescript": "~4.0.0", 29 | "rimraf": "~2.6.2", 30 | "rollup": "^1.17.0", 31 | "rollup-plugin-commonjs": "^10.0.1", 32 | "rollup-plugin-node-resolve": "^5.2.0", 33 | "rollup-plugin-terser": "^5.1.1", 34 | "rollup-plugin-typescript2": "^0.22.1", 35 | "tslint": "~5.11.0", 36 | "tslint-no-circular-imports": "^0.5.0", 37 | "typescript": "3.3.3333", 38 | "yalc": "~1.0.0-pre.21" 39 | }, 40 | "license": "Apache-2.0" 41 | } 42 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/scripts/build-npm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================= 16 | 17 | set -e 18 | 19 | yarn rimraf dist/ 20 | yarn 21 | yarn build 22 | yarn rollup -c 23 | 24 | echo "Stored standalone library at dist/tf-wasm(.min).js" 25 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/scripts/build-wasm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
15 | # ============================================================================= 16 | 17 | set -e 18 | 19 | SRCS="src/lib.cc src/kernels.cc" 20 | 21 | # Add these optimization flags in production: -g0 -O3 --llvm-lto 3 22 | 23 | emcc $SRCS \ 24 | -std=c++11 \ 25 | -fno-rtti \ 26 | -g \ 27 | -fno-exceptions \ 28 | -I./src/ \ 29 | -o wasm-out/tfjs-backend-wasm.js \ 30 | -s ALLOW_MEMORY_GROWTH=1 \ 31 | -s DEFAULT_LIBRARY_FUNCS_TO_INCLUDE=[] \ 32 | -s DISABLE_EXCEPTION_CATCHING=1 \ 33 | -s FILESYSTEM=0 \ 34 | -s EXIT_RUNTIME=0 \ 35 | -s EXPORTED_FUNCTIONS='["_malloc"]' \ 36 | -s EXTRA_EXPORTED_RUNTIME_METHODS='["cwrap"]' \ 37 | -s ENVIRONMENT=web \ 38 | -s MODULARIZE=1 \ 39 | -s EXPORT_NAME=WasmBackendModule \ 40 | -s MALLOC=emmalloc 41 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/src/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import {BackendWasm} from './backend_wasm'; 19 | export {BackendWasm}; 20 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/src/index_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs-core'; 19 | import {test_util} from '@tensorflow/tfjs-core'; 20 | import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; 21 | 22 | import {BackendWasm} from './index'; 23 | 24 | /** 25 | * Tests specific to the wasm backend. The name of these tests must start with 26 | * 'wasm' so that they are always included in the test runner. See 27 | * `env.specFilter` in `setup_test.ts` for details. 
28 | */ 29 | describeWithFlags('wasm', ALL_ENVS, () => { 30 | it('write and read values', async () => { 31 | const x = tf.tensor1d([1, 2, 3]); 32 | test_util.expectArraysClose([1, 2, 3], await x.data()); 33 | }); 34 | 35 | it('allocate repetitively and confirm reuse of heap space', () => { 36 | const backend = tf.backend() as BackendWasm; 37 | const size = 100; 38 | // Allocate for the first time, record the memory offset and dispose. 39 | const t1 = tf.zeros([size]); 40 | const memOffset1 = backend.getMemoryOffset(t1.dataId); 41 | t1.dispose(); 42 | 43 | // Allocate again and make sure the offset is the same (memory was reused). 44 | const t2 = tf.zeros([size]); 45 | const memOffset2 = backend.getMemoryOffset(t2.dataId); 46 | // This should fail in case of a memory leak. 47 | expect(memOffset1).toBe(memOffset2); 48 | }); 49 | }); 50 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/src/kernels.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Google Inc. All Rights Reserved. 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | * ===========================================================================*/ 14 | 15 | #include "kernels.h" 16 | 17 | #include 18 | 19 | namespace tfjs { 20 | namespace kernels { 21 | 22 | // TODO(smilkov): Consider inlining small methods. 
23 | 24 | template 25 | void add(T* a_buf, int a_size, T* b_buf, int b_size, T* out_buf) { 26 | // Nothing to compute if either input is empty; this also avoids `i % 0`. 27 | if (a_size == 0 || b_size == 0) { 28 | return; 29 | } 30 | int size = std::max(a_size, b_size); 31 | for (int i = 0; i < size; ++i) { 32 | // The smaller input is repeated cyclically (e.g. a scalar is broadcast). 33 | out_buf[i] = a_buf[i % a_size] + b_buf[i % b_size]; 34 | } 35 | } 36 | 37 | // Templates need explicit instantiation when implemented in a .cc file. 38 | template void add(float* a_buf, int a_size, float* b_buf, int b_size, 39 | float* out_buf); 40 | template void add(int* a_buf, int a_size, int* b_buf, int b_size, 41 | int* out_buf); 42 | template void add(bool* a_buf, int a_size, bool* b_buf, int b_size, 43 | bool* out_buf); 44 | 45 | } // namespace kernels 46 | } // namespace tfjs 47 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/src/kernels.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 Google Inc. All Rights Reserved. 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | * ===========================================================================*/ 14 | 15 | #ifndef TFJS_WASM_KERNELS_H_ 16 | #define TFJS_WASM_KERNELS_H_ 17 | 18 | namespace tfjs { 19 | namespace kernels { 20 | 21 | // Element-wise add of two tensors. 22 | template
23 | void add(T* a_buf, int a_size, T* b_buf, int b_size, T* out_buf); 24 | 25 | } // namespace kernels 26 | } // namespace tfjs 27 | 28 | #endif // TFJS_WASM_KERNELS_H_ 29 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tsconfig", 3 | "include": [ 4 | "src/" 5 | ], 6 | "exclude": [ 7 | "node_modules/" 8 | ], 9 | "compilerOptions": { 10 | "outDir": "./dist" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tslint.json" 3 | } 4 | -------------------------------------------------------------------------------- /tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.d.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | export type BackendWasmModule = EmscriptenModule&{ 19 | onRuntimeInitialized: () => void; 20 | // Using the tfjs namespace to avoid conflict with emscripten's API. 
21 | tfjs: { 22 | registerTensor( 23 | dataId: number, shape: Uint8Array, shapeLength: number, dtype: number, 24 | memoryOffset: number): void; 25 | // Disposes the data behind the data bucket. 26 | disposeData(dataId: number): void; 27 | // Disposes the backend and all of its associated data. 28 | dispose(): void; 29 | 30 | // Kernels. 31 | add(aId: number, bId: number, outId: number): void; 32 | } 33 | }; 34 | 35 | declare var moduleFactory: () => BackendWasmModule; 36 | export default moduleFactory; 37 | -------------------------------------------------------------------------------- /tfjs-core/.bazelignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | -------------------------------------------------------------------------------- /tfjs-core/.bazelrc: -------------------------------------------------------------------------------- 1 | build --symlink_prefix=dist/ 2 | -------------------------------------------------------------------------------- /tfjs-core/.npmignore: -------------------------------------------------------------------------------- 1 | # Ignore other packages in the mono-repo. 
2 | tfjs-webgpu/ 3 | tfjs-backend-nodegl/ 4 | tfjs-react-native/ 5 | 6 | .vscode/ 7 | .rpt2_cache/ 8 | src/**/*_test.ts 9 | integration_tests/ 10 | 11 | dist/backends/**/*_test.js 12 | models/ 13 | coverage/ 14 | package/ 15 | **/node_modules/ 16 | karma.conf.js 17 | *.tgz 18 | *.log 19 | .travis.yml 20 | CONTRIBUTING.md 21 | tslint.json 22 | yarn.lock 23 | DEVELOPMENT.md 24 | ISSUE_TEMPLATE.md 25 | PULL_REQUEST_TEMPLATE.md 26 | rollup.config.js 27 | tsconfig.json 28 | .yalc/ 29 | yalc.lock 30 | tfjs-react-native/ 31 | tfjs-backend-nodegl/ 32 | -------------------------------------------------------------------------------- /tfjs-core/.npmrc: -------------------------------------------------------------------------------- 1 | package-lock=false 2 | -------------------------------------------------------------------------------- /tfjs-core/.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Mac", 5 | "includePath": [ 6 | "${workspaceFolder}/**", 7 | "~/emsdk/fastcomp/emscripten/system/include/**" 8 | ], 9 | "defines": [], 10 | "macFrameworkPath": [ 11 | "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.14.sdk/System/Library/Frameworks" 12 | ], 13 | "compilerPath": "/usr/bin/clang", 14 | "cStandard": "c11", 15 | "cppStandard": "c++11", 16 | "intelliSenseMode": "clang-x64" 17 | } 18 | ], 19 | "version": 4 20 | } 21 | -------------------------------------------------------------------------------- /tfjs-core/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "chrome", 9 | "request": "attach", 10 | "name": "Attach Karma Chrome", 11 | "address": "localhost", 12 | "port": 9333, 13 | "pathMapping": { 14 | "/": "${workspaceRoot}", 15 | "/base/": "${workspaceRoot}/" 16 | } 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /tfjs-core/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | // Place your settings in this file to overwrite default and user settings. 2 | { 3 | "search.exclude": { 4 | "**/node_modules": true, 5 | "**/coverage/": true, 6 | "**/dist/": true, 7 | "**/yarn.lock": true, 8 | "**/.rpt2_cache/": true, 9 | "**/.yalc/**/*": true 10 | }, 11 | "tslint.configFile": "tslint.json", 12 | "files.trimTrailingWhitespace": true, 13 | "editor.tabSize": 2, 14 | "editor.insertSpaces": true, 15 | "[typescript]": { 16 | "editor.formatOnSave": true 17 | }, 18 | "[javascript]": { 19 | "editor.formatOnSave": true 20 | }, 21 | "editor.defaultFormatter": "xaver.clang-format", 22 | "editor.rulers": [80], 23 | "clang-format.style": "Google", 24 | "files.insertFinalNewline": true, 25 | "editor.detectIndentation": false, 26 | "editor.wrappingIndent": "none", 27 | "typescript.tsdk": "./node_modules/typescript/lib", 28 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format" 29 | } 30 | -------------------------------------------------------------------------------- /tfjs-core/.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "command": "yarn", 6 | "label": "lint", 7 | "type": "shell", 8 | "args": [ 9 | "lint" 10 | ], 11 | "problemMatcher": { 12 | "base": "$tslint5", 13 | "owner": "tslint-type-checked", 14 | "fileLocation": "absolute" 15 | } 16 | }, 17 | 
{ 18 | "command": "yarn", 19 | "label": "build", 20 | "type": "shell", 21 | "args": ["build", "--pretty", "false", "--noEmit"], 22 | "problemMatcher": [ 23 | "$tsc" 24 | ] 25 | } 26 | ] 27 | } 28 | -------------------------------------------------------------------------------- /tfjs-core/BUILD.bazel: -------------------------------------------------------------------------------- 1 | # Allow typescript rules in any package to reference this file 2 | exports_files(["tsconfig.json"]) 3 | -------------------------------------------------------------------------------- /tfjs-core/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow.js Core API 2 | 3 | A part of the TensorFlow.js ecosystem, this repo hosts `@tensorflow/tfjs-core`, 4 | the TensorFlow.js Core API, which provides low-level, hardware-accelerated 5 | linear algebra operations and an eager API for automatic differentiation. 6 | 7 | Check out [js.tensorflow.org](https://js.tensorflow.org) for more 8 | information about the library, tutorials and API docs. 9 | 10 | To keep track of issues we use the [tensorflow/tfjs](https://github.com/tensorflow/tfjs) Github repo. 11 | 12 | ## Importing 13 | 14 | You can install TensorFlow.js via yarn or npm. We recommend using the 15 | [@tensorflow/tfjs](https://www.npmjs.com/package/@tensorflow/tfjs) npm package, 16 | which gives you both this Core API and the higher-level 17 | [Layers API](https://github.com/tensorflow/tfjs-layers): 18 | 19 | ```js 20 | import * as tf from '@tensorflow/tfjs'; 21 | // You have the Core API: tf.matMul(), tf.softmax(), ... 22 | // You also have Layers API: tf.model(), tf.layers.dense(), ... 23 | ``` 24 | 25 | On the other hand, if you care about the bundle size and you do not use the 26 | Layers API, you can import only the Core API: 27 | 28 | ```js 29 | import * as tfc from '@tensorflow/tfjs-core'; 30 | // You have the Core API: tfc.matMul(), tfc.softmax(), ... 31 | // No Layers API. 
32 | ``` 33 | 34 | For info about development, check out [DEVELOPMENT.md](./DEVELOPMENT.md). 35 | 36 | ## For more information 37 | 38 | - [TensorFlow.js API documentation](https://js.tensorflow.org/api/latest/) 39 | - [TensorFlow.js Tutorials](https://js.tensorflow.org/tutorials/) 40 | 41 | Thanks BrowserStack for providing testing support. 42 | -------------------------------------------------------------------------------- /tfjs-core/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | # Install common dependencies. 3 | - name: 'node:10' 4 | id: 'yarn-common' 5 | entrypoint: 'yarn' 6 | args: ['install'] 7 | 8 | # Install tfjs-core dependencies. 9 | - name: 'node:10' 10 | dir: 'tfjs-core' 11 | id: 'yarn' 12 | entrypoint: 'yarn' 13 | args: ['install'] 14 | waitFor: ['yarn-common'] 15 | 16 | # Build 17 | - name: 'node:10' 18 | dir: 'tfjs-core' 19 | id: 'build' 20 | entrypoint: 'yarn' 21 | args: ['build-ci'] 22 | waitFor: ['yarn'] 23 | 24 | # Run unit tests. 25 | - name: 'node:10' 26 | dir: 'tfjs-core' 27 | id: 'test' 28 | entrypoint: 'yarn' 29 | args: ['test-ci'] 30 | waitFor: ['build'] 31 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1'] 32 | secretEnv: ['BROWSERSTACK_KEY'] 33 | 34 | # Run integration tests of other packages against core. 
35 | - name: 'node:10' 36 | dir: 'tfjs-core' 37 | id: 'test-integration' 38 | entrypoint: 'yarn' 39 | args: ['test-integration'] 40 | waitFor: ['build'] 41 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1', 'NIGHTLY=$_NIGHTLY'] 42 | secretEnv: ['BROWSERSTACK_KEY'] 43 | 44 | # bundle size check 45 | - name: 'node:10' 46 | dir: 'tfjs-core' 47 | id: 'test-bundle-size' 48 | entrypoint: 'yarn' 49 | args: ['test-bundle-size'] 50 | waitFor: ['yarn'] 51 | 52 | # test doc snippets 53 | - name: 'node:10' 54 | dir: 'tfjs-core' 55 | id: 'test-snippets' 56 | entrypoint: 'yarn' 57 | args: ['test-snippets'] 58 | waitFor: ['yarn'] 59 | 60 | # test Async backends 61 | - name: 'node:10' 62 | dir: 'tfjs-core' 63 | id: 'test-async-backends' 64 | entrypoint: 'yarn' 65 | args: ['test-async-backends-ci'] 66 | waitFor: ['build'] 67 | 68 | # General configuration 69 | secrets: 70 | - kmsKeyName: projects/learnjs-174218/locations/global/keyRings/tfjs/cryptoKeys/enc 71 | secretEnv: 72 | BROWSERSTACK_KEY: CiQAkwyoIW0LcnxymzotLwaH4udVTQFBEN4AEA5CA+a3+yflL2ASPQAD8BdZnGARf78MhH5T9rQqyz9HNODwVjVIj64CTkFlUCGrP1B2HX9LXHWHLmtKutEGTeFFX9XhuBzNExA= 73 | timeout: 1800s 74 | logsBucket: 'gs://tfjs-build-logs' 75 | substitutions: 76 | _NIGHTLY: '' 77 | options: 78 | logStreamingOption: 'STREAM_ON' 79 | machineType: 'N1_HIGHCPU_8' 80 | substitution_option: 'ALLOW_LOOSE' 81 | -------------------------------------------------------------------------------- /tfjs-core/scripts/build-npm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================= 16 | 17 | set -e 18 | 19 | yarn rimraf dist/ 20 | yarn 21 | 22 | yarn build 23 | yarn build-test-snippets 24 | yarn rollup -c --visualize 25 | 26 | # Use minified files for miniprogram 27 | mkdir dist/miniprogram 28 | cp dist/tf-core.min.js dist/miniprogram/index.js 29 | cp dist/tf-core.min.js.map dist/miniprogram/index.js.map 30 | 31 | echo "Stored standalone library at dist/tf-core(.min).js" 32 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/README.md: -------------------------------------------------------------------------------- 1 | This directory contains the following Google Cloud Functions. 2 | 3 | ### `trigger_nightly` 4 | Programmatically triggers a Cloud Build on master. This function is called by the Cloud Scheduler at 3am EST every day (configurable via the Cloud Scheduler UI). 5 | You can also trigger the function manually via the Cloud UI. 6 | 7 | Command to re-deploy: 8 | ```sh 9 | gcloud functions deploy nightly \ 10 | --runtime nodejs8 \ 11 | --trigger-topic nightly 12 | ``` 13 | 14 | If a build was triggered by nightly, there is a substitution variable `_NIGHTLY=true`. 15 | You can forward the substitution as the `NIGHTLY` environment variable so the scripts can use it, by specifying `env: ['NIGHTLY=$_NIGHTLY']` in `cloudbuild.yml`. For example, `test-integration` uses the `NIGHTLY` flag to always run in nightly builds.
16 | 17 | ### `send_email` 18 | Sends an email and a chat message with the nightly build status. Every build publishes a message with its build information to the `cloud-builds` topic. The `send_email` function is subscribed to that topic and ignores every build (e.g. builds triggered by pull requests) **except** the nightly build, for which it sends an email to an internal mailing list with the build status around 3:10am. 19 | 20 | Command to re-deploy: 21 | 22 | ```sh 23 | gcloud functions deploy send_email \ 24 | --runtime nodejs8 \ 25 | --stage-bucket learnjs-174218_cloudbuild \ 26 | --trigger-topic cloud-builds \ 27 | --set-env-vars MAILGUN_API_KEY="[API_KEY_HERE]",HANGOUTS_URL="[URL_HERE]" 28 | ``` 29 | 30 | ### The pipeline 31 | 32 | The pipeline looks like this: 33 | 34 | 1) At 3am, Cloud Scheduler publishes to the `nightly` topic. 35 | 2) That triggers the `nightly` function, which starts a build programmatically. 36 | 3) That build runs and publishes its status to the `cloud-builds` topic. 37 | 4) That triggers the `send_email` function, which sends an email and a chat message with the build status. 38 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/send_email/.gcloudignore: -------------------------------------------------------------------------------- 1 | # This file specifies files that are *not* uploaded to Google Cloud Platform 2 | # using gcloud. It follows the same syntax as .gitignore, with the addition of 3 | # "#!include" directives (which insert the entries of the given .gitignore-style 4 | # file at that point).
5 | # 6 | # For more information, run: 7 | # $ gcloud topic gcloudignore 8 | # 9 | .gcloudignore 10 | # If you would like to upload your .git directory, .gitignore file or files 11 | # from your .gitignore file, remove the corresponding line 12 | # below: 13 | .git 14 | .gitignore 15 | 16 | node_modules 17 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/send_email/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "MAILGUN_DOMAIN":"sandbox497e7af39a1b4dee92fb92d9bfe5a686.mailgun.org", 3 | "MAILGUN_FROM":"Cloud Build ", 4 | "MAILGUN_TO":"tfjs-builds@google.com" 5 | } 6 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/send_email/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "cloudbuild-email", 3 | "version": "0.0.1", 4 | "description": "Email integration for Google Cloud Build, using Google Cloud Functions", 5 | "main": "index.js", 6 | "dependencies": { 7 | "humanize-duration": "3.10.0", 8 | "mailgun-js": "^0.22.0", 9 | "node-fetch": "^2.5.0", 10 | "request": "^2.88.0", 11 | "request-promise-native": "^1.0.7" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/trigger_nightly/.gcloudignore: -------------------------------------------------------------------------------- 1 | # This file specifies files that are *not* uploaded to Google Cloud Platform 2 | # using gcloud. It follows the same syntax as .gitignore, with the addition of 3 | # "#!include" directives (which insert the entries of the given .gitignore-style 4 | # file at that point). 
5 | # 6 | # For more information, run: 7 | # $ gcloud topic gcloudignore 8 | # 9 | .gcloudignore 10 | # If you would like to upload your .git directory, .gitignore file or files 11 | # from your .gitignore file, remove the corresponding line 12 | # below: 13 | .git 14 | .gitignore 15 | 16 | node_modules 17 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/trigger_nightly/index.js: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Google LLC. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | // ============================================================================= 15 | 16 | const {google} = require('googleapis'); 17 | 18 | module.exports.nightly = async data => { 19 | const cloudbuild = google.cloudbuild('v1'); 20 | const auth = await google.auth.getClient( 21 | {scopes: ['https://www.googleapis.com/auth/cloud-platform']}); 22 | google.options({auth}); 23 | const resp = await cloudbuild.projects.triggers.run({ 24 | 'projectId': 'learnjs-174218', 25 | 'triggerId': '7423c985-2fd2-40f3-abe7-94d4c353eed0', 26 | 'resource': {'branchName': 'master'} 27 | }); 28 | console.log(resp); 29 | }; 30 | -------------------------------------------------------------------------------- /tfjs-core/scripts/cloud_funcs/trigger_nightly/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "cloudbuild", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "dependencies": { 7 | "googleapis": "^39.2.0" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tfjs-core/scripts/make-version: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | // Copyright 2018 Google LLC. All Rights Reserved. 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // ============================================================================= 16 | 17 | 18 | // Run this script from the base directory (not the script directory): 19 | // ./scripts/make-version 20 | 21 | const fs = require('fs'); 22 | const version = JSON.parse(fs.readFileSync('package.json', 'utf8')).version; 23 | 24 | const versionCode = 25 | `/** @license See the LICENSE file. */ 26 | 27 | // This code is auto-generated, do not modify this file! 28 | const version = '${version}'; 29 | export {version}; 30 | ` 31 | 32 | fs.writeFile('src/version.ts', versionCode, err => { 33 | if (err) { 34 | throw new Error(`Could not save version file ${version}: ${err}`); 35 | } 36 | console.log(`Version file for version ${version} saved sucessfully.`); 37 | }); 38 | -------------------------------------------------------------------------------- /tfjs-core/scripts/publish-npm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================= 16 | 17 | # Before you run this script, do this: 18 | # 1) Update the version in package.json 19 | # 2) Run ./scripts/make-version from the base dir of the project. 
20 | # 3) Run `yarn` to update `yarn.lock`, in case you updated dependencies 21 | # 4) Commit to the master branch. 22 | 23 | # Then: 24 | # 5) Checkout the master branch of this repo. 25 | # 6) Run this script as `./scripts/publish-npm.sh` from the project base dir. 26 | 27 | set -e 28 | 29 | BRANCH=`git rev-parse --abbrev-ref HEAD` 30 | ORIGIN=`git config --get remote.origin.url` 31 | CHANGES=`git status --porcelain` 32 | 33 | if [ "$BRANCH" != "master" ]; then 34 | echo "Error: Switch to the master branch before publishing." 35 | exit 36 | fi 37 | 38 | if ! [[ "$ORIGIN" =~ tensorflow/tfjs-core ]]; then 39 | echo "Error: Switch to the main repo (tensorflow/tfjs-core) before publishing." 40 | exit 41 | fi 42 | 43 | if [ ! -z "$CHANGES" ]; 44 | then 45 | echo "Make sure the master branch is clean. Found changes:" 46 | echo $CHANGES 47 | exit 1 48 | fi 49 | 50 | yarn build-npm 51 | ./scripts/make-version # This is for safety in case you forgot to do 2). 52 | ./scripts/tag-version.js 53 | npm publish 54 | echo 'Yay! Published a new package to npm.' 55 | -------------------------------------------------------------------------------- /tfjs-core/scripts/tag-version.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | // Copyright 2018 Google LLC. All Rights Reserved. 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
15 | // ============================================================================= 16 | 17 | // Run this script from the base directory (not the script directory): 18 | // ./scripts/tag-version.js 19 | 20 | var fs = require('fs'); 21 | var exec = require('child_process').exec; 22 | 23 | var version = JSON.parse(fs.readFileSync('package.json', 'utf8')).version; 24 | var tag = `v${version}`; 25 | 26 | exec(`git tag ${tag}`, (err, stdout, stderr) => { 27 | console.log('\x1b[36m%s\x1b[0m', 'git tag command stdout:'); 28 | console.log(stdout); 29 | console.log('\x1b[31m%s\x1b[0m', 'git tag command stderr:'); 30 | console.log(stderr); 31 | 32 | if (err) { 33 | throw new Error(`Could not git tag with ${tag}: ${err.message}.`); 34 | } 35 | console.log(`Successfully tagged with ${tag}.`); 36 | }); 37 | 38 | exec(`git push --tags`, (err, stdout, stderr) => { 39 | console.log('\x1b[36m%s\x1b[0m', 'git push tags command stdout:'); 40 | console.log(stdout); 41 | console.log('\x1b[41m%s\x1b[0m', 'git push tags command stderr:'); 42 | console.log(stderr); 43 | 44 | if (err) { 45 | throw new Error(`Could not push git tags: ${err.message}.`); 46 | } 47 | console.log(`Successfully pushed tags.`); 48 | }); 49 | -------------------------------------------------------------------------------- /tfjs-core/scripts/test-ci.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# CI entry point. In order: lint, node tests, one BrowserStack karma run on its
# own, three BrowserStack karma runs in parallel, and finally the web-worker
# tests.

# Stop at the first failing command.
set -e

yarn lint
# Test in node (headless environment).
yarn test-node-ci

# Run the first karma separately so it can download the BrowserStack binary
# without conflicting with others.
yarn run-browserstack --browsers=bs_safari_mac,bs_ios_11 --testEnv webgl1 --flags '{"WEBGL_CPU_FORWARD": false, "WEBGL_SIZE_UPLOAD_UNIFORM": 0}'

# Run the rest of the karma tests in parallel. These runs will reuse the
# already downloaded binary.
# The three runs are: testEnv cpu with HAS_WEBGL=false on bs_safari_mac and
# bs_ios_11; a run with no explicit testEnv or flags on bs_firefox_mac and
# bs_chrome_mac; and testEnv webgl2 on bs_chrome_mac and win_10_chrome.
npm-run-all -p -c --aggregate-output \
  "run-browserstack --browsers=bs_safari_mac,bs_ios_11 --flags '{\"HAS_WEBGL\": false}' --testEnv cpu" \
  "run-browserstack --browsers=bs_firefox_mac,bs_chrome_mac" \
  "run-browserstack --browsers=bs_chrome_mac,win_10_chrome --testEnv webgl2 --flags '{\"WEBGL_CPU_FORWARD\": false, \"WEBGL_SIZE_UPLOAD_UNIFORM\": 0}'"

### The next section tests TF.js in a webworker.
# Make a dist/tf-core.min.js file to be imported by the web worker.
yarn rollup -c --ci
# Safari doesn't have offscreen canvas so test cpu in a webworker.
# Chrome has offscreen canvas, so test webgl in a webworker.
yarn test-webworker --browsers=bs_safari_mac,bs_chrome_mac
-------------------------------------------------------------------------------- /tfjs-core/scripts/test-integration.js: --------------------------------------------------------------------------------
#!/usr/bin/env node
// Copyright 2019 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | // ============================================================================= 16 | 17 | const {exec} = require('../../scripts/test-util'); 18 | 19 | const dirName = 'tfjs-core-integration'; 20 | 21 | let shouldRunIntegration = false; 22 | if (process.env.NIGHTLY === 'true') { 23 | shouldRunIntegration = true; 24 | } else { 25 | exec( 26 | `git clone --depth=1 --single-branch ` + 27 | `https://github.com/tensorflow/tfjs-core.git ${dirName}`); 28 | const res = exec( 29 | `git diff --name-only --diff-filter=M --no-index ${dirName}/src/ src/`, 30 | {silent: true}, true); 31 | let files = res.stdout.trim().split('\n'); 32 | files.forEach(file => { 33 | if (file === 'src/version.ts') { 34 | shouldRunIntegration = true; 35 | } 36 | }); 37 | } 38 | if (shouldRunIntegration) { 39 | exec('./scripts/test-integration.sh'); 40 | } 41 | -------------------------------------------------------------------------------- /tfjs-core/scripts/test_snippets/test_snippets.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | /** 3 | * @license 4 | * Copyright 2019 Google LLC. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../../src/index';
import {parseAndEvaluateSnippets} from './util';

// Passes the whole tf namespace to the snippet runner in ./util (presumably it
// evaluates the API doc code snippets against it; see util.ts).
parseAndEvaluateSnippets(tf);
-------------------------------------------------------------------------------- /tfjs-core/scripts/test_snippets/tsconfig.json: --------------------------------------------------------------------------------
{
  "extends": "../../tsconfig.json",
  "compilerOptions": {
    "outDir": "../../dist/scripts/test_snippets/"
  },
  "include": [
    "util.ts"
  ]
}
-------------------------------------------------------------------------------- /tfjs-core/src/BUILD.bazel: --------------------------------------------------------------------------------
load("@npm_bazel_typescript//:defs.bzl", "ts_library")

# Patterns for test sources and test helpers, intended for the glob exclude
# below.
TEST_SRCS = [
    "jasmine_*",
    "test_*",
    "**/*_test.ts"
]

ts_library(
    name = "src",
    srcs = glob(
        ["**/*.ts"],
        # NOTE(review): with this exclude commented out, TEST_SRCS is unused
        # and the "**/*.ts" glob also matches *_test.ts files, so tests are
        # compiled into this library. Confirm that is intended.
        # exclude = TEST_SRCS,
    ),
    deps = [
        "@npm//@types",
    ]
)
-------------------------------------------------------------------------------- /tfjs-core/src/backends/backend_test.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {EPSILON_FLOAT16, EPSILON_FLOAT32} from './backend';

describeWithFlags('epsilon', ALL_ENVS, () => {
  it('Epsilon is a function of float precision', () => {
    // A backend with 32-bit float precision must report the float32 epsilon;
    // any other precision must report the float16 epsilon.
    const backend = tf.backend();
    let expectedEpsilon: number = EPSILON_FLOAT16;
    if (backend.floatPrecision() === 32) {
      expectedEpsilon = EPSILON_FLOAT32;
    }
    expect(backend.epsilon()).toBe(expectedEpsilon);
  });

  it('abs(epsilon) > 0', async () => {
    const absEpsilon = await tf.abs(tf.backend().epsilon()).array();
    expect(absEpsilon).toBeGreaterThan(0);
  });
});
-------------------------------------------------------------------------------- /tfjs-core/src/backends/cpu/backend_cpu_test_registry.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {Constraints, registerTestEnv} from '../../jasmine_util';

/** Test constraints that select only environments using the 'cpu' backend. */
export const CPU_ENVS: Constraints = {
  predicate: ({backendName}) => backendName === 'cpu'
};

// Register the 'cpu' test environment with isDataSync enabled.
registerTestEnv({
  name: 'cpu',
  backendName: 'cpu',
  isDataSync: true
});
-------------------------------------------------------------------------------- /tfjs-core/src/backends/packing_util.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

/**
 * Returns the first `rank` vector-component accessors of `name`, for example
 * getVecChannels('rc', 3) -> ['rc.x', 'rc.y', 'rc.z'].
 */
export function getVecChannels(name: string, rank: number): string[] {
  const components = ['x', 'y', 'z', 'w', 'u', 'v'];
  return components.slice(0, rank).map(component => `${name}.${component}`);
}

/**
 * Like getVecChannels, except that a rank-1 value is addressed by `name`
 * alone.
 */
export function getChannels(name: string, rank: number): string[] {
  return rank === 1 ? [name] : getVecChannels(name, rank);
}

/**
 * Comma-joins the first `rank` entries of `dims`. Rank 1 always yields the
 * literal 'rc'.
 */
export function getSourceCoords(rank: number, dims: string[]): string {
  if (rank === 1) {
    return 'rc';
  }

  const coords: string[] = [];
  for (let i = 0; i < rank; i++) {
    coords.push(dims[i]);
  }
  return coords.join(',');
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/split_shared.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {Tensor} from '../tensor';

/** Shared implementation of the split kernel across WebGL and CPU.
*/
export function split<T extends Tensor>(
    x: T, sizeSplits: number[], axis: number): T[] {
  // `begin` advances along `axis` by each split size. `size` keeps the full
  // extent of every other dimension and is overwritten on `axis` with the
  // current split's length. Both arrays are mutated in place across
  // iterations; NOTE(review): this relies on x.slice not retaining them.
  const begin = new Array(x.rank).fill(0);
  const size = x.shape.slice();
  return sizeSplits.map(s => {
    size[axis] = s;
    const slice = x.slice(begin, size);
    begin[axis] += s;
    return slice;
  });
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/tile_impl.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * An implementation of the tile kernel shared between webgl and cpu for string
 * tensors only.
*/

import {buffer} from '../ops/array_ops';
import {Tensor, TensorBuffer} from '../tensor';
import {DataType, Rank} from '../types';

/**
 * Tiles `xBuf` by `reps`. Output dimension d has size `xBuf.shape[d] *
 * reps[d]`. Each output element copies the input element whose coordinates
 * are the output coordinates taken modulo the input shape.
 *
 * @param xBuf Source buffer. Assumes `reps.length === xBuf.rank`; TODO confirm
 *     at call sites.
 * @param reps Number of repetitions along each axis.
 * @returns A new tensor with the same rank and dtype as `xBuf`.
 */
export function tile<R extends Rank>(
    xBuf: TensorBuffer<R, DataType>, reps: number[]): Tensor<R> {
  const newShape: number[] = new Array(xBuf.rank);
  for (let axis = 0; axis < newShape.length; axis++) {
    newShape[axis] = xBuf.shape[axis] * reps[axis];
  }
  const result = buffer(newShape, xBuf.dtype);
  for (let i = 0; i < result.values.length; ++i) {
    const newLoc = result.indexToLoc(i);
    // Wrap each output coordinate back into the input's bounds. The callback
    // uses `axis` so it does not shadow the outer element index `i`.
    const originalLoc = newLoc.map((coord, axis) => coord % xBuf.shape[axis]);
    result.values[i] = xBuf.values[xBuf.locToIndex(originalLoc)];
  }
  return result.toTensor() as Tensor<R>;
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/addn_gpu.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';

/**
 * Element-wise sum of N inputs. Input k is referred to as `T<k>` in the
 * generated shader, and each one is read at the output coordinate.
 */
export class AddNProgram implements GPGPUProgram {
  variableNames: string[];
  outputShape: number[] = [];
  userCode: string;

  constructor(outputShape: number[], shapes: number[][]) {
    this.outputShape = outputShape;
    this.variableNames = shapes.map((_, i) => `T${i}`);

    // Declare one local `v<name>` per input, sampled at the output coordinate.
    const reads = this.variableNames.map(
        name => `float v${name} = get${name}AtOutCoords();`);
    // Sum of those locals, e.g. "vT0 + vT1 + vT2".
    const sum = this.variableNames.map(name => `v${name}`).join(' + ');

    this.userCode = `
      void main() {
        ${reads.join('\n ')}

        float result = ${sum};
        setOutput(result);
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/addn_packed_gpu.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';

/**
 * Packed-texture variant of the N-ary add. It sums all inputs element-wise,
 * one vec4 at a time. Input k is referred to as `T<k>` in the generated
 * shader.
 */
export class AddNPackedProgram implements GPGPUProgram {
  variableNames: string[];
  outputShape: number[] = [];
  userCode: string;
  usesPackedTextures = true;

  constructor(outputShape: number[], shapes: number[][]) {
    this.outputShape = outputShape;
    this.variableNames = shapes.map((_, i) => `T${i}`);

    // Declare one vec4 local `v<name>` per input, sampled at the output
    // coordinate.
    const reads = this.variableNames.map(
        name => `vec4 v${name} = get${name}AtOutCoords();`);
    // Sum of those locals, e.g. "vT0 + vT1".
    const sum = this.variableNames.map(name => `v${name}`).join(' + ');

    this.userCode = `
      void main() {
        ${reads.join('\n ')}

        vec4 result = ${sum};
        setOutput(result);
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/argminmax_gpu.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {ReduceInfo} from '../../ops/reduce_util';
import {GPGPUProgram} from './gpgpu_math';

/**
 * One pass of a windowed arg-max / arg-min reduction over an input read as
 * A[batch, index].
 *
 * Output element [batch, outIdx] holds the index of the best value among the
 * `windowSize` candidates that start at `outIdx * windowSize`. The output
 * shape is therefore [batchSize, ceil(inSize / windowSize)].
 *
 * On the first pass the candidates are positions of `A` itself. On later
 * passes (`firstPass === false`) an extra input `bestIndicesA` is bound. Each
 * candidate is then the index stored in it, and that index is used to read
 * `A`.
 */
export class ArgMinMaxProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;

  constructor(reduceInfo: ReduceInfo, op: 'max'|'min', firstPass: boolean) {
    const windowSize = reduceInfo.windowSize;
    const batchSize = reduceInfo.batchSize;
    const inSize = reduceInfo.inSize;
    const outSize = Math.ceil(inSize / windowSize);
    if (!firstPass) {
      this.variableNames.push('bestIndicesA');
    }
    this.outputShape = [batchSize, outSize];
    // Strict comparison, so on ties the current best is kept.
    const compOp = (op === 'max') ? '>' : '<';
    // GLSL expression giving the i-th candidate's index into A.
    // NOTE(review): both variants end in ';' and the template below adds
    // another, so the shader contains an empty statement (';;'). Harmless, but
    // the trailing ';' here looks unintended.
    const indexSnippet = firstPass ?
        'inOffset + i;' :
        'round(getBestIndicesA(batch, inOffset + i));';

    // NOTE(review): bestIndex/bestValue are seeded with A[inOffset]. On a
    // non-first pass inOffset indexes bestIndicesA rather than A, so the seed
    // may not be one of this window's candidates. Confirm this is intended.
    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        int bestIndex = inOffset;
        float bestValue = getA(batch, bestIndex);

        for (int i = 0; i < ${windowSize}; i++) {
          int inIdx = ${indexSnippet};
          float candidate = getA(batch, inIdx);
          if (candidate ${compOp} bestValue) {
            bestValue = candidate;
            bestIndex = inIdx;
          }
        }
        setOutput(float(bestIndex));
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/backend_webgl_test_registry.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {Constraints, registerTestEnv} from '../../jasmine_util';

/** Test constraints that select only environments using the 'webgl' backend. */
export const WEBGL_ENVS: Constraints = {
  predicate: ({backendName}) => backendName === 'webgl'
};
/** Test constraints that run with the WEBGL_PACK flag enabled. */
export const PACKED_ENVS: Constraints = {
  flags: {'WEBGL_PACK': true}
};

// Register 'webgl1' and then 'webgl2'. Both set WEBGL_CPU_FORWARD to false and
// WEBGL_SIZE_UPLOAD_UNIFORM to 0, and both enable isDataSync.
for (const webglVersion of [1, 2]) {
  registerTestEnv({
    name: `webgl${webglVersion}`,
    backendName: 'webgl',
    flags: {
      'WEBGL_VERSION': webglVersion,
      'WEBGL_CPU_FORWARD': false,
      'WEBGL_SIZE_UPLOAD_UNIFORM': 0
    },
    isDataSync: true
  });
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/binaryop_complex_gpu.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as broadcast_util from '../../ops/broadcast_util';
import {GPGPUProgram} from './gpgpu_math';

// Complex multiplication, with A = Ar + i*Ai and B = Br + i*Bi:
//   A * B = ArBr + i*ArBi + i*AiBr + i^2*AiBi
//         = (ArBr - AiBi) + i*(ArBi + AiBr)
// Yr = ArBr - AiBi
// Yi = ArBi + AiBr
export const COMPLEX_MULTIPLY = {
  REAL: 'return areal * breal - aimag * bimag;',
  IMAG: 'return areal * bimag + aimag * breal;'
};

/**
 * Applies a binary op to complex inputs that are stored as separate real and
 * imaginary tensors, and writes a single float per output element.
 *
 * `op` is the GLSL body of
 * `float binaryOpComplex(float areal, float aimag, float breal, float bimag)`,
 * for example COMPLEX_MULTIPLY.REAL or COMPLEX_MULTIPLY.IMAG. The output shape
 * is the broadcast of `aShape` and `bShape`.
 */
export class BinaryOpComplexProgram implements GPGPUProgram {
  variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
  userCode: string;
  outputShape: number[];

  constructor(op: string, aShape: number[], bShape: number[]) {
    this.outputShape =
        broadcast_util.assertAndGetBroadcastShape(aShape, bShape);

    this.userCode = `
      float binaryOpComplex(
          float areal, float aimag, float breal, float bimag) {
        ${op}
      }

      void main() {
        float areal = getARealAtOutCoords();
        float aimag = getAImagAtOutCoords();
        float breal = getBRealAtOutCoords();
        float bimag = getBImagAtOutCoords();
        setOutput(binaryOpComplex(areal, aimag, breal, bimag));
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/canvas_util_test.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC.
All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENV} from '../../environment';
import {BROWSER_ENVS, describeWithFlags} from '../../jasmine_util';

import {getWebGLContext} from './canvas_util';

describeWithFlags('canvas_util', BROWSER_ENVS, () => {
  it('Returns a valid canvas', () => {
    const webglVersion = ENV.getNumber('WEBGL_VERSION');
    const canvas = getWebGLContext(webglVersion).canvas as
        (HTMLCanvasElement | OffscreenCanvas);
    // `||` short-circuits: OffscreenCanvas is only consulted when the canvas
    // is not an HTMLCanvasElement.
    const isSupportedCanvas = canvas instanceof HTMLCanvasElement ||
        canvas instanceof OffscreenCanvas;
    expect(isSupportedCanvas).toBe(true);
  });

  it('Returns a valid gl context', () => {
    const webglVersion = ENV.getNumber('WEBGL_VERSION');
    const gl = getWebGLContext(webglVersion);
    expect(gl.isContextLost()).toBe(false);
  });
});

describeWithFlags('canvas_util webgl2', {flags: {WEBGL_VERSION: 2}}, () => {
  it('is ok when the user requests webgl 1 canvas', () => {
    const {canvas} = getWebGLContext(1);
    expect(canvas instanceof HTMLCanvasElement).toBe(true);
  });
});
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/clip_gpu.ts: --------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {GPGPUContext} from './gpgpu_context';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Clamps every element of input `A` into [minVal, maxVal]; NaN elements are
 * passed through unchanged. The bounds are not baked into the shader source:
 * they are uploaded as uniforms by the callback from getCustomSetupFunc.
 */
export class ClipProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];

  // Uniform locations, looked up on the first setup call and reused after.
  minLoc: WebGLUniformLocation;
  maxLoc: WebGLUniformLocation;

  /** @param aShape Shape of the input, which is also the output shape. */
  constructor(aShape: number[]) {
    this.outputShape = aShape;
    this.userCode = `
      uniform float minVal;
      uniform float maxVal;

      void main() {
        float value = getAAtOutCoords();
        if (isnan(value)) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, minVal, maxVal));
      }
    `;
  }

  /**
   * Returns a setup callback that uploads `min` and `max` to the shader's
   * `minVal` / `maxVal` uniforms.
   */
  getCustomSetupFunc(min: number, max: number) {
    return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
      const needsLookup = this.minLoc == null;
      if (needsLookup) {
        this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
        this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
      }
      const {gl} = gpgpu;
      gl.uniform1f(this.minLoc, min);
      gl.uniform1f(this.maxLoc, max);
    };
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/clip_packed_gpu.ts:
-------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {GPGPUContext} from './gpgpu_context';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Packed-texture variant of the clip op: each texel carries four values
 * (a vec4), and every lane is clamped into [minVal, maxVal] independently.
 * NaN lanes are passed through unchanged.
 *
 * Previously the shader returned the whole texel unclamped as soon as ANY
 * lane was NaN, leaving out-of-range neighbours of a NaN unclamped. The
 * per-lane select below fixes that.
 *
 * The bounds are uploaded as uniforms by the callback from getCustomSetupFunc.
 */
export class ClipPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  usesPackedTextures = true;
  userCode: string;
  outputShape: number[];

  // Uniform locations, looked up on the first setup call and reused after.
  minLoc: WebGLUniformLocation;
  maxLoc: WebGLUniformLocation;

  /** @param aShape Shape of the input, which is also the output shape. */
  constructor(aShape: number[]) {
    this.outputShape = aShape;
    this.userCode = `
      uniform float minVal;
      uniform float maxVal;

      void main() {
        vec4 value = getAAtOutCoords();
        bvec4 nanLanes = isnan(value);
        vec4 clamped = clamp(value, vec4(minVal), vec4(maxVal));

        // Select per lane so a NaN in one lane does not stop the other
        // lanes from being clamped.
        setOutput(vec4(
          nanLanes.x ? value.x : clamped.x,
          nanLanes.y ? value.y : clamped.y,
          nanLanes.z ? value.z : clamped.z,
          nanLanes.w ? value.w : clamped.w));
      }
    `;
  }

  /**
   * Returns a setup callback that uploads `min` and `max` to the shader's
   * `minVal` / `maxVal` uniforms.
   */
  getCustomSetupFunc(min: number, max: number) {
    return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
      if (this.minLoc == null) {
        this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
        this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
      }
      gpgpu.gl.uniform1f(this.minLoc, min);
      gpgpu.gl.uniform1f(this.maxLoc, max);
    };
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/complex_abs_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */

import {GPGPUProgram} from './gpgpu_math';

/**
 * Element-wise magnitude |re + i*im| of a complex tensor given as separate
 * `real` and `imag` inputs.
 *
 * Instead of sqrt(re^2 + im^2), the shader computes
 * mx * length(vec2(1, mn / mx)) with mx = max(|re|, |im|) and
 * mn = min(|re|, |im|). This is algebraically the same but avoids squaring
 * tiny values, and it returns 0 directly when both parts are 0.
 */
export class ComplexAbsProgram implements GPGPUProgram {
  variableNames = ['real', 'imag'];
  userCode: string;
  outputShape: number[];

  /** @param shape Shape shared by both inputs and the output. */
  constructor(shape: number[]) {
    const shader = `
      void main() {
        float re = abs(getRealAtOutCoords());
        float im = abs(getImagAtOutCoords());
        float mx = max(re, im);

        // sadly the length function in glsl is not underflow-safe
        // (at least not on Intel GPUs). So the safe solution is
        // to ensure underflow-safety in all cases.
        setOutput(
          mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
        );
      }
    `;
    this.outputShape = shape;
    this.userCode = shader;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/concat_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */

import * as concat_util from '../../ops/concat_util';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Concatenates 2D tensors along axis 1 (columns). See comments in
 * MathBackendWebGL.concat().
 *
 * Input `i` is bound to the shader variable `T${i}`. For each output column
 * `yC`, the shader picks the input whose column range contains `yC`, then
 * reads it at `yC` minus that input's starting column.
 */
export class ConcatProgram implements GPGPUProgram {
  variableNames: string[];
  outputShape: number[] = [];
  userCode: string;

  /** @param shapes `[rows, cols]` of each input, in concatenation order. */
  constructor(shapes: Array<[number, number]>) {
    this.outputShape = concat_util.computeOutShape(shapes, 1 /* axis */);
    this.variableNames = shapes.map((_, i) => `T${i}`);

    const snippets: string[] = [];
    if (shapes.length === 1) {
      // A single input maps straight through. The multi-input chain below
      // would otherwise emit `getT1(...)`, a sampler that is never declared.
      snippets.push('setOutput(getT0(yR, yC));');
    } else {
      // offsets[i] is the exclusive end column of input i in the output, for
      // every input except the last, which takes all remaining columns.
      const offsets: number[] = [];
      let endColumn = 0;
      for (let i = 0; i < shapes.length - 1; i++) {
        endColumn += shapes[i][1];
        offsets.push(endColumn);
      }

      snippets.push(`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`);
      for (let i = 1; i < offsets.length; i++) {
        const shift = offsets[i - 1];
        snippets.push(
            `else if (yC < ${offsets[i]}) ` +
            `setOutput(getT${i}(yR, yC-${shift}));`);
      }
      const lastIndex = offsets.length;
      const lastShift = offsets[offsets.length - 1];
      snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
    }

    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int yR = coords.x;
        int yC = coords.y;

        ${snippets.join('\n        ')}
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/decode_matrix_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {getGlslDifferences} from './glsl_version';
import {GPGPUProgram} from './gpgpu_math';
import * as shader_util from './shader_compiler_util';

/**
 * Gathers a rank-3 tensor `A` into an RGBA texture of size `texShape`.
 * Output texel (row, col) covers the four consecutive flat indices
 * starting at 4 * (row * texShape[1] + col). Each flat index is mapped back
 * to (r, c, d) over `outputShape`, and channel i holds A at the i-th index.
 */
export class DecodeMatrixProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: [number, number, number];

  /**
   * @param outputShape Logical rank-3 shape of the data being laid out.
   * @param texShape [rows, cols] of the RGBA output texture.
   */
  constructor(outputShape: [number, number, number], texShape: [
    number, number
  ]) {
    const {output: outputVar} = getGlslDifferences();
    this.outputShape = outputShape;

    const [texRows, texCols] = texShape;
    const flatToCoords = shader_util.getLogicalCoordinatesFromFlatIndex(
        ['r', 'c', 'd'], outputShape);

    this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${flatToCoords}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx *
          vec2(${texRows}, ${texCols}));
        int index = 4 * (resTexRC.x * ${texCols} + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getA(rc.x, rc.y, rc.z);
        }

        ${outputVar} = result;
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/decode_matrix_packed_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {getGlslDifferences} from './glsl_version';
import {GPGPUProgram} from './gpgpu_math';
import * as shader_util from './shader_compiler_util';

/**
 * Packed-input variant of DecodeMatrixProgram. Output texel (row, col)
 * covers the four consecutive flat indices starting at
 * 4 * (row * texShape[1] + col) over `outputShape`. Because `A` is packed,
 * the shader uses `getChannel(..., vec2(c, d))` to pick the single value out
 * of each vec4 that `getA` returns.
 */
export class DecodeMatrixPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  usesPackedTextures = true;
  outputShape: [number, number, number];

  /**
   * @param outputShape Logical rank-3 shape of the data being laid out.
   * @param texShape [rows, cols] of the RGBA output texture.
   */
  constructor(outputShape: [number, number, number], texShape: [
    number, number
  ]) {
    const {output: outputVar} = getGlslDifferences();
    this.outputShape = outputShape;

    const [texRows, texCols] = texShape;
    const flatToCoords = shader_util.getLogicalCoordinatesFromFlatIndex(
        ['r', 'c', 'd'], outputShape);

    this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${flatToCoords}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx *
          vec2(${texRows}, ${texCols}));
        int index = 4 * (resTexRC.x * ${texCols} + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
        }

        ${outputVar} = result;
      }
    `;
  }
}
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/diag_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {GPGPUProgram} from './gpgpu_math';

/**
 * Builds a `size` x `size` matrix whose diagonal entry (k, k) is X[k] and
 * whose off-diagonal entries are 0.0.
 */
export class DiagProgram implements GPGPUProgram {
  variableNames = ['X'];
  outputShape: number[];
  userCode: string;

  /** @param size Side length of the square output matrix. */
  constructor(size: number) {
    const squareShape = [size, size];
    this.outputShape = squareShape;
    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
        setOutput(val);
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/encode_float_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {getGlslDifferences} from './glsl_version';
import {GPGPUProgram} from './gpgpu_math';
import {ENCODE_FLOAT_SNIPPET} from './shader_compiler_util';

/**
 * Writes each float of `A` to the fragment output through `encode_float`
 * (defined by ENCODE_FLOAT_SNIPPET), one input value per output texel.
 * Presumably used to read floats back as bytes on platforms without float
 * texture download — TODO confirm with the caller.
 */
export class EncodeFloatProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];

  /** @param outputShape Shape of `A`, which is also the output shape. */
  constructor(outputShape: number[]) {
    const {output: outputVar} = getGlslDifferences();
    this.outputShape = outputShape;
    this.userCode = `
      ${ENCODE_FLOAT_SNIPPET}

      void main() {
        float x = getAAtOutCoords();
        ${outputVar} = encode_float(x);
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/encode_float_packed_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {getGlslDifferences} from './glsl_version';
import {GPGPUProgram} from './gpgpu_math';
import {ENCODE_FLOAT_SNIPPET} from './shader_compiler_util';

/**
 * Packed-input variant of EncodeFloatProgram. `A` is packed, so the shader
 * extracts the single value for output coords (_, y, z) with
 * `getChannel(..., vec2(y, z))` before passing it to `encode_float`.
 */
export class EncodeFloatPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];
  usesPackedTextures = true;

  /** @param outputShape Rank-3 logical shape of `A` and of the output. */
  constructor(outputShape: [number, number, number]) {
    const {output: outputVar} = getGlslDifferences();
    this.outputShape = outputShape;
    this.userCode = `
      ${ENCODE_FLOAT_SNIPPET}

      void main() {
        ivec3 coords = getOutputCoords();
        float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
        ${outputVar} = encode_float(x);
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/fill_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */

import {GPGPUContext} from './gpgpu_context';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Fills an output of the given shape with one scalar. The scalar is not baked
 * into the shader source; the callback from getCustomSetupFunc uploads it to
 * the `value` uniform.
 */
export class FillProgram implements GPGPUProgram {
  variableNames: string[] = ['x'];
  outputShape: number[] = [];
  userCode: string;

  // Uniform location, looked up on the first setup call and reused after.
  valueLoc: WebGLUniformLocation;

  /**
   * @param shape Output shape.
   * @param value Fill value. Not read here; it is supplied at run time
   *     through getCustomSetupFunc.
   */
  constructor(shape: number[], value: number) {
    this.outputShape = shape;
    this.userCode = `
      uniform float value;
      void main() {
        // Input can be obtained from uniform value.
        setOutput(value);
      }
    `;
  }

  /** Returns a setup callback that uploads `value` to the `value` uniform. */
  getCustomSetupFunc(value: number) {
    return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
      const needsLookup = this.valueLoc == null;
      if (needsLookup) {
        this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
      }
      const {gl} = gpgpu;
      gl.uniform1f(this.valueLoc, value);
    };
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/from_pixels_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */

import {getGlslDifferences} from './glsl_version';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Converts an RGBA pixel texture `A` into a [height, width, depth] tensor.
 * Output (row, col, d) is channel d (r, g, b or a for d = 0..3) of the texel
 * at (row, col), scaled by 255 and rounded with floor(v * 255 + 0.5).
 * `halfCR` is a shader-provided constant, presumably the half-texel offset
 * for sampling texel centres — TODO confirm.
 */
export class FromPixelsProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];

  /** @param outputShape [height, width, depth] of the output tensor. */
  constructor(outputShape: number[]) {
    const glsl = getGlslDifferences();
    const height = outputShape[0];
    const width = outputShape[1];
    this.outputShape = outputShape;
    this.userCode = `
      void main() {
        ivec3 coords = getOutputCoords();
        int texR = coords[0];
        int texC = coords[1];
        int depth = coords[2];
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);

        vec4 values = ${glsl.texture2D}(A, uv);
        float value;
        if (depth == 0) {
          value = values.r;
        } else if (depth == 1) {
          value = values.g;
        } else if (depth == 2) {
          value = values.b;
        } else if (depth == 3) {
          value = values.a;
        }

        setOutput(floor(value * 255.0 + 0.5));
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/from_pixels_packed_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {getGlslDifferences} from './glsl_version';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Packed-output variant of FromPixelsProgram. Each output texel holds a 2x2
 * block over the last two dims: lane (row * 2 + col) takes texel column
 * coords[1] + row and channel coords[2] + col, scaled by 255 and rounded.
 */
export class FromPixelsPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];

  /** @param outputShape [height, width, depth] of the output tensor. */
  constructor(outputShape: number[]) {
    const glsl = getGlslDifferences();
    const height = outputShape[0];
    const width = outputShape[1];
    this.outputShape = outputShape;
    this.userCode = `
      void main() {
        ivec3 coords = getOutputCoords();
        int texR = coords[0];
        int texC = coords[1];
        int depth = coords[2];

        vec4 result = vec4(0.);

        for(int row=0; row<=1; row++) {
          for(int col=0; col<=1; col++) {
            texC = coords[1] + row;
            depth = coords[2] + col;

            vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);
            vec4 values = ${glsl.texture2D}(A, uv);
            float value;
            if (depth == 0) {
              value = values.r;
            } else if (depth == 1) {
              value = values.g;
            } else if (depth == 2) {
              value = values.b;
            } else if (depth == 3) {
              value = values.a;
            }

            result[row * 2 + col] = floor(value * 255.0 + 0.5);
          }
        }

        ${glsl.output} = result;
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/gather_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * Gathers slices of `A` along `axis` using the values of the 1D `indices`
 * input. The output has A's shape with dimension `axis` replaced by
 * `indicesLength`. Supports inputs of rank 1 to 4.
 */
export class GatherProgram implements GPGPUProgram {
  variableNames = ['A', 'indices'];
  outputShape: number[];
  userCode: string;
  rank: number;

  /**
   * @param aShape Shape of `A`.
   * @param indicesLength Number of gathered indices.
   * @param axis Dimension of `A` to gather along.
   */
  constructor(aShape: number[], indicesLength: number, axis: number) {
    const gatheredShape: number[] = aShape.slice();
    gatheredShape[axis] = indicesLength;
    this.outputShape = gatheredShape;
    this.rank = gatheredShape.length;

    const coordsType = getCoordsDataType(this.rank);
    const sourceCoords = getSourceCoords(aShape, axis);
    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        setOutput(getA(${sourceCoords}));
      }
    `;
  }
}

/**
 * Returns the comma-separated GLSL expressions used to index `A`. The `axis`
 * component looks up `indices` at the output coordinate; every other
 * component copies the output coordinate through.
 */
function getSourceCoords(aShape: number[], axis: number): string {
  const rank = aShape.length;
  if (rank > 4) {
    throw Error(`Gather for rank ${rank} is not yet supported`);
  }
  if (rank === 1) {
    return `int(getIndices(resRC))`;
  }

  const outCoordNames = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
  return aShape
      .map(
          (_, dim) => dim === axis ? `int(getIndices(${outCoordNames[dim]}))` :
                                     outCoordNames[dim])
      .join();
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/gather_nd_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */
import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * gatherND kernel. For each output row coords[0]:
 * - the `sliceDim` index values in that row of `indices` are combined with
 *   `strides` into a flat slice index;
 * - element coords[1] of that slice is copied from `x`.
 * `x` is read as 2D, presumably [numSlices, sliceSize] after a reshape by
 * the caller — TODO confirm.
 */
export class GatherNDProgram implements GPGPUProgram {
  variableNames = ['x', 'indices'];
  outputShape: number[];
  userCode: string;
  constructor(
      private sliceDim: number, private strides: number[], shape: number[]) {
    this.outputShape = shape;
    const stridesType = getCoordsDataType(strides.length);
    const coordsType = getCoordsDataType(shape.length);
    // With sliceDim == 1, `strides` is declared as a scalar and cannot be
    // indexed. This assumes strides.length === sliceDim — TODO confirm.
    const strideExpr = this.sliceDim > 1 ? 'strides[j]' : 'strides';
    this.userCode = `
        ${stridesType} strides = ${stridesType}(${this.strides});
        void main() {
          ${coordsType} coords = getOutputCoords();
          int flattenIndex = 0;
          for (int j = 0; j < ${this.sliceDim}; j++) {
            int index = round(getIndices(coords[0], j));
            flattenIndex += index * ${strideExpr};
          }
          setOutput(getX(flattenIndex, coords[1]));
        }
      `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/multinomial_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */

import {GPGPUContext} from './gpgpu_context';
import {GPGPUProgram} from './gpgpu_math';

/**
 * Draws `numSamples` outcome indices per batch row from `probs`.
 *
 * Each output element draws r = random(seed) and walks the running sum of
 * that row's probabilities. It emits the first outcome i whose cumulative
 * sum exceeds r, or the last outcome if none does. This presumably assumes
 * each row of `probs` is normalised and random() returns [0, 1) — TODO
 * confirm.
 */
export class MultinomialProgram implements GPGPUProgram {
  variableNames = ['probs'];
  outputShape: number[];
  userCode: string;

  // Uniform location, looked up on the first setup call and reused after.
  seedLoc: WebGLUniformLocation;

  constructor(batchSize: number, numOutcomes: number, numSamples: number) {
    this.outputShape = [batchSize, numSamples];
    const lastOutcome = numOutcomes - 1;

    this.userCode = `
      uniform float seed;

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];

        float r = random(seed);
        float cdf = 0.0;

        for (int i = 0; i < ${lastOutcome}; i++) {
          cdf += getProbs(batch, i);

          if (r < cdf) {
            setOutput(float(i));
            return;
          }
        }

        // If no other event happened, last event happened.
        setOutput(float(${lastOutcome}));
      }
    `;
  }

  /**
   * Returns a setup callback that uploads `seed` to the `seed` uniform. It
   * uses the throwing getUniformLocation, unlike the NoThrow lookups in other
   * programs.
   */
  getCustomSetupFunc(seed: number) {
    return (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => {
      const needsLookup = this.seedLoc == null;
      if (needsLookup) {
        this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed');
      }
      gpgpu.gl.uniform1f(this.seedLoc, seed);
    };
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/onehot_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ============================================================================= 16 | */

import {GPGPUProgram} from './gpgpu_math';

/**
 * One-hot encodes the 1D `indices` input into a [numIndices, depth] matrix.
 * Output (r, c) is onValue when the index at r, converted with the shader's
 * `round` helper, equals c; otherwise it is offValue. An index outside
 * [0, depth) therefore produces a row of offValue.
 *
 * An unused `seedLoc` uniform-location field, copy-pasted from
 * MultinomialProgram, has been removed. This program declares no uniforms
 * and nothing ever assigned or read that field.
 */
export class OneHotProgram implements GPGPUProgram {
  variableNames = ['indices'];
  outputShape: number[];
  userCode: string;

  /**
   * @param numIndices Length of `indices` (output rows).
   * @param depth Number of classes (output columns).
   * @param onValue Value written where the index matches the column.
   * @param offValue Value written everywhere else.
   */
  constructor(
      numIndices: number, depth: number, onValue: number, offValue: number) {
    this.outputShape = [numIndices, depth];

    // onValue/offValue are interpolated into the shader source, so the
    // generated code is specific to these two values.
    this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int index = round(getIndices(coords.x));
        setOutput(mix(float(${offValue}), float(${onValue}),
                      float(index == coords.y)));
      }
    `;
  }
}
-------------------------------------------------------------------------------- /tfjs-core/src/backends/webgl/reverse_gpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
* =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/** WebGL program that reverses a tensor of rank 1-4 along chosen axes. */
export class ReverseProgram implements GPGPUProgram {
  variableNames = ['x'];
  outputShape: number[];
  userCode: string;

  /**
   * @param xShape Shape of the input; the output has the same shape.
   * @param axis Axes to reverse. Not consulted for 1-D input, which is always
   *     reversed.
   */
  constructor(xShape: number[], axis: number[]) {
    const numDims = xShape.length;
    if (numDims > 4) {
      throw new Error(
          `WebGL backend: Reverse of rank-${numDims} tensor is not yet supported`);
    }
    this.outputShape = xShape;

    if (numDims === 1) {
      // A 1-D output coordinate is a plain int, so mirror it directly.
      this.userCode = `
      void main() {
        int coord = getOutputCoords();
        setOutput(getX(${xShape[0]} - coord - 1));
      }
    `;
    } else {
      // An axis is mirrored only when requested and longer than one element
      // (flipping a size-1 axis would be a no-op anyway).
      const isFlipped = (dim: number) =>
          axis.indexOf(dim) !== -1 && xShape[dim] !== 1;
      const sourceCoords =
          xShape
              .map(
                  (size, dim) => isFlipped(dim) ?
                      `${size} - coords[${dim}] - 1` :
                      `coords[${dim}]`)
              .join(',');
      const coordsType = getCoordsDataType(numDims);

      this.userCode = `
      void main() {
        ${coordsType} coords = getOutputCoords();
        setOutput(getX(${sourceCoords}));
      }
    `;
    }
  }
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/select_gpu.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * WebGL program for `where`: outputs `a` where the condition value is >= 1.0
 * and `b` otherwise.
 */
export class SelectProgram implements GPGPUProgram {
  variableNames = ['c', 'a', 'b'];
  outputShape: number[];
  userCode: string;

  /**
   * @param cRank Rank of the condition `c`; it is read with the leading
   *     `cRank` output coordinates.
   * @param shape Output shape; `a` and `b` are read with every coordinate.
   * @param rank Output rank; ranks above 4 are rejected.
   */
  constructor(cRank: number, shape: number[], rank: number) {
    this.outputShape = shape;

    if (rank > 4) {
      throw new Error(`Where for rank ${rank} is not yet supported`);
    }

    let cCoords: string;
    let abCoords: string;
    if (rank === 1) {
      // Rank-1 output coordinates are a scalar int.
      cCoords = `resRC`;
      abCoords = `resRC`;
    } else {
      const components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
      const perAxis = shape.map((_, dim) => `${components[dim]}`);
      abCoords = perAxis.join();
      cCoords = perAxis.filter((_, dim) => dim < cRank).join();
    }

    const coordsType = getCoordsDataType(rank);

    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        float cVal = getC(${cCoords});
        if (cVal >= 1.0) {
          setOutput(getA(${abCoords}));
        } else {
          setOutput(getB(${abCoords}));
        }
      }
    `;
  }
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/shader_compiler_util_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {describeWithFlags} from '../../jasmine_util';
import {WEBGL_ENVS} from './backend_webgl_test_registry';
import {dotify, getLogicalCoordinatesFromFlatIndex} from './shader_compiler_util';

// Unit tests for the GLSL-snippet builders in shader_compiler_util.
describeWithFlags('shader compiler', WEBGL_ENVS, () => {
  // Spec names split across lines end each fragment with a space so the
  // words do not run together in test reports.
  it('dotify takes two arrays of coordinates and produces ' +
         'the glsl that finds the dot product of those coordinates',
     () => {
       const coords1 = ['r', 'g', 'b', 'a'];
       const coords2 = ['x', 'y', 'z', 'w'];

       expect(dotify(coords1, coords2))
           .toEqual('dot(vec4(r,g,b,a), vec4(x,y,z,w))');
     });

  it('dotify should split up arrays into increments of vec4s', () => {
    const coords1 = ['a', 'b', 'c', 'd', 'e', 'f', 'g'];
    const coords2 = ['h', 'i', 'j', 'k', 'l', 'm', 'n'];

    expect(dotify(coords1, coords2))
        .toEqual(
            'dot(vec4(a,b,c,d), vec4(h,i,j,k))+dot(vec3(e,f,g), vec3(l,m,n))');
  });

  it('getLogicalCoordinatesFromFlatIndex produces glsl that takes ' +
         'a flat index and finds its coordinates within that shape',
     () => {
       const coords = ['r', 'c', 'd'];
       const shape = [1, 2, 3];

       expect(getLogicalCoordinatesFromFlatIndex(coords, shape))
           .toEqual(
               'int r = index / 6; index -= r * 6;' +
               'int c = index / 3; int d = index - c * 3;');
     });
});

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/strided_slice_gpu.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/** WebGL program for strided slicing with optional axis shrinking. */
export class StridedSliceProgram implements GPGPUProgram {
  variableNames = ['x'];
  outputShape: number[];
  userCode: string;

  /**
   * @param begin Start coordinate along every input axis.
   * @param strides Step along every input axis.
   * @param size Extent taken along every input axis; the output shape is
   *     `size` with the shrunk axes removed.
   * @param shrinkAxis Input axes dropped from the output; along these the
   *     input is read at `begin[axis]` only.
   */
  constructor(
      begin: number[], strides: number[], size: number[],
      shrinkAxis: number[]) {
    const isKept = (axis: number) => shrinkAxis.indexOf(axis) === -1;
    const outShape = size.filter((_, axis) => isKept(axis));
    this.outputShape = outShape;
    const inRank = size.length;
    const inCoordsType = getCoordsDataType(inRank);
    const outCoordsType = getCoordsDataType(outShape.length);

    let sourceCoords: string;
    if (inRank === 1) {
      sourceCoords = 'coords * strides + begin';
    } else {
      // Walk the input axes, consuming one output coordinate per kept axis.
      const parts: string[] = [];
      let nextOutAxis = 0;
      for (let axis = 0; axis < inRank; axis++) {
        if (!isKept(axis)) {
          parts.push(`begin[${axis}]`);
          continue;
        }
        // A rank-1 output coordinate is a scalar and cannot be indexed.
        const outCoord =
            outShape.length === 1 ? 'coords' : `coords[${nextOutAxis}]`;
        parts.push(`${outCoord} * strides[${axis}] + begin[${axis}]`);
        nextOutAxis++;
      }
      sourceCoords = parts.join(',');
    }

    this.userCode = `
      ${inCoordsType} begin = ${inCoordsType}(${begin});
      ${inCoordsType} strides = ${inCoordsType}(${strides});

      void main() {
        ${outCoordsType} coords = getOutputCoords();
        setOutput(getX(${sourceCoords}));
      }
    `;
  }
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/tile_gpu.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/** WebGL program that repeats `A` `reps[axis]` times along each axis. */
export class TileProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  rank: number;

  /**
   * @param aShape Shape of the input `A`.
   * @param reps Repetition count per axis.
   */
  constructor(aShape: number[], reps: number[]) {
    this.outputShape = aShape.map((dimSize, axis) => dimSize * reps[axis]);
    this.rank = this.outputShape.length;
    const coordsType = getCoordsDataType(this.rank);
    const srcCoords = getSourceCoords(aShape);

    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        setOutput(getA(${srcCoords}));
      }
    `;
  }
}

/**
 * Builds the GLSL argument list that maps output coordinate `resRC` back into
 * `A`, wrapping each component with `imod(component, aShape[axis])`.
 * Ranks above 5 are rejected.
 */
function getSourceCoords(aShape: number[]): string {
  const numDims = aShape.length;
  if (numDims > 5) {
    throw new Error(`Tile for rank ${numDims} is not yet supported`);
  }
  if (numDims === 1) {
    // A rank-1 output coordinate is a scalar int.
    return `imod(resRC, ${aShape[0]})`;
  }

  const components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
  return aShape
      .map((dimSize, axis) => `imod(${components[axis]}, ${dimSize})`)
      .join();
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/transpose_gpu.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/** WebGL program that permutes the axes of `A` according to `newDim`. */
export class TransposeProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  rank: number;

  /**
   * @param aShape Shape of the input `A`.
   * @param newDim Permutation: output axis `i` has the size of input axis
   *     `newDim[i]`.
   */
  constructor(aShape: number[], newDim: number[]) {
    this.outputShape = aShape.map((_, outAxis) => aShape[newDim[outAxis]]);
    this.rank = this.outputShape.length;
    const coordsType = getCoordsDataType(this.rank);
    const inputCoords = getSwitchedCoords(newDim);

    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        setOutput(getA(${inputCoords}));
      }
    `;
  }
}

/**
 * Builds the GLSL argument list that reads `A` for output coordinate `resRC`:
 * input axis `newDim[i]` is addressed by output component `i` (the inverse
 * permutation). Ranks above 6 are rejected.
 */
function getSwitchedCoords(newDim: number[]): string {
  const numDims = newDim.length;
  if (numDims > 6) {
    throw new Error(`Transpose for rank ${numDims} is not yet supported`);
  }
  const components =
      ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
  const slots: string[] = new Array(numDims);
  newDim.forEach((inAxis, outAxis) => {
    slots[inAxis] = components[outAxis];
  });
  return slots.join();
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/unaryop_packed_gpu.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {GPGPUProgram} from './gpgpu_math';

// The snippets below are GLSL statements that UnaryOpPackedProgram splices in
// as the body of `vec4 unaryOperation(vec4 x)`, so each must return a vec4.

/** Identity: returns the packed input unchanged. */
export const LINEAR = `return x;`;

/**
 * Natural log of each of the four packed values. Lanes whose input is
 * negative are replaced with NAN (a constant not defined here — presumably
 * provided by the shared shader prelude; verify).
 */
export const LOG = `
  vec4 result = log(x);
  vec4 isNaN = vec4(lessThan(x, vec4(0.0)));
  result.r = isNaN.r == 1.0 ? NAN : result.r;
  result.g = isNaN.g == 1.0 ? NAN : result.g;
  result.b = isNaN.b == 1.0 ? NAN : result.b;
  result.a = isNaN.a == 1.0 ? NAN : result.a;

  return result;
`;

/**
 * ReLU of each packed value: multiplies by a 0/1 mask of `x >= 0`, then
 * explicitly copies NaN inputs through unchanged. (`isnan` may be the GLSL
 * builtin or a prelude helper depending on the GL version — not visible here.)
 */
export const RELU = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;

/**
 * Packed-texture program that applies a GLSL snippet to each vec4 fetched
 * from `A` at the output coordinates. The output has the same shape as `A`.
 */
export class UnaryOpPackedProgram implements GPGPUProgram {
  variableNames = ['A'];
  userCode: string;
  outputShape: number[];
  // Both the input and the output use the packed texture layout.
  usesPackedTextures = true;

  /**
   * @param aShape Shape of the input; reused as the output shape.
   * @param opSnippet Body of `vec4 unaryOperation(vec4 x)`, e.g. LINEAR, LOG
   *     or RELU.
   */
  constructor(aShape: number[], opSnippet: string) {
    this.outputShape = aShape;
    this.userCode = `
      vec4 unaryOperation(vec4 x) {
        ${opSnippet}
      }

      void main() {
        vec4 x = getAAtOutCoords();
        vec4 y = unaryOperation(x);

        setOutput(y);
      }
    `;
  }
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/unpack_gpu.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {getChannels, getSourceCoords} from '../packing_util';

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * WebGL program that reads a packed texture `A` and writes each output
 * element from the matching channel of its packed texel.
 */
export class UnpackProgram implements GPGPUProgram {
  variableNames = ['A'];
  usesPackedTextures = true;
  outputShape: number[];
  userCode: string;

  /** @param outputShape Logical shape of the unpacked result. */
  constructor(outputShape: number[]) {
    this.outputShape = outputShape;
    const numDims = outputShape.length;

    const channelNames = getChannels('rc', numDims);
    const coordsType = getCoordsDataType(numDims);
    const packedCoords = getSourceCoords(numDims, channelNames);

    // getChannel() receives the scalar coordinate for rank <= 1, otherwise
    // a vec2 built from the two innermost coordinate names.
    let channelSelector: string;
    if (numDims <= 1) {
      channelSelector = 'rc';
    } else {
      channelSelector = `vec2(${channelNames.slice(-2).join(',')})`;
    }

    this.userCode = `
      void main() {
        ${coordsType} rc = getOutputCoords();
        vec4 packedInput = getA(${packedCoords});

        setOutput(getChannel(packedInput, ${channelSelector}));
      }
    `;
  }
}
--------------------------------------------------------------------------------
/tfjs-core/src/backends/webgl/webgl_types.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

// TODO(nsthorat): Move these to the webgl official typings.

/**
 * Members this backend declares for the WebGL2 disjoint timer query
 * extension object (presumably `EXT_disjoint_timer_query_webgl2`): only the
 * two enum constants.
 */
export interface WebGL2DisjointQueryTimerExtension {
  TIME_ELAPSED_EXT: number;
  GPU_DISJOINT_EXT: number;
}

/**
 * Members this backend declares for the WebGL1 disjoint timer query extension
 * object (presumably `EXT_disjoint_timer_query`): enum constants plus the
 * query-object methods.
 */
export interface WebGL1DisjointQueryTimerExtension {
  TIME_ELAPSED_EXT: number;
  QUERY_RESULT_AVAILABLE_EXT: number;
  GPU_DISJOINT_EXT: number;
  QUERY_RESULT_EXT: number;
  // NOTE(review): declared as returning `{}` while every other method takes a
  // WebGLQuery; `() => WebGLQuery` may be the intended type — confirm before
  // tightening.
  createQueryEXT: () => {};
  beginQueryEXT: (ext: number, query: WebGLQuery) => void;
  endQueryEXT: (ext: number) => void;
  deleteQueryEXT: (query: WebGLQuery) => void;
  isQueryEXT: (query: WebGLQuery) => boolean;
  getQueryObjectEXT:
      (query: WebGLQuery, queryResultAvailableExt: number) => number;
}

/** Optional flags for WebGL context creation; every field may be omitted. */
export interface WebGLContextAttributes {
  alpha?: boolean;
  antialias?: boolean;
  premultipliedAlpha?: boolean;
  preserveDrawingBuffer?: boolean;
  depth?: boolean;
  stencil?: boolean;
  failIfMajorPerformanceCaveat?: boolean;
}

--------------------------------------------------------------------------------
/tfjs-core/src/backends/where_impl.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

/** An implementation of the Where kernel shared between cpu and webgl */

import {buffer} from '../ops/array_ops';
import {Tensor2D} from '../tensor';
import {TypedArray} from '../types';

/**
 * Returns the coordinates of every truthy element of a condition tensor.
 *
 * @param condShape Shape of the condition tensor.
 * @param condVals Its values, flattened in the order `indexToLoc` inverts.
 * @returns An int32 tensor of shape `[numTrue, condShape.length]`; row k holds
 *     the coordinates of the k-th truthy element in flat order.
 */
export function whereImpl(condShape: number[], condVals: TypedArray): Tensor2D {
  const rank = condShape.length;

  const hits: number[] = [];
  for (let flatIndex = 0; flatIndex < condVals.length; flatIndex++) {
    if (condVals[flatIndex]) {
      hits.push(flatIndex);
    }
  }

  // Used only for its flat-index -> coordinates conversion.
  // NOTE(review): this presumably allocates a buffer the size of the whole
  // condition just to call indexToLoc.
  const shapeHelper = buffer(condShape, 'int32');

  const coords = buffer([hits.length, rank], 'int32');
  hits.forEach((flatIndex, row) => {
    coords.values.set(shapeHelper.indexToLoc(flatIndex), row * rank);
  });
  return coords.toTensor() as Tensor2D;
}

--------------------------------------------------------------------------------
/tfjs-core/src/browser_util.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

/**
 * Schedules a callback for "the next frame". It uses requestAnimationFrame
 * when that exists, otherwise setImmediate. If neither exists, the callback
 * runs synchronously.
 */
const delayCallback: Function = (() => {
  if (typeof requestAnimationFrame !== 'undefined') {
    return requestAnimationFrame;
  } else if (typeof setImmediate !== 'undefined') {
    return setImmediate;
  }
  return (f: Function) => f();  // no delays
})();

/**
 * Returns a promise that resolves when a requestAnimationFrame has completed.
 *
 * On Node.js this uses setImmediate instead of requestAnimationFrame.
 *
 * This is simply a sugar method so that users can do the following:
 * `await tf.nextFrame();`
 */
/** @doc {heading: 'Performance', subheading: 'Timing'} */
function nextFrame(): Promise<void> {
  // `Promise` is generic; `<void>` is required for the signature to compile
  // and lets `resolve()` be called with no value.
  return new Promise<void>(resolve => delayCallback(() => resolve()));
}

export {nextFrame};

--------------------------------------------------------------------------------
/tfjs-core/src/browser_util_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import * as tf from './index';
import {ALL_ENVS, describeWithFlags} from './jasmine_util';

describeWithFlags('nextFrame', ALL_ENVS, () => {
  it('basic usage', async () => {
    const t0 = tf.util.now();
    await tf.nextFrame();
    const t1 = tf.util.now();
    // tf.util.now should give sufficient accuracy on all supported envs.
    // Jasmine's expect() asserts nothing until a matcher is called on it.
    expect(t1).toBeGreaterThan(t0);
  });

  it('does not block timers', async () => {
    let flag = false;
    setTimeout(() => {
      flag = true;
    }, 50);
    const t0 = tf.util.now();
    expect(flag).toBe(false);
    // Yield frame by frame (up to 1s) until the 50ms timer has had a chance
    // to fire.
    while (tf.util.now() - t0 < 1000 && !flag) {
      await tf.nextFrame();
    }
    expect(flag).toBe(true);
  });
});

--------------------------------------------------------------------------------
/tfjs-core/src/log.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* =============================================================================
 */

import {ENV} from './environment';

/**
 * Forwards its arguments to console.warn. Nothing is printed while the
 * IS_TEST flag is set.
 */
export function warn(...msg: Array<{}>): void {
  if (ENV.getBool('IS_TEST')) {
    return;
  }
  console.warn(...msg);
}

/**
 * Forwards its arguments to console.log. Nothing is printed while the
 * IS_TEST flag is set.
 */
export function log(...msg: Array<{}>): void {
  if (ENV.getBool('IS_TEST')) {
    return;
  }
  console.log(...msg);
}

--------------------------------------------------------------------------------
/tfjs-core/src/math.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Exports under the tf.math.* namespace.
 */

import {confusionMatrix} from './ops/confusion_matrix';

export {confusionMatrix};

--------------------------------------------------------------------------------
/tfjs-core/src/ops/clone_test.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

describeWithFlags('clone', ALL_ENVS, () => {
  it('returns a tensor with the same shape and value', async () => {
    const a = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]);
    const aPrime = tf.clone(a);
    expect(aPrime.shape).toEqual(a.shape);
    expectArraysClose(await aPrime.data(), await a.data());
    // A clone must also preserve the dtype.
    expect(aPrime.dtype).toEqual(a.dtype);
  });

  it('accepts a tensor-like object', async () => {
    const res = tf.clone([[1, 2, 3], [4, 5, 6]]);
    expect(res.dtype).toBe('float32');
    expect(res.shape).toEqual([2, 3]);
    expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]);
  });
});

--------------------------------------------------------------------------------
/tfjs-core/src/ops/concat_util.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as util from '../util';

/**
 * Validates the shapes passed to concat via `util.assert`.
 *
 * The checks are:
 * - every shape has the same rank as `shapes[0]`;
 * - `axis` lies in `[0, rank)`;
 * - every shape matches `shapes[0]` on all dimensions except `axis`.
 *
 * @param shapes Shapes of the tensors being concatenated.
 * @param axis The concatenation axis.
 */
export function assertParamsConsistent(shapes: number[][], axis: number) {
  const rank = shapes[0].length;
  shapes.forEach((shape, i) => {
    util.assert(
        shape.length === rank,
        () =>
            `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
            `as the rank of the rest (${rank})`);
  });

  util.assert(
      axis >= 0 && axis < rank,
      () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);

  const firstShape = shapes[0];
  shapes.forEach((shape, i) => {
    for (let r = 0; r < rank; r++) {
      // Report the mismatching dimension `r`, not the tensor index `i`,
      // which is already named in the message.
      util.assert(
          (r === axis) || (shape[r] === firstShape[r]),
          () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
              `does not match the shape of the rest (${firstShape}) ` +
              `along the non-concatenated axis ${r}.`);
    }
  });
}

/**
 * Returns the shape of the concatenation. This is `shapes[0]` with its `axis`
 * dimension replaced by the sum of every shape's `axis` dimension.
 *
 * This function does not validate its input; presumably callers run
 * `assertParamsConsistent` first.
 */
export function computeOutShape(shapes: number[][], axis: number): number[] {
  const outputShape = shapes[0].slice();
  for (let i = 1; i < shapes.length; i++) {
    outputShape[axis] += shapes[i][axis];
  }
  return outputShape;
}

--------------------------------------------------------------------------------
/tfjs-core/src/ops/diag.ts:
--------------------------------------------------------------------------------
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {ENGINE} from '../engine'; 19 | import {Tensor} from '../tensor'; 20 | import {convertToTensor} from '../tensor_util_env'; 21 | import {op} from './operation'; 22 | 23 | /** 24 | * Returns a diagonal tensor with the given diagonal values. 25 | * 26 | * Given a diagonal, this operation returns a tensor with the diagonal and 27 | * everything else padded with zeros. 28 | * 29 | * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor 30 | * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`. 31 | * 32 | * ```js 33 | * const x = tf.tensor1d([1, 2, 3, 4]); 34 | * 35 | * tf.diag(x).print() 36 | * ``` 37 | * ```js 38 | * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [4, 2]); 39 | * 40 | * tf.diag(x).print() 41 | * ``` 42 | * @param x The input tensor.
43 | * @returns A tensor of shape `[...x.shape, ...x.shape]` holding `x` on its diagonal and zeros elsewhere. 44 | */ 45 | function diag_(x: Tensor): Tensor { 46 | // Take the output shape from the converted `$x`, not `x`: callers may pass 47 | // values that convertToTensor accepts but that have no `.shape` property. 48 | const $x = convertToTensor(x, 'x', 'diag'); 49 | const outShape = [...$x.shape, ...$x.shape]; 50 | const $xFlat = $x.flatten(); 51 | return ENGINE.runKernel(backend => backend.diag($xFlat), {$x: $xFlat}).reshape(outShape); 52 | } 53 | 54 | export const diag = op({diag_}); 55 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/dropout_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {Tensor} from '../tensor'; 19 | import * as util from '../util'; 20 | 21 | /** 22 | * Normalize noise shape based on provided tensor and noise shape. 23 | * 24 | * @param x Tensor. 25 | * @param noiseShape The shape for the randomly generated keep/drop flags, as 26 | * an array of numbers. Optional. 27 | * @returns Normalized noise shape.
28 | */ 29 | export function getNoiseShape(x: Tensor, noiseShape?: number[]): number[] { 30 | if (noiseShape == null) { 31 | return x.shape.slice(); 32 | } 33 | if (util.arraysEqual(x.shape, noiseShape)) { 34 | return noiseShape; 35 | } 36 | if (x.shape.length === noiseShape.length) { 37 | const newDimension: number[] = []; 38 | for (let i = 0; i < x.shape.length; i++) { 39 | if (noiseShape[i] == null && x.shape[i] != null) { 40 | newDimension.push(x.shape[i]); 41 | } else { 42 | newDimension.push(noiseShape[i]); 43 | } 44 | } 45 | return newDimension; 46 | } 47 | 48 | return noiseShape; 49 | } 50 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/dropout_util_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '../index'; 19 | import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; 20 | import {getNoiseShape} from './dropout_util'; 21 | 22 | describeWithFlags('getNoiseShape', ALL_ENVS, () => { 23 | it('x.shape == noiseShape', async () => { 24 | const x = tf.ones([2, 3]); 25 | const noiseShape = [2, 3]; 26 | const shape = getNoiseShape(x, noiseShape); 27 | expect(shape).toEqual([2, 3]); 28 | }); 29 | 30 | it('x.shape and noiseShape have same length, different value', async () => { 31 | const x = tf.ones([2, 3]); 32 | const noiseShape = [2, 1]; 33 | const shape = getNoiseShape(x, noiseShape); 34 | expect(shape).toEqual([2, 1]); 35 | }); 36 | 37 | it('noiseShape has null value', async () => { 38 | const x = tf.ones([2, 3]); 39 | const noiseShape = [2, null]; 40 | const shape = getNoiseShape(x, noiseShape); 41 | expect(shape).toEqual([2, 3]); 42 | }); 43 | 44 | it('x.shape and noiseShape has different length', async () => { 45 | const x = tf.ones([2, 3, 4]); 46 | const noiseShape = [2, 3]; 47 | const shape = getNoiseShape(x, noiseShape); 48 | expect(shape).toEqual([2, 3]); 49 | }); 50 | 51 | it('noiseShape is null', async () => { 52 | const x = tf.ones([2, 3]); 53 | const shape = getNoiseShape(x, null); 54 | expect(shape).toEqual([2, 3]); 55 | }); 56 | }); 57 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/erf_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | export const ERF_P = 0.3275911; 19 | export const ERF_A1 = 0.254829592; 20 | export const ERF_A2 = -0.284496736; 21 | export const ERF_A3 = 1.421413741; 22 | export const ERF_A4 = -1.453152027; 23 | export const ERF_A5 = 1.061405429; 24 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/fused_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import {Tensor, Tensor3D} from '../tensor'; 19 | 20 | export type Activation = 'linear'|'relu'|'prelu'; 21 | 22 | export type FusedBatchMatMulConfig = { 23 | a: Tensor3D, 24 | b: Tensor3D, 25 | transposeA: boolean, 26 | transposeB: boolean, 27 | bias?: Tensor, 28 | activation?: Activation, 29 | preluActivationWeights?: Tensor 30 | }; 31 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/operation.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | import {ENGINE} from '../engine'; 18 | 19 | /** 20 | * Used for wrapping functions that perform math operations on 21 | * Tensors. The function will be wrapped in a named scope that cleans all 22 | * memory usage after the function is done. 23 | */ 24 | export function op(f: {[name: string]: T}): T { 25 | const keys = Object.keys(f); 26 | if (keys.length !== 1) { 27 | throw new Error( 28 | `Please provide an object with a single key ` + 29 | `(operation name) mapping to a function. 
Got an object with ` + 30 | `${keys.length} keys.`); 31 | } 32 | 33 | let opName = keys[0]; 34 | const fn = f[opName]; 35 | 36 | // Strip the underscore from the end of the function name. 37 | if (opName.endsWith('_')) { 38 | opName = opName.substring(0, opName.length - 1); 39 | } 40 | 41 | // tslint:disable-next-line:no-any 42 | const f2 = (...args: any[]) => { 43 | ENGINE.startScope(opName); 44 | try { 45 | const result = fn(...args); 46 | if (result instanceof Promise) { 47 | console.error('Cannot return a Promise inside of tidy.'); 48 | } 49 | ENGINE.endScope(result); 50 | return result; 51 | } catch (ex) { 52 | ENGINE.endScope(null); 53 | throw ex; 54 | } 55 | }; 56 | Object.defineProperty(f2, 'name', {value: opName, configurable: true}); 57 | 58 | // tslint:disable-next-line:no-any 59 | return f2 as any as T; 60 | } 61 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/operation_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; 18 | import {op} from './operation'; 19 | 20 | describeWithFlags('operation', ALL_ENVS, () => { 21 | it('executes and preserves function name', () => { 22 | const f = () => 2; 23 | const opfn = op({'opName': f}); 24 | 25 | expect(opfn.name).toBe('opName'); 26 | expect(opfn()).toBe(2); 27 | }); 28 | 29 | it('executes, preserves function name, strips underscore', () => { 30 | const f = () => 2; 31 | const opfn = op({'opName_': f}); 32 | 33 | expect(opfn.name).toBe('opName'); 34 | expect(opfn()).toBe(2); 35 | }); 36 | 37 | it('throws when passing an object with multiple keys', () => { 38 | const f = () => 2; 39 | expect(() => op({'opName_': f, 'opName2_': f})) 40 | .toThrowError(/Please provide an object with a single key/); 41 | }); 42 | }); 43 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/ops.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | export * from './batchnorm'; 19 | export * from './boolean_mask'; 20 | export * from './complex_ops'; 21 | export * from './concat_split'; 22 | export * from './conv'; 23 | export * from './matmul'; 24 | export * from './reverse'; 25 | export * from './pool'; 26 | export * from './slice'; 27 | export * from './unary_ops'; 28 | export * from './reduction_ops'; 29 | export * from './compare'; 30 | export * from './binary_ops'; 31 | export * from './relu_ops'; 32 | export * from './logical_ops'; 33 | export * from './array_ops'; 34 | export * from './tensor_ops'; 35 | export * from './transpose'; 36 | export * from './softmax'; 37 | export * from './lrn'; 38 | export * from './norm'; 39 | export * from './segment_ops'; 40 | export * from './lstm'; 41 | export * from './moving_average'; 42 | export * from './strided_slice'; 43 | export * from './topk'; 44 | export * from './scatter_nd'; 45 | export * from './spectral_ops'; 46 | export * from './sparse_to_dense'; 47 | export * from './gather_nd'; 48 | export * from './diag'; 49 | export * from './dropout'; 50 | export * from './signal_ops'; 51 | export * from './in_top_k'; 52 | 53 | export {op} from './operation'; 54 | 55 | // Second level exports. 56 | import * as losses from './loss_ops'; 57 | import * as linalg from './linalg_ops'; 58 | import * as image from './image_ops'; 59 | import * as spectral from './spectral_ops'; 60 | import * as fused from './fused_ops'; 61 | import * as signal from './signal_ops'; 62 | 63 | export {image, linalg, losses, spectral, fused, signal}; 64 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/reduce_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 
4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {nearestDivisor} from '../util'; 19 | 20 | /** 21 | * Inputs of size above this threshold will be parallelized by calling multiple 22 | * shader programs. 23 | */ 24 | export const PARALLELIZE_THRESHOLD = 30; 25 | 26 | export interface ReduceInfo { 27 | windowSize: number; 28 | batchSize: number; 29 | inSize: number; 30 | } 31 | /** Returns `inSize` if it is <= PARALLELIZE_THRESHOLD, else `nearestDivisor(inSize, floor(sqrt(inSize)))`. */ 32 | export function computeOptimalWindowSize(inSize: number): number { 33 | if (inSize <= PARALLELIZE_THRESHOLD) { 34 | return inSize; 35 | } 36 | return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); 37 | } 38 | -------------------------------------------------------------------------------- /tfjs-core/src/ops/selu_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | export const SELU_SCALEALPHA = 1.7580993408473768599402175208123; 19 | export const SELU_SCALE = 1.0507009873554804934193349852946; 20 | -------------------------------------------------------------------------------- /tfjs-core/src/platforms/platform.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {RequestDetails} from '../io/types'; 19 | 20 | /** 21 | * At any given time a single platform is active and represents an 22 | * implementation of this interface. In practice, a platform is an environment 23 | * where TensorFlow.js can be executed, e.g. the browser or Node.js. 24 | */ 25 | export interface Platform { 26 | /** 27 | * Makes an HTTP request. 28 | * @param path The URL path to make a request to 29 | * @param requestInits The request init.
See init here: 30 | * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request 31 | */ 32 | fetch(path: string, requestInits?: RequestInit, options?: RequestDetails): 33 | Promise; 34 | 35 | /** 36 | * Returns the current high-resolution time in milliseconds relative to an 37 | * arbitrary time in the past. It works across different platforms (node.js, 38 | * browsers). 39 | */ 40 | now(): number; 41 | 42 | /** 43 | * Encode the provided string into an array of bytes using the provided 44 | * encoding. 45 | */ 46 | encode(text: string, encoding: string): Uint8Array; 47 | /** Decode the provided bytes into a string using the provided encoding. */ 48 | decode(bytes: Uint8Array, encoding: string): string; 49 | } 50 | -------------------------------------------------------------------------------- /tfjs-core/src/platforms/platform_browser.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | import {ENV} from '../environment'; 18 | import {Platform} from './platform'; 19 | 20 | export class PlatformBrowser implements Platform { 21 | private textEncoder: TextEncoder; 22 | 23 | constructor() { 24 | // According to the spec, the built-in encoder can do only UTF-8 encoding. 
25 | // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder 26 | this.textEncoder = new TextEncoder(); 27 | } 28 | 29 | fetch(path: string, init?: RequestInit): Promise { 30 | return fetch(path, init); 31 | } 32 | 33 | now(): number { 34 | return performance.now(); 35 | } 36 | 37 | encode(text: string, encoding: string): Uint8Array { 38 | if (encoding !== 'utf-8' && encoding !== 'utf8') { 39 | throw new Error( 40 | `Browser's encoder only supports utf-8, but got ${encoding}`); 41 | } 42 | return this.textEncoder.encode(text); 43 | } 44 | decode(bytes: Uint8Array, encoding: string): string { 45 | return new TextDecoder(encoding).decode(bytes); 46 | } 47 | } 48 | 49 | if (ENV.get('IS_BROWSER')) { 50 | ENV.setPlatform('browser', new PlatformBrowser()); 51 | } 52 | -------------------------------------------------------------------------------- /tfjs-core/src/setup_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | /** 19 | * This file is necessary so we register all test environments before we start 20 | * executing tests. 
21 | */ 22 | import './backends/cpu/backend_cpu_test_registry'; 23 | import './backends/webgl/backend_webgl_test_registry'; 24 | 25 | import {parseTestEnvFromKarmaFlags, setTestEnvs, TEST_ENVS} from './jasmine_util'; 26 | 27 | // tslint:disable-next-line:no-any 28 | declare let __karma__: any; 29 | if (typeof __karma__ !== 'undefined') { 30 | const testEnv = parseTestEnvFromKarmaFlags(__karma__.config.args, TEST_ENVS); 31 | if (testEnv != null) { 32 | setTestEnvs([testEnv]); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /tfjs-core/src/tensor_types.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import {Tensor, Variable} from './tensor'; 19 | import {DataType} from './types'; 20 | 21 | /** @docalias {[name: string]: Tensor} */ 22 | export type NamedTensorMap = { 23 | [name: string]: Tensor; 24 | }; 25 | 26 | export interface NamedTensor { 27 | name: string; 28 | tensor: Tensor; 29 | } 30 | 31 | export type NamedVariableMap = { 32 | [name: string]: Variable; 33 | }; 34 | 35 | export type GradSaveFunc = (save: Tensor[]) => void; 36 | 37 | /** 38 | * @docalias void|number|string|TypedArray|Tensor|Tensor[]|{[key: 39 | * string]:Tensor|number|string} 40 | */ 41 | export type TensorContainer = 42 | void|Tensor|string|number|boolean|TensorContainerObject| 43 | TensorContainerArray|Float32Array|Int32Array|Uint8Array; 44 | export interface TensorContainerObject { 45 | [x: string]: TensorContainer; 46 | } 47 | export interface TensorContainerArray extends Array {} 48 | 49 | export interface TensorInfo { 50 | // Name of the tensor. 51 | name: string; 52 | // Tensor shape information, Optional. 53 | shape?: number[]; 54 | // Data type of the tensor. 55 | dtype: DataType; 56 | } 57 | -------------------------------------------------------------------------------- /tfjs-core/src/test_node.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | /** 3 | * @license 4 | * Copyright 2019 Google LLC. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | * ============================================================================= 17 | */ 18 | 19 | import {setTestEnvs} from './jasmine_util'; 20 | 21 | // tslint:disable-next-line:no-require-imports 22 | const jasmine = require('jasmine'); 23 | 24 | process.on('unhandledRejection', e => { 25 | throw e; 26 | }); 27 | 28 | setTestEnvs([{name: 'node', backendName: 'cpu'}]); 29 | 30 | const runner = new jasmine(); 31 | runner.loadConfig({spec_files: ['dist/**/**_test.js'], random: false}); 32 | runner.execute(); 33 | -------------------------------------------------------------------------------- /tfjs-core/src/train.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | // So typings can propagate. 
19 | import {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; 20 | import {AdagradOptimizer} from './optimizers/adagrad_optimizer'; 21 | import {AdamOptimizer} from './optimizers/adam_optimizer'; 22 | import {AdamaxOptimizer} from './optimizers/adamax_optimizer'; 23 | import {MomentumOptimizer} from './optimizers/momentum_optimizer'; 24 | import {OptimizerConstructors} from './optimizers/optimizer_constructors'; 25 | import {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; 26 | import {SGDOptimizer} from './optimizers/sgd_optimizer'; 27 | 28 | // tslint:disable-next-line:no-unused-expression 29 | [MomentumOptimizer, SGDOptimizer, AdadeltaOptimizer, AdagradOptimizer, 30 | RMSPropOptimizer, AdamaxOptimizer, AdamOptimizer]; 31 | 32 | export const train = { 33 | sgd: OptimizerConstructors.sgd, 34 | momentum: OptimizerConstructors.momentum, 35 | adadelta: OptimizerConstructors.adadelta, 36 | adagrad: OptimizerConstructors.adagrad, 37 | rmsprop: OptimizerConstructors.rmsprop, 38 | adamax: OptimizerConstructors.adamax, 39 | adam: OptimizerConstructors.adam 40 | }; 41 | -------------------------------------------------------------------------------- /tfjs-core/src/types_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2018 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import {upcastType} from './types'; 19 | 20 | describe('upcastType', () => { 21 | it('upcasts bool to bool', () => { 22 | expect(upcastType('bool', 'bool')).toBe('bool'); 23 | }); 24 | 25 | it('upcasts bool/int32 to int32', () => { 26 | expect(upcastType('bool', 'int32')).toBe('int32'); 27 | expect(upcastType('int32', 'int32')).toBe('int32'); 28 | }); 29 | 30 | it('upcasts bool/int32/float32 to float32', () => { 31 | expect(upcastType('bool', 'float32')).toBe('float32'); 32 | expect(upcastType('int32', 'float32')).toBe('float32'); 33 | expect(upcastType('float32', 'float32')).toBe('float32'); 34 | }); 35 | 36 | it('upcasts bool/int32/float32/complex64 to complex64', () => { 37 | expect(upcastType('bool', 'complex64')).toBe('complex64'); 38 | expect(upcastType('int32', 'complex64')).toBe('complex64'); 39 | expect(upcastType('float32', 'complex64')).toBe('complex64'); 40 | expect(upcastType('complex64', 'complex64')).toBe('complex64'); 41 | }); 42 | 43 | it('fails to upcast anything other than string with string', () => { 44 | expect(() => upcastType('bool', 'string')).toThrowError(); 45 | expect(() => upcastType('int32', 'string')).toThrowError(); 46 | expect(() => upcastType('float32', 'string')).toThrowError(); 47 | expect(() => upcastType('complex64', 'string')).toThrowError(); 48 | // Ok upcasting string to string. 49 | expect(upcastType('string', 'string')).toBe('string'); 50 | }); 51 | }); 52 | -------------------------------------------------------------------------------- /tfjs-core/src/version.ts: -------------------------------------------------------------------------------- 1 | /** @license See the LICENSE file. */ 2 | 3 | // This code is auto-generated, do not modify this file! 
4 | const version = '1.2.7'; 5 | export {version}; 6 | -------------------------------------------------------------------------------- /tfjs-core/src/version_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2017 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {version_core} from './index'; 19 | 20 | describe('version', () => { 21 | it('version is contained', () => { 22 | // tslint:disable-next-line:no-require-imports 23 | const expected = require('../package.json').version; 24 | expect(version_core).toBe(expected); 25 | }); 26 | }); 27 | -------------------------------------------------------------------------------- /tfjs-core/src/webgl.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as gpgpu_util from './backends/webgl/gpgpu_util'; 19 | import * as webgl_util from './backends/webgl/webgl_util'; 20 | 21 | export {MathBackendWebGL, WebGLMemoryInfo, WebGLTimingInfo} from './backends/webgl/backend_webgl'; 22 | export {setWebGLContext} from './backends/webgl/canvas_util'; 23 | export {GPGPUContext} from './backends/webgl/gpgpu_context'; 24 | export {GPGPUProgram} from './backends/webgl/gpgpu_math'; 25 | // WebGL specific utils. 26 | export {gpgpu_util, webgl_util}; 27 | -------------------------------------------------------------------------------- /tfjs-core/src/worker_node_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import {describeWithFlags, HAS_NODE_WORKER} from './jasmine_util'; 19 | import {expectArraysClose} from './test_util'; 20 | // tslint:disable:no-require-imports 21 | 22 | const fn2String = (fn: Function): string => { 23 | const funcStr = '(' + fn.toString() + ')()'; 24 | return funcStr; 25 | }; 26 | 27 | // The source code of a web worker. 28 | const workerTestNode = () => { 29 | // Web worker scripts in node live relative to the CWD, not to the dir of the 30 | // file that spawned them. 31 | const tf = require('./dist/index.js'); 32 | const {parentPort} = require('worker_threads'); 33 | let a = tf.tensor1d([1, 2, 3]); 34 | const b = tf.tensor1d([3, 2, 1]); 35 | a = a.add(b); 36 | parentPort.postMessage({data: a.dataSync()}); 37 | }; 38 | 39 | describeWithFlags('computation in worker (node env)', HAS_NODE_WORKER, () => { 40 | it('tensor in worker', (done) => { 41 | const {Worker} = require('worker_threads'); 42 | const worker = new Worker(fn2String(workerTestNode), {eval: true}); 43 | // tslint:disable-next-line:no-any 44 | worker.on('message', (msg: any) => { 45 | const data = msg.data; 46 | expectArraysClose(data, [4, 4, 4]); 47 | done(); 48 | }); 49 | }); 50 | }); 51 | -------------------------------------------------------------------------------- /tfjs-core/src/worker_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from './index'; 19 | import {describeWithFlags, HAS_WORKER} from './jasmine_util'; 20 | import {expectArraysClose} from './test_util'; 21 | 22 | const fn2workerURL = (fn: Function): string => { 23 | const blob = 24 | new Blob(['(' + fn.toString() + ')()'], {type: 'application/javascript'}); 25 | return URL.createObjectURL(blob); 26 | }; 27 | 28 | // The source code of a web worker. 29 | const workerTest = () => { 30 | //@ts-ignore 31 | importScripts('http://bs-local.com:12345/base/dist/tf-core.min.js'); 32 | let a = tf.tensor1d([1, 2, 3]); 33 | const b = tf.tensor1d([3, 2, 1]); 34 | a = a.add(b); 35 | //@ts-ignore 36 | self.postMessage({data: a.dataSync()}); 37 | }; 38 | 39 | describeWithFlags('computation in worker', HAS_WORKER, () => { 40 | it('tensor in worker', (done) => { 41 | const worker = new Worker(fn2workerURL(workerTest)); 42 | worker.onmessage = (msg) => { 43 | const data = msg.data.data; 44 | expectArraysClose(data, [4, 4, 4]); 45 | done(); 46 | }; 47 | }); 48 | }); 49 | -------------------------------------------------------------------------------- /tfjs-core/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tsconfig", 3 | "include": [ 4 | "src/" 5 | ], 6 | "exclude": [ 7 | "node_modules/" 8 | ], 9 | "compilerOptions": { 10 | "outDir": "./dist" 11 | } 12 | } 13 | 
-------------------------------------------------------------------------------- /tfjs-core/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tslint.json" 3 | } 4 | -------------------------------------------------------------------------------- /tfjs-react-native/.npmignore: -------------------------------------------------------------------------------- 1 | .babelrc 2 | .DS_Store 3 | .idea/ 4 | .rpt2_cache 5 | .travis.yml 6 | .vscode 7 | *.tgz 8 | *.txt 9 | **.yalc 10 | **yalc.lock 11 | cloudbuild.yml 12 | coverage/ 13 | demo/ 14 | dist/**/*_test.d.ts 15 | dist/**/*_test.js 16 | karma.conf.js 17 | node_modules/ 18 | npm-debug.log 19 | package-lock.json 20 | package/ 21 | rollup.config.js 22 | scripts/ 23 | src/**/*_test.ts 24 | tsconfig.json 25 | tslint.json 26 | yarn-error.log 27 | yarn.lock 28 | -------------------------------------------------------------------------------- /tfjs-react-native/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "search.exclude": { 3 | "**/node_modules": true, 4 | "coverage/": true, 5 | "dist/": true, 6 | "**/yarn.lock": true, 7 | ".rpt2_cache/": true, 8 | ".yalc/": true 9 | }, 10 | "tslint.configFile": "tslint.json", 11 | "files.trimTrailingWhitespace": true, 12 | "editor.tabSize": 2, 13 | "editor.insertSpaces": true, 14 | "[typescript]": { 15 | "editor.formatOnSave": true 16 | }, 17 | "[javascript]": { 18 | "editor.formatOnSave": true 19 | }, 20 | "editor.rulers": [80], 21 | "clang-format.style": "Google", 22 | "files.insertFinalNewline": true, 23 | "editor.detectIndentation": false, 24 | "editor.wrappingIndent": "none", 25 | "typescript.tsdk": "./node_modules/typescript/lib", 26 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format" 27 | } 28 | -------------------------------------------------------------------------------- /tfjs-react-native/README.md: 
-------------------------------------------------------------------------------- 1 | # Platform Adapter for React Native 2 | 3 | Status: Early development. 4 | -------------------------------------------------------------------------------- /tfjs-react-native/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | # Install common dependencies. 3 | - name: 'node:10' 4 | id: 'yarn-common' 5 | entrypoint: 'yarn' 6 | args: ['install'] 7 | 8 | # Install react native dependencies. 9 | - name: 'node:10' 10 | dir: 'tfjs-react-native' 11 | entrypoint: 'yarn' 12 | id: 'test-react-native' 13 | args: ['test-ci'] 14 | env: ['BROWSERSTACK_USERNAME=deeplearnjs1'] 15 | secretEnv: ['BROWSERSTACK_KEY'] 16 | waitFor: ['yarn-common'] 17 | 18 | # General configuration 19 | secrets: 20 | - kmsKeyName: projects/learnjs-174218/locations/global/keyRings/tfjs/cryptoKeys/enc 21 | secretEnv: 22 | BROWSERSTACK_KEY: CiQAkwyoIW0LcnxymzotLwaH4udVTQFBEN4AEA5CA+a3+yflL2ASPQAD8BdZnGARf78MhH5T9rQqyz9HNODwVjVIj64CTkFlUCGrP1B2HX9LXHWHLmtKutEGTeFFX9XhuBzNExA= 23 | timeout: 1800s 24 | logsBucket: 'gs://tfjs-build-logs' 25 | substitutions: 26 | _NIGHTLY: '' 27 | options: 28 | logStreamingOption: 'STREAM_ON' 29 | substitution_option: 'ALLOW_LOOSE' 30 | -------------------------------------------------------------------------------- /tfjs-react-native/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@tensorflow/tfjs-platform-react-native", 3 | "version": "0.1.0", 4 | "description": "TensorFlow.js platform implementation for React Native", 5 | "main": "dist/index.js", 6 | "types": "dist/index.d.ts", 7 | "jsnext:main": "dist/tf-react-native.esm.js", 8 | "module": "dist/tf-react-native.esm.js", 9 | "unpkg": "dist/tf-react-native.min.js", 10 | "jsdelivr": "dist/tf-react-native.min.js", 11 | "license": "Apache-2.0", 12 | "private": true, 13 | "scripts": { 14 | "publish-local": "rimraf dist/ 
&& yarn build && rollup -c && yalc push", 15 | "build": "tsc", 16 | "link-local": "yalc link", 17 | "unlink-local": "yalc remove", 18 | "lint": "tslint -p . -t verbose", 19 | "test": "karma start", 20 | "test-ci": "./scripts/test-ci.sh" 21 | }, 22 | "devDependencies": { 23 | "@react-native-community/async-storage": "^1.4.2", 24 | "@tensorflow/tfjs-core": "^1.2.7", 25 | "@types/base64-js": "^1.2.5", 26 | "@types/jasmine": "~2.5.53", 27 | "@types/react-native": "^0.60.2", 28 | "clang-format": "~1.2.2", 29 | "expo-gl": "^5.0.1", 30 | "jasmine": "~3.1.0", 31 | "jasmine-core": "~3.1.0", 32 | "karma": "~4.2.0", 33 | "karma-browserstack-launcher": "~1.4.0", 34 | "karma-chrome-launcher": "~2.2.0", 35 | "karma-jasmine": "~1.1.0", 36 | "karma-typescript": "~4.1.1", 37 | "karma-verbose-reporter": "^0.0.6", 38 | "rimraf": "~2.6.2", 39 | "rollup": "^0.58.2", 40 | "rollup-plugin-commonjs": "9.1.3", 41 | "rollup-plugin-node-resolve": "3.3.0", 42 | "rollup-plugin-typescript2": "0.13.0", 43 | "rollup-plugin-uglify": "~3.0.0", 44 | "tslint": "~5.11.0", 45 | "tslint-no-circular-imports": "^0.5.0", 46 | "typescript": "3.3.3333", 47 | "yalc": "^1.0.0-pre.32" 48 | }, 49 | "dependencies": { 50 | "base64-js": "^1.3.0", 51 | "buffer": "^5.2.1" 52 | }, 53 | "peerDependencies": { 54 | "@react-native-community/async-storage": "^1.4.2", 55 | "@tensorflow/tfjs-core": ">=1.2.7", 56 | "expo-gl": "^5.0.1", 57 | "react-native": ">=0.59.0" 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /tfjs-react-native/scripts/test-ci.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================= 16 | 17 | set -e 18 | 19 | yarn 20 | yarn lint 21 | yarn build 22 | karma start --browserstack --browsers=bs_chrome_mac 23 | 24 | -------------------------------------------------------------------------------- /tfjs-react-native/src/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import './platform_react_native'; 19 | 20 | export {asyncStorageIO} from './async_storage_io'; 21 | export {bundleResourceIO} from './bundle_resource_io'; 22 | export {fetch} from './platform_react_native'; 23 | -------------------------------------------------------------------------------- /tfjs-react-native/src/platform_react_native_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs-core'; 19 | import {PlatformReactNative} from './platform_react_native'; 20 | 21 | describe('PlatformReactNative', () => { 22 | it('tf.util.fetch calls platform.fetch', async () => { 23 | const platform = new PlatformReactNative(); 24 | tf.setPlatform('rn-test-platform', platform); 25 | 26 | spyOn(platform, 'fetch'); 27 | 28 | await tf.util.fetch('test/url', {method: 'GET'}); 29 | expect(platform.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'}); 30 | }); 31 | }); 32 | -------------------------------------------------------------------------------- /tfjs-react-native/src/test_utils/async_storage_mock.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | // We mock this library as it cannot be loaded in a browser yet we do want 19 | // to do JS only unit tests. 20 | 21 | // @ts-ignore (use of window) 22 | const localStorage = window.localStorage; 23 | // Use default export to match the library we are mocking. 
24 | // tslint:disable-next-line 25 | export default localStorage; 26 | -------------------------------------------------------------------------------- /tfjs-react-native/src/test_utils/gl_view_mock.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | // Mock gl-view to export nothing as we don't test it in unit tests. 19 | // We mock this library as it cannot be loaded in a browser yet we do want 20 | // to do JS only unit tests. 21 | // tslint:disable-next-line 22 | export default {}; 23 | -------------------------------------------------------------------------------- /tfjs-react-native/src/test_utils/react_native_mock.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | // We mock this library as it cannot be loaded in a browser yet we do want 19 | // to do JS only unit tests. 20 | 21 | interface ImageResolvedAssetSource { 22 | uri: string; 23 | } 24 | 25 | // tslint:disable-next-line 26 | export const Image = { 27 | resolveAssetSource: (resourceId: string|number): ImageResolvedAssetSource => { 28 | return { 29 | uri: `http://localhost/assets/${resourceId}`, 30 | }; 31 | } 32 | }; 33 | -------------------------------------------------------------------------------- /tfjs-react-native/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tsconfig", 3 | "compilerOptions": { 4 | "target": "es2017", 5 | "lib": [ 6 | "es2017" 7 | ], 8 | "outDir": "./dist/", 9 | "skipLibCheck": true, 10 | }, 11 | "include": [ 12 | "src/", 13 | ], 14 | "exclude": [ 15 | "node_modules" 16 | ], 17 | } 18 | -------------------------------------------------------------------------------- /tfjs-react-native/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tslint.json" 3 | } 4 | -------------------------------------------------------------------------------- /tfjs-webgpu/.npmignore: -------------------------------------------------------------------------------- 1 | .babelrc 2 | .DS_Store 3 | .idea/ 4 | .rpt2_cache 5 | .travis.yml 6 | .vscode 7 | *.tgz 8 | *.txt 9 | **.yalc 10 | 
**yalc.lock 11 | cloudbuild.yml 12 | coverage/ 13 | demo/ 14 | dist/**/*_test.d.ts 15 | dist/**/*_test.js 16 | karma.conf.js 17 | node_modules/ 18 | npm-debug.log 19 | package-lock.json 20 | package/ 21 | rollup.config.js 22 | scripts/ 23 | src/**/*_test.ts 24 | tsconfig.json 25 | tslint.json 26 | yarn-error.log 27 | yarn.lock 28 | -------------------------------------------------------------------------------- /tfjs-webgpu/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "search.exclude": { 3 | "**/node_modules": true, 4 | "coverage/": true, 5 | "dist/": true, 6 | "**/yarn.lock": true, 7 | ".rpt2_cache/": true, 8 | ".yalc/": true 9 | }, 10 | "tslint.configFile": "tslint.json", 11 | "files.trimTrailingWhitespace": true, 12 | "editor.tabSize": 2, 13 | "editor.insertSpaces": true, 14 | "[typescript]": { 15 | "editor.formatOnSave": true 16 | }, 17 | "[javascript]": { 18 | "editor.formatOnSave": true 19 | }, 20 | "editor.rulers": [80], 21 | "clang-format.style": "Google", 22 | "files.insertFinalNewline": true, 23 | "editor.detectIndentation": false, 24 | "editor.wrappingIndent": "none", 25 | "typescript.tsdk": "./node_modules/typescript/lib", 26 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format" 27 | } 28 | -------------------------------------------------------------------------------- /tfjs-webgpu/README.md: -------------------------------------------------------------------------------- 1 | This is an experimental backend. 2 | 3 | ``` 4 | $ yarn # to install dependencies 5 | $ yarn test 6 | ``` 7 | 8 | # To run the test suite: 9 | The `$CHROME_BIN` environment variable must be set to the location of the Chrome Canary application. 10 | 11 | e.g. 
in `~/.bash_profile`: 12 | 13 | `export CHROME_BIN="$HOME/Documents/PROJECTS/tfjs-core-wrapper/Google Chrome Canary.app/Contents/MacOS/Google Chrome Canary"` 14 | -------------------------------------------------------------------------------- /tfjs-webgpu/cloudbuild.yml: -------------------------------------------------------------------------------- 1 | steps: 2 | # Install common dependencies. 3 | - name: 'node:10' 4 | id: 'yarn-common' 5 | entrypoint: 'yarn' 6 | args: ['install'] 7 | 8 | # Install webgpu dependencies. 9 | - name: 'node:10' 10 | dir: 'tfjs-webgpu' 11 | entrypoint: 'yarn' 12 | id: 'test-webgpu' 13 | args: ['test-ci'] 14 | waitFor: ['yarn-common'] 15 | 16 | # General configuration 17 | timeout: 1800s 18 | logsBucket: 'gs://tfjs-build-logs' 19 | substitutions: 20 | _NIGHTLY: '' 21 | options: 22 | logStreamingOption: 'STREAM_ON' 23 | substitution_option: 'ALLOW_LOOSE' 24 | -------------------------------------------------------------------------------- /tfjs-webgpu/karma.conf.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * ============================================================================= 16 | */ 17 | 18 | const karmaTypescriptConfig = { 19 | tsconfig: 'tsconfig.json', 20 | // Disable coverage reports and instrumentation by default for tests 21 | coverageOptions: {instrumentation: false}, 22 | reports: {} 23 | }; 24 | 25 | module.exports = function(config) { 26 | const args = []; 27 | if (config.grep) { 28 | args.push('--grep', config.grep); 29 | } 30 | if (config.flags) { 31 | args.push('--flags', config.flags); 32 | } 33 | config.set({ 34 | basePath: '', 35 | frameworks: ['jasmine', 'karma-typescript'], 36 | files: [ 37 | 'src/setup_test.ts', // Setup the environment for the tests. 38 | {pattern: 'src/**/*.ts'}, // Import all tests. 39 | ], 40 | preprocessors: {'**/*.ts': ['karma-typescript']}, 41 | karmaTypescriptConfig, 42 | reporters: ['progress', 'karma-typescript'], 43 | port: 9876, 44 | colors: true, 45 | autoWatch: false, 46 | browsers: ['Chrome', 'chrome_webgpu'], 47 | singleRun: true, 48 | customLaunchers: { 49 | chrome_webgpu: { 50 | base: 'Chrome', 51 | flags: ['--enable-unsafe-webgpu'], 52 | } 53 | }, 54 | client: {jasmine: {random: false}, args: args} 55 | }) 56 | } 57 | -------------------------------------------------------------------------------- /tfjs-webgpu/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@tensorflow/tfjs-backend-webgpu", 3 | "version": "0.0.1", 4 | "main": "dist/index.js", 5 | "types": "dist/index.d.ts", 6 | "jsnext:main": "dist/tf-webgpu.esm.js", 7 | "module": "dist/tf-webgpu.esm.js", 8 | "unpkg": "dist/tf-webgpu.min.js", 9 | "jsdelivr": "dist/tf-webgpu.min.js", 10 | "scripts": { 11 | "publish-local": "rimraf dist/ && yarn build && rollup -c && yalc push", 12 | "build": "tsc", 13 | "link-local": "yalc link", 14 | "unlink-local": "yalc remove", 15 | "lint": "tslint -p . 
-t verbose", 16 | "test": "karma start --browsers='chrome_webgpu'", 17 | "test-ci": "./scripts/test-ci.sh" 18 | }, 19 | "license": "Apache-2.0", 20 | "devDependencies": { 21 | "@tensorflow/tfjs-core": "1.2.1", 22 | "@types/jasmine": "~2.5.53", 23 | "clang-format": "~1.2.2", 24 | "http-server": "~0.10.0", 25 | "jasmine-core": "~3.1.0", 26 | "karma": "~4.0.0", 27 | "karma-browserstack-launcher": "~1.4.0", 28 | "karma-chrome-launcher": "~2.2.0", 29 | "karma-firefox-launcher": "~1.1.0", 30 | "karma-jasmine": "~1.1.1", 31 | "karma-typescript": "~4.1.1", 32 | "rimraf": "~2.6.2", 33 | "rollup": "^0.58.2", 34 | "rollup-plugin-commonjs": "9.1.3", 35 | "rollup-plugin-node-resolve": "3.3.0", 36 | "rollup-plugin-typescript2": "0.13.0", 37 | "rollup-plugin-uglify": "~3.0.0", 38 | "tslint": "~5.11.0", 39 | "tslint-no-circular-imports": "^0.5.0", 40 | "typescript": "3.3.3333", 41 | "yalc": "~1.0.0-pre.21" 42 | }, 43 | "dependencies": { 44 | "@webgpu/shaderc": "0.0.6", 45 | "@webgpu/types": "0.0.6" 46 | }, 47 | "peerDependencies": { 48 | "@tensorflow/tfjs-core": "1.2.1" 49 | } 50 | } -------------------------------------------------------------------------------- /tfjs-webgpu/scripts/test-ci.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2019 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================= 16 | 17 | set -e 18 | 19 | yarn 20 | yarn lint 21 | yarn build 22 | 23 | -------------------------------------------------------------------------------- /tfjs-webgpu/src/flags_webgpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google Inc. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {ENV} from '@tensorflow/tfjs-core'; 19 | 20 | /** Whether we submit commands to the device queue immediately. */ 21 | ENV.registerFlag('WEBGPU_IMMEDIATE_EXECUTION_ENABLED', () => true); 22 | 23 | /** 24 | * Thread register block size for matmul kernel. If 0, we use the version of 25 | * matMul without register blocking. 26 | */ 27 | ENV.registerFlag('WEBGPU_MATMUL_WORK_PER_THREAD', () => 4); 28 | 29 | /** 30 | * -1: conv2d_naive 31 | * 0: conv2d_mm with matmul without register blocking 32 | * >0: conv2d_mm with matmul_packed with WPT=this 33 | */ 34 | ENV.registerFlag('WEBGPU_CONV2D_WORK_PER_THREAD', () => 2); 35 | -------------------------------------------------------------------------------- /tfjs-webgpu/src/index.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. 
All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import * as tf from '@tensorflow/tfjs-core'; 19 | import * as Shaderc from '@webgpu/shaderc'; 20 | 21 | import {WebGPUBackend} from './backend_webgpu'; 22 | 23 | tf.registerBackend('webgpu', async () => { 24 | const shaderc = await Shaderc.instantiate(); 25 | // @ts-ignore navigator.gpu is required 26 | const adapter = await navigator.gpu.requestAdapter({}); 27 | const device = await adapter.requestDevice({}); 28 | return new WebGPUBackend(device, shaderc); 29 | }, 3 /*priority*/); 30 | -------------------------------------------------------------------------------- /tfjs-webgpu/src/kernels/binary_op_webgpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {backend_util} from '@tensorflow/tfjs-core'; 19 | 20 | import {computeDispatch} from '../webgpu_util'; 21 | 22 | import {WebGPUProgram} from './webgpu_program'; 23 | 24 | export const MUL = 'return a * b;'; 25 | export const ADD = 'return a + b;'; 26 | export const SUB = 'return a - b;'; 27 | 28 | export const INT_DIV = ` 29 | float s = sign(a) * sign(b); 30 | int ia = int(round(a)); 31 | int ib = int(round(b)); 32 | return float(idiv(ia, ib, s)); 33 | `; 34 | 35 | export class BinaryOpProgram implements WebGPUProgram { 36 | outputShape: number[]; 37 | userCode: string; 38 | dispatchLayout: {x: number[]}; 39 | dispatch: [number, number, number]; 40 | variableNames = ['A', 'B']; 41 | 42 | constructor(op: string, aShape: number[], bShape: number[]) { 43 | this.outputShape = backend_util.assertAndGetBroadcastShape(aShape, bShape); 44 | 45 | this.dispatchLayout = {x: this.outputShape.map((d, i) => i)}; 46 | this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape); 47 | 48 | this.userCode = ` 49 | float binaryOperation(float a, float b) { 50 | ${op} 51 | } 52 | 53 | void main() { 54 | uint index = gl_GlobalInvocationID.x; 55 | float a = getAAtOutCoords(); 56 | float b = getBAtOutCoords(); 57 | setOutput(index, binaryOperation(a, b)); 58 | } 59 | `; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /tfjs-webgpu/src/kernels/from_pixels_webgpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 | 18 | import {computeDispatch, flatDispatchLayout} from '../webgpu_util'; 19 | 20 | import {WebGPUProgram} from './webgpu_program'; 21 | 22 | export class FromPixelsProgram implements WebGPUProgram { 23 | outputShape: number[]; 24 | userCode: string; 25 | variableNames = ['A']; 26 | dispatchLayout: {x: number[]}; 27 | dispatch: [number, number, number]; 28 | 29 | constructor(outputShape: number[]) { 30 | const [height, width, ] = outputShape; 31 | this.outputShape = outputShape; 32 | this.dispatchLayout = flatDispatchLayout(this.outputShape); 33 | this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape); 34 | 35 | this.userCode = ` 36 | void main() { 37 | ivec3 coords = getOutputCoords(); 38 | int texR = coords[0]; 39 | int texC = coords[1]; 40 | int depth = coords[2]; 41 | vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0); 42 | 43 | vec4 values = texelFetch(A, uv); 44 | float value; 45 | if (depth == 0) { 46 | value = values.r; 47 | } else if (depth == 1) { 48 | value = values.g; 49 | } else if (depth == 2) { 50 | value = values.b; 51 | } else if (depth == 3) { 52 | value = values.a; 53 | } 54 | 55 | setOutput(floor(value * 255.0 + 0.5)); 56 | } 57 | `; 58 | } 59 | } -------------------------------------------------------------------------------- /tfjs-webgpu/src/kernels/unary_op_webgpu.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | 
// * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 |
import {computeDispatch, flatDispatchLayout} from '../webgpu_util';
import {WebGPUProgram} from './webgpu_program';

/** GLSL body for `unaryOperation`: rectified linear unit, max(a, 0). */
export const RELU = 'return max(a, 0.0);';

/** GLSL body for `unaryOperation`: logistic sigmoid, 1 / (1 + exp(-a)). */
export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;

/**
 * WebGPU program that applies an element-wise unary operation to input `A`.
 *
 * `op` is spliced verbatim as the body of the GLSL function
 * `float unaryOperation(float a)`, so it is expected to `return` a float
 * (see RELU and SIGMOID). The shader reads `A` at the output coordinates and
 * writes the result at `gl_GlobalInvocationID.x`.
 */
export class UnaryOpProgram implements WebGPUProgram {
  outputShape: number[];
  userCode: string;
  dispatchLayout: {x: number[]};
  dispatch: [number, number, number];
  variableNames = ['A'];

  /**
   * @param outputShape Shape of the output tensor.
   * @param op GLSL statements forming the body of `unaryOperation(float a)`.
   */
  constructor(outputShape: number[], op: string) {
    this.outputShape = outputShape;
    // Use the shared helper instead of rebuilding the flat layout inline. The
    // result is the same: every output dimension is folded into the x axis.
    this.dispatchLayout = flatDispatchLayout(this.outputShape);
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);

    this.userCode = `
      float unaryOperation(float a) {
        ${op}
      }

      void main() {
        uint index = gl_GlobalInvocationID.x;
        float a = getAAtOutCoords();
        setOutput(index, unaryOperation(a));
      }
    `;
  }
}
// 50 | -------------------------------------------------------------------------------- /tfjs-webgpu/src/shader_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved.
// 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 |
/**
 * Generates GLSL that computes strides.
 *
 * For a layout of dimension indices into the GLSL shape variable
 * `variableName`, builds one stride expression for every dimension except
 * the innermost one. Each stride is the product of the sizes of all
 * dimensions that follow it in the layout.
 *
 * Example: `([0, 2, 1], 'output')` returns
 * `['(output[1] * output[2])', 'output[1]']`.
 *
 * @param indicesArr Dimension indices, outermost first. Each must be an
 *     integer in [0, 3].
 * @param variableName Name of the GLSL variable holding the shape.
 * @returns `indicesArr.length - 1` GLSL expressions, or an empty array for a
 *     layout with zero or one dimension.
 * @throws Error if an index is greater than 3, negative, or not an integer.
 */
export function symbolicallyComputeStrides(
    indicesArr: number[], variableName: string): string[] {
  if (Math.max(...indicesArr) > 3) {
    throw new Error('Cannot symbolically compute strides for rank > 4 tensor.');
  }
  // Negative or fractional indices would emit invalid GLSL such as
  // `output[-1]`, so reject them up front.
  for (const d of indicesArr) {
    if (!Number.isInteger(d) || d < 0) {
      throw new Error(
          `Invalid dimension index ${d}; expected an integer in [0, 3].`);
    }
  }

  const numCoords = indicesArr.length;
  // Zero or one coordinate has no strides to emit. Returning early also
  // avoids `new Array(-1)` (a RangeError) and a write to `strides[-1]` (a
  // stray non-index property on the returned array).
  if (numCoords <= 1) {
    return [];
  }

  const shape = indicesArr.map(d => `${variableName}[${d}]`);
  const strides: string[] = new Array(numCoords - 1);
  strides[numCoords - 2] = shape[numCoords - 1];
  for (let i = numCoords - 3; i >= 0; --i) {
    strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`;
  }

  return strides;
}
// -------------------------------------------------------------------------------- /tfjs-webgpu/src/shader_util_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
// 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 |
import {symbolicallyComputeStrides} from './shader_util';

// Unit tests for the GLSL stride-expression generator.
describe('shader util', () => {
  it('symbolicallyComputeStrides takes in array of dimensions ' +
         'and returns GLSL to compute strides for those dimensions',
     () => {
       // Layout [0, 2, 1]: the last stride is the size of dimension 1, and
       // the one before it multiplies in the size of dimension 2.
       const result = symbolicallyComputeStrides([0, 2, 1], 'output');
       const [outerStride, innerStride] = result;
       expect(outerStride).toEqual('(output[1] * output[2])');
       expect(innerStride).toEqual('output[1]');
     });

  it('symbolicallyComputeStrides throws if given a dimension ' +
         'that cannot be accessed from a GLSL data type',
     () => {
       // Index 5 exceeds the largest index (3) the generator accepts.
       const computeWithBadIndex = () =>
           symbolicallyComputeStrides([0, 5, 2], 'output');
       expect(computeWithBadIndex).toThrowError();
     });
});
// -------------------------------------------------------------------------------- /tfjs-webgpu/src/test_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
// 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 |
import {ALL_ENVS, describeWithFlags, TestEnv} from '@tensorflow/tfjs-core/dist/jasmine_util';

/**
 * Registers a test suite named `webgpu <name>` through `describeWithFlags`,
 * passing `ALL_ENVS` as its environment filter.
 *
 * @param name Suite name; it is prefixed with "webgpu ".
 * @param tests Suite body; it receives a `TestEnv`.
 */
export function describeWebGPU(name: string, tests: (env: TestEnv) => void) {
  const suiteName = `webgpu ${name}`;
  describeWithFlags(suiteName, ALL_ENVS, tests);
}
// 23 | -------------------------------------------------------------------------------- /tfjs-webgpu/src/webgpu_util.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
// 15 | * ============================================================================= 16 | */ 17 |
/** Multiplies all entries of `values` together; returns 1 for an empty array. */
const arrayProduct = (values: number[]) => {
  let product = 1;
  for (const value of values) {
    product *= value;
  }
  return product;
};

/**
 * Computes dispatch geometry based on layout of output dimensions and
 * workGroupSize.
 *
 * @param layout Output-dimension indices assigned to each dispatch axis.
 *     `y` and `z` are optional; a missing axis dispatches 1 group.
 * @param outputShape Shape of the output tensor.
 * @param workGroupSize Threads per work group along x, y and z.
 * @param elementsPerThread Output elements each thread covers along x, y, z.
 * @returns Work-group counts along x, y and z. Each count is the product of
 *     the axis's output dims divided by
 *     (workGroupSize * elementsPerThread) for that axis, rounded up.
 */
export function computeDispatch(
    layout: {x: number[], y?: number[], z?: number[]}, outputShape: number[],
    workGroupSize: [number, number, number] = [1, 1, 1],
    elementsPerThread: [number, number, number] =
        [1, 1, 1]): [number, number, number] {
  const groupsAlong = (dims: number[], axis: number): number => {
    const elementCount = arrayProduct(dims.map(dim => outputShape[dim]));
    const perGroup = workGroupSize[axis] * elementsPerThread[axis];
    return Math.ceil(elementCount / perGroup);
  };

  const x = groupsAlong(layout.x, 0);
  const y = layout.y ? groupsAlong(layout.y, 1) : 1;
  const z = layout.z ? groupsAlong(layout.z, 2) : 1;
  return [x, y, z];
}

/** Dispatch layout that folds every dimension of `shape` into the x axis. */
export function flatDispatchLayout(shape: number[]) {
  return {x: shape.map((dim, index) => index)};
}
// -------------------------------------------------------------------------------- /tfjs-webgpu/src/webgpu_util_test.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * @license 3 | * Copyright 2019 Google LLC. All Rights Reserved. 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ============================================================================= 16 | */ 17 |
import {computeDispatch} from './webgpu_util';

// Unit tests for computeDispatch's work-group rounding.
describe('webgpu util', () => {
  it('computeDispatch returns dispatch dimensions based on layout of ' +
         'output dimensions and workGroupSize.',
     () => {
       const workGroupSize: [number, number, number] = [2, 2, 1];
       // x covers dim 0 (size 1), y covers dim 1 (size 2), and z covers
       // dims 2 and 3 (3 * 2 = 6).
       const result = computeDispatch(
           {x: [0], y: [1], z: [2, 3]}, [1, 2, 3, 2], workGroupSize);
       expect(result).toEqual([1, 1, 6]);
     });

  it('computeDispatch returns dispatch dimensions based on layout of ' +
         'output dimensions, workGroupSize, and elementsPerThread.',
     () => {
       const workGroupSize: [number, number, number] = [2, 1, 1];
       const elementsPerThread: [number, number, number] = [2, 2, 3];
       // Each axis is ceil(product of its dims / (group size * per-thread)):
       // x = 4 / 4, y = 8 / 2, z = 24 / 3.
       const result = computeDispatch(
           {x: [0], y: [1], z: [2, 3]}, [4, 8, 12, 2], workGroupSize,
           elementsPerThread);
       expect(result).toEqual([1, 4, 8]);
     });
});
// 47 | -------------------------------------------------------------------------------- /tfjs-webgpu/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tsconfig", 3 | "include": [ 4 | "src/" 5 | ], 6 | "exclude": [ 7 | "node_modules/" 8 | ], 9 | "compilerOptions": { 10 | "outDir": "./dist" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /tfjs-webgpu/tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "../tslint.json" 3 | } 4 | -------------------------------------------------------------------------------- /tfjs.code-workspace:
-------------------------------------------------------------------------------- 1 | { 2 | "folders": [ 3 | { 4 | "name": "Common", 5 | "path": "." 6 | }, 7 | { 8 | "name": "Core", 9 | "path": "tfjs-core" 10 | }, 11 | { 12 | "name": "Wasm", 13 | "path": "tfjs-backend-wasm" 14 | }, 15 | { 16 | "name": "NodeGL", 17 | "path": "tfjs-backend-nodegl" 18 | }, 19 | { 20 | "name": "ReactNative", 21 | "path": "tfjs-react-native" 22 | }, 23 | { 24 | "name": "WebGPU", 25 | "path": "tfjs-webgpu" 26 | } 27 | ], 28 | "settings": { 29 | "search.exclude": { 30 | "**/node_modules": true, 31 | "**/coverage/": true, 32 | "**/dist/": true, 33 | "**/yarn.lock": true, 34 | "**/.rpt2_cache/": true, 35 | "**/.yalc/**/*": true 36 | }, 37 | "tslint.configFile": "tslint.json", 38 | "files.trimTrailingWhitespace": true, 39 | "editor.tabSize": 2, 40 | "editor.insertSpaces": true, 41 | "[typescript]": { 42 | "editor.formatOnSave": true 43 | }, 44 | "[javascript]": { 45 | "editor.formatOnSave": true 46 | }, 47 | "[cpp]": { 48 | "editor.formatOnSave": true 49 | }, 50 | "editor.defaultFormatter": "xaver.clang-format", 51 | "editor.rulers": [ 52 | 80 53 | ], 54 | "clang-format.style": "Google", 55 | "files.insertFinalNewline": true, 56 | "editor.detectIndentation": false, 57 | "editor.wrappingIndent": "none", 58 | "typescript.tsdk": "./node_modules/typescript/lib", 59 | "clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format" 60 | }, 61 | "extensions": { 62 | "recommendations": [ 63 | // Formats TypeScript, JavaScript and C++ code. 64 | "xaver.clang-format", 65 | // Lints TypeScript code.
66 | "ms-vscode.vscode-typescript-tslint-plugin" 67 | ] 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "module": "commonjs", 4 | "moduleResolution": "node", 5 | "noImplicitAny": true, 6 | "sourceMap": true, 7 | "removeComments": false, 8 | "preserveConstEnums": true, 9 | "declaration": true, 10 | "target": "es5", 11 | "lib": [ 12 | "es2015", 13 | "dom" 14 | ], 15 | "outDir": "./dist", 16 | "noUnusedLocals": true, 17 | "noImplicitReturns": true, 18 | "noImplicitThis": true, 19 | "alwaysStrict": true, 20 | "noUnusedParameters": false, 21 | "pretty": true, 22 | "noFallthroughCasesInSwitch": true, 23 | "allowUnreachableCode": false 24 | }, 25 | "include": [ 26 | "src/" 27 | ], 28 | "exclude": [ 29 | "node_modules/" 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["tslint-no-circular-imports"], 3 | "rules": { 4 | "array-type": [true, "array-simple"], 5 | "arrow-return-shorthand": true, 6 | "ban": [true, 7 | ["fit"], 8 | ["fdescribe"], 9 | ["xit"], 10 | ["xdescribe"], 11 | ["fitAsync"], 12 | ["xitAsync"], 13 | ["fitFakeAsync"], 14 | ["xitFakeAsync"] 15 | ], 16 | "ban-types": [true, 17 | ["Object", "Use {} instead."], 18 | ["String", "Use 'string' instead."], 19 | ["Number", "Use 'number' instead."], 20 | ["Boolean", "Use 'boolean' instead."] 21 | ], 22 | "no-return-await": true, 23 | "class-name": true, 24 | "curly": true, 25 | "interface-name": [true, "never-prefix"], 26 | "jsdoc-format": true, 27 | "forin": false, 28 | "label-position": true, 29 | "max-line-length": { 30 | "options": {"limit": 80, "ignore-pattern": "^import |^export "} 31 | }, 32 | "new-parens": true, 33 | "no-angle-bracket-type-assertion": true, 34 | 
"no-any": true, 35 | "no-construct": true, 36 | "no-consecutive-blank-lines": true, 37 | "no-debugger": true, 38 | "no-default-export": true, 39 | "no-inferrable-types": true, 40 | "no-namespace": [true, "allow-declarations"], 41 | "no-reference": true, 42 | "no-require-imports": true, 43 | "no-string-throw": true, 44 | "no-unused-expression": true, 45 | "no-var-keyword": true, 46 | "object-literal-shorthand": true, 47 | "only-arrow-functions": [true, "allow-declarations", "allow-named-functions"], 48 | "prefer-const": true, 49 | "quotemark": [true, "single"], 50 | "radix": true, 51 | "restrict-plus-operands": true, 52 | "semicolon": [true, "always", "ignore-bound-class-methods"], 53 | "switch-default": true, 54 | "triple-equals": [true, "allow-null-check"], 55 | "use-isnan": true, 56 | "use-default-type-parameter": true, 57 | "variable-name": [ 58 | true, 59 | "check-format", 60 | "ban-keywords", 61 | "allow-leading-underscore", 62 | "allow-trailing-underscore" 63 | ] 64 | } 65 | } 66 | --------------------------------------------------------------------------------