├── .clang-format ├── .gitattributes ├── .github └── workflows │ ├── github-actions-cpp-check.yml │ └── github-actions-python-check.yml ├── .gitignore ├── .pylintrc ├── .vscode └── c_cpp_properties.json ├── BUILD.md ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── KERNEL_COOKBOOK.md ├── LICENSE ├── NOTICE.txt ├── README.md ├── SECURITY.md ├── __init__.py ├── build.py ├── cmake └── dependencies.cmake ├── format.ps1 ├── format.sh ├── generate_op_defs_core.py ├── pipelines ├── ADOHelper.ps1 ├── SubmitFilesToBlobStorage.ps1 ├── build.yml ├── create_agent_summary.ps1 ├── create_report_email.ps1 ├── create_test_env.ps1 ├── create_test_env.sh ├── create_test_matrix.ps1 ├── create_test_summary_json.ps1 ├── create_test_summary_xml.ps1 ├── pipeline.yml ├── report.yml ├── setup_agent_linux.yml ├── setup_agent_windows.yml └── test.yml ├── pyproject.toml ├── requirements.txt ├── test ├── README.md ├── __init__.py ├── c │ └── c_api_tests.cc ├── ops │ ├── __init__.py │ ├── adagrad_test.py │ ├── aggregate_ops_test.py │ ├── argmax_op_test.py │ ├── array_ops_test.py │ ├── batch_matmul_op_test.py │ ├── batchtospace_op_test.py │ ├── bias_op_base.py │ ├── bias_op_deterministic_test.py │ ├── bias_op_test.py │ ├── bitcast_op_test.py │ ├── cast_op_test.py │ ├── concat_op_test.py │ ├── constant_op_test.py │ ├── conv1d_test.py │ ├── conv1d_transpose_test.py │ ├── conv2d_backprop_filter_grad_test.py │ ├── conv2d_transpose_test.py │ ├── conv3d_backprop_filter_v2_grad_test.py │ ├── conv3d_transpose_test.py │ ├── conv_ops_3d_test.py │ ├── conv_ops_test.py │ ├── cross_device_ops_test.py │ ├── cross_grad_test.py │ ├── cudnn_recurrent_test.py │ ├── cwise_ops_binary_test.py │ ├── cwise_ops_test.py │ ├── cwise_ops_unary_test.py │ ├── data_format_op_test.py │ ├── deepcopy_op_test.py │ ├── depthtospace_op_test.py │ ├── depthwise_conv_op_test.py │ ├── diag_op_test.py │ ├── dml_get_adapter_name.py │ ├── dml_test_util.py │ ├── dynamic_stitch_op_test.py │ ├── empty_op_test.py │ ├── 
extract_image_patches_grad_test.py │ ├── extract_image_patches_op_test.py │ ├── extract_volume_patches_grad_test.py │ ├── extract_volume_patches_op_test.py │ ├── fill_op_test.py │ ├── ftrl_test.py │ ├── gather_nd_op_test.py │ ├── gather_op_test.py │ ├── image_grad_deterministic_test.py │ ├── image_grad_test.py │ ├── image_grad_test_base.py │ ├── image_ops_test.py │ ├── in_topk_op_test.py │ ├── inplace_ops_test.py │ ├── l2loss_test.py │ ├── lrn_op_test.py │ ├── math_ops_test.py │ ├── matmul_op_test.py │ ├── matrix_band_part_op_test.py │ ├── nn_batchnorm_test.py │ ├── nn_fused_batchnorm_test.py │ ├── numerics_test.py │ ├── one_hot_op_test.py │ ├── ones_like_test.py │ ├── pad_op_test.py │ ├── pool_test.py │ ├── pooling_ops_3d_test.py │ ├── pooling_ops_test.py │ ├── proximal_adagrad_test.py │ ├── ragged_one_hot_op_test.py │ ├── ragged_reduce_op_test.py │ ├── random_ops_test.py │ ├── range_op_test.py │ ├── reduce_join_op_test.py │ ├── reduce_test.py │ ├── reduction_ops_test.py │ ├── reduction_ops_test_big.py │ ├── relu_op_test.py │ ├── resource_variable_ops_test.py │ ├── reverse_sequence_op_test.py │ ├── rnn_grad_test.py │ ├── roll_op_test.py │ ├── scan_ops_test.py │ ├── scatter_nd_ops_test.py │ ├── segment_reduction_ops_test.py │ ├── slice_op_test.py │ ├── spacetobatch_op_test.py │ ├── spacetodepth_op_test.py │ ├── sparse_xent_op_test.py │ ├── sparse_xent_op_test_base.py │ ├── split_op_test.py │ ├── stack_op_test.py │ ├── strings_reduce_join_op_test.py │ ├── testdata │ │ ├── bad_huffman.jpg │ │ ├── cat_q20.jpg │ │ ├── cat_q72.jpg │ │ ├── cat_q95.jpg │ │ ├── checkerboard1.png │ │ ├── checkerboard2.png │ │ ├── checkerboard3.png │ │ ├── corrupt.jpg │ │ ├── corrupt34_2.jpg │ │ ├── corrupt34_3.jpg │ │ ├── corrupt34_4.jpg │ │ ├── grayscale_small.bmp │ │ ├── grayscale_small_3channels.bmp │ │ ├── grayscale_small_4channels.bmp │ │ ├── jpeg_merge_test1.jpg │ │ ├── jpeg_merge_test1_cmyk.jpg │ │ ├── lena.bmp │ │ ├── lena.gif │ │ ├── lena_gray.png │ │ ├── lena_palette.png │ │ ├── 
lena_palette_trns.png │ │ ├── lena_rgba.png │ │ ├── medium.jpg │ │ ├── optimized.gif │ │ ├── palette_only.png │ │ ├── pendulum_sm.gif │ │ ├── pendulum_sm_frame0.png │ │ ├── pendulum_sm_frame1.png │ │ ├── pendulum_sm_frame2.png │ │ ├── red_black.gif │ │ ├── rgb_small.bmp │ │ ├── rgb_small_255.bmp │ │ ├── rgba_small.bmp │ │ ├── rgba_small_255.bmp │ │ ├── scan.gif │ │ ├── small.jpg │ │ └── squares.gif │ ├── topk_op_test.py │ ├── training_arrays_test.py │ ├── training_dataset_test.py │ ├── training_eager_test.py │ ├── training_generator_test.py │ ├── training_gpu_test.py │ ├── training_integration_test.py │ ├── training_ops_test.py │ ├── training_test.py │ ├── training_util_test.py │ ├── training_utils_v1_test.py │ ├── transpose_op_test.py │ ├── unstack_op_test.py │ ├── where_op_test.py │ ├── xent_op_test.py │ ├── xent_op_test_base.py │ └── zeros_like_test.py ├── plugin │ ├── __init__.py │ ├── dml_multiple_devices_test.py │ ├── dml_visible_devices_test.py │ └── profiler_test.py ├── test.py ├── tests.json └── tests_schema.json ├── tfdml.def ├── tfdml ├── core │ ├── dml_adapter.cc │ ├── dml_adapter.h │ ├── dml_adapter_heuristics.h │ ├── dml_adapter_impl.cc │ ├── dml_adapter_impl.h │ ├── dml_bfc_allocator.cc │ ├── dml_bfc_allocator.h │ ├── dml_buffer.cc │ ├── dml_buffer.h │ ├── dml_buffer_region.cc │ ├── dml_buffer_region.h │ ├── dml_command_allocator_ring.h │ ├── dml_command_list.cc │ ├── dml_command_list.h │ ├── dml_command_queue.cc │ ├── dml_command_queue.h │ ├── dml_common.h │ ├── dml_descriptor_bfc_allocator.cc │ ├── dml_descriptor_bfc_allocator.h │ ├── dml_descriptor_heap_allocator.cc │ ├── dml_descriptor_heap_allocator.h │ ├── dml_descriptor_pool.cc │ ├── dml_descriptor_pool.h │ ├── dml_device.cc │ ├── dml_device.h │ ├── dml_device_cache.cc │ ├── dml_device_cache.h │ ├── dml_device_context.cc │ ├── dml_device_context.h │ ├── dml_device_manager.cc │ ├── dml_device_manager.h │ ├── dml_device_state.cc │ ├── dml_device_state.h │ ├── dml_dso_loader.cc │ ├── 
dml_dso_loader.h │ ├── dml_error_handling.cc │ ├── dml_error_handling.h │ ├── dml_event_queue.cc │ ├── dml_event_queue.h │ ├── dml_execution_context.cc │ ├── dml_execution_context.h │ ├── dml_gpu_event.h │ ├── dml_guids.cc │ ├── dml_guids.h │ ├── dml_heap_allocator.cc │ ├── dml_heap_allocator.h │ ├── dml_kernel_context.cc │ ├── dml_kernel_context.h │ ├── dml_kernel_definition.h │ ├── dml_kernel_key.cc │ ├── dml_kernel_key.h │ ├── dml_kernel_manager.cc │ ├── dml_kernel_manager.h │ ├── dml_kernel_wrapper.cc │ ├── dml_kernel_wrapper.h │ ├── dml_operator_helper.cc │ ├── dml_operator_helper.h │ ├── dml_ops_common.cc │ ├── dml_ops_common.h │ ├── dml_pooled_heap.cc │ ├── dml_pooled_heap.h │ ├── dml_readback_heap.cc │ ├── dml_readback_heap.h │ ├── dml_tagged_pointer.cc │ ├── dml_tagged_pointer.h │ ├── dml_tensor_desc.cc │ ├── dml_tensor_desc.h │ ├── dml_tracing.cc │ ├── dml_tracing.h │ ├── dml_upload_heap.cc │ ├── dml_upload_heap.h │ ├── dml_util.cc │ └── dml_util.h ├── kernels │ ├── dml_addn_op.cc │ ├── dml_assign_variable_op.cc │ ├── dml_batch_norm_ops.cc │ ├── dml_batch_to_space_op.cc │ ├── dml_bias_add_op.cc │ ├── dml_bitcast_op.cc │ ├── dml_broadcast_to_op.cc │ ├── dml_cast_op.cc │ ├── dml_check_numerics_op.cc │ ├── dml_concat_op.cc │ ├── dml_conv_ops.cc │ ├── dml_crop_and_resize_grad_boxes_op.cc │ ├── dml_crop_and_resize_grad_image_op.cc │ ├── dml_crop_and_resize_op.cc │ ├── dml_cross_op.cc │ ├── dml_cudnn_rnn_ops.cc │ ├── dml_cwise_ops.cc │ ├── dml_data_format_dim_map.cc │ ├── dml_data_format_vec_permute.cc │ ├── dml_deepcopy_op.cc │ ├── dml_diag_op.cc │ ├── dml_diag_part_op.cc │ ├── dml_dynamic_stitch_op.cc │ ├── dml_empty_op.cc │ ├── dml_extract_image_patches_op.cc │ ├── dml_extract_patches_helpers.cc │ ├── dml_extract_patches_helpers.h │ ├── dml_extract_volume_patches_op.cc │ ├── dml_fill_op.cc │ ├── dml_gather_nd_op.cc │ ├── dml_gather_op.cc │ ├── dml_gru_ops.cc │ ├── dml_image_ops.cc │ ├── dml_in_topk_op.cc │ ├── dml_inplace_ops.cc │ ├── dml_l2loss_op.cc │ ├── 
dml_lrn_ops.cc │ ├── dml_lstm_helpers.h │ ├── dml_lstm_ops.cc │ ├── dml_matmul_op.cc │ ├── dml_matrix_band_part_ops.cc │ ├── dml_matrix_diag_helpers.cc │ ├── dml_matrix_diag_helpers.h │ ├── dml_matrix_diag_ops.cc │ ├── dml_matrix_diag_part_ops.cc │ ├── dml_matrix_set_diag_ops.cc │ ├── dml_mirror_pad_grad_op.cc │ ├── dml_one_hot_op.cc │ ├── dml_ones_like_op.cc │ ├── dml_pack_op.cc │ ├── dml_pad_op.cc │ ├── dml_parallel_concat_ops.cc │ ├── dml_pooling_ops.cc │ ├── dml_random_ops.cc │ ├── dml_range_op.cc │ ├── dml_reduce_ops.cc │ ├── dml_relu_ops.cc │ ├── dml_resize_grad_ops.cc │ ├── dml_resize_op.cc │ ├── dml_reverse_op.cc │ ├── dml_reverse_sequence_op.cc │ ├── dml_roll_op.cc │ ├── dml_scan_ops.cc │ ├── dml_scatter_nd_ops.cc │ ├── dml_scatter_ops.cc │ ├── dml_segment_reduction_ops.cc │ ├── dml_select_op.cc │ ├── dml_slice_op.cc │ ├── dml_snapshot_op.cc │ ├── dml_space_depth_ops.cc │ ├── dml_space_to_batch_op.cc │ ├── dml_sparse_xent_op.cc │ ├── dml_split_op.cc │ ├── dml_strided_slice_op.cc │ ├── dml_swapping_ops.cc │ ├── dml_tile_op.cc │ ├── dml_topk_op.cc │ ├── dml_training_ops.cc │ ├── dml_transpose_op.cc │ ├── dml_unpack_op.cc │ ├── dml_where_op.cc │ ├── dml_xent_op.cc │ ├── dml_zeros_like_op.cc │ └── pch.h ├── optimizer │ ├── byte_order.h │ ├── device_name_utils.cc │ ├── device_name_utils.h │ ├── device_type.cc │ ├── device_type.h │ ├── graph.h │ ├── graph_optimizer.cc │ ├── graph_optimizer.h │ ├── graph_properties.cc │ ├── graph_properties.h │ ├── graph_view.cc │ ├── graph_view.h │ ├── graph_view_internal.h │ ├── grappler_item.cc │ ├── grappler_item.h │ ├── hash.cc │ ├── hash.h │ ├── map_utils.h │ ├── op_registry.cc │ ├── op_registry.h │ ├── op_types.cc │ ├── op_types.h │ ├── optimizer_runner.cc │ ├── optimizer_runner.h │ ├── perm_utils.cc │ ├── perm_utils.h │ ├── proto_buffer_helpers.cc │ ├── proto_buffer_helpers.h │ ├── remapper.cc │ ├── remapper.h │ ├── tensor_id.cc │ ├── tensor_id.h │ ├── tensor_proto_util.cc │ ├── tensor_proto_util.h │ ├── 
transpose_remover.cc │ ├── transpose_remover.h │ ├── utils.cc │ └── utils.h ├── plugin │ ├── plugin_device.cc │ ├── plugin_kernel.cc │ ├── plugin_optimizer.cc │ ├── plugin_profiler.cc │ └── plugin_version.h ├── runtime_adapter │ ├── allocator.cc │ ├── allocator.h │ ├── allocator_retry.cc │ ├── allocator_retry.h │ ├── attribute.h │ ├── bcast.cc │ ├── bcast.h │ ├── bfc_allocator.cc │ ├── bfc_allocator.h │ ├── determinism.cc │ ├── determinism.h │ ├── device.cc │ ├── device.h │ ├── env.cc │ ├── env.h │ ├── env_var.cc │ ├── env_var.h │ ├── fused_eigen_output_kernels.cc │ ├── fused_eigen_output_kernels.h │ ├── guarded_philox_random.cc │ ├── guarded_philox_random.h │ ├── kernel_shape_util.cc │ ├── kernel_shape_util.h │ ├── macros.h │ ├── matmul_bcast.h │ ├── mirror_pad_mode.cc │ ├── mirror_pad_mode.h │ ├── node_def.h │ ├── numbers.cc │ ├── numbers.h │ ├── op_defs.h │ ├── op_defs_core.cc │ ├── op_defs_core.h │ ├── op_defs_dml.h │ ├── op_kernel.h │ ├── op_kernel_construction.cc │ ├── op_kernel_construction.h │ ├── op_kernel_context.cc │ ├── op_kernel_context.h │ ├── padding.cc │ ├── padding.h │ ├── path.cc │ ├── path.h │ ├── philox_random.h │ ├── random_ops_util.h │ ├── rng_alg.h │ ├── stateless_random_ops.cc │ ├── stateless_random_ops.h │ ├── status.cc │ ├── status.h │ ├── statusor.h │ ├── stream.h │ ├── tensor.cc │ ├── tensor.h │ ├── tensor_format.cc │ ├── tensor_format.h │ ├── tensor_shape.cc │ ├── tensor_shape.h │ ├── tensor_shape_utils.cc │ ├── tensor_shape_utils.h │ ├── tensor_types.h │ ├── training_op_helpers.cc │ ├── training_op_helpers.h │ ├── types.cc │ ├── types.h │ ├── variable_lock.cc │ ├── variable_lock.h │ ├── wide_char.h │ ├── xplane_builder.cc │ └── xplane_builder.h ├── tfdml.natvis └── wheel │ ├── MANIFEST.in │ ├── README │ ├── build_wheel.py │ ├── setup.py │ └── template_init.py └── third_party └── microsofttelemetry.h /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 
| # clang-format -i --style=file 3 | # 4 | # To reformat all files using cmd.exe: 5 | # for /r %t in (*.cc *.h) do clang-format.exe -i -style=file "%t" 6 | # 7 | BasedOnStyle: Microsoft 8 | DerivePointerAlignment: false 9 | PointerAlignment: Left 10 | UseTab: Never 11 | IndentWidth: 4 12 | ColumnLimit: 80 13 | AlignAfterOpenBracket: AlwaysBreak 14 | BinPackArguments: false 15 | BinPackParameters: false 16 | AllowAllArgumentsOnNextLine: false 17 | AllowAllParametersOfDeclarationOnNextLine: false 18 | AllowAllConstructorInitializersOnNextLine: false 19 | BreakConstructorInitializers: BeforeColon 20 | AllowShortLambdasOnASingleLine: true 21 | AllowShortFunctionsOnASingleLine: true 22 | AllowShortCaseLabelsOnASingleLine: true 23 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 24 | BreakBeforeBraces: Custom 25 | BraceWrapping: 26 | BeforeLambdaBody: true 27 | AlwaysBreakTemplateDeclarations: Yes 28 | AllowShortIfStatementsOnASingleLine: WithoutElse 29 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.sh eol=lf -------------------------------------------------------------------------------- /.github/workflows/github-actions-cpp-check.yml: -------------------------------------------------------------------------------- 1 | name: C++ Code Check 2 | on: 3 | pull_request: 4 | types: [opened, synchronize, edited, reopened] 5 | branches: [main, release/*] 6 | concurrency: 7 | group: cpp-check-${{ github.ref }} 8 | cancel-in-progress: true 9 | jobs: 10 | cpp-code-check: 11 | name: Check C++ Code Formatting 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Run clang-format style check for C/C++/Protobuf programs. 
16 | uses: jidicula/clang-format-action@v4.6.2 17 | with: 18 | clang-format-version: '13' 19 | exclude-regex: ^(\./build/|\./third_party/).*$ -------------------------------------------------------------------------------- /.github/workflows/github-actions-python-check.yml: -------------------------------------------------------------------------------- 1 | name: Python Code Check 2 | on: 3 | pull_request: 4 | types: [opened, synchronize, edited, reopened] 5 | branches: [main, release/*] 6 | concurrency: 7 | group: python-check-${{ github.ref }} 8 | cancel-in-progress: true 9 | jobs: 10 | python-code-check: 11 | name: Check Python Code Formatting And Lint 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set Up Python 3.9 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: '3.9' 19 | - run: pip install -r requirements.txt 20 | - run: "black . --check --diff --verbose" 21 | - run: pylint *.py 22 | - run: pylint test 23 | - run: pylint tfdml 24 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | max-line-length=88 3 | ignored-modules=tensorflow 4 | ignore-paths=(test/ops/|build/) 5 | good-names=tensorflow-directml-plugin -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Win32", 5 | "includePath": [ 6 | "${workspaceFolder}", 7 | "${workspaceFolder}/build/_deps/abseil-src", 8 | "${workspaceFolder}/build/_deps/tensorflow_whl-src/tensorflow/include", 9 | "${workspaceFolder}/build/_deps/directx_headers-src/include", 10 | "${workspaceFolder}/build/_deps/directml_redist-src/include", 11 | "${workspaceFolder}/build/_deps/directmlx-src", 12 | "${workspaceFolder}/build/_deps/protobuf-src/src", 13 | 
"${workspaceFolder}/build/_deps/tensorflow_whl-build/proto", 14 | "${workspaceFolder}/build/_deps/pix_event_runtime-src/Include" 15 | ], 16 | "defines": [ 17 | "_DEBUG", 18 | "UNICODE", 19 | "_UNICODE", 20 | "DML_TARGET_VERSION=0x5000", 21 | "DMLX_USE_ABSEIL" 22 | ], 23 | "cppStandard": "c++14" 24 | } 25 | ], 26 | "version": 4 27 | } -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 15 | 16 | ## Contributing Code 17 | 18 | We encourage contributions such as bug fixes, DirectML kernels, or general performance and stability improvements. For more substantial changes, we ask that you reach out first with GitHub issues or by contacting us directly at askdirectml@microsoft.com. This project's focus is currently on improving functional and performance parity with the official CUDA backend, so unrelated changes are less likely to be approved. 19 | 20 | Before creating a pull request, make sure to format your change in accordance with TensorFlow's coding style (see below). 21 | 22 | ### C++ coding style 23 | 24 | Changes to TensorFlow C++ code should conform to 25 | [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). 26 | 27 | Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: 28 | 29 | ```bash 30 | apt-get install -y clang-tidy 31 | ``` 32 | 33 | You can check a C/C++ file by doing (replace `my_cc_file.cc` with the path of the file to check): 34 | 35 | 36 | ```bash 37 | clang-format my_cc_file.cc --style=google > /tmp/my_cc_file.cc 38 | diff my_cc_file.cc /tmp/my_cc_file.cc 39 | ``` 40 | 41 | ### Formatting all files 42 | 43 | To automatically format all files in the repository and make sure that they conform to the guidelines, run `format.sh` on Ubuntu or `format.ps1` on Windows. 
44 | 45 | ### Adding a new kernel 46 | 47 | To add a new kernel, follow the steps outlined in the [Kernel Cookbook](KERNEL_COOKBOOK.md) -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | TensorFlow-DirectML-Plugin 2 | Copyright (c) Microsoft Corporation 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/__init__.py -------------------------------------------------------------------------------- /format.ps1: --------------------------------------------------------------------------------
# Formats the C/C++ sources of the repository with clang-format (using the
# repository's .clang-format via --style=file), then runs black and pylint on
# the Python sources.

# Returns $true when $FileName matches one of the C/C++ wildcard patterns.
# -clike is the case-sensitive variant of -like, so "FOO.H" is not matched.
function ShouldFormat($FileName)
{
    return $FileName -clike "*.h" `
        -or $FileName -clike "*.c" `
        -or $FileName -clike "*.cpp" `
        -or $FileName -clike "*.cc"
}

# Formats $Root in place. When $Root is a file it is formatted if its name
# matches ShouldFormat; when it is a folder, every matching file directly in
# it is formatted and the function recurses into its subfolders.
function FormatFiles($Root)
{
    if (Test-Path -Path $Root -PathType Leaf)
    {
        if (ShouldFormat($Root))
        {
            Write-Output "Formatting $($Root)"
            # Format the file that was passed in. $File is only assigned by
            # the foreach loop further down, so it must not be used here.
            clang-format -i --style=file $Root
        }
        # Nothing else to do for a single file. `return` leaves only this
        # function; it does not depend on a loop existing in the caller.
        return
    }

    $Folder = (New-Object -Com Scripting.FileSystemObject).GetFolder($Root)
    $Files = $Folder.Files | Where-Object { ShouldFormat($_.Name) }
    foreach ($File in $Files)
    {
        Write-Output "Formatting $($File.Path)"
        clang-format -i --style=file $File.Path
    }
    foreach ($SubFolder in $Folder.Subfolders)
    {
        # Skip subfolders that Get-Item cannot resolve and those whose
        # attributes include ReparsePoint.
        $CurrentItem = Get-Item $SubFolder.Path -ErrorAction SilentlyContinue
        if ($CurrentItem -and !$CurrentItem.Attributes.ToString().Contains("ReparsePoint"))
        {
            FormatFiles($SubFolder.Path)
        }
    }
}

$ErrorActionPreference = "Stop"

# Walk every top-level item next to this script except the build output and
# the vendored third-party code.
foreach ($Item in Get-ChildItem $PSScriptRoot)
{
    if ($Item.Name -ne "build" -and $Item.Name -ne "third_party")
    {
        FormatFiles($Item.FullName)
    }
}

black .
pylint build.py
pylint generate_op_defs_core.py
pylint test/plugin
pylint tfdml

Write-Output "Done!"
-------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for f in $(find . -path ./build -prune -o -path ./third_party -prune -o -name '*.h' -or -name '*.c' -or -name '*.cpp' -or -name '*.cc'); do 4 | if [ ${f} != './build' -a ${f} != './third_party' ]; then 5 | echo "Formatting ${f}" 6 | clang-format -i --style=file ${f} 7 | fi 8 | done 9 | 10 | black . 11 | pylint build.py 12 | pylint generate_op_defs_core.py 13 | pylint test/plugin 14 | pylint tfdml 15 | 16 | echo "Done!" 17 | -------------------------------------------------------------------------------- /pipelines/SubmitFilesToBlobStorage.ps1: -------------------------------------------------------------------------------- 1 | <# 2 | .SYNOPSIS 3 | Submit a folder containing artifacts to azure blob storage 4 | #> 5 | Param( 6 | [Parameter(Position=1,mandatory=$true)] 7 | [string]$AzureBlobLink, 8 | [Parameter(Position=2,mandatory=$true)] 9 | [string]$AzureBlobLinkToken, 10 | [Parameter(Position=3,mandatory=$true)] 11 | [string]$SourcePath, 12 | [Parameter(Position=4,mandatory=$true)] 13 | [string]$TargetPath 14 | ) 15 | $azcopyPath = "./azcopy.exe" 16 | 17 | # download azcopy 18 | if(-not (Test-Path $azcopyPath)) { 19 | $tempZipPath = "./temp.zip" 20 | try { 21 | Invoke-WebRequest -Uri "https://aka.ms/downloadazcopy-v10-windows" -OutFile $tempZipPath 22 | } 23 | catch { 24 | Write-Host "Failed to fetch azcopy from https://aka.ms/downloadazcopy-v10-windows . Make sure that the link is working correctly." 25 | } 26 | 27 | # Expand the Zip file 28 | Expand-Archive $tempZipPath . 
-Force 29 | 30 | # Move to $InstallPath 31 | Get-ChildItem "./azcopy*windows*\azcopy.exe" | Move-Item -Destination "." 32 | 33 | # Clean up 34 | Remove-Item $tempZipPath -Force -Confirm:$false 35 | Remove-Item "./azcopy*windows*\" -Recurse 36 | } 37 | 38 | # Call azcopy and copy to blob storage 39 | try { 40 | Start-Process $azcopyPath -ArgumentList "copy $($SourcePath) $($AzureBlobLink)$($TargetPath)$($AzureBlobLinkToken)" -Wait -NoNewWindow 41 | } 42 | catch { 43 | Write-Host "Failed to upload $($SourcePath) to $($AzureBlobLink) please ensure that the token isn't expired." 44 | } 45 | 46 | # Clean up azcopy 47 | Remove-Item $azcopyPath -------------------------------------------------------------------------------- /pipelines/create_agent_summary.ps1: --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

<#
.SYNOPSIS
Collates agent info into a master summary.

.DESCRIPTION
Visits every sub-directory of $TestArtifactsPath, reads its dxdiag.xml (when
present) and writes one JSON object to $OutputPath, keyed by the
sub-directory name, describing that agent's system and display adapter.
#>
param
(
    # Folder whose immediate sub-directories are treated as one agent each.
    [string]$TestArtifactsPath,
    # Destination JSON file; defaults to agent_summary.json inside $TestArtifactsPath.
    [string]$OutputPath = "$TestArtifactsPath\agent_summary.json"
)

# Maps agent name (sub-directory leaf name) -> hashtable of collected properties.
$AllResults = @{}

$AgentPaths = Get-ChildItem $TestArtifactsPath -Directory
foreach ($AgentPath in $AgentPaths.FullName)
{
    # Stays empty when the agent folder has no dxdiag.xml.
    $Result = @{}

    $DxDiagPath = "$AgentPath\dxdiag.xml"
    if (Test-Path $DxDiagPath)
    {
        $DxDiag = ([xml](Get-Content $DxDiagPath -Raw)).DxDiag

        $System = $DxDiag.SystemInformation
        # @(...) forces an array so [0] picks the first display device even
        # when only a single DisplayDevice element exists.
        $Adapter = @($DxDiag.DisplayDevices.DisplayDevice)[0]

        # Keep only the text inside the parentheses that end the string.
        $Result.OperatingSystem = $System.OperatingSystem -replace '.*\((.*)\)$','$1'
        $Result.Processor = $System.Processor
        $Result.SystemManufacturer = $System.SystemManufacturer
        $Result.SystemModel = $System.SystemModel
        $Result.SystemDescription = "$($Result.SystemManufacturer) - $($Result.SystemModel)"
        $Result.DisplayAdapter = $Adapter.CardName
        $Result.DisplayDriver = $Adapter.DriverVersion
        $Result.DisplayDriverDate = $Adapter.DriverDate
        $Result.DisplayDriverModel = $Adapter.DriverModel

    }

    $EnvironmentVarsPath = "$AgentPath\environment_vars.json"
    if (Test-Path $EnvironmentVarsPath)
    {
        # NOTE(review): $AgentVars is assigned here but never read anywhere
        # else in this script, so environment_vars.json does not currently
        # contribute to the summary.
        $AgentVars = Get-Content $EnvironmentVarsPath | ConvertFrom-Json
    }

    $AgentName = $AgentPath | Split-Path -Leaf
    $AllResults.$AgentName = $Result
}

# -Force creates the output file even if it already exists.
New-Item -ItemType File -Path $OutputPath -Force
ConvertTo-Json $AllResults -Depth 8 | Out-File $OutputPath -Encoding utf8
-------------------------------------------------------------------------------- /pipelines/create_test_env.ps1: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | param 4 | ( 5 | [string]$ArtifactsDirectory = $env:SYSTEM_ARTIFACTSDIRECTORY, 6 | [Parameter(Mandatory)][string]$TestArtifactPath, 7 | [Parameter(Mandatory)][string]$TensorFlowPackage, 8 | [Parameter(Mandatory)][string]$KerasPackage, 9 | [Parameter(Mandatory)][string]$SourcesDirectory 10 | ) 11 | 12 | $ErrorActionPreference = 'Stop' 13 | 14 | $InstallDir = Join-Path ($ArtifactsDirectory | Resolve-Path) "miniconda3" 15 | $PluginPackage = (Get-ChildItem "$TestArtifactPath/tensorflow_directml_plugin*.whl").FullName 16 | $TestEnvPath = "$ArtifactsDirectory/test_env" 17 | $TestArtifact = $TestArtifactPath | Split-Path -Leaf 18 | $PyVersionMajorDotMinor = $TestArtifact -replace '.*-cp(\d)(\d)', '$1.$2' 19 | 20 | Write-Host "Installing miniconda3 to $InstallDir" 21 | $Url = 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe' 22 | $DownloadPath = "$ArtifactsDirectory/miniconda.exe" 23 | (New-Object System.Net.WebClient).DownloadFile($Url, $DownloadPath) 24 | Start-Process $DownloadPath -ArgumentList '/NoRegistry=1', '/InstallationType=JustMe', '/RegisterPython=0', '/S', "/D=$InstallDir" -Wait 25 | & "$InstallDir/shell/condabin/conda-hook.ps1" 
26 | 27 | conda create --prefix $TestEnvPath python=$PyVersionMajorDotMinor -y 28 | conda activate $TestEnvPath 29 | pip install $TensorFlowPackage 30 | pip install $KerasPackage 31 | pip install tensorboard_plugin_profile 32 | pip install $PluginPackage 33 | pip install portpicker 34 | pip list 35 | 36 | $ActivateCmd = "$InstallDir/shell/condabin/conda-hook.ps1; conda activate $TestEnvPath" 37 | echo "##vso[task.setVariable variable=activateCommand;isOutput=true]$ActivateCmd" 38 | 39 | # Extract the C Library API tests to the build folder 40 | # TODO: Make available on Windows once the TensorFlow C API exports all the necessary symbols 41 | # TF #40927951 42 | ls $TestArtifactPath 43 | $ApiTestsZip = (Get-ChildItem "$TestArtifactPath/tensorflow_directml_plugin-*-c-api-tests.zip").FullName 44 | Expand-Archive "$ApiTestsZip" -Destination "$SourcesDirectory/build" -------------------------------------------------------------------------------- /pipelines/create_test_env.sh: --------------------------------------------------------------------------------
#!/usr/bin/bash

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# Builds a fresh conda environment for testing the plugin wheel and publishes
# the command that activates it as the pipeline variable `activateCommand`.
#
# Positional arguments:
#   $1  directory holding the plugin wheel and the C API tests zip
#   $2  TensorFlow package handed to `pip install`
#   $3  Keras package handed to `pip install`
#   $4  sources directory; the C API tests are unzipped into its build/ folder

test_artifact_path=$1
tensorflow_package=$2
keras_package=$3
sources_directory=$4

# Windows agents use the agent artifacts directory for the conda installation, but
# this is slow in WSL (filesystem networking overhead). Instead the agent will use
# a temp directory in the native Linux filesystem.
tmp_testing_root="/tmp/tfdml_plugin_pipeline"
rm -rf "$tmp_testing_root"
mkdir "$tmp_testing_root"
# Stop here if the directory cannot be entered; otherwise the installer would
# be downloaded into whatever directory the script happened to start in.
cd "$tmp_testing_root" || exit 1

# Extract the C Library API tests to the build folder.
# Paths are quoted so that spaces in them do not split the arguments; the glob
# itself stays outside the quotes so it is still expanded.
api_tests_zip=$(ls "$test_artifact_path"/tensorflow_directml_plugin-*-c-api-tests.zip)
unzip "$api_tests_zip" -d "$sources_directory/build"

install_dir="$tmp_testing_root/miniconda3"
# Assumes exactly one wheel matches in the artifact directory -- TODO confirm.
plugin_package=$(ls "$test_artifact_path"/tensorflow_directml_plugin*.whl)
test_env_path="$tmp_testing_root/test_env"
test_artifact=$(basename "$test_artifact_path")
# Rewrites everything up to and including "-cpXY" as "X.Y" (e.g. "...-cp38"
# becomes "3.8"). Characters after the two captured digits are kept, so
# "...-cp310" becomes "3.10".
py_version_major_dot_minor=$(echo "$test_artifact" | sed -E "s/.*-cp([0-9])([0-9])/\1.\2/")

echo "Installing miniconda3 to $install_dir"
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p "$install_dir"
eval "$("$install_dir/bin/conda" shell.bash hook)"
conda create --prefix "$test_env_path" python="$py_version_major_dot_minor" -y

conda activate "$test_env_path"
# The two package arguments are left unquoted on purpose: callers may pass
# more than one pip argument in a single parameter -- TODO confirm against the
# pipeline definition.
pip install $tensorflow_package
pip install $keras_package
pip install tensorboard_plugin_profile
pip install "$plugin_package"
pip install portpicker
pip list

activate_cmd="source $install_dir/bin/activate $test_env_path"
echo "##vso[task.setVariable variable=activateCommand;isOutput=true]$activate_cmd"
-------------------------------------------------------------------------------- /pipelines/create_test_summary_json.ps1: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | <# 5 | .SYNOPSIS 6 | Creates a summary of all test groups across all test environments. 7 | 8 | .DESCRIPTION 9 | Each tests runs on multiple environments, where an environment is defined by the agent (hardware), 10 | build architecture, and build configuration. For example, a run may produce 4 builds 11 | (e.g. 
x64.release, x64.debug, x86.release, x86.debug) that are tested by 8 agents for up to 12 | 8*4 = 32 environments. If there are 2000 tests, then there may be up to 2000*32 = 64,000 results. 13 | #> 14 | param 15 | ( 16 | [string]$TestArtifactsPath, 17 | [string]$OutputPath = "$TestArtifactsPath\test_summary.json" 18 | ) 19 | 20 | $Summary = @{} 21 | 22 | $TestMatrix = (Get-Content "$TestArtifactsPath/matrix.json" -Raw) | ConvertFrom-Json 23 | 24 | foreach ($Job in $TestMatrix) 25 | { 26 | foreach ($TestConfig in $Job.agentTestConfigs) 27 | { 28 | $Build, $TensorflowPackage, $Group = $TestConfig.split(':') 29 | $ResultsPath = "$TestArtifactsPath/$($Job.agentName)/$Build/$TensorflowPackage/summary.$Group.json" 30 | 31 | if (!$Summary.ContainsKey($Group)) 32 | { 33 | $Summary[$Group] = @() 34 | } 35 | 36 | $SummaryEntry = @{ 37 | "agent" = $Job.agentName; 38 | "build" = "${Build}-${TensorflowPackage}"; 39 | "agentWasOnline" = $Job.agentStatus -eq 'online'; 40 | "agentWasEnabled" = $Job.agentEnabled; 41 | "agentHasResults" = Test-Path $ResultsPath; 42 | } 43 | 44 | if (Test-Path $ResultsPath) 45 | { 46 | $EnvSummary = (Get-Content $ResultsPath -Raw) | ConvertFrom-Json 47 | $SummaryEntry["tests_total_count"] = $EnvSummary.tests_total_count 48 | $SummaryEntry["tests_passed_count"] = $EnvSummary.tests_passed_count 49 | $SummaryEntry["tests_failed_count"] = $EnvSummary.tests_failed_count 50 | $SummaryEntry["tests_skipped_count"] = $EnvSummary.tests_skipped_count 51 | $SummaryEntry["tests_timed_out_count"] = $EnvSummary.tests_timed_out_count 52 | $SummaryEntry["start_timestamp_seconds"] = $EnvSummary.start_timestamp_seconds 53 | $SummaryEntry["end_timestamp_seconds"] = $EnvSummary.end_timestamp_seconds 54 | $SummaryEntry["duration_seconds"] = $EnvSummary.duration_seconds 55 | } 56 | 57 | $Summary[$Group] += $SummaryEntry 58 | } 59 | } 60 | 61 | New-Item -ItemType File -Path $OutputPath -Force 62 | $Summary | ConvertTo-Json -Depth 8 -Compress | Out-File $OutputPath -Encoding 
utf8 -------------------------------------------------------------------------------- /pipelines/setup_agent_linux.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | steps: 5 | - script: | 6 | wget --output-document="$(Build.StagingDirectory)/cmake.sh" https://github.com/Kitware/CMake/releases/download/v3.22.1/cmake-3.22.1-linux-x86_64.sh 7 | mkdir "$(Build.StagingDirectory)/cmake" 8 | bash "$(Build.StagingDirectory)/cmake.sh" --skip-license --prefix="$(Build.StagingDirectory)/cmake" 9 | echo "##vso[task.prependpath]$(Build.StagingDirectory)/cmake/bin" 10 | displayName: Install CMake 11 | workingDirectory: $(Build.StagingDirectory) 12 | target: manylinux 13 | 14 | - script: | 15 | miniconda_path="miniconda3" 16 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 17 | bash Miniconda3-latest-Linux-x86_64.sh -b -p $miniconda_path 18 | eval "$($miniconda_path/bin/conda shell.bash hook)" 19 | conda create --name build python=$(vars.pyVersionMajorDotMinor) -y 20 | activate_cmd="source $(Build.StagingDirectory)/$miniconda_path/bin/activate build" 21 | source $(Build.StagingDirectory)/$miniconda_path/bin/activate build 22 | pip install wheel 23 | echo "##vso[task.setVariable variable=activateCommand;isOutput=true]$activate_cmd" 24 | displayName: Install Miniconda 25 | name: miniconda 26 | workingDirectory: $(Build.StagingDirectory) 27 | target: manylinux 28 | 29 | - script: | 30 | sudo apt update 31 | sudo apt install ninja-build -y 32 | displayName: Install Ninja 33 | workingDirectory: $(Build.StagingDirectory) 34 | target: manylinux -------------------------------------------------------------------------------- /pipelines/setup_agent_windows.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | steps: 5 | - powershell: | 6 | $Url = 'https://github.com/Kitware/CMake/releases/download/v3.22.1/cmake-3.22.1-windows-x86_64.zip' 7 | $DownloadPath = '$(Build.StagingDirectory)/cmake.zip' 8 | (New-Object System.Net.WebClient).DownloadFile($Url, $DownloadPath) 9 | Expand-Archive $DownloadPath -DestinationPath cmake 10 | # Expand-Archive extracts into the "cmake" folder under the working directory (the staging directory); $InstallDir only exists in the Miniconda step below. 11 | echo "##vso[task.prependpath]$(Build.StagingDirectory)/cmake/cmake-3.22.1-windows-x86_64/bin" 12 | displayName: Install CMake 13 | workingDirectory: $(Build.StagingDirectory) 14 | 15 | - powershell: | 16 | $Url = 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe' 17 | $DownloadPath = '$(Build.StagingDirectory)/miniconda.exe' 18 | (New-Object System.Net.WebClient).DownloadFile($Url, $DownloadPath) 19 | $InstallDir = '$(Build.StagingDirectory)\miniconda3' 20 | Start-Process $DownloadPath -ArgumentList '/NoRegistry=1', '/InstallationType=JustMe', '/RegisterPython=0', '/S', "/D=$InstallDir" -Wait 21 | & "$InstallDir/shell/condabin/conda-hook.ps1" 22 | conda create --name build python=$(vars.pyVersionMajorDotMinor) -y 23 | $ActivateCmd = "$InstallDir/shell/condabin/conda-hook.ps1; conda activate build" 24 | Invoke-Expression "$InstallDir/shell/condabin/conda-hook.ps1; conda activate build" 25 | pip install wheel vswhere 26 | Write-Host "##vso[task.setVariable variable=activateCommand;isOutput=true]$ActivateCmd" 27 | displayName: Install Miniconda 28 | name: miniconda 29 | workingDirectory: $(Build.StagingDirectory) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | extend-exclude = '(^/build/|^/test/ops/)' 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==22.6.0 2 | pylint==2.14.5 3 | tensorboard_plugin_profile==2.8.0 4 | tensorflow-cpu==2.12.0 5 | 
vswhere==1.4.0 ; sys_platform == 'win32' 6 | portpicker==1.5.2 7 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/__init__.py -------------------------------------------------------------------------------- /test/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/__init__.py -------------------------------------------------------------------------------- /test/ops/bias_op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Functional tests for BiasAdd.""" 16 | 17 | import bias_op_base 18 | from tensorflow.python.platform import test 19 | 20 | BiasAddTest = bias_op_base.BiasAddTestBase 21 | 22 | if __name__ == "__main__": 23 | test.main() -------------------------------------------------------------------------------- /test/ops/cross_grad_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for tensorflow.ops.nn_ops.Cross.""" 16 | 17 | from tensorflow.python.framework import test_util 18 | from tensorflow.python.ops import array_ops 19 | from tensorflow.python.ops import gradient_checker 20 | from tensorflow.python.ops import math_ops 21 | from tensorflow.python.platform import test 22 | 23 | 24 | class CrossOpTest(test.TestCase): 25 | 26 | @test_util.run_deprecated_v1 27 | def testGradientRandomValues(self): 28 | with self.cached_session(): 29 | us = [2, 3] 30 | u = array_ops.reshape( 31 | [0.854, -0.616, 0.767, 0.725, -0.927, 0.159], shape=us) 32 | v = array_ops.reshape( 33 | [-0.522, 0.755, 0.407, -0.652, 0.241, 0.247], shape=us) 34 | s = math_ops.cross(u, v) 35 | jacob_u, jacob_v = gradient_checker.compute_gradient([u, v], [us, us], s, 36 | us) 37 | 38 | self.assertAllClose(jacob_u[0], jacob_u[1], rtol=1e-3, atol=1e-3) 39 | self.assertAllClose(jacob_v[0], jacob_v[1], rtol=1e-3, atol=1e-3) 40 | 41 | 42 | if __name__ == "__main__": 43 | test.main() 44 | -------------------------------------------------------------------------------- /test/ops/deepcopy_op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Tests for tensorflow.ops.tf.deepcopy.""" 17 | 18 | import numpy as np 19 | 20 | from tensorflow.python.framework import test_util 21 | from tensorflow.python.ops import gen_array_ops 22 | from tensorflow.python.ops import math_ops 23 | from tensorflow.python.framework import ops 24 | from tensorflow.python.platform import test 25 | 26 | class DeepcopyTest(test.TestCase): 27 | 28 | @test_util.run_in_graph_and_eager_modes 29 | def testDeepcopy(self): 30 | with self.cached_session(): 31 | for type in [np.float32, np.int64]: 32 | x = np.array([[0, -1, 2, -3, 4], [-5, 6, -7, 8, -9]]).astype(type) 33 | x = ops.convert_to_tensor(x) 34 | y = gen_array_ops.DeepCopy(x=x) 35 | x = math_ops.abs(x) 36 | self.assertNotAllEqual(x, y) 37 | 38 | if __name__ == "__main__": 39 | test.main() -------------------------------------------------------------------------------- /test/ops/dml_get_adapter_name.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper that outputs the logs of DirectML device creation 3 | """ 4 | 5 | from tensorflow.python.client import device_lib 6 | 7 | if __name__ == "__main__": 8 | device_lib.list_local_devices() 9 | -------------------------------------------------------------------------------- /test/ops/dml_test_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers used by the tensorflow-directml-plugin python tests 3 | """ 4 | import re 5 | 6 | 7 | def should_skip_test(pattern, args): 8 | """ 9 | Returns whether the test should be skipped based on a regex that should match the 10 | adapter name 11 | """ 12 | for arg in args: 13 | if re.match(pattern, arg) is not None: 14 | return True 15 | return False 16 | -------------------------------------------------------------------------------- /test/ops/empty_op_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.ops.tf.empty.""" 17 | 18 | import numpy as np 19 | 20 | from tensorflow.python.framework import test_util 21 | from tensorflow.python.framework import ops 22 | from tensorflow.python.ops import gen_array_ops 23 | from tensorflow.python.ops import array_ops 24 | from tensorflow.python.platform import test 25 | from tensorflow.python.framework import dtypes 26 | 27 | class EmptyTest(test.TestCase): 28 | 29 | @test_util.run_in_graph_and_eager_modes 30 | def testEmpty(self): 31 | 32 | def empty_like(x, init=None): 33 | x = ops.convert_to_tensor(x) 34 | return gen_array_ops.empty(array_ops.shape(x), x.dtype, init=init) 35 | 36 | for dtype in [ 37 | dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64]: 38 | with self.cached_session(use_gpu=True): 39 | test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)] 40 | for shape in test_shapes: 41 | val = gen_array_ops.empty(shape, dtype) 42 | self.assertEqual(val.shape, shape) 43 | self.assertDTypeEqual(val, dtype) 44 | val = gen_array_ops.empty(shape, dtype, init=True) 45 | self.assertEqual(val.shape, shape) 46 | self.assertDTypeEqual(val, dtype) 47 | 
self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype)) 48 | val = empty_like(array_ops.zeros(shape, dtype)) 49 | self.assertEqual(val.shape, shape) 50 | self.assertDTypeEqual(val, dtype) 51 | val = empty_like( 52 | array_ops.zeros(shape, dtype), init=True) 53 | self.assertEqual(val.shape, shape) 54 | self.assertDTypeEqual(val, dtype) 55 | self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype)) 56 | 57 | 58 | if __name__ == "__main__": 59 | test.main() -------------------------------------------------------------------------------- /test/ops/fill_op_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Tests for tensorflow.ops.tf.fill.""" 17 | 18 | from tensorflow.python.framework import test_util 19 | from tensorflow.python.ops import array_ops 20 | from tensorflow.python.platform import test 21 | 22 | class FillTest(test.TestCase): 23 | 24 | @test_util.run_in_graph_and_eager_modes 25 | def testFill(self): 26 | with self.cached_session(): 27 | fill = array_ops.fill([2, 3], 1.0) 28 | self.assertAllEqual([[1, 1, 1], [1, 1, 1]], fill) 29 | 30 | def testFillBadShape(self): 31 | with self.cached_session(): 32 | with self.assertRaisesOpError( 33 | r"dims must be a vector, got shape \[2,2\]"): 34 | array_ops.fill([[2, 3], [2, 3]], 1.0) 35 | 36 | def testFillBadValue(self): 37 | with self.cached_session(): 38 | with self.assertRaisesOpError( 39 | r"value must be a scalar, got shape \[2\]"): 40 | array_ops.fill([2, 3], [1.0, 2.0]) 41 | 42 | 43 | if __name__ == "__main__": 44 | test.main() -------------------------------------------------------------------------------- /test/ops/image_grad_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Functional tests for Image Op Gradients.""" 16 | 17 | import image_grad_test_base as test_base 18 | from tensorflow.python.platform import test 19 | 20 | ResizeNearestNeighborOpTest = test_base.ResizeNearestNeighborOpTestBase 21 | ResizeBilinearOpTest = test_base.ResizeBilinearOpTestBase 22 | ResizeBicubicOpTest = test_base.ResizeBicubicOpTestBase 23 | ScaleAndTranslateOpTest = test_base.ScaleAndTranslateOpTestBase 24 | CropAndResizeOpTest = test_base.CropAndResizeOpTestBase 25 | RGBToHSVOpTest = test_base.RGBToHSVOpTestBase 26 | 27 | if __name__ == "__main__": 28 | test.main() 29 | -------------------------------------------------------------------------------- /test/ops/l2loss_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Tests for tensorflow.nn.l2_loss.""" 17 | 18 | import tensorflow as tf 19 | import numpy as np 20 | 21 | from tensorflow.python.framework import constant_op 22 | from tensorflow.python.framework import dtypes 23 | from tensorflow.python.framework import test_util 24 | from tensorflow.python.platform import test 25 | 26 | class L2LossTest(test.TestCase): 27 | 28 | @test_util.run_in_graph_and_eager_modes 29 | def testL2Loss(self): 30 | with self.cached_session(): 31 | data = np.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]) 32 | expected_result = np.sum(np.square(data)) / 2 33 | for dtype in (dtypes.float16, dtypes.float32,): 34 | result = tf.nn.l2_loss(constant_op.constant(data, dtype=dtype)) 35 | self.assertAllCloseAccordingToType(result, expected_result) 36 | 37 | 38 | if __name__ == "__main__": 39 | test.main() -------------------------------------------------------------------------------- /test/ops/numerics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for tensorflow.ops.numerics.""" 16 | 17 | import numpy as np 18 | 19 | from tensorflow.python.framework import constant_op 20 | from tensorflow.python.framework import dtypes 21 | from tensorflow.python.framework import test_util 22 | from tensorflow.python.ops import numerics 23 | from tensorflow.python.platform import test 24 | 25 | 26 | class VerifyTensorAllFiniteTest(test.TestCase): 27 | 28 | def testVerifyTensorAllFiniteSucceeds(self): 29 | x_shape = [5, 4] 30 | x = np.random.random_sample(x_shape).astype(np.float32) 31 | with test_util.use_gpu(): 32 | t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32) 33 | t_verified = numerics.verify_tensor_all_finite(t, 34 | "Input is not a number.") 35 | self.assertAllClose(x, self.evaluate(t_verified)) 36 | 37 | def testVerifyTensorAllFiniteFails(self): 38 | x_shape = [5, 4] 39 | x = np.random.random_sample(x_shape).astype(np.float32) 40 | my_msg = "Input is not a number." 41 | 42 | # Test NaN. 43 | x[0] = np.nan 44 | with test_util.use_gpu(): 45 | with self.assertRaisesOpError(my_msg): 46 | t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32) 47 | t_verified = numerics.verify_tensor_all_finite(t, my_msg) 48 | self.evaluate(t_verified) 49 | 50 | # Test Inf. 51 | x[0] = np.inf 52 | with test_util.use_gpu(): 53 | with self.assertRaisesOpError(my_msg): 54 | t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32) 55 | t_verified = numerics.verify_tensor_all_finite(t, my_msg) 56 | self.evaluate(t_verified) 57 | 58 | 59 | if __name__ == "__main__": 60 | test.main() -------------------------------------------------------------------------------- /test/ops/ones_like_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for tensorflow.ops.tf.ones_like.""" 17 | 18 | import tensorflow as tf 19 | from tensorflow.python.framework import test_util 20 | from tensorflow.python.platform import test 21 | 22 | class OnesLikeTest(test.TestCase): 23 | 24 | @test_util.run_in_graph_and_eager_modes 25 | def testOnesLike(self): 26 | with self.cached_session(): 27 | for dtype in [tf.float16, tf.float32, tf.int64]: 28 | const_input = tf.constant([[1, 1, 3], [4, 5, 6]], dtype=dtype) 29 | result = tf.raw_ops.OnesLike(x=const_input) 30 | self.assertAllEqual([[1, 1, 1], [1, 1, 1]], result) 31 | 32 | 33 | if __name__ == "__main__": 34 | test.main() -------------------------------------------------------------------------------- /test/ops/testdata/bad_huffman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/bad_huffman.jpg -------------------------------------------------------------------------------- /test/ops/testdata/cat_q20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/cat_q20.jpg 
-------------------------------------------------------------------------------- /test/ops/testdata/cat_q72.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/cat_q72.jpg -------------------------------------------------------------------------------- /test/ops/testdata/cat_q95.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/cat_q95.jpg -------------------------------------------------------------------------------- /test/ops/testdata/checkerboard1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/checkerboard1.png -------------------------------------------------------------------------------- /test/ops/testdata/checkerboard2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/checkerboard2.png -------------------------------------------------------------------------------- /test/ops/testdata/checkerboard3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/checkerboard3.png -------------------------------------------------------------------------------- /test/ops/testdata/corrupt.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/corrupt.jpg -------------------------------------------------------------------------------- /test/ops/testdata/corrupt34_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/corrupt34_2.jpg -------------------------------------------------------------------------------- /test/ops/testdata/corrupt34_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/corrupt34_3.jpg -------------------------------------------------------------------------------- /test/ops/testdata/corrupt34_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/corrupt34_4.jpg -------------------------------------------------------------------------------- /test/ops/testdata/grayscale_small.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/grayscale_small.bmp -------------------------------------------------------------------------------- /test/ops/testdata/grayscale_small_3channels.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/grayscale_small_3channels.bmp -------------------------------------------------------------------------------- 
/test/ops/testdata/grayscale_small_4channels.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/grayscale_small_4channels.bmp -------------------------------------------------------------------------------- /test/ops/testdata/jpeg_merge_test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/jpeg_merge_test1.jpg -------------------------------------------------------------------------------- /test/ops/testdata/jpeg_merge_test1_cmyk.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/jpeg_merge_test1_cmyk.jpg -------------------------------------------------------------------------------- /test/ops/testdata/lena.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/lena.bmp -------------------------------------------------------------------------------- /test/ops/testdata/lena.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/lena.gif -------------------------------------------------------------------------------- /test/ops/testdata/lena_gray.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/lena_gray.png -------------------------------------------------------------------------------- /test/ops/testdata/lena_palette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/lena_palette.png -------------------------------------------------------------------------------- /test/ops/testdata/lena_palette_trns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/lena_palette_trns.png -------------------------------------------------------------------------------- /test/ops/testdata/lena_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/lena_rgba.png -------------------------------------------------------------------------------- /test/ops/testdata/medium.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/medium.jpg -------------------------------------------------------------------------------- /test/ops/testdata/optimized.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/optimized.gif -------------------------------------------------------------------------------- /test/ops/testdata/palette_only.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/palette_only.png -------------------------------------------------------------------------------- /test/ops/testdata/pendulum_sm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/pendulum_sm.gif -------------------------------------------------------------------------------- /test/ops/testdata/pendulum_sm_frame0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/pendulum_sm_frame0.png -------------------------------------------------------------------------------- /test/ops/testdata/pendulum_sm_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/pendulum_sm_frame1.png -------------------------------------------------------------------------------- /test/ops/testdata/pendulum_sm_frame2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/pendulum_sm_frame2.png -------------------------------------------------------------------------------- /test/ops/testdata/red_black.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/red_black.gif 
-------------------------------------------------------------------------------- /test/ops/testdata/rgb_small.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/rgb_small.bmp -------------------------------------------------------------------------------- /test/ops/testdata/rgb_small_255.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/rgb_small_255.bmp -------------------------------------------------------------------------------- /test/ops/testdata/rgba_small.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/rgba_small.bmp -------------------------------------------------------------------------------- /test/ops/testdata/rgba_small_255.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/rgba_small_255.bmp -------------------------------------------------------------------------------- /test/ops/testdata/scan.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/scan.gif -------------------------------------------------------------------------------- /test/ops/testdata/small.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/small.jpg -------------------------------------------------------------------------------- /test/ops/testdata/squares.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/ops/testdata/squares.gif -------------------------------------------------------------------------------- /test/ops/zeros_like_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Tests for tensorflow.ops.tf.zeros_like.""" 17 | 18 | import tensorflow as tf 19 | from tensorflow.python.framework import test_util 20 | from tensorflow.python.platform import test 21 | 22 | class ZerosLikeTest(test.TestCase): 23 | 24 | @test_util.run_in_graph_and_eager_modes 25 | def testZerosLike(self): 26 | with self.cached_session(): 27 | for dtype in [tf.float16, tf.float32, tf.int64]: 28 | const_input = tf.constant([[1, 1, 3], [4, 5, 6]], dtype=dtype) 29 | result = tf.raw_ops.ZerosLike(x=const_input) 30 | self.assertAllEqual([[0, 0, 0], [0, 0, 0]], result) 31 | 32 | 33 | if __name__ == "__main__": 34 | test.main() -------------------------------------------------------------------------------- /test/plugin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/test/plugin/__init__.py -------------------------------------------------------------------------------- /test/plugin/dml_multiple_devices_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # ============================================================================== 17 | 18 | """Tests computation and copies across multiple DML devices""" 19 | 20 | from absl.testing import absltest 21 | import tensorflow as tf 22 | 23 | 24 | class VisibleDevicesTest(absltest.TestCase): 25 | """Tests computation and copies across multiple DML devices""" 26 | 27 | def test(self): 28 | """Tests computation and copies across multiple DML devices""" 29 | 30 | gpu_devices = tf.config.list_physical_devices("GPU") 31 | # pylint: disable=duplicate-code 32 | dml_devices = list( 33 | filter( 34 | lambda x: tf.config.experimental.get_device_details(x)["device_name"] 35 | == "DML", 36 | gpu_devices, 37 | ) 38 | ) 39 | # pylint: enable=duplicate-code 40 | 41 | if len(dml_devices) < 2: 42 | self.skipTest( 43 | f"This test requires more than 1 DirectML GPU, but only " 44 | f"{len(dml_devices)} devices were found." 45 | ) 46 | 47 | with tf.device("GPU:0"): 48 | device1_tensor = tf.constant(1, dtype=tf.float32) 49 | 50 | with tf.device("GPU:1"): 51 | device2_tensor = tf.constant(2, dtype=tf.float32) 52 | 53 | actual = tf.math.add(device1_tensor, device2_tensor) 54 | expected = tf.constant(3, dtype=tf.float32) 55 | self.assertEqual(actual, expected) 56 | 57 | 58 | if __name__ == "__main__": 59 | absltest.main() 60 | -------------------------------------------------------------------------------- /test/plugin/dml_visible_devices_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Microsoft Corporation. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | """Tests the DML device creation and visibility""" 19 | 20 | import os 21 | from absl.testing import absltest 22 | from absl import flags 23 | 24 | flags.DEFINE_string( 25 | "dml_visible_devices", "", "Value of DML_VISIBLE_DEVICES environment variable" 26 | ) 27 | 28 | 29 | class VisibleDevicesTest(absltest.TestCase): 30 | """Tests the visibility of DML devices""" 31 | 32 | def test(self): 33 | """Tests the visibility of DML devices""" 34 | os.environ["DML_VISIBLE_DEVICES"] = flags.FLAGS.dml_visible_devices 35 | 36 | # Tensorflow needs to be imported after the environment variable is set 37 | import tensorflow as tf # pylint:disable=import-outside-toplevel 38 | 39 | # See https://docs.microsoft.com/en-us/windows/ai/directml/gpu-faq 40 | # The value should be a comma-separated list of device IDs. 41 | # Any IDs appearing after -1 are invalid. 42 | valid_id_count = 0 43 | for device_id in flags.FLAGS.dml_visible_devices.split(","): 44 | if device_id == "-1": 45 | break 46 | valid_id_count += 1 47 | 48 | gpu_devices = tf.config.list_physical_devices("GPU") 49 | # pylint: disable=duplicate-code 50 | dml_devices = list( 51 | filter( 52 | lambda x: tf.config.experimental.get_device_details(x)["device_name"] 53 | == "DML", 54 | gpu_devices, 55 | ) 56 | ) 57 | # pylint: enable=duplicate-code 58 | 59 | # We can't guarantee the machine running this test has multiple 60 | # devices/adapters, but it must have at least one. 
61 | if valid_id_count == 0: 62 | self.assertEmpty(dml_devices) 63 | else: 64 | self.assertBetween(len(dml_devices), 1, valid_id_count) 65 | 66 | 67 | if __name__ == "__main__": 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /test/tests_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "Tests Metadata", 3 | "description": "Describes test content for tensorflow-directml-plugin", 4 | "type": "object", 5 | "required": [ 6 | "groups" 7 | ], 8 | "properties": { 9 | "groups": { 10 | "type": "array", 11 | "description": "Test groups to execute", 12 | "items": { 13 | "$ref": "#/$defs/test_group" 14 | } 15 | } 16 | }, 17 | "$defs": { 18 | "test_group": { 19 | "type": "object", 20 | "properties": { 21 | "name": { 22 | "type": "string", 23 | "description": "Name of the test group" 24 | }, 25 | "tests": { 26 | "type": "array", 27 | "description": "Tests to execute in the test group", 28 | "items": { 29 | "$ref": "#/$defs/test" 30 | } 31 | }, 32 | "timeout_seconds": { 33 | "type": "number", 34 | "description": "Max number of seconds to wait for all tests in the group to complete.", 35 | "default": 300 36 | } 37 | }, 38 | "required": [ 39 | "name", 40 | "tests" 41 | ] 42 | }, 43 | "test": { 44 | "type": "object", 45 | "properties": { 46 | "type": { 47 | "type": "string", 48 | "enum": [ 49 | "py_abseil", 50 | "gtest" 51 | ], 52 | "description": "Type of test the file references." 
53 | }, 54 | "file": { 55 | "type": "string", 56 | "description": "Path to test file to execute" 57 | }, 58 | "args": { 59 | "type": "array", 60 | "description": "Additional command-line arguments to use when executing the test file", 61 | "items": { "type": "string" } 62 | }, 63 | "disabled": { 64 | "type": "boolean", 65 | "description": "Skip executing the test file", 66 | "default": false 67 | }, 68 | "timeout_seconds": { 69 | "type": "number", 70 | "description": "Max number of seconds to wait for the test to complete.", 71 | "default": 30 72 | } 73 | }, 74 | "required": [ 75 | "file" 76 | ] 77 | } 78 | } 79 | } -------------------------------------------------------------------------------- /tfdml.def: -------------------------------------------------------------------------------- 1 | EXPORTS 2 | SE_InitPlugin @1 3 | TF_InitKernel @2 4 | TF_InitGraph @3 5 | TF_InitProfiler @4 6 | -------------------------------------------------------------------------------- /tfdml/core/dml_adapter.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/core/dml_adapter.h" 17 | #include "tfdml/core/dml_adapter_impl.h" 18 | 19 | namespace tfdml 20 | { 21 | 22 | DmlAdapter::DmlAdapter(const DmlAdapterImpl& impl) 23 | : impl_(std::make_shared<DmlAdapterImpl>(impl)) 24 | { 25 | } 26 | 27 | DmlAdapter::~DmlAdapter() = default; 28 | 29 | DriverVersion DmlAdapter::DriverVersion() const 30 | { 31 | return impl_->DriverVersion(); 32 | } 33 | 34 | VendorID DmlAdapter::VendorID() const { return impl_->VendorID(); } 35 | uint32_t DmlAdapter::DeviceID() const { return impl_->DeviceID(); } 36 | const std::string& DmlAdapter::Name() const { return impl_->Name(); } 37 | bool DmlAdapter::IsComputeOnly() const { return impl_->IsComputeOnly(); } 38 | const LUID& DmlAdapter::AdapterLuid() const { return impl_->AdapterLuid(); } 39 | 40 | uint64_t DmlAdapter::GetTotalDedicatedMemory() const 41 | { 42 | return impl_->GetTotalDedicatedMemory(); 43 | } 44 | 45 | uint64_t DmlAdapter::GetTotalSharedMemory() const 46 | { 47 | return impl_->GetTotalSharedMemory(); 48 | } 49 | 50 | uint64_t DmlAdapter::QueryAvailableLocalMemory() const 51 | { 52 | return impl_->QueryAvailableLocalMemory(); 53 | } 54 | 55 | uint64_t DmlAdapter::QueryAvailableNonLocalMemory() const 56 | { 57 | return impl_->QueryAvailableNonLocalMemory(); 58 | } 59 | 60 | bool DmlAdapter::IsUmaAdapter() const { return impl_->IsUmaAdapter(); } 61 | 62 | const char* GetVendorName(VendorID id) 63 | { 64 | switch (id) 65 | { 66 | case VendorID::kAmd: return "AMD"; 67 | case VendorID::kNvidia: return "NVIDIA"; 68 | case VendorID::kMicrosoft: return "Microsoft"; 69 | case VendorID::kQualcomm: return "Qualcomm"; 70 | case VendorID::kIntel: return "Intel"; 71 | default: return "Unknown"; 72 | } 73 | } 74 | 75 | std::vector<DmlAdapter> EnumerateAdapters() 76 | { 77 | auto impls = EnumerateAdapterImpls(); 78 | 79 | std::vector<DmlAdapter> adapters; 80 | adapters.reserve(impls.size()); 81 | 82 | for (auto&& impl : 
impls) 83 | { 84 | adapters.push_back(DmlAdapter(std::move(impl))); 85 | } 86 | 87 | return adapters; 88 | } 89 | 90 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/core/dml_adapter_heuristics.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | namespace tfdml 19 | { 20 | 21 | // We don't have an easy way to get detailed architecture-specific information 22 | // from a D3D adapter, so we set these properties to be roughly equivalent to 23 | // an NVIDIA GTX 1070 which we use as the archetype for an "average" GPU. 
24 | struct DmlAdapterArchetype 25 | { 26 | // Core clock frequency, in Hz 27 | static constexpr int64_t kFrequency = 1700e6; // 1700MHz 28 | 29 | // Number of SMs/CUs/EUs 30 | static constexpr int64_t kNumCores = 15; 31 | 32 | // Number of registers per "core" (SM/CU/EU) 33 | static constexpr int64_t kNumRegisters = 65536; 34 | 35 | // Cache sizes, in bytes 36 | static constexpr int64_t kL1CacheSize = 24576; // 24KB 37 | static constexpr int64_t kL2CacheSize = 2097152; // 2MB 38 | static constexpr int64_t kL3CacheSize = 0; 39 | 40 | // Shared memory size, in bytes 41 | static constexpr int64_t kSharedMemorySizePerMultiprocessor = 98304; // 96KB 42 | 43 | // Non-shared dedicated video memory, in bytes 44 | static constexpr int64_t kMemorySize = 8ll << 30; // 8GB 45 | 46 | // Memory bandwidth, in bytes/s 47 | static constexpr int64_t kBandwidth = 256ll << 30; // 256GB/s 48 | 49 | // Total compute, in 32-bit floating point operations per second 50 | static constexpr int64_t kComputeFlops = 6.5e12; // 6.5 TFLOPS 51 | }; 52 | 53 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/core/dml_bfc_allocator.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "dml_bfc_allocator.h" 17 | 18 | #include "dml_common.h" 19 | #include "dml_heap_allocator.h" 20 | #include "tfdml/runtime_adapter/env_var.h" 21 | 22 | namespace tfdml 23 | { 24 | 25 | DmlAllocator::DmlAllocator( 26 | D3D12HeapAllocator* heap_allocator, 27 | const std::string& name) 28 | : heap_allocator_(heap_allocator) 29 | { 30 | } 31 | 32 | D3D12BufferRegion DmlAllocator::CreateBufferRegion( 33 | const void* ptr, 34 | uint64_t size_in_bytes) 35 | { 36 | return heap_allocator_->CreateBufferRegion(ptr, size_in_bytes); 37 | } 38 | 39 | void* DmlAllocator::Alloc(uint32_t device_id, size_t num_bytes) 40 | { 41 | void* p = heap_allocator_->Alloc(device_id, num_bytes); 42 | return p; 43 | } 44 | 45 | void DmlAllocator::Free(void* ptr, size_t num_bytes) 46 | { 47 | heap_allocator_->Free(ptr, num_bytes); 48 | } 49 | 50 | } // namespace tfdml 51 | -------------------------------------------------------------------------------- /tfdml/core/dml_bfc_allocator.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "dml_buffer_region.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | class D3D12HeapAllocator; 24 | 25 | // The framework "wraps" this allocator inside a BFC allocator and calls Alloc 26 | // when it determines that it needs to grow the allocated memory. Here, 27 | // DmlAllocator is basically a SubAllocator with additional functionalities like 28 | // CreateBufferRegion(). 29 | class DmlAllocator 30 | { 31 | public: 32 | DmlAllocator(D3D12HeapAllocator* heap_allocator, const std::string& name); 33 | D3D12BufferRegion CreateBufferRegion( 34 | const void* ptr, 35 | uint64_t size_in_bytes); 36 | void* Alloc(uint32_t device_id, size_t num_bytes); 37 | void Free(void* ptr, size_t num_bytes); 38 | 39 | private: 40 | D3D12HeapAllocator* heap_allocator_; 41 | }; 42 | 43 | } // namespace tfdml 44 | -------------------------------------------------------------------------------- /tfdml/core/dml_buffer.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "dml_buffer_region.h" 19 | #include "dml_common.h" 20 | #include "tfdml/runtime_adapter/tensor.h" 21 | 22 | struct TF_OpKernelContext; 23 | 24 | namespace tfdml 25 | { 26 | 27 | class DmlAllocator; 28 | class OpKernelContext; 29 | 30 | // Owns a D3D12 default heap buffer allocated using the DML device's 31 | // allocator. This is essentially a convenience wrapper over a device memory 32 | // allocation as well as the buffer region that spans it. When this object is 33 | // destructed, the device memory is freed to the allocator. 34 | class DmlBuffer 35 | { 36 | public: 37 | explicit DmlBuffer( 38 | TF_OpKernelContext* op_kernel_context, 39 | DmlAllocator* allocator, 40 | uint64_t size_in_bytes); 41 | 42 | // Move-only 43 | DmlBuffer(const DmlBuffer&) = delete; 44 | DmlBuffer& operator=(const DmlBuffer&) = delete; 45 | DmlBuffer(DmlBuffer&&) = default; 46 | DmlBuffer& operator=(DmlBuffer&&) = default; 47 | 48 | ID3D12Resource* ResourceInUavState() const; 49 | ID3D12Resource* ResourceInCopySrcState() const; 50 | ID3D12Resource* ResourceInCopyDstState() const; 51 | uint64_t Offset() const; 52 | uint64_t SizeInBytes() const; 53 | const D3D12BufferRegion& Region() const { return buffer_region_; } 54 | 55 | DML_BUFFER_BINDING GetBufferBinding() const; 56 | 57 | explicit operator bool() const { return !!buffer_region_; } 58 | 59 | private: 60 | DmlAllocator* allocator_; // weak; owned by the device state 61 | D3D12BufferRegion buffer_region_; 62 | 63 | // Dummy tensor that holds the memory allocated by the BFC Allocator 64 | Tensor tensor_; 65 | }; 66 | 67 | } // namespace tfdml 68 | -------------------------------------------------------------------------------- /tfdml/core/dml_command_queue.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "dml_command_queue.h" 17 | 18 | namespace tfdml 19 | { 20 | 21 | DmlCommandQueue::DmlCommandQueue(ID3D12CommandQueue* existing_queue) 22 | : queue_(existing_queue), 23 | type_(existing_queue->GetDesc().Type) 24 | { 25 | Microsoft::WRL::ComPtr<ID3D12Device> device; 26 | DML_CHECK_SUCCEEDED(queue_->GetDevice(IID_PPV_ARGS(&device))); 27 | 28 | DML_CHECK_SUCCEEDED( 29 | device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&fence_))); 30 | } 31 | 32 | void DmlCommandQueue::ExecuteCommandLists( 33 | absl::Span<ID3D12CommandList*> command_lists) 34 | { 35 | queue_->ExecuteCommandLists( 36 | static_cast<uint32_t>(command_lists.size()), 37 | command_lists.data()); 38 | 39 | ++last_fence_value_; 40 | DML_CHECK_SUCCEEDED(queue_->Signal(fence_.Get(), last_fence_value_)); 41 | } 42 | 43 | DmlGpuEvent DmlCommandQueue::GetCurrentCompletionEvent() 44 | { 45 | return DmlGpuEvent{last_fence_value_, fence_}; 46 | } 47 | 48 | DmlGpuEvent DmlCommandQueue::GetNextCompletionEvent() 49 | { 50 | return DmlGpuEvent{last_fence_value_ + 1, fence_}; 51 | } 52 | 53 | } // namespace tfdml 54 | -------------------------------------------------------------------------------- /tfdml/core/dml_command_queue.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | #include "dml_common.h" 21 | #include "dml_gpu_event.h" 22 | 23 | namespace tfdml 24 | { 25 | 26 | // Manages a D3D12 command queue and provides a waitable fence which is signaled 27 | // with a monotonically increasing value once each execute completes on the GPU. 28 | class DmlCommandQueue 29 | { 30 | public: 31 | // Creates a CommandQueue object that wraps an existing D3D12 queue. 32 | DmlCommandQueue(ID3D12CommandQueue* existing_queue); 33 | 34 | D3D12_COMMAND_LIST_TYPE GetType() const { return type_; } 35 | Microsoft::WRL::ComPtr GetFence() const { return fence_; } 36 | uint64_t GetLastFenceValue() const { return last_fence_value_; } 37 | 38 | void ExecuteCommandLists(absl::Span command_lists); 39 | 40 | // Returns an event that will become signaled when everything submitted to 41 | // the queue thus far has completed execution on the GPU. 42 | DmlGpuEvent GetCurrentCompletionEvent(); 43 | 44 | // Returns an event that will become signaled after the next 45 | // ExecuteCommandLists call. 
46 | DmlGpuEvent GetNextCompletionEvent(); 47 | 48 | private: 49 | Microsoft::WRL::ComPtr queue_; 50 | D3D12_COMMAND_LIST_TYPE type_; 51 | 52 | Microsoft::WRL::ComPtr fence_; 53 | uint64_t last_fence_value_ = 0; 54 | }; 55 | 56 | } // namespace tfdml 57 | -------------------------------------------------------------------------------- /tfdml/core/dml_device.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/device.h" 19 | 20 | class IDMLDevice; 21 | class ID3D12Device; 22 | 23 | namespace tfdml 24 | { 25 | 26 | class Tensor; 27 | class DmlAdapter; 28 | class DmlAllocator; 29 | class DmlDescriptorAllocator; 30 | class DmlKernelManager; 31 | class DmlExecutionContext; 32 | class DmlUploadHeap; 33 | class DmlReadbackHeap; 34 | class DmlEventQueue; 35 | class DMLDeviceContext; 36 | struct DmlDeviceState; 37 | 38 | class DmlDevice : public Device 39 | { 40 | public: 41 | DmlDevice(const DmlDeviceState* state, uint32_t device_ordinal); 42 | 43 | ID3D12Device* GetD3D12Device() const; 44 | IDMLDevice* GetDmlDevice() const; 45 | DmlAllocator* GetAllocator() const; 46 | DmlDescriptorAllocator* GetDescriptorAllocator() const; 47 | DmlKernelManager* GetKernelManager() const; 48 | DmlExecutionContext* GetExecutionContext() const; 49 | DmlUploadHeap* GetUploadHeap() const; 50 | DmlReadbackHeap* GetReadbackHeap() const; 51 | DmlEventQueue* GetEventQueue() const; 52 | DMLDeviceContext* GetDeviceContext() const; 53 | Status Sync(); 54 | inline uint32_t GetDeviceOrdinal() const { return device_ordinal_; } 55 | 56 | absl::optional TryLogKernelComputeStart( 57 | const absl::string_view type, 58 | const absl::string_view name) const final; 59 | 60 | void LogKernelComputeEnd(uint32_t event_id) const final; 61 | 62 | void CopyTensorInSameDevice( 63 | const Tensor* input_tensor, 64 | Tensor* output_tensor) final; 65 | 66 | Status CopyCPUTensorToDevice( 67 | const Tensor* cpu_tensor, 68 | Tensor* device_tensor) final; 69 | 70 | Status CopyDeviceTensorToCPU( 71 | const Tensor* device_tensor, 72 | Tensor* cpu_tensor) final; 73 | 74 | Status CopyDeviceTensorsToCPU( 75 | absl::Span device_tensors, 76 | absl::Span cpu_tensors) final; 77 | 78 | private: 79 | const DmlDeviceState* state_; // Weak; owned by the device factory 80 | std::unique_ptr 
device_context_; 81 | uint32_t device_ordinal_; 82 | }; 83 | 84 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/core/dml_device_cache.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | 21 | #include "tfdml/core/dml_adapter.h" 22 | #include "tfdml/core/dml_device_state.h" 23 | #include "tfdml/runtime_adapter/status.h" 24 | 25 | namespace tfdml 26 | { 27 | 28 | // Maintains a static cache of device singletons, one per adapter. This class is 29 | // thread-safe. 30 | class DmlDeviceCache 31 | { 32 | public: 33 | static DmlDeviceCache& Instance(); 34 | uint32_t GetAdapterCount() const; 35 | 36 | // It is a little odd that we require GPUOptions and memory_limit here, as 37 | // those can vary per TF device instance - they're not process-global. We 38 | // handle this by using the options and memory limit that are provided to 39 | // the first device created on this adapter. If subsequent devices are 40 | // created on the same adapter but with different options/memory_limit, they 41 | // are ignored. This is unusual, but matches the behavior of the CUDA 42 | // device. 
43 | const DmlDeviceState* GetOrCreateDeviceState(uint32_t adapter_index); 44 | const DmlAdapter& GetAdapter(uint32_t adapter_index) const; 45 | 46 | private: 47 | DmlDeviceCache(); 48 | 49 | mutable std::mutex mutex_; 50 | std::vector adapters_; 51 | std::vector> device_states_; 52 | }; 53 | 54 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/core/dml_device_manager.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tfdml/core/dml_device_manager.h" 17 | 18 | namespace tfdml 19 | { 20 | /*static*/ DmlDeviceManager& DmlDeviceManager::Instance() 21 | { 22 | static DmlDeviceManager device_manager; 23 | return device_manager; 24 | } 25 | 26 | Status DmlDeviceManager::InsertDevice(uint32_t device_id, DmlDevice* dml_device) 27 | { 28 | if (device_id >= kNumMaxDevices) 29 | { 30 | return errors::InvalidArgument( 31 | "DML doesn't support more than ", 32 | kNumMaxDevices, 33 | " devices at the moment. Use the DML_VISIBLE_DEVICES environment " 34 | "variable to reduce the number of visible devices (e.g. 
" 35 | "DML_VISIBLE_DEVICES=\"0,1,2\" to show only the first 3 devices)."); 36 | } 37 | 38 | if (device_id >= devices_.size()) 39 | { 40 | devices_.resize(device_id + 1, nullptr); 41 | } 42 | 43 | devices_[device_id] = dml_device; 44 | return Status::OK(); 45 | } 46 | 47 | DmlDevice* DmlDeviceManager::GetDevice(uint32_t device_id) const 48 | { 49 | return devices_[device_id]; 50 | } 51 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/core/dml_device_manager.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "absl/container/inlined_vector.h" 19 | #include "tfdml/core/dml_device.h" 20 | #include "tfdml/core/dml_tagged_pointer.h" 21 | 22 | namespace tfdml 23 | { 24 | class DmlDeviceManager 25 | { 26 | public: 27 | static DmlDeviceManager& Instance(); 28 | Status InsertDevice(uint32_t device_id, DmlDevice* dml_device); 29 | DmlDevice* GetDevice(uint32_t device_id) const; 30 | 31 | private: 32 | static constexpr uint64_t kNumMaxDevices = TaggedPointer::kDeviceIDBits 33 | << 1; 34 | 35 | DmlDeviceManager() = default; 36 | absl::InlinedVector devices_; 37 | }; 38 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/core/dml_device_state.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
// File: tfdml/core/dml_device_state.h
// NOTE(review): template arguments (e.g. on std::unique_ptr and
// Microsoft::WRL::ComPtr) appear to have been stripped from this copy of the
// file; the declarations below are reproduced as found.
#pragma once

#include "dml_adapter.h"
#include "dml_adapter_impl.h"
#include "dml_common.h"

namespace tfdml
{

// Forward declarations for the types referenced by DmlDeviceState.
class DmlAdapter;
class DmlExecutionContext;
class DmlEventQueue;
class D3D12HeapAllocator;
class DmlAllocator;
class D3D12DescriptorHeapAllocator;
class DmlDescriptorAllocator;
class DmlUploadHeap;
class DmlReadbackHeap;
class DmlKernelManager;
class GPUOptions;

// Holds device state that is shared across one or more DmlDevice instances.
// Instances of these state objects are owned by the DML device factory.
// Typically one of these state objects exists for each physical D3D adapter,
// but multiple TF DmlDevice instances can share this state. All objects owned
// by this state object are thread-safe.
struct DmlDeviceState
{
  public:
    // Factory: builds the state for `adapter`, identified by `adapter_index`.
    static std::unique_ptr Create(
        const DmlAdapter& adapter,
        uint32_t adapter_index);

    DmlDeviceState();
    ~DmlDeviceState();

    // The adapter this state was created for.
    std::unique_ptr adapter;

    // COM-managed D3D12 / DirectML objects.
    Microsoft::WRL::ComPtr d3d_device;
    Microsoft::WRL::ComPtr command_queue;
    Microsoft::WRL::ComPtr sharing_contract;
    Microsoft::WRL::ComPtr dml_device;

    // Helper objects owned by this state. Other classes (for example the
    // upload and readback heaps) keep non-owning pointers to some of these.
    std::unique_ptr execution_context;
    std::unique_ptr event_queue;
    std::unique_ptr heap_allocator;
    std::unique_ptr dml_allocator;
    std::unique_ptr descriptor_heap_allocator;
    std::unique_ptr descriptor_allocator;
    std::unique_ptr upload_heap;
    std::unique_ptr readback_heap;
    std::unique_ptr kernel_manager;
};

} // namespace tfdml
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Common DSO loading functionality: exposes callables that dlopen DSOs 17 | // in either the runfiles directories 18 | 19 | #pragma once 20 | 21 | #include "tfdml/runtime_adapter/statusor.h" 22 | 23 | namespace tfdml 24 | { 25 | namespace DmlDsoLoader 26 | { 27 | // The following methods either load the DSO of interest and return a dlopen 28 | // handle or error status. 29 | StatusOr GetDirectMLDsoHandle(); 30 | StatusOr GetDirectMLDebugDsoHandle(); 31 | StatusOr GetD3d12DsoHandle(); 32 | StatusOr GetDxgiDsoHandle(); 33 | StatusOr GetDxCoreDsoHandle(); 34 | StatusOr GetPixDsoHandle(); 35 | StatusOr GetKernel32DsoHandle(); 36 | } // namespace DmlDsoLoader 37 | 38 | // Wrapper around the DmlDsoLoader that prevents us from dlopen'ing any of the 39 | // DSOs more than once. 40 | namespace DmlCachedDsoLoader 41 | { 42 | // Cached versions of the corresponding DmlDsoLoader methods above. 
43 | StatusOr GetDirectMLDsoHandle(); 44 | StatusOr GetDirectMLDebugDsoHandle(); 45 | StatusOr GetD3d12DsoHandle(); 46 | StatusOr GetDxgiDsoHandle(); 47 | StatusOr GetDxCoreDsoHandle(); 48 | StatusOr GetPixDsoHandle(); 49 | StatusOr GetKernel32DsoHandle(); 50 | } // namespace DmlCachedDsoLoader 51 | } // namespace tfdml 52 | -------------------------------------------------------------------------------- /tfdml/core/dml_error_handling.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
// File: tfdml/core/dml_error_handling.h
#pragma once

#include "absl/strings/string_view.h"

// NOTE(review): the platform include targets below were lost in this copy of
// the file (angle-bracket text stripped); they must provide HRESULT/FAILED.
#ifdef _WIN32
#include
#else
#include
#endif

namespace tfdml
{
namespace dml_util
{
// Reports a failed HRESULT together with the failing expression text and its
// source location. Declared [[noreturn]]: control never comes back.
[[noreturn]] void HandleFailedHr(
    HRESULT hr,
    const char* expression,
    const char* file,
    int line);

// Classifies `hr` as an out-of-memory failure.
bool HrIsOutOfMemory(HRESULT hr);

// Returns a textual form of a device-removed reason code.
absl::string_view StringifyDeviceRemovedReason(HRESULT reason);

} // namespace dml_util
} // namespace tfdml

// Evaluates `x` exactly once and, if the resulting HRESULT is a failure,
// hands it to HandleFailedHr along with the stringified expression, file and
// line. Wrapped in do/while(0) so it behaves as a single statement.
#define DML_CHECK_SUCCEEDED(x)                                                 \
    do                                                                         \
    {                                                                          \
        HRESULT _hr = (x);                                                     \
        if (FAILED(_hr))                                                       \
        {                                                                      \
            tfdml::dml_util::HandleFailedHr(_hr, #x, __FILE__, __LINE__);      \
        }                                                                      \
    } while (0)

// File: tfdml/core/dml_event_queue.h
#pragma once

// NOTE(review): the four standard-library include targets below were lost in
// this copy of the file; the class uses std::function, std::mutex,
// std::condition_variable, std::multimap, std::thread and std::shared_ptr.
#include
#include
#include
#include

#include "dml_common.h"
#include "dml_gpu_event.h"

namespace tfdml
{

// Allows for queueing CPU work in response to a signaled GPU event. Each
// instance of this queue can only be used with a single fence, and the fence's
// signaled values are assumed to only ever increase in a monotonic fashion.
// This class is thread-safe.
class DmlEventQueue
{
  public:
    // Callback type invoked once the awaited GPU event has been signaled.
    using DoneCallback = std::function;

    explicit DmlEventQueue(ID3D12Fence* fence);
    ~DmlEventQueue();

    // Enqueues an arbitrary callback to fire once the given GPU event becomes
    // signaled. The callback is invoked asynchronously, on an arbitrary thread.
    // If there are multiple callbacks enqueued for a single fence value, those
    // callbacks are executed in the order they were queued. This method is
    // thread-safe.
    void Enqueue(DmlGpuEvent gpu_event, DoneCallback done_callback);

  private:
    // One pending unit of work.
    struct Event
    {
        DoneCallback done_callback;
    };

    // State shared with the background thread. Protected by `mutex`.
    struct SharedState
    {
        // The fence associated with this queue.
        Microsoft::WRL::ComPtr fence;
        std::mutex mutex;
        std::condition_variable
            new_event_enqueued; // An event that fires whenever
                                // a new event is added.
        std::multimap events_by_fence_value;

        // The current fence value that the thread is waiting to be signaled.
        // This value is guaranteed to be <= fence->GetCompletedValue()+1.
        uint64_t current_awaited_fence_value = 0;

        // Set to ask the background thread to stop.
        bool exit_requested = false;
    };

    // Entry point of the background thread; receives the shared state.
    static void ThreadProc(std::shared_ptr state);

    std::shared_ptr shared_state_;
    std::thread thread_;
};

} // namespace tfdml
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // This file's sole purpose is to initialize the GUIDs declared using the 17 | // DEFINE_GUID macro. This file is used instead of dxguids.cpp in the 18 | // DirectX-Headers repository for two reasons: 19 | // 1. DXGI IIDs aren't defined in DirectX-Headers 20 | // 2. DirectML IIDs aren't defined in DirectX-Headers 21 | 22 | #define INITGUID 23 | 24 | // clang-format off 25 | #ifndef _WIN32 26 | #include 27 | #include 28 | #include 29 | #include "DirectML.h" 30 | #include "dxguids/dxguids.h" 31 | #include "dml_guids.h" 32 | #else 33 | #include 34 | #include 35 | #endif 36 | // clang-format on 37 | -------------------------------------------------------------------------------- /tfdml/core/dml_guids.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | // clang-format off 19 | 20 | WINADAPTER_IID(IDMLObject, 0xc8263aac, 0x9e0c, 0x4a2d, 0x9b, 0x8e, 0x00, 0x75, 0x21, 0xa3, 0x31, 0x7c); 21 | WINADAPTER_IID(IDMLDevice, 0x6dbd6437, 0x96fd, 0x423f, 0xa9, 0x8c, 0xae, 0x5e, 0x7c, 0x2a, 0x57, 0x3f); 22 | WINADAPTER_IID(IDMLDeviceChild, 0x27e83142, 0x8165, 0x49e3, 0x97, 0x4e, 0x2f, 0xd6, 0x6e, 0x4c, 0xb6, 0x9d); 23 | WINADAPTER_IID(IDMLPageable, 0xb1ab0825, 0x4542, 0x4a4b, 0x86, 0x17, 0x6d, 0xde, 0x6e, 0x8f, 0x62, 0x01); 24 | WINADAPTER_IID(IDMLOperator, 0x26caae7a, 0x3081, 0x4633, 0x95, 0x81, 0x22, 0x6f, 0xbe, 0x57, 0x69, 0x5d); 25 | WINADAPTER_IID(IDMLDispatchable, 0xdcb821a8, 0x1039, 0x441e, 0x9f, 0x1c, 0xb1, 0x75, 0x9c, 0x2f, 0x3c, 0xec); 26 | WINADAPTER_IID(IDMLCompiledOperator, 0x6b15e56a, 0xbf5c, 0x4902, 0x92, 0xd8, 0xda, 0x3a, 0x65, 0x0a, 0xfe, 0xa4); 27 | WINADAPTER_IID(IDMLOperatorInitializer, 0x427c1113, 0x435c, 0x469c, 0x86, 0x76, 0x4d, 0x5d, 0xd0, 0x72, 0xf8, 0x13); 28 | WINADAPTER_IID(IDMLBindingTable, 0x29c687dc, 0xde74, 0x4e3b, 0xab, 0x00, 0x11, 0x68, 0xf2, 0xfc, 0x3c, 0xfc); 29 | WINADAPTER_IID(IDMLCommandRecorder, 0xe6857a76, 0x2e3e, 0x4fdd, 0xbf, 0xf4, 0x5d, 0x2b, 0xa1, 0x0f, 0xb4, 0x53); 30 | WINADAPTER_IID(IDMLDebugDevice, 0x7d6f3ac9, 0x394a, 0x4ac3, 0x92, 0xa7, 0x39, 0x0c, 0xc5, 0x7a, 0x82, 0x17); 31 | WINADAPTER_IID(IDMLDevice1, 0xa0884f9a, 0xd2be, 0x4355, 0xaa, 0x5d, 0x59, 0x01, 0x28, 0x1a, 0xd1, 0xd2); 32 | 33 | // clang-format on 34 | -------------------------------------------------------------------------------- /tfdml/core/dml_readback_heap.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 
// File: tfdml/core/dml_readback_heap.h
// NOTE(review): template arguments (StatusOr, absl::Span) appear to have been
// stripped from this copy of the file; declarations are reproduced as found.
#pragma once

#include "dml_common.h"
#include "dml_event_queue.h"
#include "dml_execution_context.h"
#include "dml_pooled_heap.h"

namespace tfdml
{
// Also forward-declared here although dml_execution_context.h is included.
class DmlExecutionContext;

// Performs non-blocking readback from GPU resources. This class is thread-safe.
class DmlReadbackHeap : public DmlPooledHeap
{
  public:
    // Neither `execution_context` nor `event_queue` is owned by this heap;
    // see the member comments below.
    DmlReadbackHeap(
        ID3D12Device* device,
        DmlExecutionContext* execution_context,
        DmlEventQueue* event_queue);

    // Copies data from the specified GPU resource into CPU memory pointed-to by
    // the span. This is non-blocking; the copy is not complete until the
    // returned event becomes signaled. Both the dst buffer and src resource
    // must stay alive until the copy is complete.
    StatusOr ReadbackFromGpu(
        absl::Span dst,
        const D3D12BufferRegion& src);

  private:
    std::mutex mutex_;
    DmlExecutionContext* execution_context_; // weak; owned by DmlDeviceState
    DmlEventQueue* event_queue_;             // weak; owned by DmlDeviceState

    // We maintain a completion event independent of the execution context,
    // because the execution context's completion event only tells you when the
    // copy to the readback heap has completed, whereas what the caller cares
    // about is whether the copy to the `dst` buffer is complete.
    DmlGpuEvent current_completion_event_;
};

} // namespace tfdml
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/core/dml_tagged_pointer.h" 17 | #include 18 | 19 | namespace tfdml 20 | { 21 | /*static*/ TaggedPointer TaggedPointer::Unpack(const void* ptr) 22 | { 23 | uint64_t ptr_val = reinterpret_cast(ptr); 24 | 25 | static constexpr uint64_t kAllocationIDMask = 26 | (1ull << kAllocationIDBits) - 1; 27 | static constexpr uint64_t kOffsetMask = (1ull << kOffsetBits) - 1; 28 | 29 | TaggedPointer tagged_ptr; 30 | tagged_ptr.device_id = (ptr_val >> (kAllocationIDBits + kOffsetBits)); 31 | tagged_ptr.allocation_id = (ptr_val >> kOffsetBits) & kAllocationIDMask; 32 | tagged_ptr.offset = (ptr_val & kOffsetMask); 33 | 34 | return tagged_ptr; 35 | } 36 | 37 | /*static*/ void* TaggedPointer::Pack( 38 | uint32_t device_id, 39 | uint32_t allocation_id, 40 | uint64_t offset) 41 | { 42 | assert(device_id < (1ull << kDeviceIDBits)); 43 | assert(allocation_id < (1ull << kAllocationIDBits)); 44 | assert(offset < (1ull << kOffsetBits)); 45 | 46 | // Store the device ID in the upper bits of the pointer, followed by the 47 | // allocation id and the offset in the lower bits 48 | uint64_t ptr = ((uint64_t)device_id << (kAllocationIDBits + kOffsetBits)) | 49 | ((uint64_t)allocation_id << kOffsetBits) | offset; 50 | 51 | return reinterpret_cast(ptr); 52 | } 53 | } // namespace tfdml 54 | -------------------------------------------------------------------------------- /tfdml/core/dml_tagged_pointer.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
// File: tfdml/core/dml_tagged_pointer.h
#pragma once

#include <climits>
#include <cstdint>

namespace tfdml
{

// D3D12HeapAllocator and D3D12DescriptorHeapAllocator hand out "pointers"
// that are really three packed integers. From the most significant bit down:
// the device ID, then the allocation ID, then an offset into that allocation.
// Note that since the layout of bitfields is implementation-defined, you can't
// just cast a void* into a TaggedPointer: it must be done using masks and
// shifts (see Pack/Unpack).
struct TaggedPointer
{
    // Field widths in bits; together they must fill a 64-bit pointer.
    static constexpr uint64_t kDeviceIDBits = 4;
    static constexpr uint64_t kAllocationIDBits = 20;
    static constexpr uint64_t kOffsetBits = 40;

    uint64_t device_id : kDeviceIDBits;
    uint64_t allocation_id : kAllocationIDBits;
    uint64_t offset : kOffsetBits;

    // Encodes the three fields into one pointer-sized value.
    static void* Pack(
        uint32_t device_id,
        uint32_t allocation_id,
        uint64_t offset);

    // Decodes a value previously produced by Pack().
    static TaggedPointer Unpack(const void* ptr);
};

static_assert(
    sizeof(TaggedPointer) == sizeof(void*),
    "DML requires a 64-bit architecture");
static_assert(
    TaggedPointer::kDeviceIDBits + TaggedPointer::kAllocationIDBits +
            TaggedPointer::kOffsetBits ==
        sizeof(void*) * CHAR_BIT,
    "TaggedPointer field widths must add up to the pointer width");

} // namespace tfdml
// File: tfdml/core/dml_upload_heap.h
// NOTE(review): template arguments (StatusOr, absl::Span) appear to have been
// stripped from this copy of the file; declarations are reproduced as found.
#pragma once

#include "dml_common.h"
#include "dml_execution_context.h"
#include "dml_pooled_heap.h"

namespace tfdml
{

// Also forward-declared here although dml_execution_context.h is included.
class DmlExecutionContext;

// Implements a non-blocking, ring-buffer style upload heap for copying CPU data
// to GPU resources. This class is thread-safe.
class DmlUploadHeap : public DmlPooledHeap
{
  public:
    // `execution_context` is not owned by this heap; see the member comment.
    DmlUploadHeap(ID3D12Device* device, DmlExecutionContext* execution_context);

    // Makes a copy of the source data and begins copying it into the
    // destination resource, and returns a DmlGpuEvent which will become
    // signaled when the copy is complete. The destination resource must be a
    // default or readback buffer.
    StatusOr BeginUploadToGpu(
        const D3D12BufferRegion& dst,
        absl::Span src);

  private:
    std::mutex mutex_;
    DmlExecutionContext* execution_context_; // weak; owned by DmlDeviceState
};

} // namespace tfdml
// File: tfdml/kernels/dml_deepcopy_op.cc
// NOTE(review): template arguments (std::shared_ptr, StatusOr, static_cast,
// KernelDefinition) appear to have been stripped from this copy of the file;
// the code is reproduced as found.
#include "tfdml/kernels/pch.h"

namespace tfdml
{

// Kernel that produces a fresh output tensor holding a byte copy of input 0.
class DmlDeepCopyKernel : public OpKernel
{
  public:
    // `ctx` is not used; the node definition is forwarded to the base class.
    explicit DmlDeepCopyKernel(
        OpKernelConstruction* ctx,
        std::shared_ptr node_def)
        : OpKernel(std::move(node_def))
    {
    }

  private:
    void ComputeImpl(OpKernelContext* ctx) final
    {
        const Tensor& input = ctx->input(0);
        const TensorShape& input_shape = input.shape();

        // The output always has the input's shape and is allocated even when
        // the input is empty.
        StatusOr status_or_output_tensor =
            ctx->allocate_output(0, input_shape);
        OP_REQUIRES_OK(ctx, status_or_output_tensor.status());

        DmlDevice* device = static_cast(ctx->device());
        auto* device_context = device->GetDeviceContext();

        // Nothing to copy for an empty tensor.
        if (input.NumElements() == 0)
        {
            return;
        }

        D3D12BufferRegion input_buffer =
            device_context->GetBufferForTensor(input);

        D3D12BufferRegion output_buffer = device_context->GetBufferForTensor(
            status_or_output_tensor.ValueOrDie());

        // Copy only as many bytes as both regions can hold.
        uint64_t copy_size =
            std::min(output_buffer.SizeInBytes(), input_buffer.SizeInBytes());

        device_context->CopyBufferToBuffer(
            output_buffer,
            input_buffer.Subregion(0, copy_size));
    }
};

// Registers the DeepCopy kernel for float, half and int64 element types.
void RegisterKernels_DeepCopy()
{
    using K = KernelDefinition;

    RegisterWithTypes<
        K,
        ops::DeepCopy::Attribute::T,
        TF_FLOAT,
        TF_HALF,
        TF_INT64>();
}

} // namespace tfdml
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #include "tfdml/core/dml_common.h" 18 | 19 | namespace dml 20 | { 21 | dml::Expression ExtractPatches( 22 | dml::Graph& scope, 23 | dml::Expression input, 24 | absl::Span window_sizes, 25 | absl::Span window_strides, 26 | absl::Span window_rates, 27 | absl::Span start_padding, 28 | absl::Span end_padding, 29 | absl::Span output_sizes); 30 | } // namespace dml 31 | -------------------------------------------------------------------------------- /tfdml/kernels/dml_matrix_diag_helpers.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | Portions Copyright (c) Microsoft Corporation. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions 14 | and limitations under the License.
15 | ==============================================================================*/ 16 | 17 | namespace dml 18 | { 19 | dml::Expression MatrixDiag( 20 | dml::Graph& scope, 21 | dml::Expression diag, 22 | int32_t k_min, 23 | int32_t k_max, 24 | float padding_value, 25 | int64_t out_height, 26 | int64_t out_width, 27 | bool align_sup_left, 28 | bool align_sub_left); 29 | 30 | dml::Expression MatrixDiagPart( 31 | dml::Graph& scope, 32 | dml::Expression m, 33 | int32_t k0, 34 | int32_t k1, 35 | float padding_value, 36 | uint32_t out_height, 37 | uint32_t out_width, 38 | bool align_sup_left, 39 | bool align_sub_left); 40 | } // namespace dml 41 | -------------------------------------------------------------------------------- /tfdml/kernels/dml_ones_like_op.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | Portions Copyright (c) Microsoft Corporation. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
// File: tfdml/kernels/dml_ones_like_op.cc
// NOTE(review): template arguments (static_cast, ComPtr, DmlKernelWrapper)
// appear to have been stripped from this copy of the file; the code is
// reproduced as found.
#include "tfdml/kernels/pch.h"

namespace tfdml
{

// DML kernel that fills its single output with the value one.
class DmlOnesLikeKernel : public DmlKernel
{
  public:
    using InitHelper = NoOpInitializationHelper;

    // Builds and compiles the DML graph at construction time. `init_helper`
    // is not used.
    explicit DmlOnesLikeKernel(
        DmlKernelConstruction* ctx,
        const InitHelper* init_helper)
    {
        // The output is described as a flat 1-D tensor with as many elements
        // as the real output shape.
        TensorShape tensor_shape({ctx->GetOutputTensorShape(0).num_elements()});
        DmlTensorInfo output;
        output.kernel_index = 0;
        output.desc = DmlTensorDesc::Create(
            ctx->GetOutputDataType(0),
            tensor_shape,
            tensor_shape);

        // No input tensors are bound, only the output.
        DmlKernelTensors tensors;
        tensors.outputs = {output};

        // NOTE(review): the fill data type is taken from input 0 while the
        // output descriptor above uses output 0's type; presumably these are
        // always the same type T -- confirm.
        auto dml_dtype = GetDmlDataTypeFromTfDataType(ctx->GetInputDataType(0));
        DML_SCALAR_UNION one_value = dml::ScalarUnion(1, dml_dtype);

        auto scope = dml::Graph(ctx->GetDmlDevice());
        auto result = dml::FillValueConstant(
            scope,
            {static_cast(tensor_shape.num_elements())},
            dml_dtype,
            one_value);

        Microsoft::WRL::ComPtr compiled_op =
            scope.Compile(DML_EXECUTION_FLAG_NONE, {result});

        Initialize(ctx, std::move(tensors), compiled_op.Get());
    }
};

// Registers the OnesLike kernel for float, half, int64 and bool.
void RegisterKernels_OnesLike()
{
    using K = KernelDefinition<
        ops::OnesLike,
        DmlKernelWrapper>;

    RegisterWithTypes<
        K,
        ops::OnesLike::Attribute::T,
        TF_FLOAT,
        TF_HALF,
        TF_INT64,
        TF_BOOL>();
}

} // namespace tfdml
// File: tfdml/kernels/dml_snapshot_op.cc
// NOTE(review): template arguments (std::shared_ptr, StatusOr,
// KernelDefinition) appear to have been stripped from this copy of the file;
// the code is reproduced as found.
#include "tfdml/kernels/pch.h"

namespace tfdml
{
// Kernel whose output 0 has the same contents as input 0, reusing the input
// buffer when forwarding succeeds and copying otherwise.
class DmlSnapshotOp : public OpKernel
{
  public:
    // `c` is not used; the node definition is forwarded to the base class.
    explicit DmlSnapshotOp(
        OpKernelConstruction* c,
        std::shared_ptr node_def)
        : OpKernel(std::move(node_def))
    {
    }

  private:
    void ComputeImpl(OpKernelContext* context) final
    {
        const Tensor& input = context->input(0);
        // Try to use buffer forwarding to avoid an explicit copy.
        int candidate_input_indices[] = {0};
        StatusOr status_or_output =
            context->forward_input_or_allocate_output(
                candidate_input_indices,
                0,
                input.shape());

        OP_REQUIRES_OK(context, status_or_output.status());
        // A copy is only needed when the output did not end up sharing the
        // input's buffer.
        if (!status_or_output.ValueOrDie().SharesBufferWith(input))
        {
            context->device()->CopyTensorInSameDevice(
                &input,
                &status_or_output.ValueOrDie());
        }
    }
};

// Registers the Snapshot kernel for the listed element types.
void RegisterKernels_Snapshot()
{
    using K = KernelDefinition;

    RegisterWithTypes<
        K,
        ops::Snapshot::Attribute::T,
        TF_FLOAT,
        TF_HALF,
        TF_BOOL,
        TF_INT64,
        TF_INT32,
        TF_UINT16,
        TF_INT16,
        TF_UINT8,
        TF_INT8>();
}

} // namespace tfdml
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/core/dml_kernel_definition.h" 19 | #include "tfdml/core/dml_kernel_wrapper.h" 20 | #include "tfdml/core/dml_ops_common.h" -------------------------------------------------------------------------------- /tfdml/optimizer/byte_order.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | // Byte order defines provided by gcc. MSVC doesn't define those so 19 | // we define them here. 20 | // We assume that all windows platform out there are little endian. 
21 | #if defined(_MSC_VER) && !defined(__clang__) 22 | #define __ORDER_LITTLE_ENDIAN__ 0x4d2 23 | #define __ORDER_BIG_ENDIAN__ 0x10e1 24 | #define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__ 25 | #endif 26 | 27 | namespace tfdml 28 | { 29 | 30 | constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; 31 | 32 | } // namespace tfdml 33 | -------------------------------------------------------------------------------- /tfdml/optimizer/device_type.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/optimizer/device_type.h" 17 | 18 | DeviceType::DeviceType(const char* type) // NOLINT(runtime/explicit) 19 | : type_(type) 20 | { 21 | } 22 | 23 | DeviceType::DeviceType(absl::string_view type) : type_(type.data(), type.size()) 24 | { 25 | } 26 | 27 | const char* DeviceType::type() const { return type_.c_str(); } 28 | const std::string& DeviceType::type_string() const { return type_; } 29 | 30 | bool DeviceType::operator<(const DeviceType& other) const 31 | { 32 | return type_ < other.type_; 33 | } 34 | bool DeviceType::operator==(const DeviceType& other) const 35 | { 36 | return type_ == other.type_; 37 | }; 38 | bool DeviceType::operator!=(const DeviceType& other) const 39 | { 40 | return !(*this == other); 41 | } 42 | -------------------------------------------------------------------------------- /tfdml/optimizer/device_type.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "absl/strings/string_view.h" 19 | #include 20 | 21 | // A DeviceType is just a string, but we wrap it up in a class to give 22 | // some type checking as we're passing these around 23 | class DeviceType 24 | { 25 | public: 26 | DeviceType(const char* type); 27 | explicit DeviceType(absl::string_view type); 28 | const char* type() const; 29 | const std::string& type_string() const; 30 | bool operator<(const DeviceType& other) const; 31 | bool operator==(const DeviceType& other) const; 32 | bool operator!=(const DeviceType& other) const; 33 | 34 | private: 35 | std::string type_; 36 | }; -------------------------------------------------------------------------------- /tfdml/optimizer/graph.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | namespace tfdml 19 | { 20 | const int kControlSlot = -1; 21 | } 22 | -------------------------------------------------------------------------------- /tfdml/optimizer/graph_optimizer.cc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/tfdml/optimizer/graph_optimizer.cc -------------------------------------------------------------------------------- /tfdml/optimizer/graph_optimizer.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/status.h" 19 | 20 | namespace tensorflow 21 | { 22 | class GraphDef; 23 | } 24 | 25 | namespace tfdml 26 | { 27 | class GrapplerItem; 28 | 29 | // An abstract interface for an algorithm for generating a candidate 30 | // optimization of a GrapplerItem for running on a cluster. 
31 | class GraphOptimizer 32 | { 33 | public: 34 | GraphOptimizer() {} 35 | virtual ~GraphOptimizer() {} 36 | 37 | // Routine called to allow an algorithm to propose a rewritten graph 38 | // for the graph, feeds and fetches in "item" to run more efficiently. If 39 | // the returned status is Status::OK() then *optimized_graph contains the 40 | // rewritten graph. Returns an error status if it failed to generate a 41 | // solution. 42 | // 43 | // A return value of error::Aborted() can be used signal early termination 44 | // of the optimizer, e.g. if the optimization turned out to be a no-op. In 45 | // this case the content of *optimized_graph is undefined. 46 | virtual Status Optimize( 47 | const GrapplerItem& item, 48 | tensorflow::GraphDef* optimized_graph) = 0; 49 | 50 | // Subclasses may define a version of Optimize that consumes item. 51 | virtual Status Optimize( 52 | GrapplerItem&& item, 53 | tensorflow::GraphDef* optimized_graph) 54 | { 55 | return Optimize(item, optimized_graph); 56 | } 57 | }; 58 | } // namespace tfdml 59 | -------------------------------------------------------------------------------- /tfdml/optimizer/graph_properties.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "absl/container/flat_hash_map.h" 17 | #include "tensorflow/core/framework/graph.pb.h" 18 | #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" 19 | #include "tfdml/runtime_adapter/status.h" 20 | 21 | struct TF_GraphProperties; 22 | 23 | namespace tfdml 24 | { 25 | class GrapplerItem; 26 | 27 | class GraphProperties 28 | { 29 | public: 30 | GraphProperties(const GrapplerItem& item); 31 | ~GraphProperties(); 32 | 33 | Status InferStatically( 34 | bool assume_valid_feeds, 35 | bool aggressive_shape_inference, 36 | bool include_input_tensor_values, 37 | bool include_output_tensor_values); 38 | Status InferStatically( 39 | bool assume_valid_feeds, 40 | bool aggressive_shape_inference, 41 | bool include_tensor_values); 42 | Status InferStatically(bool assume_valid_feeds); 43 | 44 | const std::vector& GetInputProperties( 45 | const std::string& node_name) const 46 | { 47 | auto it = input_properties_.find(node_name); 48 | if (it != input_properties_.end()) 49 | { 50 | return it->second; 51 | } 52 | return missing_properties_; 53 | } 54 | 55 | const std::vector& 56 | GetOutputProperties(const std::string& node_name) const 57 | { 58 | auto it = output_properties_.find(node_name); 59 | if (it != output_properties_.end()) 60 | { 61 | return it->second; 62 | } 63 | return missing_properties_; 64 | } 65 | 66 | private: 67 | TF_GraphProperties* graph_props_; 68 | 69 | const GrapplerItem& item_; 70 | 71 | absl::flat_hash_map< 72 | std::string, 73 | std::vector> 74 | input_properties_; 75 | 76 | absl::flat_hash_map< 77 | std::string, 78 | std::vector> 79 | output_properties_; 80 | 81 | const std::vector missing_properties_; 82 | }; 83 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/grappler_item.cc: 
-------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | ==============================================================================*/ 13 | 14 | #include "tfdml/optimizer/grappler_item.h" 15 | #include "absl/container/flat_hash_set.h" 16 | #include "tensorflow/c/experimental/grappler/grappler.h" 17 | #include "tensorflow/core/framework/graph.pb.h" 18 | #include "tfdml/runtime_adapter/macros.h" 19 | #include "tfdml/runtime_adapter/status.h" 20 | 21 | namespace tfdml 22 | { 23 | GrapplerItem::GrapplerItem( 24 | const TF_GrapplerItem* grappler_item, 25 | OptimizationOptions optimization_options, 26 | tensorflow::GraphDef graph) 27 | : grappler_item_(grappler_item), 28 | optimization_options_(optimization_options), 29 | graph(std::move(graph)) 30 | { 31 | } 32 | 33 | absl::flat_hash_set GrapplerItem::NodesToPreserve() const 34 | { 35 | int num_preserved_nodes; 36 | size_t preserved_nodes_size; 37 | 38 | Status status; 39 | TF_GetNodesToPreserveListSize( 40 | grappler_item_, 41 | &num_preserved_nodes, 42 | &preserved_nodes_size, 43 | status.raw()); 44 | CHECK(status.ok()); 45 | 46 | std::vector preserved_node_names(num_preserved_nodes); 47 | std::vector preserved_node_name_lengths(num_preserved_nodes); 48 | std::vector preserved_node_name_storage(preserved_nodes_size); 49 | 50 | TF_GetNodesToPreserveList( 51 | grappler_item_, 52 | preserved_node_names.data(), 53 | preserved_node_name_lengths.data(), 54 | 
num_preserved_nodes, 55 | preserved_node_name_storage.data(), 56 | preserved_node_name_storage.size(), 57 | status.raw()); 58 | CHECK(status.ok()); 59 | 60 | absl::flat_hash_set preserved_nodes; 61 | for (int i = 0; i < num_preserved_nodes; ++i) 62 | { 63 | preserved_nodes.insert(std::string( 64 | preserved_node_names[i], 65 | preserved_node_name_lengths[i])); 66 | } 67 | return preserved_nodes; 68 | } 69 | 70 | OptimizationOptions& GrapplerItem::optimization_options() 71 | { 72 | return optimization_options_; 73 | } 74 | 75 | const TF_GrapplerItem* GrapplerItem::raw() const { return grappler_item_; } 76 | } // namespace tfdml 77 | -------------------------------------------------------------------------------- /tfdml/optimizer/grappler_item.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "absl/container/flat_hash_set.h" 17 | #include "tensorflow/core/framework/graph.pb.h" 18 | 19 | struct TF_GrapplerItem; 20 | 21 | namespace tfdml 22 | { 23 | struct OptimizationOptions 24 | { 25 | // Is it allowed to add nodes to the graph that do not have registered 26 | // gradient function. 
27 | bool allow_non_differentiable_rewrites = true; 28 | }; 29 | 30 | struct GrapplerItem 31 | { 32 | GrapplerItem( 33 | const TF_GrapplerItem* grappler_item, 34 | OptimizationOptions optimization_options, 35 | tensorflow::GraphDef graph); 36 | 37 | absl::flat_hash_set NodesToPreserve() const; 38 | const TF_GrapplerItem* raw() const; 39 | tensorflow::GraphDef graph; 40 | OptimizationOptions& optimization_options(); 41 | OptimizationOptions optimization_options_; 42 | 43 | private: 44 | const TF_GrapplerItem* const grappler_item_; 45 | }; 46 | } // namespace tfdml 47 | -------------------------------------------------------------------------------- /tfdml/optimizer/hash.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | // Simple hash functions used for internal data structures 17 | 18 | #pragma once 19 | 20 | #include 21 | #include 22 | 23 | namespace tfdml 24 | { 25 | uint32_t Hash32(const char* data, size_t n, uint32_t seed); 26 | } // namespace tfdml 27 | -------------------------------------------------------------------------------- /tfdml/optimizer/map_utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | namespace tfdml 17 | { 18 | // Returns a pointer to the const value associated with the given key if it 19 | // exists, or NULL otherwise. 20 | template 21 | const typename Collection::value_type::second_type* FindOrNull( 22 | const Collection& collection, 23 | const typename Collection::value_type::first_type& key) 24 | { 25 | typename Collection::const_iterator it = collection.find(key); 26 | if (it == collection.end()) 27 | { 28 | return 0; 29 | } 30 | return &it->second; 31 | } 32 | 33 | } // end namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/op_registry.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #include "tfdml/optimizer/op_registry.h" 15 | #include "absl/cleanup/cleanup.h" 16 | #include "tensorflow/c/experimental/grappler/grappler.h" 17 | #include "tensorflow/core/framework/op_def.pb.h" 18 | #include "tfdml/optimizer/proto_buffer_helpers.h" 19 | #include "tfdml/runtime_adapter/macros.h" 20 | #include "tfdml/runtime_adapter/status.h" 21 | 22 | namespace tfdml 23 | { 24 | OpRegistry::OpRegistry() {} 25 | 26 | void OpRegistry::Initialize(TF_FunctionLibraryDefinition* function_lib_def) 27 | { 28 | function_lib_def_ = function_lib_def; 29 | } 30 | 31 | Status OpRegistry::LookUpOpDef(const char* op_name, tensorflow::OpDef* op_def) 32 | { 33 | if (function_lib_def_ == nullptr) 34 | { 35 | return errors::InvalidArgument("OpRegistry::Initialize must be called " 36 | "once before OpRegistry::LookUpOpDef."); 37 | } 38 | 39 | Status status; 40 | TF_Buffer* op_buf = TF_NewBuffer(); 41 | auto buf_cleanup = absl::MakeCleanup([op_buf] { TF_DeleteBuffer(op_buf); }); 42 | 43 | TF_LookUpOpDef(function_lib_def_, op_name, op_buf, status.raw()); 44 | TF_RETURN_IF_ERROR(status); 45 | 46 | return ParseBuffer(op_buf, op_def); 47 | } 48 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/op_registry.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "tfdml/runtime_adapter/status.h" 17 | 18 | struct TF_FunctionLibraryDefinition; 19 | 20 | namespace tensorflow 21 | { 22 | class OpDef; 23 | } 24 | 25 | namespace tfdml 26 | { 27 | class OpRegistry 28 | { 29 | public: 30 | OpRegistry(); 31 | void Initialize(TF_FunctionLibraryDefinition* function_lib_def); 32 | Status LookUpOpDef(const char* op_name, tensorflow::OpDef* op_def); 33 | 34 | private: 35 | TF_FunctionLibraryDefinition* function_lib_def_ = nullptr; 36 | }; 37 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/op_types.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tensorflow/core/framework/types.pb.h" 19 | 20 | namespace tensorflow 21 | { 22 | class NodeDef; 23 | } // namespace tensorflow 24 | 25 | namespace tfdml 26 | { 27 | class MutableNodeView; 28 | 29 | constexpr char kOpDataFormatVecPermute[] = "DataFormatVecPermute"; 30 | constexpr char kOpDataFormatDimMap[] = "DataFormatDimMap"; 31 | 32 | bool IsBiasAdd(const tensorflow::NodeDef& node); 33 | bool IsConstant(const tensorflow::NodeDef& node); 34 | bool IsConv2D(const tensorflow::NodeDef& node); 35 | bool IsElu(const tensorflow::NodeDef& node); 36 | bool IsFusedBatchNormGrad(const tensorflow::NodeDef& node); 37 | bool IsLeakyRelu(const tensorflow::NodeDef& node); 38 | bool IsMerge(const tensorflow::NodeDef& node); 39 | bool IsNextIteration(const tensorflow::NodeDef& node); 40 | bool IsPad(const tensorflow::NodeDef& node); 41 | bool IsPlaceholder(const tensorflow::NodeDef& node); 42 | bool IsRelu(const tensorflow::NodeDef& node); 43 | bool IsRelu6(const tensorflow::NodeDef& node); 44 | bool IsSymbolicGradient(const tensorflow::NodeDef& node); 45 | bool IsTranspose(const tensorflow::NodeDef& node); 46 | 47 | } // namespace tfdml 48 | -------------------------------------------------------------------------------- /tfdml/optimizer/optimizer_runner.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "tfdml/runtime_adapter/status.h" 17 | 18 | struct TF_GrapplerItem; 19 | 20 | namespace tensorflow 21 | { 22 | class GraphDef; 23 | } 24 | 25 | namespace tfdml 26 | { 27 | Status RunOptimizer( 28 | void* optimizer, 29 | const tensorflow::GraphDef& input_graph_def, 30 | const TF_GrapplerItem* grappler_item, 31 | tensorflow::GraphDef& output_graph_def); 32 | } // namespace tfdml 33 | -------------------------------------------------------------------------------- /tfdml/optimizer/perm_utils.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #include "tfdml/optimizer/perm_utils.h" 15 | #include "absl/container/flat_hash_map.h" 16 | #include 17 | 18 | namespace tfdml 19 | { 20 | absl::flat_hash_map GetDimensionIndices( 21 | absl::string_view data_format) 22 | { 23 | const int size = data_format.size(); 24 | absl::flat_hash_map index; 25 | index.reserve(size); 26 | for (int i = 0; i < size; i++) 27 | { 28 | index[data_format[i]] = i; 29 | } 30 | return index; 31 | } 32 | 33 | std::vector GetPermutation( 34 | const absl::flat_hash_map& src_dim_indices, 35 | absl::string_view dst_format) 36 | { 37 | // Generate permutation for transformation between src and dst format. 38 | // Example: 39 | // src = NWHC, dst = NCWH 40 | // index = { N:0 W:1 H:2 C:3 } 41 | // permutation = [0, 3, 1, 2] 42 | assert(src_dim_indices.size() == dst_format.size()); 43 | std::vector permutation; 44 | const int size = dst_format.size(); 45 | permutation.reserve(size); 46 | for (int i = 0; i < size; i++) 47 | { 48 | permutation.push_back(src_dim_indices.at(dst_format[i])); 49 | } 50 | return permutation; 51 | } 52 | } // namespace tfdml 53 | -------------------------------------------------------------------------------- /tfdml/optimizer/perm_utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "absl/container/flat_hash_map.h" 17 | #include 18 | 19 | namespace tfdml 20 | { 21 | absl::flat_hash_map GetDimensionIndices( 22 | absl::string_view data_format); 23 | 24 | std::vector GetPermutation( 25 | const absl::flat_hash_map& src_dim_indices, 26 | absl::string_view dst_format); 27 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/proto_buffer_helpers.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #include "tfdml/optimizer/proto_buffer_helpers.h" 15 | #include "tensorflow/c/c_api.h" 16 | #include "tensorflow/core/framework/graph.pb.h" 17 | #include "tfdml/runtime_adapter/status.h" 18 | 19 | namespace tfdml 20 | { 21 | Status GraphDefToBuffer(const tensorflow::GraphDef& in, TF_Buffer* out) 22 | { 23 | if (out->data != nullptr) 24 | { 25 | return errors::InvalidArgument( 26 | "Passing non-empty TF_Buffer is invalid."); 27 | } 28 | const size_t proto_size = in.ByteSizeLong(); 29 | void* buf = malloc(proto_size); 30 | if (buf == nullptr) 31 | { 32 | return errors::ResourceExhausted( 33 | "Failed to allocate memory to serialize message of type '", 34 | in.GetTypeName(), 35 | "' and size ", 36 | proto_size); 37 | } 38 | if (!in.SerializeWithCachedSizesToArray(static_cast(buf))) 39 | { 40 | free(buf); 41 | return errors::InvalidArgument( 42 | "Unable to serialize ", 43 | in.GetTypeName(), 44 | " protocol buffer, perhaps the serialized size (", 45 | proto_size, 46 | " bytes) is too large?"); 47 | } 48 | out->data = buf; 49 | out->length = proto_size; 50 | out->data_deallocator = [](void* data, size_t length) { free(data); }; 51 | return Status::OK(); 52 | } 53 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/proto_buffer_helpers.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "tensorflow/c/c_api.h" 17 | #include "tfdml/runtime_adapter/status.h" 18 | 19 | namespace tensorflow 20 | { 21 | class GraphDef; 22 | } 23 | 24 | namespace tfdml 25 | { 26 | template 27 | Status ParseBuffer(const TF_Buffer* in, T* out) 28 | { 29 | if (in == nullptr || !out->ParseFromArray(in->data, in->length)) 30 | { 31 | return errors::InvalidArgument( 32 | "Unparseable ", 33 | out->GetTypeName(), 34 | " proto"); 35 | } 36 | return Status::OK(); 37 | } 38 | 39 | Status GraphDefToBuffer(const tensorflow::GraphDef& in, TF_Buffer* out); 40 | 41 | } // namespace tfdml 42 | -------------------------------------------------------------------------------- /tfdml/optimizer/remapper.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #include "tfdml/optimizer/graph_optimizer.h" 17 | #include "tfdml/runtime_adapter/status.h" 18 | 19 | namespace tensorflow 20 | { 21 | class GraphDef; 22 | } 23 | 24 | namespace tfdml 25 | { 26 | // Optimize TF computations by remapping subgraphs/nodes onto other subgraphs or 27 | // nodes to decrease the amount of operations needed to perform a computation. 28 | class Remapper : public GraphOptimizer 29 | { 30 | public: 31 | ~Remapper() override = default; 32 | Status Optimize( 33 | const GrapplerItem& item, 34 | tensorflow::GraphDef* optimized_graph) override; 35 | }; 36 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/tensor_id.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/optimizer/tensor_id.h" 17 | #include "absl/strings/match.h" 18 | 19 | namespace tfdml 20 | { 21 | 22 | TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {} 23 | 24 | SafeTensorId::SafeTensorId(const TensorId& id) 25 | : SafeTensorId(std::string(id.first), id.second) 26 | { 27 | } 28 | 29 | TensorId ParseTensorName(const std::string& name) 30 | { 31 | return ParseTensorName(absl::string_view(name.data(), name.size())); 32 | } 33 | 34 | TensorId ParseTensorName(absl::string_view name) 35 | { 36 | // Parse either a name, ^name, or name:digits. To do so, we go backwards 37 | // from the end of the string, skipping over a run of digits. If we hit a 38 | // ':' character, then we know we are in the 'name:digits' regime. 39 | // Otherwise, we see if the name starts with '^', indicating a control edge. 40 | // If we find neither ':' nor '^' characters, the output index is implicitly 41 | // 0, and the whole name string forms the first part of the tensor name.
42 | const char* base = name.data(); 43 | // Anchor `p` at `base` for an empty name so it never points before the buffer. 44 | const char* p = name.empty() ? base : base + name.size() - 1; 45 | unsigned int index = 0; 46 | unsigned int mul = 1; 47 | while (p > base && (*p >= '0' && *p <= '9')) 48 | { 49 | index += ((*p - '0') * mul); 50 | mul *= 10; 51 | p--; 52 | } 53 | TensorId id; 54 | if (p > base && *p == ':' && mul > 1) 55 | { 56 | id.first = absl::string_view(base, p - base); 57 | id.second = index; 58 | } 59 | else if (absl::StartsWith(name, "^")) 60 | { 61 | // Control edge. Use substr() so the result stays within `name`; a view 62 | // built from `base + 1` alone would scan for a NUL terminator. 63 | id.first = name.substr(1); 64 | id.second = kControlSlot; 65 | } 66 | else 67 | { 68 | id.first = name; 69 | id.second = 0; 70 | } 71 | return id; 72 | } 73 | 74 | bool IsTensorIdControl(const TensorId& tensor_id) 75 | { 76 | return tensor_id.index() == kControlSlot; 77 | } 78 | 79 | } // namespace tfdml 80 | -------------------------------------------------------------------------------- /tfdml/optimizer/tensor_proto_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/optimizer/tensor_proto_util.h" 17 | #include "tensorflow/core/framework/tensor.pb.h" 18 | 19 | namespace tfdml 20 | { 21 | int GetNumElements(const tensorflow::TensorProto& tensor) 22 | { 23 | assert(tensor.has_tensor_shape()); 24 | const tensorflow::TensorShapeProto& shape = tensor.tensor_shape(); 25 | 26 | if (shape.dim_size() == 0) 27 | { 28 | return 0; 29 | } 30 | 31 | int64_t num_elements = 1; 32 | for (int i = 0; i < shape.dim_size(); ++i) 33 | { 34 | num_elements *= shape.dim(i).size(); 35 | } 36 | 37 | return num_elements; 38 | } 39 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/optimizer/transpose_remover.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/optimizer/graph_optimizer.h" 19 | #include "tfdml/runtime_adapter/status.h" 20 | 21 | namespace tensorflow 22 | { 23 | class GraphDef; 24 | } 25 | 26 | namespace tfdml 27 | { 28 | class TransposeRemover : public GraphOptimizer 29 | { 30 | public: 31 | ~TransposeRemover() override = default; 32 | Status Optimize(const GrapplerItem& item, tensorflow::GraphDef* output) 33 | override; 34 | }; 35 | } // namespace tfdml 36 | -------------------------------------------------------------------------------- /tfdml/optimizer/utils.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tfdml/optimizer/utils.h" 17 | 18 | namespace tfdml 19 | { 20 | 21 | // Returns the data type in attribute `attr_name` of `node`. If that attribute 22 | // doesn't exist, returns DT_INVALID. 
23 | tensorflow::DataType GetDataTypeFromAttr( 24 | const tensorflow::NodeDef& node, 25 | const std::string& type_attr) 26 | { 27 | // Single map lookup; a missing attribute means there is no type to report. 28 | const auto it = node.attr().find(type_attr); 29 | if (it == node.attr().end()) 30 | { 31 | return tensorflow::DT_INVALID; 32 | } 33 | const auto& attr = it->second; 34 | // An attribute that exists but does not hold a type is also invalid. 35 | if (attr.value_case() != tensorflow::AttrValue::kType) 36 | { 37 | return tensorflow::DT_INVALID; 38 | } 39 | return attr.type(); 40 | } 41 | 42 | } // end namespace tfdml 43 | -------------------------------------------------------------------------------- /tfdml/optimizer/utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | #pragma once 16 | #include "tensorflow/core/framework/graph.pb.h" 17 | 18 | namespace tensorflow 19 | { 20 | class NodeDef; 21 | } 22 | 23 | namespace tfdml 24 | { 25 | 26 | // Utilities for manipulating node name and input strings. 27 | 28 | // Returns the data type in attribute `type_attr` of `node`. If that attribute 29 | // doesn't exist or does not hold a type, returns DT_INVALID.
30 | tensorflow::DataType GetDataTypeFromAttr( 31 | const tensorflow::NodeDef& node, 32 | const std::string& type_attr); 33 | 34 | } // end namespace tfdml -------------------------------------------------------------------------------- /tfdml/plugin/plugin_version.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Use of this source code is governed by an MIT-style 4 | license that can be found in the LICENSE file or at 5 | https://opensource.org/licenses/MIT. 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | ==============================================================================*/ 13 | 14 | #pragma once 15 | 16 | #define DML_MAJOR_VERSION 0 17 | #define DML_MINOR_VERSION 0 18 | #define DML_PATCH_VERSION 1 19 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/allocator.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "allocator.h" 17 | 18 | #include 19 | 20 | #include "absl/strings/str_cat.h" 21 | #include "absl/strings/str_format.h" 22 | 23 | namespace tfdml 24 | { 25 | 26 | std::string AllocatorStats::DebugString() const 27 | { 28 | return absl::StrFormat( 29 | "Limit: %20lld\n" 30 | "InUse: %20lld\n" 31 | "MaxInUse: %20lld\n" 32 | "NumAllocs: %20lld\n" 33 | "MaxAllocSize: %20lld\n", 34 | this->bytes_limit ? *this->bytes_limit : 0, 35 | this->bytes_in_use, 36 | this->peak_bytes_in_use, 37 | this->num_allocs, 38 | this->largest_alloc_size); 39 | } 40 | 41 | constexpr size_t Allocator::kAllocatorAlignment; 42 | 43 | Allocator::~Allocator() {} 44 | 45 | std::string AllocatorAttributes::DebugString() const 46 | { 47 | return absl::StrCat( 48 | "AllocatorAttributes(on_host=", 49 | on_host(), 50 | " nic_compatible=", 51 | nic_compatible(), 52 | " gpu_compatible=", 53 | gpu_compatible(), 54 | ")"); 55 | } 56 | 57 | SubAllocator::SubAllocator( 58 | const std::vector& alloc_visitors, 59 | const std::vector& free_visitors) 60 | : alloc_visitors_(alloc_visitors), 61 | free_visitors_(free_visitors) 62 | { 63 | } 64 | 65 | void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) 66 | { 67 | for (const auto& v : alloc_visitors_) 68 | { 69 | v(ptr, index, num_bytes); 70 | } 71 | } 72 | 73 | void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) 74 | { 75 | // Although we don't guarantee any order of visitor application, strive 76 | // to apply free visitors in reverse order of alloc visitors. 
77 | for (int i = free_visitors_.size() - 1; i >= 0; --i) 78 | { 79 | free_visitors_[i](ptr, index, num_bytes); 80 | } 81 | } 82 | } // namespace tfdml 83 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/allocator_retry.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "allocator_retry.h" 17 | 18 | namespace tfdml 19 | { 20 | 21 | static void WaitForMilliseconds( 22 | std::unique_lock<std::mutex>* mu, 23 | std::condition_variable* cv, 24 | int64_t ms) 25 | { 26 | cv->wait_for(*mu, std::chrono::milliseconds(ms)); 27 | } 28 | 29 | AllocatorRetry::AllocatorRetry() {} 30 | 31 | void* AllocatorRetry::AllocateRaw( 32 | std::function< 33 | void*(size_t alignment, size_t num_bytes, bool verbose_failure)> 34 | alloc_func, 35 | int max_millis_to_wait, 36 | size_t alignment, 37 | size_t num_bytes) 38 | { 39 | if (num_bytes == 0) 40 | { 41 | return nullptr; 42 | } 43 | uint64_t deadline_micros = 0; 44 | bool first = true; 45 | void* ptr = nullptr; 46 | while (ptr == nullptr) 47 | { 48 | ptr = alloc_func(alignment, num_bytes, false); 49 | if (ptr == nullptr) 50 | { 51 | // Measure against the monotonic clock: the deadline is a relative 52 | // timeout and must not shift when the wall clock is adjusted. 53 | uint64_t now = 54 | std::chrono::duration_cast<std::chrono::microseconds>( 55 |
std::chrono::steady_clock::now().time_since_epoch()) 56 | .count(); 57 | if (first) 58 | { 59 | // Widen before scaling so a large timeout cannot overflow int. 60 | deadline_micros = 61 | now + static_cast<int64_t>(max_millis_to_wait) * 1000; 62 | first = false; 63 | } 64 | if (now < deadline_micros) 65 | { 66 | std::unique_lock<std::mutex> l(mu_); 67 | WaitForMilliseconds( 68 | &l, 69 | &memory_returned_, 70 | (deadline_micros - now) / 1000); 71 | } 72 | else 73 | { 74 | return alloc_func(alignment, num_bytes, true); 75 | } 76 | } 77 | } 78 | return ptr; 79 | } 80 | 81 | } // namespace tfdml 82 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/allocator_retry.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | namespace tfdml 23 | { 24 | 25 | // A retrying wrapper for a memory allocator. 26 | class AllocatorRetry 27 | { 28 | public: 29 | AllocatorRetry(); 30 | 31 | // Call 'alloc_func' to obtain memory. On first call, 32 | // 'verbose_failure' will be false. If return value is nullptr, 33 | // then wait up to 'max_millis_to_wait' milliseconds, retrying each 34 | // time a call to DeallocateRaw() is detected, until either a good 35 | // pointer is returned or the deadline is exhausted.
If the 36 | // deadline is exhausted, try one more time with 'verbose_failure' 37 | // set to true. The value returned is either the first good pointer 38 | // obtained from 'alloc_func' or nullptr. 39 | void* AllocateRaw( 40 | std::function< 41 | void*(size_t alignment, size_t num_bytes, bool verbose_failure)> 42 | alloc_func, 43 | int max_millis_to_wait, 44 | size_t alignment, 45 | size_t bytes); 46 | 47 | // Called to notify clients that some memory was returned. 48 | void NotifyDealloc(); 49 | 50 | private: 51 | std::mutex mu_; 52 | std::condition_variable memory_returned_; 53 | }; 54 | 55 | // Implementation details below 56 | inline void AllocatorRetry::NotifyDealloc() 57 | { 58 | std::unique_lock l(mu_); 59 | memory_returned_.notify_all(); 60 | } 61 | 62 | } // namespace tfdml 63 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/attribute.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "absl/container/inlined_vector.h" 19 | #include "absl/types/optional.h" 20 | #include "absl/types/span.h" 21 | #include "absl/types/variant.h" 22 | #include "tfdml/runtime_adapter/types.h" 23 | 24 | namespace tfdml 25 | { 26 | using AttributeValue = absl::optional, 33 | std::vector, 34 | std::vector, 35 | std::vector, 36 | std::vector>>; 37 | 38 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/determinism.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "absl/strings/string_view.h" 17 | #include "tfdml/runtime_adapter/env_var.h" 18 | #include "tfdml/runtime_adapter/macros.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | namespace 24 | { 25 | 26 | class DeterminismState 27 | { 28 | public: 29 | explicit DeterminismState(absl::string_view env_var) : env_var_(env_var) {} 30 | bool Required() 31 | { 32 | if (state_ == Value::NOT_SET) 33 | { 34 | bool env_var_set = false; 35 | TF_CHECK_OK(ReadBoolFromEnvVar(env_var_, false, &env_var_set)); 36 | state_ = env_var_set ? 
Value::ENABLED : Value::DISABLED; 37 | } 38 | 39 | return state_ == Value::ENABLED; 40 | } 41 | 42 | private: 43 | absl::string_view env_var_; 44 | enum class Value 45 | { 46 | DISABLED, 47 | ENABLED, 48 | NOT_SET 49 | }; 50 | Value state_ = Value::NOT_SET; 51 | }; 52 | 53 | } // namespace 54 | 55 | DeterminismState OpDeterminismState = DeterminismState("TF_DETERMINISTIC_OPS"); 56 | 57 | bool OpDeterminismRequired() { return OpDeterminismState.Required(); } 58 | 59 | } // namespace tfdml 60 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/determinism.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | namespace tfdml 19 | { 20 | 21 | bool OpDeterminismRequired(); 22 | 23 | } // namespace tfdml 24 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/device.cc: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tfdml/runtime_adapter/device.h" 17 | 18 | namespace tfdml 19 | { 20 | 21 | } // namespace tfdml 22 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/device.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "absl/types/optional.h" 19 | #include "absl/types/span.h" 20 | #include "tfdml/runtime_adapter/status.h" 21 | 22 | namespace tfdml 23 | { 24 | class Tensor; 25 | 26 | class Device 27 | { 28 | public: 29 | virtual ~Device() = default; 30 | 31 | virtual Status CopyCPUTensorToDevice( 32 | const Tensor* cpu_tensor, 33 | Tensor* device_tensor) = 0; 34 | 35 | virtual Status CopyDeviceTensorToCPU( 36 | const Tensor* device_tensor, 37 | Tensor* cpu_tensor) = 0; 38 | 39 | virtual Status CopyDeviceTensorsToCPU( 40 | absl::Span device_tensors, 41 | absl::Span cpu_tensors) = 0; 42 | 43 | virtual void CopyTensorInSameDevice( 44 | const Tensor* input_tensor, 45 | Tensor* output_tensor) = 0; 46 | 47 | virtual absl::optional TryLogKernelComputeStart( 48 | const absl::string_view type, 49 | const absl::string_view name) const = 0; 50 | 51 | virtual void LogKernelComputeEnd(uint32_t event_id) const = 0; 52 | }; 53 | } // namespace tfdml 54 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/env.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | #include "tfdml/runtime_adapter/status.h" 21 | 22 | namespace tfdml 23 | { 24 | namespace env 25 | { 26 | std::string FormatLibraryFileName( 27 | const std::string& name, 28 | const std::string& version); 29 | 30 | Status GetSymbolFromLibrary( 31 | void* handle, 32 | const char* symbol_name, 33 | void** symbol); 34 | 35 | Status LoadDynamicLibrary(const char* library_filename, void** handle); 36 | } // namespace env 37 | } // namespace tfdml 38 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/env_var.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/status.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | // Returns a boolean into "value" from the environmental variable 24 | // "env_var_name". If it is unset, the default value is used. A string "0" or a 25 | // case insensitive "false" is interpreted as false. A string "1" or a case 26 | // insensitive "true" is interpreted as true. Otherwise, an error status is 27 | // returned. 
28 | Status ReadBoolFromEnvVar( 29 | absl::string_view env_var_name, 30 | bool default_val, 31 | bool* value); 32 | 33 | // Returns an int64 into "value" from the environmental variable "env_var_name". 34 | // If it is unset, the default value is used. 35 | // If the string cannot be parsed into int64, an error status is returned. 36 | Status ReadInt64FromEnvVar( 37 | absl::string_view env_var_name, 38 | int64_t default_val, 39 | int64_t* value); 40 | 41 | // Returns a string into "value" from the environmental variable "env_var_name". 42 | // If it is unset, the default value is used. 43 | Status ReadStringFromEnvVar( 44 | absl::string_view env_var_name, 45 | absl::string_view default_val, 46 | std::string* value); 47 | 48 | } // namespace tfdml 49 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/guarded_philox_random.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/runtime_adapter/guarded_philox_random.h" 17 | #include "tfdml/runtime_adapter/determinism.h" 18 | #include "tfdml/runtime_adapter/macros.h" 19 | #include "tfdml/runtime_adapter/op_kernel_construction.h" 20 | #include "tfdml/runtime_adapter/status.h" 21 | 22 | namespace tfdml 23 | { 24 | 25 | static std::mt19937_64* InitRngWithRandomSeed() 26 | { 27 | std::random_device device("/dev/urandom"); 28 | return new std::mt19937_64(device()); 29 | } 30 | 31 | static uint64_t New64() 32 | { 33 | static std::mt19937_64* rng = InitRngWithRandomSeed(); 34 | static std::mutex mu; 35 | std::unique_lock l(mu); 36 | return (*rng)(); 37 | } 38 | 39 | Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) 40 | { 41 | CHECK(!initialized_); 42 | // Grab seed Attrs. 43 | int64_t seed, seed2; 44 | auto status = context->GetAttr("seed", &seed); 45 | if (!status.ok()) return status; 46 | status = context->GetAttr("seed2", &seed2); 47 | if (!status.ok()) return status; 48 | if (seed == 0 && seed2 == 0 && OpDeterminismRequired()) 49 | { 50 | return errors::InvalidArgument("When determinism is enabled, random " 51 | "ops must have a seed specified."); 52 | } 53 | 54 | // Initialize with the given seeds 55 | if (seed == 0 && seed2 == 0) 56 | { 57 | // If both seeds are unspecified, use completely random seeds. 
58 | seed = New64(); 59 | seed2 = New64(); 60 | } 61 | std::unique_lock lock(mu_); 62 | generator_ = random::PhiloxRandom(seed, seed2); 63 | initialized_ = true; 64 | return Status::OK(); 65 | } 66 | 67 | random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64_t samples) 68 | { 69 | CHECK(initialized_); 70 | std::unique_lock lock(mu_); 71 | auto local = generator_; 72 | generator_.Skip(samples); 73 | return local; 74 | } 75 | 76 | } // namespace tfdml 77 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/guarded_philox_random.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/philox_random.h" 19 | #include "tfdml/runtime_adapter/status.h" 20 | #include 21 | 22 | namespace tfdml 23 | { 24 | 25 | struct OpKernelConstruction; 26 | 27 | // A thread safe wrapper around a Philox generator. 
Example usage: 28 | // 29 | // GuardedRandomPhilox generator; 30 | // generator.Init(context); 31 | // 32 | // // In thread safe code 33 | // const int samples = ...; 34 | // auto local_generator = generator.ReserveSamples128(samples); 35 | // for (int i = 0; i < samples; i++) 36 | // Array sample = local_generator(); 37 | // // Use sample 38 | // } 39 | // 40 | class GuardedPhiloxRandom 41 | { 42 | public: 43 | // Must call Init to finish initialization 44 | GuardedPhiloxRandom() : initialized_(false) {} 45 | 46 | // Initialize the generator from attributes "seed" and "seed2". 47 | // If both seeds are unspecified, use random seeds. 48 | // Must be called exactly once. 49 | Status Init(OpKernelConstruction* context); 50 | 51 | // Reserve a certain number of 128-bit samples. 52 | // This function is thread safe. The returned generator is valid for the 53 | // given number of samples, and can be used without a lock. 54 | random::PhiloxRandom ReserveSamples128(int64_t samples); 55 | 56 | // Reserve enough random samples in the generator for the given output 57 | // count. 58 | random::PhiloxRandom ReserveRandomOutputs( 59 | int64_t output_count, 60 | int multiplier) 61 | { 62 | int64_t conservative_sample_count = output_count * multiplier; 63 | return ReserveSamples128(conservative_sample_count); 64 | } 65 | 66 | private: 67 | std::mutex mu_; 68 | random::PhiloxRandom generator_; 69 | bool initialized_; 70 | 71 | GuardedPhiloxRandom(const GuardedPhiloxRandom&) = delete; 72 | void operator=(const GuardedPhiloxRandom&) = delete; 73 | }; 74 | 75 | } // namespace tfdml 76 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/mirror_pad_mode.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tfdml/runtime_adapter/mirror_pad_mode.h" 17 | 18 | namespace tfdml 19 | { 20 | 21 | Status GetMirrorPaddingFromString( 22 | absl::string_view str_value, 23 | MirrorPadMode* value) 24 | { 25 | if (str_value == "REFLECT") 26 | { 27 | *value = REFLECT; 28 | } 29 | else if (str_value == "SYMMETRIC") 30 | { 31 | *value = SYMMETRIC; 32 | } 33 | else 34 | { 35 | return errors::NotFound(str_value, " is not an allowed padding type"); 36 | } 37 | return Status::OK(); 38 | } 39 | 40 | std::string GetMirrorPadModeAttrString() 41 | { 42 | return "mode: {'REFLECT', 'SYMMETRIC'}"; 43 | } 44 | 45 | } // end namespace tfdml 46 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/mirror_pad_mode.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/status.h" 19 | #include 20 | 21 | namespace tfdml 22 | { 23 | 24 | // REFLECT: Border elements are not mirrored to padded regions. 25 | // SYMMETRIC: Border elements are mirrored to padded regions. 26 | // 27 | // For example, if two elements are padded to the right of an array [1, 2, 3], 28 | // then the result is [1, 2, 3, 2, 1] for REFLECT mode, and is [1, 2, 3, 3, 2] 29 | // for SYMMETRIC mode. 30 | enum MirrorPadMode 31 | { 32 | REFLECT = 1, 33 | SYMMETRIC = 2, 34 | }; 35 | 36 | // Return the string containing the list of valid padding modes, that can be 37 | // used as an Attr() in REGISTER_OP. 38 | std::string GetMirrorPadModeAttrString(); 39 | 40 | // Sets mirror pad mode based on the given string padding value. 41 | Status GetMirrorPaddingFromString( 42 | absl::string_view str_value, 43 | MirrorPadMode* value); 44 | 45 | } // end namespace tfdml 46 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/numbers.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "numbers.h" 17 | 18 | #include 19 | 20 | #include "tfdml/runtime_adapter/macros.h" 21 | 22 | namespace tfdml 23 | { 24 | namespace strings 25 | { 26 | std::string HumanReadableNumBytes(int64_t num_bytes) 27 | { 28 | if (num_bytes == std::numeric_limits::min()) 29 | { 30 | // Special case for number with not representable negation. 31 | return "-8E"; 32 | } 33 | 34 | const char* neg_str = (num_bytes < 0) ? "-" : ""; 35 | if (num_bytes < 0) 36 | { 37 | num_bytes = -num_bytes; 38 | } 39 | 40 | // Special case for bytes. 41 | if (num_bytes < 1024) 42 | { 43 | // No fractions for bytes. 44 | char buf[8]; // Longest possible string is '-XXXXB' 45 | snprintf( 46 | buf, 47 | sizeof(buf), 48 | "%s%lldB", 49 | neg_str, 50 | static_cast(num_bytes)); 51 | return std::string(buf); 52 | } 53 | 54 | static const char units[] = "KMGTPE"; // int64 only goes up to E. 55 | const char* unit = units; 56 | while (num_bytes >= static_cast(1024) * 1024) 57 | { 58 | num_bytes /= 1024; 59 | ++unit; 60 | CHECK(unit < units + TF_ARRAYSIZE(units)); 61 | } 62 | 63 | // We use SI prefixes. 64 | char buf[16]; 65 | snprintf( 66 | buf, 67 | sizeof(buf), 68 | ((*unit == 'K') ? 
"%s%.1f%ciB" : "%s%.2f%ciB"), 69 | neg_str, 70 | num_bytes / 1024.0, 71 | *unit); 72 | return std::string(buf); 73 | } 74 | } // namespace strings 75 | } // namespace tfdml 76 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/numbers.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | namespace tfdml 21 | { 22 | namespace strings 23 | { 24 | std::string HumanReadableNumBytes(int64_t num_bytes); 25 | 26 | } // namespace strings 27 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/op_defs_dml.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | namespace tfdml 19 | { 20 | namespace ops 21 | { 22 | // Placeholder for DirectML-specific op definitions. 23 | } // namespace ops 24 | } // namespace tfdml 25 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/op_kernel.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | #include "absl/cleanup/cleanup.h" 21 | #include "absl/strings/string_view.h" 22 | #include "absl/types/span.h" 23 | #include "attribute.h" 24 | #include "node_def.h" 25 | #include "tfdml/runtime_adapter/device.h" 26 | #include "tfdml/runtime_adapter/op_kernel_context.h" 27 | #include "types.h" 28 | 29 | namespace tfdml 30 | { 31 | class OpKernel 32 | { 33 | public: 34 | OpKernel(std::shared_ptr node_def) 35 | : node_def_(std::move(node_def)) 36 | { 37 | } 38 | 39 | virtual ~OpKernel() = default; 40 | 41 | std::shared_ptr node_def() const { return node_def_; } 42 | 43 | const absl::string_view type_string() const 44 | { 45 | return node_def_->GetOpTypeName(); 46 | } 47 | const absl::string_view name() const { return node_def_->GetOpName(); } 48 | 49 | MemoryType input_memory_type(int index) const 50 | { 51 | return node_def_->GetInputTensorMemoryType(index); 52 | } 53 | 54 | MemoryType output_memory_type(int index) const 55 | { 56 | return node_def_->GetOutputTensorMemoryType(index); 57 | } 58 | 59 | void Compute(OpKernelContext* ctx) 60 | { 61 | auto profiler_event_id = ctx->device()->TryLogKernelComputeStart( 62 | ctx->op_kernel().type_string(), 63 | ctx->op_kernel().name()); 64 | 65 | auto profiler_cleanup = absl::MakeCleanup( 66 | [ctx, &profiler_event_id] 67 | { 68 | if (profiler_event_id) 69 | { 70 | ctx->device()->LogKernelComputeEnd(*profiler_event_id); 71 | } 72 | }); 73 | 74 | ComputeImpl(ctx); 75 | } 76 | 77 | private: 78 | virtual void ComputeImpl(OpKernelContext* raw_ctx) = 0; 79 | 80 | std::shared_ptr node_def_; 81 | }; 82 | } // namespace tfdml 83 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/padding.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | // This file contains helper routines to deal with padding in various ops and 19 | // kernels. 20 | 21 | #include "tfdml/runtime_adapter/status.h" 22 | #include "tfdml/runtime_adapter/tensor_format.h" 23 | #include 24 | #include 25 | 26 | namespace tfdml 27 | { 28 | 29 | // Padding: the padding we apply to the input tensor along the rows and columns 30 | // dimensions. This is usually used to make sure that the spatial dimensions do 31 | // not shrink when we progress with convolutions. Three types of padding are 32 | // supported: 33 | // VALID: No padding is carried out. 34 | // SAME: The pad value is computed so that the output will have the same 35 | // dimensions as the input. 36 | // EXPLICIT: The user specifies the pad values in the explicit_paddings 37 | // attribute. 38 | // The padded area is typically zero-filled. For pooling ops, the padded area is 39 | // instead ignored. For max pool, this is equivalent to padding with -infinity. 40 | enum Padding 41 | { 42 | VALID = 1, // No padding. 43 | SAME = 2, // Input and output layers have the same size. 44 | EXPLICIT = 3, // Padding is explicitly specified 45 | }; 46 | 47 | // Returns an error if the padding attributes are invalid. 
48 | Status CheckValidPadding( 49 | Padding padding_type, 50 | absl::Span explicit_paddings, 51 | int num_dims, 52 | TensorFormat data_format); 53 | 54 | // Return the string containing the list of valid padding types, that can be 55 | // used as an Attr() in REGISTER_OP. 56 | std::string GetPaddingAttrString(); 57 | 58 | // Like GetPaddingAttrString(), but also includes EXPLICIT. 59 | std::string GetPaddingAttrStringWithExplicit(); 60 | 61 | std::string GetExplicitPaddingsAttrString(); 62 | 63 | // Sets padding value based on the given string padding value. 64 | Status GetPaddingFromString(absl::string_view str_value, Padding* value); 65 | 66 | } // namespace tfdml 67 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/path.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/runtime_adapter/path.h" 17 | #include "absl/strings/str_cat.h" 18 | #include "absl/strings/string_view.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | static bool IsAbsolutePath(absl::string_view path) 24 | { 25 | return !path.empty() && path[0] == '/'; 26 | } 27 | 28 | // For an array of paths of length count, append them all together, 29 | // ensuring that the proper path separators are inserted between them. 30 | std::string JoinPathImpl(std::initializer_list paths) 31 | { 32 | std::string result; 33 | 34 | for (absl::string_view path : paths) 35 | { 36 | if (path.empty()) continue; 37 | 38 | if (result.empty()) 39 | { 40 | result = std::string(path); 41 | continue; 42 | } 43 | 44 | if (result[result.size() - 1] == '/') 45 | { 46 | if (IsAbsolutePath(path)) 47 | { 48 | absl::StrAppend(&result, path.substr(1)); 49 | } 50 | else 51 | { 52 | absl::StrAppend(&result, path); 53 | } 54 | } 55 | else 56 | { 57 | if (IsAbsolutePath(path)) 58 | { 59 | absl::StrAppend(&result, path); 60 | } 61 | else 62 | { 63 | absl::StrAppend(&result, "/", path); 64 | } 65 | } 66 | } 67 | 68 | return result; 69 | } 70 | 71 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/path.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "absl/strings/string_view.h" 19 | 20 | namespace tfdml 21 | { 22 | std::string JoinPathImpl(std::initializer_list paths); 23 | 24 | // Join multiple paths together. 25 | // JoinPath unconditionally joins all paths together. For example: 26 | // 27 | // Arguments | JoinPath 28 | // ---------------------------+--------------------- 29 | // '/foo', 'bar' | /foo/bar 30 | // '/foo/', 'bar' | /foo/bar 31 | // '/foo', '/bar' | /foo/bar 32 | // '/foo', '/bar', '/baz' | /foo/bar/baz 33 | // 34 | // All paths will be treated as relative paths, regardless of whether or not 35 | // they start with a leading '/'. That is, all paths will be concatenated 36 | // together, with the appropriate path separator inserted in between. 37 | // Arguments must be convertible to absl::string_view. 38 | // 39 | // Usage: 40 | // string path = file::JoinPath("/var/log", dirname, filename); 41 | // string path = file::JoinPath(FLAGS_test_srcdir, filename); 42 | template 43 | inline std::string JoinPath(const T&... args) 44 | { 45 | return JoinPathImpl({args...}); 46 | } 47 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/random_ops_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/philox_random.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | using random::PhiloxRandom; 24 | 25 | // The following 2 functions use the contract "lower 32 bits for the first 26 | // uint32_t, higher 32 bits for the second". Note that this is endian-neutral, 27 | // unlike a direct memory copy `memcpy(output, &input, 8)`. 28 | void Uint64ToUint32s(uint64_t input, uint32_t* output1, uint32_t* output2) 29 | { 30 | *output1 = static_cast(input); 31 | *output2 = static_cast(input >> 32); 32 | } 33 | 34 | uint64_t Uint32sToUint64(uint32_t input1, uint32_t input2) 35 | { 36 | auto u64_1 = static_cast(input1); 37 | auto u64_2 = static_cast(input2); 38 | return u64_1 | (u64_2 << 32); 39 | } 40 | 41 | PhiloxRandom::ResultType GetCounterFromMem(uint64_t const* ptr) 42 | { 43 | PhiloxRandom::ResultType counter; 44 | Uint64ToUint32s(ptr[0], &counter[0], &counter[1]); 45 | Uint64ToUint32s(ptr[1], &counter[2], &counter[3]); 46 | return counter; 47 | } 48 | 49 | void WriteCounterToMem(PhiloxRandom::ResultType const& counter, uint64_t* ptr) 50 | { 51 | ptr[0] = Uint32sToUint64(counter[0], counter[1]); 52 | ptr[1] = Uint32sToUint64(counter[2], counter[3]); 53 | } 54 | 55 | PhiloxRandom::Key GetKeyFromMem(uint64_t const* ptr) 56 | { 57 | PhiloxRandom::Key key; 58 | Uint64ToUint32s(ptr[0], &key[0], &key[1]); 59 | return key; 60 | } 61 | 62 | void WriteKeyToMem(PhiloxRandom::Key const& key, uint64_t* ptr) 63 | { 64 | 
*ptr = Uint32sToUint64(key[0], key[1]); 65 | } 66 | 67 | PhiloxRandom GetPhiloxRandomFromCounterKeyMem( 68 | uint64_t const* counter_ptr, 69 | uint64_t const* key_ptr) 70 | { 71 | return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr)); 72 | } 73 | 74 | } // end namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/rng_alg.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | namespace tfdml 19 | { 20 | 21 | enum Algorithm 22 | { 23 | // The Philox algorithm, as described in paper 24 | // ['Parallel Random Numbers: As Easy as 1, 2, 3'] 25 | // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) 26 | RNG_ALG_PHILOX = 1, 27 | // The ThreeFry algorithm, as described in paper 28 | // ['Parallel Random Numbers: As Easy as 1, 2, 3'] 29 | // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) 30 | RNG_ALG_THREEFRY = 2, 31 | // An algorithm auto-selected by the system according to device type. 
32 | RNG_ALG_AUTO_SELECT = 3 33 | }; 34 | 35 | static constexpr int RNG_KEY_SIZE = 1; 36 | static constexpr int RNG_MAX_COUNTER_SIZE = 2; 37 | // Gets the counter size (in unit of uint64) for a counter-based RNG 38 | // algorithm `alg`. In the case of RNG_ALG_AUTO_SELECT, gets the minimal 39 | // counter size among all algorithms. 40 | inline int GetCounterSize(Algorithm alg) 41 | { 42 | if (alg == RNG_ALG_PHILOX) 43 | { 44 | return 2; 45 | } 46 | else if (alg == RNG_ALG_THREEFRY) 47 | { 48 | return 1; 49 | } 50 | return 1; 51 | } 52 | 53 | } // end namespace tfdml 54 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/stateless_random_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/runtime_adapter/philox_random.h" 17 | #include "tfdml/runtime_adapter/status.h" 18 | #include "tfdml/runtime_adapter/tensor.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | Status GenerateKey( 24 | const Tensor& seed, 25 | random::PhiloxRandom::Key* out_key, 26 | random::PhiloxRandom::ResultType* out_counter) 27 | { 28 | // Grab the two seeds 29 | uint64_t seed0; 30 | uint64_t seed1; 31 | if (seed.dtype() == TF_INT32) 32 | { 33 | const auto seed_vals = seed.base(); 34 | seed0 = seed_vals[0]; 35 | seed1 = seed_vals[1]; 36 | } 37 | else if (seed.dtype() == TF_INT64) 38 | { 39 | const auto seed_vals = seed.base(); 40 | seed0 = seed_vals[0]; 41 | seed1 = seed_vals[1]; 42 | } 43 | else 44 | { 45 | return errors::InvalidArgument( 46 | "Invalid seed type: ", 47 | DataTypeString(seed.dtype())); 48 | } 49 | 50 | // Scramble the seeds so that the user doesn't need to worry about which 51 | // part of the seed needs to be strong. 52 | (*out_key)[0] = 0x3ec8f720; 53 | (*out_key)[1] = 0x02461e29; 54 | (*out_counter)[0] = static_cast(seed0); 55 | (*out_counter)[1] = static_cast(seed0 >> 32); 56 | (*out_counter)[2] = static_cast(seed1); 57 | (*out_counter)[3] = static_cast(seed1 >> 32); 58 | const auto mix = random::PhiloxRandom(*out_counter, *out_key)(); 59 | (*out_key)[0] = mix[0]; 60 | (*out_key)[1] = mix[1]; 61 | (*out_counter)[0] = (*out_counter)[1] = 0; 62 | (*out_counter)[2] = mix[2]; 63 | (*out_counter)[3] = mix[3]; 64 | return Status::OK(); 65 | } 66 | 67 | } // namespace tfdml 68 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/stateless_random_ops.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/philox_random.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | class Tensor; 24 | 25 | // Generates a key and counter that can be used to seed a PhiloxRandom, 26 | // generator, based on the seed value in `seed_t`. 27 | // 28 | // REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}. 29 | // `out_key` and `out_counter` must be non-null. 30 | Status GenerateKey( 31 | const Tensor& seed_t, 32 | random::PhiloxRandom::Key* out_key, 33 | random::PhiloxRandom::ResultType* out_counter); 34 | 35 | } // namespace tfdml 36 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/status.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "status.h" 17 | 18 | #include "tensorflow/c/tf_status.h" 19 | 20 | namespace tfdml 21 | { 22 | 23 | void delete_status(TF_Status* s) 24 | { 25 | if (s != nullptr) 26 | { 27 | TF_DeleteStatus(s); 28 | } 29 | } 30 | 31 | Status::Status() : Status(TF_OK, "") {} 32 | 33 | Status::Status(TF_Code code, const char* message) 34 | : safe_status_(TF_NewStatus(), delete_status) 35 | { 36 | TF_SetStatus(safe_status_.get(), code, message); 37 | } 38 | 39 | Status::Status(TF_Code code, const std::string& message) 40 | : Status(code, message.c_str()) 41 | { 42 | } 43 | 44 | Status::Status(TF_Code code, std::string&& message) 45 | : Status(code, message.c_str()) 46 | { 47 | } 48 | 49 | TF_Code Status::code() const { return TF_GetCode(safe_status_.get()); } 50 | bool Status::ok() const { return TF_GetCode(safe_status_.get()) == TF_OK; } 51 | 52 | TF_Status* Status::raw() const { return safe_status_.get(); } 53 | 54 | const char* Status::error_message() const 55 | { 56 | return TF_Message(safe_status_.get()); 57 | } 58 | 59 | Status Status::OK() { return Status(); } 60 | 61 | void Status::Update(const Status& new_status) 62 | { 63 | if (ok()) 64 | { 65 | *this = new_status; 66 | } 67 | } 68 | 69 | } // namespace tfdml 70 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/stream.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | struct SP_Stream_st 19 | { 20 | explicit SP_Stream_st(void* stream_h) : stream_handle(stream_h) {} 21 | void* stream_handle; 22 | }; 23 | 24 | struct SP_Event_st 25 | { 26 | explicit SP_Event_st(void* event_h) : event_handle(event_h) {} 27 | void* event_handle; 28 | }; 29 | 30 | struct SP_Timer_st 31 | { 32 | explicit SP_Timer_st(int id) : timer_handle(id) {} 33 | int timer_handle; 34 | }; 35 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/tensor_shape_utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/status.h" 19 | #include "tfdml/runtime_adapter/tensor_shape.h" 20 | 21 | namespace tfdml 22 | { 23 | class Tensor; 24 | /* Static-only helper class for building TensorShape values from tensors and for classifying shapes. */ 25 | class TensorShapeUtils 26 | { 27 | public: 28 | // Makes a shape from a tensor. The datatype of the tensor must be int32 or 29 | // int64 30 | static TensorShape MakeShape(const Tensor& tensor); 31 | /* Status-returning overload: the resulting shape is written through *out. */ static Status MakeShape(const Tensor& tensor, TensorShape* out); 32 | /* NOTE(review): the predicates below are presumably rank checks (scalar = rank 0, vector = rank 1, matrix = rank 2, "OrHigher" = at least that rank) -- confirm against the definitions. */ 33 | static bool IsScalar(const TensorShape& shape); 34 | static bool IsVector(const TensorShape& shape); 35 | static bool IsVectorOrHigher(const TensorShape& shape); 36 | static bool IsMatrix(const TensorShape& shape); 37 | static bool IsMatrixOrHigher(const TensorShape& shape); 38 | /* NOTE(review): presumably true when the leading dimensions of shape equal prefix -- confirm. */ static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); 39 | }; 40 | /* NOTE(review): presumably returns x * y and signals overflow through a sentinel result -- the definition is not visible here, confirm before relying on it. */ 41 | int64_t MultiplyWithoutOverflow(int64_t x, int64_t y); 42 | 43 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/tensor_types.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | #pragma once 16 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 17 | 18 | namespace tfdml 19 | { 20 | /* T is the scalar type, NDIMS the rank used by Tensor/ConstTensor (default 1) and IndexType the Eigen index type. Every alias is an Eigen::TensorMap (a view over existing memory) with row-major layout and Eigen::Aligned access. */ 21 | // Helper to define Tensor types given that the scalar is of type T. 22 | template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex> 23 | struct TTypes 24 | { 25 | // Rank-<NDIMS> tensor of scalar type T. 26 | typedef Eigen::TensorMap< 27 | Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>, 28 | Eigen::Aligned> 29 | Tensor; 30 | typedef Eigen::TensorMap< 31 | Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, 32 | Eigen::Aligned> 33 | ConstTensor; 34 | 35 | // Rank-2 tensor (matrix) of scalar type T. 36 | typedef Eigen::TensorMap< 37 | Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>, 38 | Eigen::Aligned> 39 | Matrix; 40 | typedef Eigen::TensorMap< 41 | Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, 42 | Eigen::Aligned> 43 | ConstMatrix; 44 | }; 45 | 46 | } // namespace tfdml 47 | -------------------------------------------------------------------------------- /tfdml/runtime_adapter/training_op_helpers.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tfdml/runtime_adapter/training_op_helpers.h" 17 | #include "tfdml/runtime_adapter/device.h" 18 | #include "tfdml/runtime_adapter/op_kernel_context.h" 19 | 20 | namespace tfdml 21 | { 22 | // This is for use with ResourceVariables to ensure *tensor has a 23 | // reference count of 1 before you update it. 24 | // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held. 25 | /* ctx supplies the temporary allocation and the device used for the copy; tensor is rebound in place to the copy; copy_on_read_mode == false leaves *tensor untouched. Returns the allocate_temp error if allocation fails, otherwise Status::OK(). */ Status PrepareToUpdateVariable( 26 | OpKernelContext* ctx, 27 | Tensor* tensor, 28 | bool copy_on_read_mode) 29 | { 30 | if (copy_on_read_mode) 31 | { 32 | // Tensor's buffer is in use by some read, so we need to copy before 33 | // updating. 34 | Tensor tmp; 35 | if (tensor->dtype() == TF_VARIANT) 36 | { 37 | /* NOTE(review): presumably LogFatal does not return; if it did, *tensor would be replaced by the default-constructed tmp below -- confirm. */ LogFatal("TF_VARIANT is not supported yet."); 38 | } 39 | else 40 | { 41 | /* The temporary is requested with on_host == false. */ constexpr bool on_host = false; 42 | TF_RETURN_IF_ERROR(ctx->allocate_temp( 43 | tensor->dtype(), 44 | tensor->shape(), 45 | &tmp, 46 | on_host)); 47 | 48 | /* Copies the current contents of *tensor into tmp through the context's device. */ ctx->device()->CopyTensorInSameDevice(tensor, &tmp); 49 | } 50 | /* Rebind the caller's tensor to the fresh copy. */ *tensor = tmp; 51 | } 52 | return Status::OK(); 53 | } 54 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/training_op_helpers.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "tfdml/runtime_adapter/status.h" 19 | 20 | namespace tfdml 21 | { 22 | class OpKernelContext; 23 | class Tensor; 24 | 25 | // This is for use with ResourceVariables to ensure *tensor has a 26 | // reference count of 1 before you update it. 27 | // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held. 28 | /* ctx: kernel context used for the temporary allocation and device copy. tensor: variable tensor, replaced in place by a private copy when copy_on_read_mode is true and left unchanged otherwise. Returns a non-OK Status when the temporary allocation fails. */ Status PrepareToUpdateVariable( 29 | OpKernelContext* ctx, 30 | Tensor* tensor, 31 | bool copy_on_read_mode); 32 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/variable_lock.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) Microsoft Corporation. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #include "absl/types/span.h" 19 | 20 | struct TF_VariableInputLockHolder; 21 | 22 | namespace tfdml 23 | { 24 | 25 | class OpKernelContext; 26 | /* Wraps a TF_VariableInputLockHolder* together with the OpKernelContext it belongs to. input_indices name the kernel inputs whose variables are to be locked. NOTE(review): presumably the destructor releases whatever lock is still held and the move constructor transfers the holder -- confirm in the .cc. */ 27 | class VariableLock 28 | { 29 | public: 30 | /* Binds to ctx; lock_holder_ starts out null, so nothing is locked yet as far as this header shows. */ VariableLock(OpKernelContext* ctx); 31 | /* Binds to ctx and takes input_indices up front; exclusive_lock presumably selects LockUnique over LockShared -- confirm. */ VariableLock( 32 | OpKernelContext* ctx, 33 | bool exclusive_lock, 34 | absl::Span<const int> input_indices); 35 | VariableLock(VariableLock&& other); 36 | ~VariableLock(); 37 | void LockShared(absl::Span<const int> input_indices); 38 | void LockUnique(absl::Span<const int> input_indices); 39 | void Unlock(); 40 | 41 | private: 42 | /* Opaque holder declared by the TensorFlow C API; null by default. */ TF_VariableInputLockHolder* lock_holder_ = nullptr; 43 | OpKernelContext* ctx_; 44 | }; 45 | 46 | } // namespace tfdml -------------------------------------------------------------------------------- /tfdml/runtime_adapter/wide_char.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 Google Inc. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #pragma once 17 | 18 | #if _WIN32 19 | #include <windows.h> 20 | 21 | #include <string> 22 | 23 | namespace tfdml 24 | { 25 | /* Converts a UTF-8 std::string to a wide-character string with MultiByteToWideChar(CP_UTF8). The first call (NULL buffer, size 0) only queries the required length; the second call fills the buffer. Returns an empty string for empty input. */ 26 | inline std::wstring Utf8ToWideChar(const std::string& utf8str) 27 | { 28 | /* MultiByteToWideChar fails for a zero-length input, so short-circuit the same way WideCharToUtf8 does. */ if (utf8str.empty()) return std::wstring(); 29 | int size_required = MultiByteToWideChar( 30 | CP_UTF8, 31 | 0, 32 | utf8str.c_str(), 33 | (int)utf8str.size(), 34 | NULL, 35 | 0); 36 | std::wstring ws_translated_str(size_required, 0); 37 | MultiByteToWideChar( 38 | CP_UTF8, 39 | 0, 40 | utf8str.c_str(), 41 | (int)utf8str.size(), 42 | &ws_translated_str[0], 43 | size_required); 44 | return ws_translated_str; 45 | } 46 | /* Converts a wide-character string to UTF-8 with WideCharToMultiByte(CP_UTF8), using the same query-then-fill sequence. Returns an empty string for empty input. */ 47 | inline std::string WideCharToUtf8(const std::wstring& wstr) 48 | { 49 | if (wstr.empty()) return std::string(); 50 | int size_required = WideCharToMultiByte( 51 | CP_UTF8, 52 | 0, 53 | wstr.c_str(), 54 | (int)wstr.size(), 55 | NULL, 56 | 0, 57 | NULL, 58 | NULL); 59 | std::string utf8_translated_str(size_required, 0); 60 | WideCharToMultiByte( 61 | CP_UTF8, 62 | 0, 63 | wstr.c_str(), 64 | (int)wstr.size(), 65 | &utf8_translated_str[0], 66 | size_required, 67 | NULL, 68 | NULL); 69 | return utf8_translated_str; 70 | } 71 | 72 | } // namespace tfdml 73 | #endif 74 | -------------------------------------------------------------------------------- /tfdml/tfdml.natvis: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {{ Count={storage_.metadata_.value >> 1} }} 5 | 6 | storage_.metadata_.value >> 1 7 | 8 | storage_.metadata_.value >> 1 9 | (($T1 *)(& storage_.data_.inlined.inlined_data[0])) 10 | 11 | 12 | storage_.metadata_.value >> 1 13 | & storage_.data_.allocated.allocated_data[0] 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /tfdml/wheel/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README 2 | recursive-include * *.py 3 | recursive-include * *.pyd 4 | 
recursive-include * *.pd 5 | recursive-include * *.so 6 | recursive-include * *.so.[0-9] 7 | recursive-include * *.dylib 8 | recursive-include * *.dll 9 | recursive-include * *.lib 10 | recursive-include * *.csv 11 | recursive-include * *.h 12 | recursive-include * *.hpp 13 | recursive-include * *.txt -------------------------------------------------------------------------------- /tfdml/wheel/README: -------------------------------------------------------------------------------- 1 | tfdml 2 | -------------------------------------------------------------------------------- /tfdml/wheel/template_init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorflow-directml-plugin/6c3e918359b70fb4f0de9d703eb8aa56a8cd3592/tfdml/wheel/template_init.py --------------------------------------------------------------------------------