├── .clang-format ├── .gitignore ├── CMakeLists.txt ├── README.md ├── cmake ├── absl.cmake ├── glog.cmake ├── oneflow.cmake ├── openvino.cmake ├── protobuf.cmake ├── pybind11.cmake ├── python.cmake ├── tensorflow-xla.cmake ├── tensorrt.cmake ├── third_party.cmake └── util.cmake ├── dev-requirements.txt ├── oneflow_xrt ├── CMakeLists.txt ├── api │ ├── api_internal.cpp │ ├── api_internal.h │ ├── api_serving.cpp │ └── api_serving.h ├── common │ ├── device.cpp │ ├── device.h │ ├── env.h │ ├── registry.h │ ├── shape_util.h │ └── typedef.h ├── compiler │ ├── compilation_cache.cpp │ ├── compilation_cache.h │ ├── executable.h │ ├── graph_compiler.h │ ├── kernel │ │ ├── op_context.h │ │ ├── op_kernel.h │ │ ├── op_kernel_registry.h │ │ └── op_kernel_registry_id.h │ ├── openvino │ │ ├── inference_engine_data_desc.h │ │ ├── ngraph_shape.h │ │ ├── openvino_executable.cpp │ │ ├── openvino_executable.h │ │ ├── openvino_graph_compiler.cpp │ │ ├── openvino_graph_compiler.h │ │ └── ops │ │ │ ├── add_op.cpp │ │ │ ├── argument_op.cpp │ │ │ ├── batch_norm_op.cpp │ │ │ ├── concat_op.cpp │ │ │ ├── convolution_op.cpp │ │ │ ├── leaky_relu_op.cpp │ │ │ ├── matmul_op.cpp │ │ │ ├── op_context.cpp │ │ │ ├── op_context.h │ │ │ ├── op_kernel.h │ │ │ ├── relu_op.cpp │ │ │ ├── reshape_op.cpp │ │ │ └── upsample_op.cpp │ ├── parameter.h │ ├── passes │ │ ├── build_subgraph_pass.cpp │ │ ├── build_subgraph_pass.h │ │ ├── cluster.cpp │ │ ├── cluster.h │ │ ├── mark_cluster_id_pass.cpp │ │ ├── mark_cluster_id_pass.h │ │ ├── options.h │ │ ├── rebuild_job_pass.cpp │ │ ├── shape_inference_context.h │ │ ├── shape_inference_pass.cpp │ │ ├── trainable_propagation_pass.cpp │ │ └── trainable_propagation_pass.h │ ├── tensorrt │ │ ├── README.md │ │ ├── common.h │ │ ├── ops │ │ │ ├── activation_op.cpp │ │ │ ├── argument_op.cpp │ │ │ ├── batch_norm_op.cpp │ │ │ ├── bias_add_op.cpp │ │ │ ├── broadcast_binary_grad_ops.cpp │ │ │ ├── broadcast_binary_ops.cpp │ │ │ ├── broadcast_like_op.cpp │ │ │ ├── cast_op.cpp │ │ │ ├── 
concat_op.cpp │ │ │ ├── convolution_grad_op.cpp │ │ │ ├── convolution_op.cpp │ │ │ ├── deconvolution_op.cpp │ │ │ ├── element_wise_op.cpp │ │ │ ├── expand_dims.cpp │ │ │ ├── identity_op.cpp │ │ │ ├── matmul_op.cpp │ │ │ ├── multiply_op.cpp │ │ │ ├── narrow.cpp │ │ │ ├── op_context.cpp │ │ │ ├── op_context.h │ │ │ ├── op_kernel.h │ │ │ ├── pad_grad_op.cpp │ │ │ ├── pad_op.cpp │ │ │ ├── pooling_op.cpp │ │ │ ├── prelu_grad_op.cpp │ │ │ ├── reduce_op.cpp │ │ │ ├── reshape_op.cpp │ │ │ ├── scalar_binary_ops.cpp │ │ │ ├── scalar_pow_grad_op.cpp │ │ │ ├── softmax_op.cpp │ │ │ ├── topk_op.cpp │ │ │ ├── transpose_op.cpp │ │ │ ├── unary_grad_ops.cpp │ │ │ ├── unary_ops.cpp │ │ │ └── upsample_op.cpp │ │ ├── plugin │ │ │ ├── README.md │ │ │ ├── broadcast_like_plugin.cpp │ │ │ └── broadcast_like_plugin.h │ │ ├── trt_builder.cpp │ │ ├── trt_builder.h │ │ ├── trt_executable.cpp │ │ ├── trt_executable.h │ │ ├── trt_graph_compiler.cpp │ │ ├── trt_graph_compiler.h │ │ ├── trt_helpers.cpp │ │ ├── trt_helpers.h │ │ ├── trt_int8_calibrator.cpp │ │ ├── trt_int8_calibrator.h │ │ ├── trt_logger.cpp │ │ ├── trt_logger.h │ │ ├── trt_plugin.h │ │ ├── trt_shape.h │ │ ├── trt_unique_ptr.h │ │ └── trt_value.h │ └── xla │ │ ├── README.md │ │ ├── memory │ │ ├── device_buffer_allocator.h │ │ ├── device_memory_pool.cpp │ │ └── device_memory_pool.h │ │ ├── ops │ │ ├── activation_grad_op.cpp │ │ ├── adam_optimizer_op.cpp │ │ ├── add_op.cpp │ │ ├── argument_op.cpp │ │ ├── batch_matmul_op.cpp │ │ ├── bias_add_op.cpp │ │ ├── binary_op.h │ │ ├── broadcast_binary_op.cpp │ │ ├── cast_op.cpp │ │ ├── fc_op.cpp │ │ ├── gather.cpp │ │ ├── layer_norm_op.cpp │ │ ├── matmul_op.cpp │ │ ├── op_context.cpp │ │ ├── op_context.h │ │ ├── op_kernel.h │ │ ├── optimizer_op.h │ │ ├── reduce_op.cpp │ │ ├── reshape_op.cpp │ │ ├── scalar_binary_op.cpp │ │ ├── softmax_op.cpp │ │ ├── square_sum_op.cpp │ │ ├── transpose_op.cpp │ │ ├── unary_op.cpp │ │ └── unary_op.h │ │ ├── xla_allocator.cpp │ │ ├── xla_allocator.h │ │ ├── 
xla_data_type.cpp │ │ ├── xla_data_type.h │ │ ├── xla_executable.cpp │ │ ├── xla_executable.h │ │ ├── xla_executable_context.cpp │ │ ├── xla_executable_context.h │ │ ├── xla_executable_scope.h │ │ ├── xla_graph_compiler.cpp │ │ ├── xla_graph_compiler.h │ │ ├── xla_helpers.cpp │ │ ├── xla_helpers.h │ │ ├── xla_macro.h │ │ ├── xla_resource_manager.cpp │ │ ├── xla_resource_manager.h │ │ ├── xla_shape.cpp │ │ └── xla_shape.h ├── graph │ ├── algorithm.h │ ├── argument.h │ ├── graph.cpp │ ├── graph.h │ ├── graph_util.cpp │ ├── node.cpp │ ├── node.h │ ├── node_util.cpp │ └── node_util.h ├── int8_calibration │ ├── calibration.cpp │ ├── calibration.h │ ├── calibration_mode.cpp │ └── calibration_mode.h ├── python │ ├── CMakeLists.txt │ ├── graph.cpp │ ├── int8_calibration.cpp │ ├── openvino_stub.cpp │ ├── options.cpp │ ├── stub.cpp │ ├── tensorrt_stub.cpp │ └── xla_stub.cpp ├── version_script.lds ├── xrt.proto ├── xrt_launch_kernel.cpp └── xrt_launch_op.cpp ├── patches └── xla.patch ├── python ├── .gitignore └── oneflow_xrt │ ├── __init__.py │ ├── calibration_mode.py │ ├── graph.py │ ├── import_engine.py │ └── module.py ├── setup.py └── tools ├── create_python_module.py ├── env.py ├── run_clang_format.py ├── run_cmake_format.py ├── run_license_format.py └── run_py_format.py /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | DerivePointerAlignment: false 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /build-* 3 | *.pyc 4 | *.ipynb 5 | core.* 6 | *.egg-info 7 | /dist 8 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18.0) 2 | 3 | option(BUILD_XLA "Option to build with XLA" OFF) 4 | 
option(BUILD_TENSORRT "Option to build with TensorRT" OFF)
5 | option(BUILD_OPENVINO "Option to build with OpenVINO" OFF)
6 | option(BUILD_PYTHON "Option to build python module" ON)
7 | option(AUTO_INSTALL_ONEFLOW "Option to install oneflow automatically with pip" OFF)
8 |
9 | project(oneflow-xrt CXX)
10 |
# Build as C++17 position-independent code and make the helper modules under
# <source>/cmake reachable through include().
11 | set(CMAKE_CXX_STANDARD 17)
12 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
13 | set(CMAKE_POSITION_INDEPENDENT_CODE ON)
14 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}/cmake)
15 |
# Default download mirror; use_mirror() in cmake/util.cmake accepts "aliyun" or
# an empty string and rejects any other value.
16 | set(THIRD_PARTY_MIRROR aliyun CACHE STRING "")
17 |
18 | include(third_party)
19 |
# Each backend module is included only when its BUILD_* option is ON.
20 | if(BUILD_XLA)
21 |   include(tensorflow-xla)
22 | endif()
23 | if(BUILD_TENSORRT)
24 |   include(tensorrt)
25 | endif()
26 | if(BUILD_OPENVINO)
27 |   include(openvino)
28 | endif()
29 |
30 | set(INSTALL_DIR
31 |     "${PROJECT_BINARY_DIR}/install"
32 |     CACHE STRING "")
33 | add_subdirectory(oneflow_xrt)
34 |
--------------------------------------------------------------------------------
/cmake/absl.cmake:
--------------------------------------------------------------------------------
# Fetches abseil from the (possibly mirrored) archive URL and adds it to the build.
1 | set(ABSL_URL https://github.com/Oneflow-Inc/abseil-cpp/archive/d0.tar.gz)
2 | use_mirror(VARIABLE ABSL_URL URL ${ABSL_URL})
3 |
4 | include(FetchContent)
5 |
6 | FetchContent_Declare(absl URL ${ABSL_URL})
7 |
8 | set(BUILD_TESTING OFF)
9 | FetchContent_MakeAvailable(absl)
10 |
--------------------------------------------------------------------------------
/cmake/glog.cmake:
--------------------------------------------------------------------------------
# Fetches glog v0.5.0 (MD5-verified) and adds it to the build.
1 | include(ExternalProject)
2 |
3 | set(GLOG_URL https://github.com/google/glog/archive/refs/tags/v0.5.0.tar.gz)
4 | use_mirror(VARIABLE GLOG_URL URL ${GLOG_URL})
5 | set(GLOG_URL_HASH 2368e3e0a95cce8b5b35a133271b480f)
6 |
7 | include(FetchContent)
8 |
9 | FetchContent_Declare(
10 |   glog
11 |   URL ${GLOG_URL}
12 |   URL_HASH MD5=${GLOG_URL_HASH})
13 |
# Cache entries set before FetchContent_MakeAvailable(glog); presumably read by
# glog's own CMakeLists -- confirm against the glog release in use.
14 | set(WITH_GFLAGS
15 |     OFF
16 |     CACHE BOOL "")
17 | set(BUILD_SHARED_LIBS
18 |     OFF
19 |     CACHE BOOL
"")
20 | set(WITH_GTEST
21 |     OFF
22 |     CACHE BOOL "")
23 | FetchContent_MakeAvailable(glog)
24 |
25 | # just for tensorflow, DO NOT USE IN OTHER PLACE
26 | FetchContent_GetProperties(glog)
27 | set(GLOG_INCLUDE_DIR ${glog_BINARY_DIR})
28 |
--------------------------------------------------------------------------------
/cmake/openvino.cmake:
--------------------------------------------------------------------------------
# Locate OpenVINO. The CMake variable OPENVINO_ROOT takes precedence over the
# environment variable of the same name; if neither is set, find_package() falls
# back to its default search.
1 | if(OPENVINO_ROOT)
2 |   set(InferenceEngine_DIR ${OPENVINO_ROOT}/cmake)
3 |   set(ngraph_DIR ${OPENVINO_ROOT}/cmake)
# `elseif($ENV{OPENVINO_ROOT})` would expand to the path itself, which if() then
# evaluates as a variable name (normally undefined, hence false). Test for the
# presence of the environment variable instead.
4 | elseif(DEFINED ENV{OPENVINO_ROOT})
5 |   set(InferenceEngine_DIR $ENV{OPENVINO_ROOT}/cmake)
6 |   set(ngraph_DIR $ENV{OPENVINO_ROOT}/cmake)
7 | endif()
8 |
9 | find_package(InferenceEngine REQUIRED)
10 | find_package(ngraph REQUIRED)
11 |
12 | list(APPEND XRT_OPENVINO_THIRD_PARTY_LIBRARIES IE::inference_engine
13 |      ${NGRAPH_LIBRARIES})
14 |
--------------------------------------------------------------------------------
/cmake/pybind11.cmake:
--------------------------------------------------------------------------------
# Fetches pybind11 v2.7.0 (MD5-verified) and adds it to the build.
1 | include(FetchContent)
2 |
3 | set(PYBIND11_URL https://github.com/pybind/pybind11/archive/v2.7.0.zip)
4 | use_mirror(VARIABLE PYBIND11_URL URL ${PYBIND11_URL})
5 | set(PYBIND11_URL_HASH 267807f790ef598ef912a79aceefdc10)
6 |
7 | FetchContent_Declare(
8 |   pybind11
9 |   URL ${PYBIND11_URL}
10 |   URL_HASH MD5=${PYBIND11_URL_HASH})
11 |
12 | FetchContent_MakeAvailable(pybind11)
13 |
--------------------------------------------------------------------------------
/cmake/tensorrt.cmake:
--------------------------------------------------------------------------------
# TensorRT is only usable when WITH_CUDA is set; abort configuration otherwise.
1 | if(NOT WITH_CUDA)
2 |   message(FATAL_ERROR "Should recompile OneFlow with BUILD_CUDA=ON")
3 | endif()
4 |
# Search the TENSORRT_ROOT CMake variable first, then the environment variable.
5 | find_path(TENSORRT_INCLUDE_DIR NvInfer.h
6 |           PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include $ENV{TENSORRT_ROOT}
7 |                 $ENV{TENSORRT_ROOT}/include)
8 |
9 | find_library(
10 |   TENSORRT_LIBRARIES
11 |   NAMES nvinfer
12 |   PATHS ${TENSORRT_ROOT}
${TENSORRT_ROOT}/lib $ENV{TENSORRT_ROOT}
13 |         $ENV{TENSORRT_ROOT}/lib)
14 |
# Both the header directory and the nvinfer library must have been found.
15 | if(NOT TENSORRT_INCLUDE_DIR OR NOT TENSORRT_LIBRARIES)
16 |   message(
17 |     FATAL_ERROR
18 |       "TensorRT was not found. You can set TENSORRT_ROOT to specify the search path."
19 |   )
20 | endif()
21 |
22 | message(STATUS "TensorRT Include: ${TENSORRT_INCLUDE_DIR}")
23 | message(STATUS "TensorRT Lib: ${TENSORRT_LIBRARIES}")
24 |
25 | list(APPEND XRT_TENSORRT_THIRD_PARTY_LIBRARIES ${TENSORRT_LIBRARIES})
26 |
--------------------------------------------------------------------------------
/cmake/third_party.cmake:
--------------------------------------------------------------------------------
# util must come first: it defines use_mirror(), which the fetch-based modules
# (glog, absl) call.
1 | include(util)
2 | include(protobuf)
3 | include(python)
4 | include(oneflow)
5 | include(glog)
6 | include(absl)
7 |
# Libraries shared by every XRT target.
8 | set(XRT_COMMON_THIRD_PARTY_LIBRARIES
9 |     glog::glog
10 |     absl::algorithm
11 |     absl::base
12 |     absl::debugging
13 |     absl::flat_hash_map
14 |     absl::flags
15 |     absl::memory
16 |     absl::meta
17 |     absl::numeric
18 |     absl::strings
19 |     absl::synchronization
20 |     absl::time
21 |     absl::utility
22 |     absl::span)
23 | set(XRT_THIRD_PARTY_DEPENDICES protobuf)
24 |
25 | set(XRT_THIRD_PARTY_LIBRARIES ${XRT_COMMON_THIRD_PARTY_LIBRARIES} oneflow)
# With CUDA enabled, also link the static CUDA runtime and expose its headers.
26 | if(WITH_CUDA)
27 |   find_package(CUDAToolkit REQUIRED)
28 |   list(APPEND XRT_THIRD_PARTY_LIBRARIES CUDA::cudart_static)
29 |   include_directories(${CUDAToolkit_INCLUDE_DIRS})
30 | endif()
31 |
--------------------------------------------------------------------------------
/cmake/util.cmake:
--------------------------------------------------------------------------------
# use_mirror(VARIABLE <var> URL <url>)
#
# Rewrites a download URL to the mirror selected by THIRD_PARTY_MIRROR and
# stores the result in <var> in the caller's scope.
#   - Both VARIABLE and URL are required; a missing one is a fatal error.
#   - A URL matching "file://" is stored unchanged (the pattern is unanchored,
#     so it matches anywhere in the URL).
#   - THIRD_PARTY_MIRROR == "aliyun": the URL must start with "https://"; that
#     prefix is replaced by ALIYUN_URL_PREFIX.
#   - THIRD_PARTY_MIRROR empty or undefined: <var> is not touched, so the caller
#     keeps whatever value it already assigned.
#   - Any other THIRD_PARTY_MIRROR value is a fatal error.
1 | function(use_mirror)
2 |   set(ALIYUN_URL_PREFIX
3 |       "https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/https/"
4 |       CACHE STRING "URL prefix of Aliyun OSS mirror")
5 |   cmake_parse_arguments(PARSED_ARGS "" "VARIABLE;URL" "" ${ARGN})
6 |
7 |   if((NOT PARSED_ARGS_VARIABLE) OR (NOT PARSED_ARGS_URL))
8 |     message(FATAL_ERROR "VARIABLE or URL required")
9 |   endif()
10 |
11 |   if(PARSED_ARGS_URL MATCHES "file://")
12 |     set(${PARSED_ARGS_VARIABLE} ${PARSED_ARGS_URL} PARENT_SCOPE)
13 |     return()
14 |   endif()
15 |   if(DEFINED THIRD_PARTY_MIRROR)
16 |     if(THIRD_PARTY_MIRROR STREQUAL "aliyun")
17 |       if(NOT PARSED_ARGS_URL MATCHES "^https://")
18 |         message(FATAL_ERROR "URL should start with 'https://'")
19 |       endif()
20 |       string(REPLACE "https://" ${ALIYUN_URL_PREFIX} MIRRORED_URL ${PARSED_ARGS_URL})
21 |       set(${PARSED_ARGS_VARIABLE} ${MIRRORED_URL} PARENT_SCOPE)
22 |       message(NOTICE "-- fetch ${PARSED_ARGS_VARIABLE} using aliyun mirror ${MIRRORED_URL}")
23 |     elseif(NOT THIRD_PARTY_MIRROR STREQUAL "")
24 |       message(FATAL_ERROR "invalid key for third party mirror")
25 |     endif()
26 |   endif()
27 | endfunction()
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | oneflow
--------------------------------------------------------------------------------
/oneflow_xrt/api/api_internal.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #include "oneflow_xrt/api/api_internal.h"
17 |
18 | #include "oneflow_xrt/compiler/passes/build_subgraph_pass.h"
19 | #include "oneflow_xrt/compiler/passes/mark_cluster_id_pass.h"
20 | #include "oneflow_xrt/compiler/passes/trainable_propagation_pass.h"
21 |
22 | namespace oneflow {
23 | namespace xrt {
24 |
// Runs the three clustering passes in sequence, each consuming the graph
// produced by the previous one:
//   1. TrainablePropagationPass(graph)            -- does not take `options`
//   2. RunMarkClusterIdPass(result, options)
//   3. RunBuildSubGraphPass(result, options)
// and returns the output of the last pass. `graph` itself is only read through
// a const pointer.
// NOTE(review): the template argument of std::shared_ptr was lost when this
// dump was extracted; presumably XrtGraph -- confirm against the repository.
25 | std::shared_ptr RunClusterSubGraphPass(
26 |     const XrtGraph* graph, const ClusteringOptions& options) {
27 |   std::shared_ptr new_graph;
28 |   new_graph = TrainablePropagationPass(graph);
29 |   new_graph = RunMarkClusterIdPass(new_graph.get(), options);
30 |   return RunBuildSubGraphPass(new_graph.get(), options);
31 | }
32 |
33 | }  // namespace xrt
34 | }  // namespace oneflow
35 |
--------------------------------------------------------------------------------
/oneflow_xrt/api/api_internal.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #ifndef ONEFLOW_XRT_API_INTERNAL_H_
17 | #define ONEFLOW_XRT_API_INTERNAL_H_
18 |
19 | #include "oneflow/core/framework/framework.h"
20 | #include "oneflow/core/job/job.pb.h"
21 | #include "oneflow_xrt/compiler/passes/options.h"
22 | #include "oneflow_xrt/compiler/passes/shape_inference_context.h"
23 | #include "oneflow_xrt/graph/graph.h"
24 | #include "oneflow_xrt/int8_calibration/calibration_mode.h"
25 | #include "oneflow_xrt/xrt.pb.h"
26 |
27 | namespace oneflow {
28 | namespace xrt {
29 |
// Internal entry points of the XRT pipeline. The `extern` functions are defined
// in other translation units that are not part of this header.
// NOTE(review): every std::shared_ptr below lost its template argument when
// this dump was extracted; restore them from the repository before compiling.

// Build a graph from a serialized function description or from a whole job.
30 | extern std::shared_ptr BuildGraph(const FunctionProto& function);
31 | extern std::shared_ptr BuildGraph(const Job& job);
32 |
// Runs the clustering passes on `graph`; defined in api_internal.cpp.
33 | std::shared_ptr RunClusterSubGraphPass(
34 |     const XrtGraph* graph, const ClusteringOptions& options);
35 |
// Produces a new job from `graph` and the original job `origin`.
36 | extern std::shared_ptr RunRebuildJobPass(const XrtGraph* graph,
37 |                                          const Job& origin,
38 |                                          const ReBuildJobOptions& options);
39 |
// `context` is taken by non-const reference, so the pass may modify it.
40 | extern void RunShapeInferencePass(const XrtGraph* graph,
41 |                                   ShapeInferenceContext& context);
42 |
43 | }  // namespace xrt
44 | }  // namespace oneflow
45 |
46 | #endif  // ONEFLOW_XRT_API_INTERNAL_H_
47 |
--------------------------------------------------------------------------------
/oneflow_xrt/api/api_serving.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #include "oneflow_xrt/api/api_serving.h"
17 |
18 | #include "oneflow_xrt/api/api_internal.h"
19 |
20 | namespace oneflow {
21 | namespace xrt {
22 |
// Compiles a serialized Job for the requested XRT engines and returns the
// rebuilt job, again serialized.
//
// Steps:
//   1. Parse `job` into a Job proto; an unparsable string is fatal.
//   2. Build a graph from the job.
//   3. For every engine name in `engine`, in order, run the clustering passes
//      on the graph produced by the previous iteration. The cluster_* arguments
//      and `dump_subgraph_dir` are forwarded through ClusteringOptions.
//   4. Rebuild the job from the clustered graph using the precision, batch-size
//      and workspace settings, and serialize the result.
//
// An engine name that does not parse as an XrtEngine value is fatal. Previously
// the parse result was ignored, so clustering ran with an uninitialized engine.
//
// NOTE(review): the element type of `engine` was lost when this dump was
// extracted; each element is passed to XrtEngine_Parse as the enum name.
23 | std::string CompileJob(const std::string& job,
24 |                        const std::vector& engine, bool use_fp16,
25 |                        bool use_int8, size_t max_batch_size,
26 |                        size_t max_workspace_size, bool strict_types,
27 |                        bool force_precision_constraints, bool force_compile,
28 |                        size_t cluster_minimum_nodes,
29 |                        size_t cluster_maximum_nodes,
30 |                        bool cluster_ignore_pipeline,
31 |                        size_t cluster_max_iteration,
32 |                        bool cluster_strict_sbp_policy,
33 |                        const std::string& dump_subgraph_dir) {
34 |   Job job_proto;
35 |   if (!job_proto.ParseFromString(job)) {
36 |     LOG(FATAL) << "invalid serialized job";
37 |   }
38 |   auto graph = BuildGraph(job_proto);
39 |   ClusteringOptions cluster_options;
40 |   cluster_options.minimum_nodes = cluster_minimum_nodes;
41 |   cluster_options.maximum_nodes = cluster_maximum_nodes;
42 |   cluster_options.ignore_pipeline = cluster_ignore_pipeline;
43 |   cluster_options.max_iteration = cluster_max_iteration;
44 |   cluster_options.strict_sbp_policy = cluster_strict_sbp_policy;
45 |   cluster_options.dump_subgraph_dir = dump_subgraph_dir;
46 |   for (const auto& e : engine) {
47 |     XrtEngine xrt_engine;
48 |     CHECK(XrtEngine_Parse(e, &xrt_engine)) << "invalid xrt engine: " << e;
49 |     cluster_options.engine = xrt_engine;
50 |     graph = RunClusterSubGraphPass(graph.get(), cluster_options);
51 |   }
52 |
53 |   ReBuildJobOptions options;
54 |   options.use_fp16 = use_fp16;
55 |   options.use_int8 = use_int8;
56 |   options.max_batch_size = max_batch_size;
57 |   options.max_workspace_size = max_workspace_size;
58 |   options.strict_types = strict_types;
59 |   options.force_precision_constraints = force_precision_constraints;
60 |   options.force_compile = force_compile;
61 |   options.dump_subgraph_dir = dump_subgraph_dir;
62 |
63 |   auto new_job = RunRebuildJobPass(graph.get(), job_proto, options);
64 |   return new_job->SerializeAsString();
65 | }
66 |
67 | }  // namespace xrt
68 | }
// namespace oneflow 69 | -------------------------------------------------------------------------------- /oneflow_xrt/api/api_serving.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_API_SERVING_H_ 17 | #define ONEFLOW_XRT_API_SERVING_H_ 18 | 19 | #include 20 | #include 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | 25 | std::string CompileJob( 26 | const std::string& job, const std::vector& engine, 27 | bool use_fp16 = false, bool use_int8 = false, size_t max_batch_size = 1, 28 | size_t max_workspace_size = -1, bool strict_types = false, 29 | bool force_precision_constraints = true, bool force_compile = false, 30 | size_t cluster_minimum_nodes = 1, size_t cluster_maximum_nodes = 0x7fffffff, 31 | bool cluster_ignore_pipeline = true, size_t cluster_max_iteration = 20, 32 | bool cluster_strict_sbp_policy = true, 33 | const std::string& dump_subgraph_dir = ""); 34 | 35 | } // namespace xrt 36 | } // namespace oneflow 37 | 38 | #endif // ONEFLOW_XRT_API_SERVING_H_ 39 | -------------------------------------------------------------------------------- /oneflow_xrt/common/device.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #include "oneflow_xrt/common/device.h"
17 |
18 | #include "glog/logging.h"
19 | #include "oneflow/core/common/device_type.pb.h"
20 | #ifdef WITH_CUDA
21 | #include "cuda_runtime.h"
22 | #endif
23 |
24 | namespace oneflow {
25 | namespace xrt {
26 |
// Resolves a device tag string to a DeviceType via DeviceType4DeviceTag (the
// result is unwrapped with CHECK_JUST) and forwards to the DeviceType overload.
27 | XrtDevice OfDeviceToXrtDevice(const std::string& device) {
28 |   DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(device));
29 |   return OfDeviceToXrtDevice(device_type);
30 | }
31 |
// Maps kCUDA -> GPU_CUDA and kCPU -> CPU_X86. Any other device type is not an
// error: a warning is logged and CPU_X86 is returned.
32 | XrtDevice OfDeviceToXrtDevice(const DeviceType& device) {
33 |   switch (device) {
34 |     case DeviceType::kCUDA:
35 |       return XrtDevice::GPU_CUDA;
36 |     case DeviceType::kCPU:
37 |       return XrtDevice::CPU_X86;
38 |     default:
39 |       LOG(WARNING) << "unsupported oneflow device type (" << device
40 |                    << ") is encountered, so use the default CPU device instead";
41 |       return XrtDevice::CPU_X86;
42 |   }
43 | }
44 |
// Inverse mapping. Unlike the function above, an unsupported XrtDevice is
// reported with LOG(FATAL); the return after it only satisfies the compiler.
45 | DeviceType XrtDeviceToOfDevice(const XrtDevice& device) {
46 |   if (device == XrtDevice::GPU_CUDA) {
47 |     return DeviceType::kCUDA;
48 |   } else if (device == XrtDevice::CPU_X86) {
49 |     return DeviceType::kCPU;
50 |   } else {
51 |     LOG(FATAL) << "unsupported xrt device " << device;
52 |     return DeviceType::kCPU;
53 |   }
54 | }
55 |
// Returns the current device ordinal for `device`. Only GPU_CUDA in a WITH_CUDA
// build queries the runtime (cudaGetDevice); every other case returns 0.
56 | int GetDeviceId(const XrtDevice& device) {
57 |   switch (device) {
58 |     case XrtDevice::CPU_X86:
59 |       return 0;
60 |     case XrtDevice::GPU_CUDA: {
61 | #ifdef WITH_CUDA
62 |       int device_id = 0;
63 |       CHECK_EQ(cudaSuccess, cudaGetDevice(&device_id));
64 |       return device_id;
65 | #endif
// Without WITH_CUDA this case body is empty and control falls through to the
// cases below, which return 0.
66 |     }
67 |     case XrtDevice::GPU_CL:
68 |     case XrtDevice::CPU_ARM:
69 |       return 0;
70 |   }
71 |   return 0;  // let compiler warning free
72 | }
73 |
// Makes `device_id` the current device. Only GPU_CUDA in a WITH_CUDA build has
// an effect (cudaSetDevice); every other case is a no-op.
74 | void SetDeviceId(const XrtDevice& device, const int device_id) {
75 |   switch (device) {
76 |     case XrtDevice::CPU_X86:
77 |       return;
78 |     case XrtDevice::GPU_CUDA: {
79 | #ifdef WITH_CUDA
80 |       CHECK_EQ(cudaSuccess, cudaSetDevice(device_id));
81 |       return;
82 | #endif
// Without WITH_CUDA control falls through to the no-op cases below.
83 |     }
84 |     case XrtDevice::GPU_CL:
85 |     case XrtDevice::CPU_ARM:
86 |       return;
87 |   }
88 | }
89 |
90 | }  // namespace xrt
91 | }  // namespace oneflow
92 |
--------------------------------------------------------------------------------
/oneflow_xrt/common/device.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #ifndef ONEFLOW_XRT_COMMON_DEVICE_H_
17 | #define ONEFLOW_XRT_COMMON_DEVICE_H_
18 |
19 | #include "oneflow/core/framework/framework.h"
20 | #include "oneflow_xrt/xrt.pb.h"
21 |
22 | namespace oneflow {
23 | namespace xrt {
24 |
// Conversions between OneFlow device identifiers (tag string or DeviceType)
// and XrtDevice.
25 | XrtDevice OfDeviceToXrtDevice(const std::string& device);
26 | XrtDevice OfDeviceToXrtDevice(const DeviceType& device);
27 |
28 | DeviceType XrtDeviceToOfDevice(const XrtDevice& device);
29 |
// Query / select the current device ordinal for the given XrtDevice kind.
30 | int GetDeviceId(const XrtDevice& device);
31 |
32 | void SetDeviceId(const XrtDevice& device, const int device_id);
33 |
34 | }  // namespace xrt
35 | }  // namespace oneflow
36 |
37 | #endif  // ONEFLOW_XRT_COMMON_DEVICE_H_
38 |
--------------------------------------------------------------------------------
/oneflow_xrt/common/env.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #ifndef ONEFLOW_XRT_COMMON_ENV_H_
17 | #define ONEFLOW_XRT_COMMON_ENV_H_
18 |
// NOTE(review): the three header names below were lost when this dump was
// extracted; the macros use getenv, strtol/strtoll and memchr.
19 | #include
20 | #include
21 |
22 | #include
23 |
// All macros take the BARE environment variable name (it is stringized with
// #envname) and a default that is used when the variable is not set. Each
// macro calls getenv twice.

// Value of the variable as a C string, or `dflt`.
24 | #define EnvToString(envname, dflt) \
25 |   (!getenv(#envname) ? (dflt) : getenv(#envname))
26 |
// True when the first character of the value is one of t, T, y, Y or 1. The
// searched range is 6 bytes and so includes the literal's explicit NUL, which
// means a variable that is set but empty also evaluates to true.
27 | #define EnvToBool(envname, dflt) \
28 |   (!getenv(#envname) ? (dflt)   \
29 |                      : memchr("tTyY1\0", getenv(#envname)[0], 6) != NULL)
30 |
// Base-10 integer value via strtol; no end-pointer or range check is made.
31 | #define EnvToInt(envname, dflt) \
32 |   (!getenv(#envname) ? (dflt) : strtol(getenv(#envname), NULL, 10))
33 |
// Same as EnvToInt but parsed with strtoll.
34 | #define EnvToInt64(envname, dflt) \
35 |   (!getenv(#envname) ? (dflt) : strtoll(getenv(#envname), NULL, 10))
36 |
37 | #endif  // ONEFLOW_XRT_COMMON_ENV_H_
38 |
--------------------------------------------------------------------------------
/oneflow_xrt/common/registry.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */ 16 | #ifndef ONEFLOW_XRT_COMMON_REGISTER_H_ 17 | #define ONEFLOW_XRT_COMMON_REGISTER_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "glog/logging.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace common { 27 | 28 | template 29 | class Registry { 30 | public: 31 | using Factory = typename ID::Factory; 32 | 33 | static Registry* Global() { 34 | static Registry registry; 35 | return ®istry; 36 | } 37 | 38 | bool Has(const K& key) const { return factories_.count(key); } 39 | 40 | bool Register(const K& key, const Factory& value) { 41 | return factories_.emplace(key, value).second; 42 | } 43 | 44 | const Factory& Lookup(const K& key) const { 45 | const auto& it = factories_.find(key); 46 | if (it == factories_.end()) { 47 | LOG(FATAL) << "key " << key << " has not been registered"; 48 | } 49 | return it->second; 50 | } 51 | 52 | private: 53 | Registry() = default; 54 | virtual ~Registry() = default; 55 | 56 | private: 57 | std::unordered_map factories_; 58 | }; 59 | 60 | #define XRT_REGISTER(ID, key, value) \ 61 | common::Registry::type>::Global()->Register( \ 62 | key, value) 63 | 64 | #define XRT_REGISTER_HAS(ID, key) \ 65 | common::Registry::type>::Global()->Has(key) 66 | 67 | #define XRT_REGISTER_LOOKUP(ID, key) \ 68 | common::Registry::type>::Global()->Lookup(key) 69 | 70 | } // namespace common 71 | } // namespace xrt 72 | } // namespace oneflow 73 | 74 | #endif // ONEFLOW_XRT_COMMON_REGISTER_H_ 75 | -------------------------------------------------------------------------------- /oneflow_xrt/common/shape_util.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #ifndef ONEFLOW_XRT_COMMON_SHAPE_UTIL_H_
17 | #define ONEFLOW_XRT_COMMON_SHAPE_UTIL_H_
18 |
19 | #include "oneflow/core/common/shape.h"
20 |
21 | namespace oneflow {
22 | namespace xrt {
23 |
// Builds a Shape from a vector of dimensions by copying the elements into a
// DimVector through its iterator-range constructor.
// NOTE(review): the template parameter list and the vector's element type were
// lost when this dump was extracted; presumably the element type is the
// template parameter -- confirm against the repository.
24 | template
25 | inline Shape AsShape(const std::vector& dim_vec) {
26 |   return Shape(DimVector(dim_vec.begin(), dim_vec.end()));
27 | }
28 |
29 | }  // namespace xrt
30 | }  // namespace oneflow
31 |
32 | #endif  // ONEFLOW_XRT_COMMON_SHAPE_UTIL_H_
33 |
--------------------------------------------------------------------------------
/oneflow_xrt/common/typedef.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */
16 | #ifndef ONEFLOW_XRT_COMMON_TYPEDEF_H_
17 | #define ONEFLOW_XRT_COMMON_TYPEDEF_H_
18 |
19 | namespace oneflow {
20 | namespace xrt {
21 |
// String constants shared across XRT.
// NOTE(review): names that begin with an underscore followed by an uppercase
// letter are reserved for the implementation in C++; renaming would touch every
// user of these constants, so it is only flagged here.

// Op type names.
22 | constexpr char const _XrtLaunchOpType[] = "XrtLaunch";
23 | constexpr char const _XrtEntryOpType[] = "XrtEntry";
24 | constexpr char const _XrtReturnOpType[] = "XrtReturn";
25 | constexpr char const _XrtNoOpType[] = "XrtNoOp";
26 | constexpr char const _XrtUnsupportedOpType[] = "XrtUnsupported";
27 |
// Node name prefix / names.
28 | constexpr char const _XrtLaunchPrefix[] = "_xrt_launch_";
29 | constexpr char const _XrtEntryName[] = "_xrt_entry";
30 | constexpr char const _XrtReturnName[] = "_xrt_return";
31 |
32 | }  // namespace xrt
33 | }  // namespace oneflow
34 |
35 | #endif  // ONEFLOW_XRT_COMMON_TYPEDEF_H_
36 |
--------------------------------------------------------------------------------
/oneflow_xrt/compiler/compilation_cache.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright 2020 The OneFlow Authors. All rights reserved.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | */ 16 | #include "oneflow_xrt/compiler/compilation_cache.h" 17 | 18 | namespace oneflow { 19 | namespace xrt { 20 | 21 | bool operator==(const Signature& lhs, const Signature& rhs) { 22 | return lhs.builder_name == rhs.builder_name && 23 | lhs.device_ordinal == rhs.device_ordinal && 24 | lhs.entry_shapes == rhs.entry_shapes; 25 | } 26 | 27 | size_t SignatureHash::operator()(const Signature& signature) const { 28 | size_t hash_val = std::hash()(signature.builder_name) ^ 29 | std::hash()(signature.device_ordinal); 30 | for (const auto& shape : signature.entry_shapes) { 31 | hash_val ^= std::hash()(shape); 32 | } 33 | return hash_val; 34 | } 35 | 36 | Executable* CompilationCache::GetRecord(const Signature& signature) const { 37 | Executable* record = nullptr; 38 | std::lock_guard lock(mutex_); 39 | const auto& it = records_.find(signature); 40 | if (it != records_.end()) { 41 | record = it->second.get(); 42 | } 43 | return record; 44 | } 45 | 46 | void CompilationCache::Record(const Signature& signature, 47 | const std::shared_ptr& result) { 48 | std::lock_guard lock(mutex_); 49 | records_.emplace(signature, result); 50 | } 51 | 52 | void CompilationCache::Release() { 53 | std::unordered_map, SignatureHash> 54 | empty_records; 55 | records_.swap(empty_records); 56 | } 57 | 58 | Signature ComputeSignature(const std::string& name, const int device_ordinal, 59 | const std::vector& entry_params) { 60 | Signature signature; 61 | signature.builder_name = name; 62 | signature.device_ordinal = device_ordinal; 63 | signature.entry_shapes.resize(entry_params.size()); 64 | for (int i = 0; i < entry_params.size(); ++i) { 65 | signature.entry_shapes[i] = entry_params[i].shape(); 66 | } 67 | return signature; 68 | } 69 | 70 | } // namespace xrt 71 | } // namespace oneflow 72 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/compilation_cache.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_COMPILATION_CACHE_H_ 17 | #define ONEFLOW_XRT_COMPILER_COMPILATION_CACHE_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "oneflow/core/common/shape.h" 26 | #include "oneflow_xrt/compiler/executable.h" 27 | #include "oneflow_xrt/compiler/parameter.h" 28 | 29 | namespace oneflow { 30 | namespace xrt { 31 | 32 | struct Signature { 33 | // builder name 34 | std::string builder_name; 35 | // device ordinal 36 | int device_ordinal; 37 | 38 | // the signature should be recompute if the entry shape has been changed 39 | std::vector entry_shapes; 40 | }; 41 | 42 | bool operator==(const Signature& lhs, const Signature& rhs); 43 | 44 | struct SignatureHash { 45 | size_t operator()(const Signature& signature) const; 46 | }; 47 | 48 | class CompilationCache { 49 | public: 50 | Executable* GetRecord(const Signature& signature) const; 51 | 52 | void Record(const Signature& signature, 53 | const std::shared_ptr& result); 54 | 55 | void Release(); 56 | 57 | private: 58 | mutable std::mutex mutex_; 59 | std::unordered_map, SignatureHash> 60 | records_; 61 | }; 62 | 63 | Signature ComputeSignature(const std::string& name, const int device_ordinal, 64 | const std::vector& entry_params); 65 
| 66 | } // namespace xrt 67 | } // namespace oneflow 68 | 69 | #endif // ONEFLOW_XRT_COMPILER_COMPILATION_CACHE_H_ 70 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/executable.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_EXECUTABLE_H_ 17 | #define ONEFLOW_XRT_COMPILER_EXECUTABLE_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "oneflow_xrt/compiler/parameter.h" 23 | #include "oneflow_xrt/xrt.pb.h" 24 | 25 | namespace oneflow { 26 | namespace xrt { 27 | 28 | struct ExecutableRunOptions { 29 | // user defined common options 30 | ExecuteOptionsProto common; 31 | // specify the compute stream if the engine supports multiple streams 32 | // the default compute stream will be used if `stream` is nullptr 33 | void* stream = nullptr; 34 | 35 | int32_t device_ordinal = -1; 36 | 37 | // populate the return parameters to reuse their storages while running 38 | // the executable 39 | std::vector return_params; 40 | }; 41 | 42 | class Executable { 43 | public: 44 | Executable(const std::string& name, const XrtEngine& engine) 45 | : name_(name), engine_(engine) {} 46 | virtual ~Executable() = default; 47 | 48 | const XrtEngine& engine() const { return engine_; } 49 | 50 | const std::string& name() const { 
return name_; } 51 | 52 | virtual bool Run(const std::vector& inputs, 53 | const ExecutableRunOptions& run_options, 54 | bool block_until_done = true) = 0; 55 | 56 | bool RunAsync(const std::vector inputs, 57 | const ExecutableRunOptions& run_options) { 58 | return Run(inputs, run_options, false); 59 | } 60 | 61 | const std::vector& Results() const { return results_; } 62 | 63 | protected: 64 | std::string name_; 65 | XrtEngine engine_; 66 | std::vector results_; 67 | }; 68 | 69 | } // namespace xrt 70 | } // namespace oneflow 71 | 72 | #endif // ONEFLOW_XRT_COMPILER_EXECUTABLE_H_ 73 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/kernel/op_context.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_KERNEL_OP_CONTEXT_H_ 17 | #define ONEFLOW_XRT_COMPILER_KERNEL_OP_CONTEXT_H_ 18 | 19 | #include "oneflow/core/framework/attr_map.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | 24 | class OpContext { 25 | public: 26 | explicit OpContext(const AttrMap& attrs) : attrs_(attrs) {} 27 | 28 | template 29 | T Attr(const std::string& name) { 30 | return CHECK_JUST(attrs_.GetAttr(name)); 31 | } 32 | 33 | bool HasAttr(const std::string& name) const { return attrs_.Has(name); } 34 | 35 | private: 36 | AttrMap attrs_; 37 | }; 38 | 39 | } // namespace xrt 40 | } // namespace oneflow 41 | 42 | #endif // ONEFLOW_XRT_COMPILER_KERNEL_OP_CONTEXT_H_ 43 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/kernel/op_kernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_KERNEL_OP_KERNEL_H_ 17 | #define ONEFLOW_XRT_COMPILER_KERNEL_OP_KERNEL_H_ 18 | 19 | #include "oneflow_xrt/common/registry.h" 20 | #include "oneflow_xrt/compiler/kernel/op_context.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | 25 | class OpKernelBase { 26 | public: 27 | OpKernelBase() = default; 28 | virtual ~OpKernelBase() = default; 29 | }; 30 | 31 | template 32 | class OpKernel : public OpKernelBase { 33 | public: 34 | virtual void Compile(ContextT* ctx) = 0; 35 | 36 | OpKernel() = default; 37 | virtual ~OpKernel() = default; 38 | }; 39 | 40 | } // namespace xrt 41 | } // namespace oneflow 42 | 43 | #endif // ONEFLOW_XRT_COMPILER_KERNEL_OP_KERNEL_H_ 44 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ngraph_shape.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_OPENVINO_NGRAPH_SHAPE_H_ 17 | #define ONEFLOW_XRT_COMPILER_OPENVINO_NGRAPH_SHAPE_H_ 18 | 19 | #include "glog/logging.h" 20 | #include "ngraph/ngraph.hpp" 21 | #include "oneflow/core/common/data_type.pb.h" 22 | #include "oneflow/core/common/shape.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace openvino { 27 | 28 | inline ngraph::element::Type DataTypeToNgraphDataType( 29 | const DataType& data_type) { 30 | switch (data_type) { 31 | case oneflow::kDouble: 32 | return ngraph::element::f64; 33 | case oneflow::kFloat: 34 | return ngraph::element::f32; 35 | case oneflow::kFloat16: 36 | return ngraph::element::f16; 37 | case oneflow::kInt8: 38 | return ngraph::element::i8; 39 | case oneflow::kInt32: 40 | return ngraph::element::i32; 41 | case oneflow::kInt64: 42 | return ngraph::element::i64; 43 | default: { 44 | LOG(FATAL) << "Unsupported data type " << data_type << " for Ngraph"; 45 | return ngraph::element::f32; 46 | } 47 | } 48 | } 49 | 50 | inline ngraph::Shape ShapeToNgraphShape(const Shape& shape) { 51 | CHECK_LE(shape.NumAxes(), 8) 52 | << "The maximum dimensions is 8 supported by Ngraph"; 53 | std::vector dim_vec; 54 | for (int i = 0; i < shape.NumAxes(); ++i) { 55 | dim_vec.push_back(shape.At(i)); 56 | } 57 | return ngraph::Shape(dim_vec); 58 | } 59 | 60 | class NgraphShape { 61 | public: 62 | NgraphShape() = default; 63 | 64 | NgraphShape(const Shape& shape, const DataType& data_type) 65 | : shape_(ShapeToNgraphShape(shape)), 66 | data_type_(DataTypeToNgraphDataType(data_type)) {} 67 | 68 | const ngraph::element::Type& data_type() const { return data_type_; } 69 | 70 | const ngraph::Shape& shape() const { return shape_; } 71 | 72 | private: 73 | ngraph::Shape shape_; 74 | ngraph::element::Type data_type_; 75 | }; 76 | 77 | } // namespace openvino 78 | } // namespace xrt 79 | } // namespace oneflow 80 | 81 | #endif // ONEFLOW_XRT_COMPILER_OPENVINO_NGRAPH_SHAPE_H_ 82 | 
-------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/openvino_executable.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_OPENVINO_OPENVINO_EXECUTABLE_H_ 17 | #define ONEFLOW_XRT_COMPILER_OPENVINO_OPENVINO_EXECUTABLE_H_ 18 | 19 | #include 20 | 21 | #include "oneflow_xrt/compiler/executable.h" 22 | #include "oneflow_xrt/compiler/openvino/inference_engine_data_desc.h" 23 | #include "oneflow_xrt/compiler/parameter.h" 24 | 25 | namespace oneflow { 26 | namespace xrt { 27 | namespace openvino { 28 | 29 | class OpenvinoExecutable : public Executable { 30 | public: 31 | OpenvinoExecutable( 32 | std::unique_ptr network, 33 | const std::unordered_map& in_out_to_param_idx) 34 | : Executable("", XrtEngine::OPENVINO), 35 | executable_network_(std::move(network)), 36 | in_out_to_param_idx_(in_out_to_param_idx) {} 37 | virtual ~OpenvinoExecutable() = default; 38 | 39 | bool Run(const std::vector& inputs, 40 | const ExecutableRunOptions& run_options, 41 | bool block_until_done = true) override; 42 | 43 | InferenceEngine::Blob::Ptr ParameterToBlobPtr( 44 | const Parameter& input, const InferenceEngine::TensorDesc& in_desc); 45 | 46 | private: 47 | std::unique_ptr executable_network_; 48 | std::unordered_map 
in_out_to_param_idx_; 49 | }; 50 | 51 | } // namespace openvino 52 | } // namespace xrt 53 | } // namespace oneflow 54 | 55 | #endif // ONEFLOW_XRT_COMPILER_OPENVINO_OPENVINO_EXECUTABLE_H_ 56 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/openvino_graph_compiler.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_OPENVINO_OPENVINO_GRAPH_COMPILER_H_ 17 | #define ONEFLOW_XRT_COMPILER_OPENVINO_OPENVINO_GRAPH_COMPILER_H_ 18 | 19 | #include "oneflow_xrt/compiler/graph_compiler.h" 20 | #include "oneflow_xrt/compiler/openvino/ngraph_shape.h" 21 | #include "oneflow_xrt/compiler/openvino/openvino_executable.h" 22 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace openvino { 27 | 28 | class OpenvinoGraphCompiler : public GraphCompiler::Impl { 29 | public: 30 | explicit OpenvinoGraphCompiler(const std::string& name) 31 | : GraphCompiler::Impl(name) {} 32 | 33 | virtual ~OpenvinoGraphCompiler() = default; 34 | 35 | std::shared_ptr Compile( 36 | const XrtGraph* graph, const std::vector& entry_params, 37 | const std::vector& return_params, 38 | const std::vector& aliases) override; 39 | 40 | private: 41 | void SetupKernelContextParam(const XrtNode* node, 42 | OpenvinoOpContext::Param* context_param); 43 | 44 | void PopulateEntryParams( 45 | const std::vector& entry_params, 46 | std::unordered_map& entry_params_map, 47 | std::unordered_map& entry_params_index_map); 48 | 49 | Argument ArgFromParameter(const Parameter& param); 50 | 51 | private: 52 | std::unordered_map arguments_; 53 | std::unordered_map> operands_; 54 | }; 55 | 56 | } // namespace openvino 57 | } // namespace xrt 58 | } // namespace oneflow 59 | 60 | #endif // ONEFLOW_XRT_COMPILER_OPENVINO_OPENVINO_GRAPH_COMPILER_H_ 61 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/add_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "absl/strings/str_cat.h" 19 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 20 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace openvino { 25 | 26 | class AddOp : public OpenvinoOpKernel { 27 | public: 28 | void Compile(OpenvinoOpContext* ctx) override { 29 | int num_inputs = ctx->num_inputs(); 30 | CHECK_GE(num_inputs, 2) << "ElementWiseOp needs 2 inputs at least."; 31 | Shape in_shape = ctx->InputShape("in_0"); 32 | std::shared_ptr result = ctx->Input("in_0"); 33 | for (int i = 1; i < num_inputs; ++i) { 34 | std::string name = absl::StrCat("in_", i); 35 | CHECK_EQ(in_shape, ctx->InputShape(name)); 36 | result = std::make_shared(ctx->Input(name), result); 37 | result->set_friendly_name(absl::StrCat(ctx->op_name().c_str(), i)); 38 | } 39 | ctx->SetOutput("out_0", result); 40 | } 41 | }; 42 | 43 | REGISTER_OPENVINO_OP_KERNEL(add_n, AddOp).EnableTrainPhase().Finalize(); 44 | 45 | } // namespace openvino 46 | } // namespace xrt 47 | } // namespace oneflow 48 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/argument_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace openvino { 22 | 23 | class XrtEntryOp : public OpenvinoOpKernel { 24 | public: 25 | void Compile(OpenvinoOpContext* ctx) override { 26 | ctx->SetOutput("value", ctx->Variable()); 27 | } 28 | }; 29 | 30 | class XrtReturnOp : public OpenvinoOpKernel { 31 | public: 32 | void Compile(OpenvinoOpContext* ctx) override { 33 | ctx->SetVariable(ctx->Input("value")); 34 | } 35 | }; 36 | 37 | REGISTER_OPENVINO_OP_KERNEL(XrtEntry, XrtEntryOp).Finalize(); 38 | REGISTER_OPENVINO_OP_KERNEL(XrtReturn, XrtReturnOp).Finalize(); 39 | 40 | } // namespace openvino 41 | } // namespace xrt 42 | } // namespace oneflow 43 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/batch_norm_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "absl/strings/str_cat.h" 19 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 20 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace openvino { 25 | 26 | class NormalizationOp : public OpenvinoOpKernel { 27 | public: 28 | void Compile(OpenvinoOpContext* ctx) override { 29 | std::shared_ptr input = ctx->Input("x_0"); 30 | std::shared_ptr gamma = ctx->Weight("gamma_0"); 31 | std::shared_ptr beta = ctx->Weight("beta_0"); 32 | std::shared_ptr moving_mean = ctx->Weight("moving_mean_0"); 33 | std::shared_ptr moving_variance = 34 | ctx->Weight("moving_variance_0"); 35 | float epsilon = ctx->Attr("epsilon"); 36 | std::shared_ptr ngraph_node = 37 | std::make_shared( 38 | input, gamma, beta, moving_mean, moving_variance, epsilon); 39 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 40 | ctx->SetOutput("y_0", ngraph_node); 41 | } 42 | }; 43 | 44 | REGISTER_OPENVINO_OP_KERNEL(normalization, NormalizationOp).Finalize(); 45 | 46 | } // namespace openvino 47 | } // namespace xrt 48 | } // namespace oneflow 49 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/concat_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "absl/strings/str_cat.h" 19 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 20 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace openvino { 25 | 26 | class ConcatOp : public OpenvinoOpKernel { 27 | public: 28 | void Compile(OpenvinoOpContext* ctx) override { 29 | int num_inputs = ctx->num_inputs(); 30 | CHECK_GE(num_inputs, 2) << "Concat needs 2 inputs at least."; 31 | Shape in_shape = ctx->InputShape("in_0"); 32 | int64_t axis = ctx->Attr("axis"); 33 | if (axis < 0) { 34 | axis += in_shape.NumAxes(); 35 | } 36 | CHECK_GE(axis, 0); 37 | CHECK_LT(axis, in_shape.NumAxes()); 38 | 39 | std::vector> in(num_inputs); 40 | for (int i = 0; i < num_inputs; ++i) { 41 | in[i] = ctx->Input(absl::StrCat("in_", i)); 42 | } 43 | std::shared_ptr ngraph_node = 44 | std::make_shared(in, axis); 45 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 46 | ctx->SetOutput("out_0", ngraph_node); 47 | } 48 | }; 49 | 50 | REGISTER_OPENVINO_OP_KERNEL(concat, ConcatOp).EnableTrainPhase().Finalize(); 51 | 52 | } // namespace openvino 53 | } // namespace xrt 54 | } // namespace oneflow 55 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/leaky_relu_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 20 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace openvino { 25 | 26 | class LeakyReluOp : public OpenvinoOpKernel { 27 | public: 28 | void Compile(OpenvinoOpContext* ctx) override { 29 | float alpha = ctx->Attr("alpha"); 30 | std::shared_ptr input = ctx->Input("x_0"); 31 | std::shared_ptr alpha_node = 32 | std::make_shared(ngraph::element::f32, 33 | ngraph::Shape({1}), &alpha); 34 | std::shared_ptr ngraph_node = 35 | std::make_shared(input, alpha_node); 36 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 37 | ctx->SetOutput("y_0", ngraph_node); 38 | } 39 | }; 40 | 41 | REGISTER_OPENVINO_OP_KERNEL(leaky_relu, LeakyReluOp) 42 | .EnableTrainPhase() 43 | .Finalize(); 44 | 45 | } // namespace openvino 46 | } // namespace xrt 47 | } // namespace oneflow 48 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/matmul_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace openvino { 24 | 25 | class MatMulOp : public OpenvinoOpKernel { 26 | public: 27 | void Compile(OpenvinoOpContext* ctx) override { 28 | Shape a_shape = ctx->InputShape("a_0"); 29 | Shape b_shape = ctx->InputShape("b_0"); 30 | CHECK_GE(a_shape.NumAxes(), 2); 31 | CHECK_EQ(a_shape.NumAxes(), b_shape.NumAxes()); 32 | 33 | bool transpose_a = ctx->Attr("transpose_a"); 34 | bool transpose_b = ctx->Attr("transpose_b"); 35 | auto a = ctx->Input("a_0"); 36 | auto b = ctx->Input("b_0"); 37 | 38 | std::shared_ptr ngraph_node = 39 | std::make_shared(a, b, transpose_a, 40 | transpose_b); 41 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 42 | ctx->SetOutput("out_0", ngraph_node); 43 | } 44 | }; 45 | 46 | REGISTER_OPENVINO_OP_KERNEL(matmul, MatMulOp).EnableTrainPhase().Finalize(); 47 | 48 | } // namespace openvino 49 | } // namespace xrt 50 | } // namespace oneflow 51 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/op_kernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_OPENVINO_OPS_OP_KERNEL_H_ 17 | #define ONEFLOW_XRT_COMPILER_OPENVINO_OPS_OP_KERNEL_H_ 18 | 19 | #include "oneflow_xrt/common/registry.h" 20 | #include "oneflow_xrt/compiler/kernel/op_kernel.h" 21 | #include "oneflow_xrt/compiler/kernel/op_kernel_registry.h" 22 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace openvino { 27 | 28 | class OpenvinoOpKernel : public OpKernel { 29 | public: 30 | virtual void Compile(OpenvinoOpContext* ctx) = 0; 31 | 32 | OpenvinoOpKernel() = default; 33 | virtual ~OpenvinoOpKernel() = default; 34 | }; 35 | 36 | #define REGISTER_OPENVINO_OP_KERNEL(OpName, KernelType) \ 37 | static OpKernelRegistrar _openvino_op_kernel_##OpName##_ \ 38 | __attribute__((unused)) = \ 39 | OpKernelRegistrar(#OpName) \ 40 | .SetEngine(XrtEngine::OPENVINO) \ 41 | .SetDevice({XrtDevice::CPU_X86}) \ 42 | .SetFactory([]() -> OpKernelBase* { return new KernelType; }) 43 | 44 | inline std::shared_ptr BuildOpKernel( 45 | const std::string& op_name) { 46 | OpKernelRegKey reg_key{op_name, XrtEngine::OPENVINO, XrtDevice::CPU_X86}; 47 | const auto& f = XRT_REGISTER_LOOKUP(OpKernelRegId, reg_key); 48 | auto* openvino_kernel = dynamic_cast(f()); 49 | CHECK(openvino_kernel) << "failed to build openvino op kernel for " 50 | << reg_key; 51 | return std::shared_ptr(openvino_kernel); 52 | } 53 | 54 | } // namespace openvino 55 | } // namespace xrt 56 | } // namespace oneflow 57 | 58 | #endif // ONEFLOW_XRT_COMPILER_OPENVINO_OPS_OP_KERNEL_H_ 
59 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/relu_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | 18 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace openvino { 24 | 25 | class ReluOp : public OpenvinoOpKernel { 26 | public: 27 | void Compile(OpenvinoOpContext* ctx) override { 28 | std::shared_ptr ngraph_node = 29 | std::make_shared(ctx->Input("x_0")); 30 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 31 | ctx->SetOutput("y_0", ngraph_node); 32 | } 33 | }; 34 | 35 | REGISTER_OPENVINO_OP_KERNEL(relu, ReluOp).EnableTrainPhase().Finalize(); 36 | 37 | } // namespace openvino 38 | } // namespace xrt 39 | } // namespace oneflow 40 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/reshape_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 20 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace openvino { 25 | 26 | class ReshapeOp : public OpenvinoOpKernel { 27 | public: 28 | void Compile(OpenvinoOpContext* ctx) override { 29 | Shape in_shape = ctx->InputShape("in_0"); 30 | Shape shape = ctx->OutputShape("out_0"); 31 | CHECK_EQ(shape.Count(0), in_shape.Count(0)); 32 | std::vector dim_vec; 33 | for (int i = 0; i < shape.NumAxes(); ++i) { 34 | dim_vec.push_back(shape.At(i)); 35 | } 36 | 37 | std::shared_ptr alpha_node = 38 | std::make_shared( 39 | ngraph::element::i32, ngraph::Shape({dim_vec.size()}), dim_vec); 40 | std::shared_ptr ngraph_node = 41 | std::make_shared(ctx->Input("in_0"), 42 | alpha_node, false); 43 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 44 | ctx->SetOutput("out_0", ngraph_node); 45 | } 46 | }; 47 | 48 | REGISTER_OPENVINO_OP_KERNEL(reshape, ReshapeOp).EnableTrainPhase().Finalize(); 49 | 50 | } // namespace openvino 51 | } // namespace xrt 52 | } // namespace oneflow 53 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/openvino/ops/upsample_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | #include "oneflow_xrt/compiler/openvino/ops/op_context.h" 20 | #include "oneflow_xrt/compiler/openvino/ops/op_kernel.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace openvino { 25 | 26 | class UpsampleNearestOp : public OpenvinoOpKernel { 27 | public: 28 | void Compile(OpenvinoOpContext* ctx) override { 29 | // upsample_nearest only supports NCHW 30 | const int32_t scale = ctx->Attr("scale"); 31 | if (ctx->Attr("data_format") != "channels_first") { 32 | LOG(FATAL) << "upsample_nearest only supports NCHW"; 33 | } 34 | std::vector out_shape = {ctx->InputShape("in_0").At(2) * scale, 35 | ctx->InputShape("in_0").At(3) * scale}; 36 | ngraph::op::InterpolateAttrs attrs; 37 | attrs.axes.insert(2); 38 | attrs.axes.insert(3); 39 | attrs.mode = "nearest"; 40 | attrs.align_corners = 0; 41 | attrs.antialias = 0; 42 | attrs.pads_begin.push_back(0); 43 | attrs.pads_end.push_back(0); 44 | std::shared_ptr input = ctx->Input("in_0"); 45 | std::shared_ptr out_shape_node = 46 | std::make_shared(ngraph::element::i64, 47 | ngraph::Shape({2}), out_shape); 48 | std::shared_ptr ngraph_node = 49 | std::make_shared(input, out_shape_node, attrs); 50 | ngraph_node->set_friendly_name(ctx->op_name().c_str()); 51 | ctx->SetOutput("out_0", ngraph_node); 52 | } 53 | }; 54 | 55 | REGISTER_OPENVINO_OP_KERNEL(upsample_nearest, UpsampleNearestOp) 56 | .EnableTrainPhase() 57 | 
.Finalize(); 58 | 59 | } // namespace openvino 60 | } // namespace xrt 61 | } // namespace oneflow 62 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/passes/build_subgraph_pass.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_PASSES_BUILD_SUBGRAPH_PASS_H_ 17 | #define ONEFLOW_XRT_COMPILER_PASSES_BUILD_SUBGRAPH_PASS_H_ 18 | 19 | #include "oneflow_xrt/compiler/passes/options.h" 20 | #include "oneflow_xrt/graph/graph.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | 25 | std::shared_ptr RunBuildSubGraphPass( 26 | const XrtGraph* graph, const ClusteringOptions& options); 27 | 28 | } // namespace xrt 29 | } // namespace oneflow 30 | 31 | #endif // ONEFLOW_XRT_COMPILER_PASSES_BUILD_SUBGRAPH_PASS_H_ 32 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/passes/mark_cluster_id_pass.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_PASSES_MARK_CLUSTER_ID_PASS_H_ 17 | #define ONEFLOW_XRT_COMPILER_PASSES_MARK_CLUSTER_ID_PASS_H_ 18 | 19 | #include "oneflow_xrt/compiler/passes/options.h" 20 | #include "oneflow_xrt/graph/graph.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | 25 | std::shared_ptr RunMarkClusterIdPass( 26 | const XrtGraph* graph, const ClusteringOptions& options); 27 | 28 | } // namespace xrt 29 | } // namespace oneflow 30 | 31 | #endif // ONEFLOW_XRT_COMPILER_PASSES_MARK_CLUSTER_ID_PASS_H_ 32 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/passes/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_PASSES_OPTIONS_H_ 17 | #define ONEFLOW_XRT_COMPILER_PASSES_OPTIONS_H_ 18 | 19 | #include 20 | 21 | #include "oneflow_xrt/xrt.pb.h" 22 | 23 | namespace oneflow { 24 | namespace xrt { 25 | 26 | struct ClusteringOptions { 27 | XrtEngine engine = XrtEngine::DEFAULT; 28 | XrtDevice device = XrtDevice::CPU_X86; 29 | 30 | // minimum node number in each cluster after clustering. If the number of 31 | // nodes contained by a cluster is less than `minimum_nodes` or greater than 32 | // `maximum_nodes`, then this cluster will be discarded and not compiled 33 | int32_t minimum_nodes = 0x1; 34 | int32_t maximum_nodes = 0x7fffffff; 35 | 36 | // ignore strict dependencies analysis 37 | bool ignore_pipeline = true; 38 | // check whether the strict sbp policy is satisfied 39 | bool strict_sbp_policy = true; 40 | 41 | // maximum iteration count for iterative clustering. -1 means 42 | // that it will always iteratively merge as much as possible until no 43 | // node can be merged 44 | int32_t max_iteration = 20; 45 | 46 | std::string dump_subgraph_dir = ""; 47 | }; 48 | 49 | struct ReBuildJobOptions { 50 | bool use_fp16 = false; 51 | bool use_int8 = false; 52 | 53 | std::string int8_calibration = ""; 54 | 55 | bool force_compile = false; 56 | bool strict_types = false; 57 | bool force_precision_constraints = true; 58 | 59 | int64_t max_batch_size = 1; 60 | int64_t max_workspace_size = -1; 61 | 62 | std::string dump_subgraph_dir = ""; 63 | }; 64 | 65 | } // namespace xrt 66 | } // namespace oneflow 67 | 68 | #endif // ONEFLOW_XRT_COMPILER_PASSES_OPTIONS_H_ 69 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/passes/trainable_propagation_pass.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/passes/trainable_propagation_pass.h" 17 | 18 | namespace oneflow { 19 | namespace xrt { 20 | 21 | std::shared_ptr TrainablePropagationPass(const XrtGraph* graph) { 22 | auto new_graph = graph->clone(); 23 | algorithm::TopologyVisit(*new_graph, [&](XrtNode* node) { 24 | if (node->trainable()) { 25 | return; 26 | } 27 | bool trainable = false; 28 | for (const auto& edge : node->in_edges()) { 29 | if (edge->start()->trainable()) { 30 | trainable = true; 31 | break; 32 | } 33 | } 34 | node->set_trainable(trainable); 35 | }); 36 | return new_graph; 37 | } 38 | 39 | } // namespace xrt 40 | } // namespace oneflow 41 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/passes/trainable_propagation_pass.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_PASSES_TRAINABLE_PROPAGATION_PASS_H_ 17 | #define ONEFLOW_XRT_COMPILER_PASSES_TRAINABLE_PROPAGATION_PASS_H_ 18 | 19 | #include "oneflow_xrt/graph/graph.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | 24 | std::shared_ptr TrainablePropagationPass(const XrtGraph* graph); 25 | 26 | } // namespace xrt 27 | } // namespace oneflow 28 | 29 | #endif // ONEFLOW_XRT_COMPILER_PASSES_TRAINABLE_PROPAGATION_PASS_H_ 30 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/README.md: -------------------------------------------------------------------------------- 1 | ## TensorRT 2 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/common.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_COMMON_H_ 18 | #define ONEFLOW_XRT_COMPILER_TENSORRT_COMMON_H_ 19 | 20 | #include "NvInferVersion.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace tensorrt { 25 | 26 | #define TRT_VERSION \ 27 | ((NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH) 28 | 29 | #if NV_TENSORRT_MAJOR > 7 30 | #define TRT_NOEXCEPT noexcept 31 | #else 32 | #define TRT_NOEXCEPT 33 | #endif 34 | 35 | } // namespace tensorrt 36 | } // namespace xrt 37 | } // namespace oneflow 38 | 39 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_COMMON_H_ 40 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/argument_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | class XrtEntryOp : public TrtOpKernel { 24 | public: 25 | void Compile(TrtOpContext* ctx) override { 26 | ctx->SetVariable("value", ctx->Variable("variable")); 27 | } 28 | }; 29 | 30 | class XrtReturnOp : public TrtOpKernel { 31 | public: 32 | void Compile(TrtOpContext* ctx) override { 33 | ctx->SetVariable("variable", ctx->Variable("value")); 34 | } 35 | }; 36 | 37 | REGISTER_TRT_OP_KERNEL(XrtEntry, XrtEntryOp).Finalize(); 38 | REGISTER_TRT_OP_KERNEL(XrtReturn, XrtReturnOp).Finalize(); 39 | 40 | } // namespace tensorrt 41 | } // namespace xrt 42 | } // namespace oneflow 43 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/bias_add_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/common/shape_util.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 18 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 19 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace tensorrt { 24 | 25 | class BiasAddOp : public TrtOpKernel { 26 | public: 27 | void Compile(TrtOpContext* ctx) override { 28 | CHECK_EQ(ctx->InputType("a_0"), ctx->InputType("b_0")); 29 | 30 | Shape in_shape = ctx->InputShape("a_0"); 31 | Shape bias_shape = ctx->InputShape("b_0"); 32 | CHECK_GE(in_shape.NumAxes(), 2); 33 | CHECK_EQ(bias_shape.NumAxes(), 1); 34 | 35 | std::vector dims(in_shape.NumAxes(), 1); 36 | int32_t axis = ctx->Attr("axis"); 37 | dims[axis] = bias_shape.At(0); 38 | 39 | nvinfer1::ITensor* in = ctx->Input("a_0"); 40 | ; 41 | nvinfer1::Weights bias = ctx->Weight("b_0"); 42 | nvinfer1::ITensor* reshaped_bias = 43 | helpers::Reshape(ctx, bias, AsShape(dims)); 44 | // Add bias to input by ElementWise layer. 45 | auto* layer = ctx->builder()->addElementWise( 46 | *in, *reshaped_bias, nvinfer1::ElementWiseOperation::kSUM); 47 | layer->setName(ctx->op_name().c_str()); 48 | 49 | ctx->SetOutput("out_0", layer->getOutput(0)); 50 | } 51 | }; 52 | 53 | REGISTER_TRT_OP_KERNEL(bias_add, BiasAddOp).Finalize(); 54 | 55 | } // namespace tensorrt 56 | } // namespace xrt 57 | } // namespace oneflow 58 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/broadcast_binary_ops.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "NvInfer.h" 17 | #include "oneflow/core/common/shape_view.h" 18 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 20 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace tensorrt { 25 | 26 | template 27 | class BcastBinaryOp : public TrtOpKernel { 28 | public: 29 | void Compile(TrtOpContext* ctx) override { 30 | Shape shape_a = ctx->InputShape("x_0"); 31 | Shape shape_b = ctx->InputShape("y_0"); 32 | 33 | int axes = std::max(shape_a.NumAxes(), shape_b.NumAxes()); 34 | shape_a = CreateLeftExtendedShape(ShapeView(shape_a), axes); 35 | shape_b = CreateLeftExtendedShape(ShapeView(shape_b), axes); 36 | 37 | nvinfer1::ITensor* x = helpers::Reshape(ctx, ctx->Input("x_0"), shape_a); 38 | nvinfer1::ITensor* y = helpers::Reshape(ctx, ctx->Input("y_0"), shape_b); 39 | auto* layer = ctx->builder()->addElementWise(*x, *y, element_wise_op); 40 | layer->setName(ctx->op_name().c_str()); 41 | ctx->SetSoleOutput(layer->getOutput(0)); 42 | } 43 | }; 44 | 45 | REGISTER_TRT_OP_KERNEL(broadcast_add, 46 | BcastBinaryOp) 47 | .EnableTrainPhase() 48 | .Finalize(); 49 | REGISTER_TRT_OP_KERNEL(broadcast_sub, 50 | BcastBinaryOp) 51 | .EnableTrainPhase() 52 | .Finalize(); 53 | REGISTER_TRT_OP_KERNEL(broadcast_mul, 54 | BcastBinaryOp) 55 | .EnableTrainPhase() 56 | .Finalize(); 57 | REGISTER_TRT_OP_KERNEL(broadcast_div, 58 | BcastBinaryOp) 59 | .EnableTrainPhase() 60 | .Finalize(); 61 | 62 | } // namespace 
tensorrt 63 | } // namespace xrt 64 | } // namespace oneflow 65 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/broadcast_like_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/plugin/broadcast_like_plugin.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class BroadcastLikeOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | const auto& broadcast_axes = 28 | ctx->Attr>("broadcast_axes"); 29 | std::vector inputs(2); 30 | inputs[0] = ctx->Input("x_0"); 31 | inputs[1] = ctx->Input("like_0"); 32 | BroadcastLikePlugin plugin(ctx->op_name(), broadcast_axes); 33 | auto* layer = ctx->builder()->addPluginV2(inputs.data(), 2, plugin); 34 | layer->setName(ctx->op_name().c_str()); 35 | ctx->SetOutput("y_0", layer->getOutput(0)); 36 | } 37 | }; 38 | 39 | // REGISTER_TRT_OP_KERNEL(broadcast_like, BroadcastLikeOp) 40 | // .EnableTrainPhase() 41 | // .Finalize(); 42 | 43 | } // namespace tensorrt 44 | } // namespace xrt 45 | } // namespace oneflow 46 | 
-------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/cast_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/trt_shape.h" // DataTypeToTrtDataType 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class CastOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | DataType dest_dtype = ctx->Attr("dtype"); 28 | DataType src_dtype = ctx->SoleInputType(); 29 | nvinfer1::ITensor* in = ctx->SoleInput(); 30 | if (src_dtype == dest_dtype) { 31 | ctx->SetSoleOutput(in); 32 | } else { 33 | auto* layer = ctx->builder()->addIdentity(*in); 34 | layer->setOutputType(0, DataTypeToTrtDataType(dest_dtype)); 35 | layer->setName(ctx->op_name().c_str()); 36 | ctx->SetSoleOutput(layer->getOutput(0)); 37 | } 38 | } 39 | }; 40 | 41 | REGISTER_TRT_OP_KERNEL(cast, CastOp).EnableTrainPhase().Finalize(); 42 | 43 | } // namespace tensorrt 44 | } // namespace xrt 45 | } // namespace oneflow 46 | -------------------------------------------------------------------------------- 
/oneflow_xrt/compiler/tensorrt/ops/concat_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "absl/strings/str_cat.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 18 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class ConcatOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | int num_inputs = ctx->num_inputs(); 28 | CHECK_GE(num_inputs, 2) << "Concat needs 2 inputs at least."; 29 | Shape in_shape = ctx->InputShape("in_0"); 30 | int64_t axis = ctx->Attr("axis"); 31 | if (axis < 0) { 32 | axis += in_shape.NumAxes(); 33 | } 34 | CHECK_GE(axis, 0); 35 | CHECK_LT(axis, in_shape.NumAxes()); 36 | 37 | std::vector in(num_inputs); 38 | for (int i = 0; i < num_inputs; ++i) { 39 | in[i] = ctx->Input(absl::StrCat("in_", i)); 40 | } 41 | auto* layer = ctx->builder()->addConcatenation(in.data(), num_inputs); 42 | layer->setAxis(axis); 43 | layer->setName(ctx->op_name().c_str()); 44 | ctx->SetSoleOutput(layer->getOutput(0)); 45 | } 46 | }; 47 | 48 | REGISTER_TRT_OP_KERNEL(concat, ConcatOp).EnableTrainPhase().Finalize(); 49 | 50 | } // namespace tensorrt 51 | } // namespace xrt 52 | } // namespace oneflow 53 | 
-------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/element_wise_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "NvInfer.h" 17 | #include "absl/strings/str_cat.h" 18 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace tensorrt { 24 | 25 | template 26 | class ElementWiseOp : public TrtOpKernel { 27 | public: 28 | void Compile(TrtOpContext* ctx) override { 29 | int num_inputs = ctx->num_inputs(); 30 | CHECK_GE(num_inputs, 2) << "ElementWiseOp needs 2 inputs at least."; 31 | 32 | Shape in_shape = ctx->InputShape("in_0"); 33 | nvinfer1::ITensor* result = ctx->Input("in_0"); 34 | for (int i = 1; i < num_inputs; ++i) { 35 | std::string name = absl::StrCat("in_", i); 36 | CHECK_EQ(in_shape, ctx->InputShape(name)); 37 | auto* layer = ctx->builder()->addElementWise(*ctx->Input(name), *result, 38 | element_wise_op); 39 | result = layer->getOutput(0); 40 | } 41 | ctx->SetSoleOutput(result); 42 | } 43 | }; 44 | 45 | REGISTER_TRT_OP_KERNEL(add_n, 46 | ElementWiseOp) 47 | .EnableTrainPhase() 48 | .Finalize(); 49 | 50 | } // namespace tensorrt 51 | } // namespace xrt 52 | } // namespace oneflow 
53 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/expand_dims.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class ExpandDimsOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | Shape in_shape = ctx->InputShape("in_0"); 28 | int32_t axis = ctx->Attr("axis"); 29 | if (axis < 0) { 30 | axis = axis + in_shape.NumAxes() + 1; 31 | } 32 | 33 | auto dim_vec = in_shape.dim_vec(); 34 | dim_vec.insert(dim_vec.begin() + axis, 1); 35 | ctx->SetSoleOutput( 36 | helpers::Reshape(ctx, ctx->Input("in_0"), Shape(dim_vec))); 37 | } 38 | }; 39 | 40 | REGISTER_TRT_OP_KERNEL(expand_dims, ExpandDimsOp).EnableTrainPhase().Finalize(); 41 | 42 | } // namespace tensorrt 43 | } // namespace xrt 44 | } // namespace oneflow 45 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/identity_op.cpp: -------------------------------------------------------------------------------- 1 | /* 
2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | class IdentityOp : public TrtOpKernel { 24 | public: 25 | void Compile(TrtOpContext* ctx) override { 26 | nvinfer1::ITensor* in = ctx->SoleInput(); 27 | auto* layer = ctx->builder()->addIdentity(*in); 28 | layer->setName(ctx->op_name().c_str()); 29 | ctx->SetSoleOutput(layer->getOutput(0)); 30 | } 31 | }; 32 | 33 | REGISTER_TRT_OP_KERNEL(identity, IdentityOp).EnableTrainPhase().Finalize(); 34 | 35 | } // namespace tensorrt 36 | } // namespace xrt 37 | } // namespace oneflow 38 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/multiply_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "NvInfer.h" 17 | #include "absl/strings/str_cat.h" 18 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace tensorrt { 24 | 25 | class MultiplyOp : public TrtOpKernel { 26 | public: 27 | void Compile(TrtOpContext* ctx) override { 28 | Shape x_shape = ctx->InputShape("x_0"); 29 | Shape y_shape = ctx->InputShape("y_0"); 30 | nvinfer1::ITensor* x = ctx->Input("x_0"); 31 | nvinfer1::ITensor* y = ctx->Input("y_0"); 32 | CHECK_EQ(x_shape, y_shape); 33 | auto* layer = ctx->builder()->addElementWise( 34 | *x, *y, nvinfer1::ElementWiseOperation::kPROD); 35 | ctx->SetSoleOutput(layer->getOutput(0)); 36 | } 37 | }; 38 | 39 | REGISTER_TRT_OP_KERNEL(multiply, MultiplyOp).EnableTrainPhase().Finalize(); 40 | 41 | } // namespace tensorrt 42 | } // namespace xrt 43 | } // namespace oneflow 44 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/op_kernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_OPS_OP_KERNEL_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_OPS_OP_KERNEL_H_ 18 | 19 | #include "oneflow_xrt/common/registry.h" 20 | #include "oneflow_xrt/compiler/kernel/op_kernel.h" 21 | #include "oneflow_xrt/compiler/kernel/op_kernel_registry.h" 22 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace tensorrt { 27 | 28 | class TrtOpKernel : public OpKernel { 29 | public: 30 | virtual void Compile(TrtOpContext* ctx) = 0; 31 | 32 | TrtOpKernel() = default; 33 | virtual ~TrtOpKernel() = default; 34 | }; 35 | 36 | #define REGISTER_TRT_OP_KERNEL(OpName, KernelType) \ 37 | static OpKernelRegistrar _trt_op_kernel_##OpName##_ \ 38 | __attribute__((unused)) = \ 39 | OpKernelRegistrar(#OpName) \ 40 | .SetEngine(XrtEngine::TENSORRT) \ 41 | .SetDevice({XrtDevice::GPU_CUDA}) \ 42 | .SetFactory([]() -> OpKernelBase* { return new KernelType; }) 43 | 44 | inline std::shared_ptr BuildOpKernel(const std::string& op_name) { 45 | OpKernelRegKey reg_key{op_name, XrtEngine::TENSORRT, XrtDevice::GPU_CUDA}; 46 | const auto& f = XRT_REGISTER_LOOKUP(OpKernelRegId, reg_key); 47 | auto* trt_kernel = dynamic_cast(f()); 48 | CHECK(trt_kernel) << "failed to build tensorrt op kernel for " << reg_key; 49 | return std::shared_ptr(trt_kernel); 50 | } 51 | 52 | } // namespace tensorrt 53 | } // namespace xrt 54 | } // namespace oneflow 55 | 56 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_OPS_OP_KERNEL_H_ 57 | 
-------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/pad_grad_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | class PaddingGradOp : public TrtOpKernel { 24 | public: 25 | void Compile(TrtOpContext* ctx) override { 26 | const auto& padding_before = 27 | ctx->Attr>("padding_before"); 28 | const auto& padding_after = 29 | ctx->Attr>("padding_after"); 30 | const auto& shape = ctx->InputShape("dy_0"); 31 | CHECK_EQ(shape.NumAxes(), padding_before.size()); 32 | CHECK_EQ(shape.NumAxes(), padding_after.size()); 33 | nvinfer1::Dims start, size, stride; 34 | start.nbDims = padding_before.size(); 35 | size.nbDims = start.nbDims; 36 | stride.nbDims = start.nbDims; 37 | for (int i = 0; i < start.nbDims; ++i) { 38 | start.d[i] = padding_before[i]; 39 | size.d[i] = shape.At(i) - padding_before[i] - padding_after[i]; 40 | stride.d[i] = 1; 41 | } 42 | auto* layer = 43 | ctx->builder()->addSlice(*(ctx->Input("dy_0")), start, size, stride); 44 | layer->setName(ctx->op_name().c_str()); 45 | 46 | // add identity layer after slice to bypass 
some internal error, 47 | // refer to https://github.com/NVIDIA/TensorRT/issues/1821 48 | auto* identity_layer = ctx->builder()->addIdentity(*(layer->getOutput(0))); 49 | std::string name = ctx->op_name() + ".identity"; 50 | identity_layer->setName(name.c_str()); 51 | ctx->SetOutput("dx_0", identity_layer->getOutput(0)); 52 | } 53 | }; 54 | 55 | REGISTER_TRT_OP_KERNEL(pad_grad, PaddingGradOp).EnableTrainPhase().Finalize(); 56 | 57 | } // namespace tensorrt 58 | } // namespace xrt 59 | } // namespace oneflow 60 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/pad_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | class PaddingOp : public TrtOpKernel { 24 | public: 25 | void Compile(TrtOpContext* ctx) override { 26 | // TODO(hjchen2): Support not constant 0 padding 27 | if (ctx->HasAttr("floating_constant_value")) { 28 | double value = ctx->Attr("floating_constant_value"); 29 | CHECK_EQ(value, 0) << "Only support constant 0 padding"; 30 | } 31 | if (ctx->HasAttr("integral_constant_value")) { 32 | int64_t value = ctx->Attr("integral_constant_value"); 33 | CHECK_EQ(value, 0) << "Only support constant 0 padding"; 34 | } 35 | 36 | const auto& padding_before = 37 | ctx->Attr>("padding_before"); 38 | const auto& padding_after = 39 | ctx->Attr>("padding_after"); 40 | CHECK_EQ(padding_before.size(), padding_after.size()); 41 | CHECK_EQ(padding_before.size(), 4); 42 | if (padding_before[0] != 0 || padding_before[1] != 0 || 43 | padding_after[0] != 0 || padding_after[1] != 0) { 44 | UNIMPLEMENTED() 45 | << "TensorRT does not support padding batch and channel dimension"; 46 | } 47 | 48 | nvinfer1::ITensor* x = ctx->SoleInput(); 49 | nvinfer1::DimsHW prePadding{static_cast(padding_before[2]), 50 | static_cast(padding_before[3])}; 51 | nvinfer1::DimsHW postPadding{static_cast(padding_after[2]), 52 | static_cast(padding_after[3])}; 53 | auto* layer = ctx->builder()->addPaddingNd(*x, prePadding, postPadding); 54 | layer->setName(ctx->op_name().c_str()); 55 | ctx->SetSoleOutput(layer->getOutput(0)); 56 | } 57 | }; 58 | 59 | REGISTER_TRT_OP_KERNEL(pad, PaddingOp).EnableTrainPhase().Finalize(); 60 | 61 | } // namespace tensorrt 62 | } // namespace xrt 63 | } // namespace oneflow 64 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/reduce_op.cpp: 
-------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | template 24 | class ReduceOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | const auto& axis = ctx->Attr>("axis"); 28 | 29 | int32_t reduce_axis = 0; 30 | for (int i = 0; i < axis.size(); ++i) { 31 | reduce_axis = reduce_axis | (1U << axis[i]); 32 | } 33 | bool keepDimensions = ctx->Attr("keepdims"); 34 | // // TensorRT does not support full reduce without keepDimensions. 
35 | // Shape in_shape = ctx->SoleInputShape(); 36 | // if (!keepDimensions) { 37 | // CHECK_NE(reduce_axis, (1U << in_shape.NumAxes()) - 1) 38 | // << "TensorRT does not support full reduce without keepDimensions."; 39 | // } 40 | 41 | nvinfer1::ITensor* in = ctx->SoleInput(); 42 | auto* layer = 43 | ctx->builder()->addReduce(*in, reduce_op, reduce_axis, keepDimensions); 44 | layer->setName(ctx->op_name().c_str()); 45 | ctx->SetSoleOutput(layer->getOutput(0)); 46 | } 47 | }; 48 | 49 | REGISTER_TRT_OP_KERNEL(reduce_sum, ReduceOp) 50 | .EnableTrainPhase() 51 | .Finalize(); 52 | REGISTER_TRT_OP_KERNEL(reduce_mean, ReduceOp) 53 | .EnableTrainPhase() 54 | .Finalize(); 55 | 56 | } // namespace tensorrt 57 | } // namespace xrt 58 | } // namespace oneflow 59 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/reshape_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class ReshapeOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | Shape in_shape = ctx->SoleInputShape(); 28 | Shape shape = ctx->SoleOutputShape(); 29 | CHECK_EQ(shape.Count(0), in_shape.Count(0)); 30 | 31 | nvinfer1::ITensor* input = ctx->SoleInput(); 32 | ctx->SetSoleOutput(helpers::Reshape(ctx, input, shape)); 33 | } 34 | }; 35 | 36 | REGISTER_TRT_OP_KERNEL(reshape, ReshapeOp).EnableTrainPhase().Finalize(); 37 | 38 | class ReshapeLikeOp : public TrtOpKernel { 39 | public: 40 | void Compile(TrtOpContext* ctx) override { 41 | Shape x_shape = ctx->InputShape("in_0"); 42 | Shape like_shape = ctx->InputShape("like_0"); 43 | CHECK_EQ(x_shape.Count(0), like_shape.Count(0)); 44 | 45 | nvinfer1::ITensor* input = ctx->Input("in_0"); 46 | ctx->SetSoleOutput(helpers::Reshape(ctx, input, like_shape)); 47 | } 48 | }; 49 | 50 | REGISTER_TRT_OP_KERNEL(reshape_like, ReshapeLikeOp) 51 | .EnableTrainPhase() 52 | .Finalize(); 53 | 54 | } // namespace tensorrt 55 | } // namespace xrt 56 | } // namespace oneflow 57 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/softmax_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | class SoftmaxOp : public TrtOpKernel { 24 | public: 25 | void Compile(TrtOpContext* ctx) override { 26 | Shape in_shape = ctx->SoleInputShape(); 27 | CHECK_GE(in_shape.NumAxes(), 2); 28 | int32_t axis = in_shape.NumAxes() - 1; 29 | nvinfer1::ITensor* in = ctx->SoleInput(); 30 | auto* layer = ctx->builder()->addSoftMax(*in); 31 | layer->setAxes((1U << axis)); 32 | layer->setName(ctx->op_name().c_str()); 33 | ctx->SetSoleOutput(layer->getOutput(0)); 34 | } 35 | }; 36 | 37 | REGISTER_TRT_OP_KERNEL(softmax, SoftmaxOp).EnableTrainPhase().Finalize(); 38 | 39 | } // namespace tensorrt 40 | } // namespace xrt 41 | } // namespace oneflow 42 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/topk_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | 19 | namespace oneflow { 20 | namespace xrt { 21 | namespace tensorrt { 22 | 23 | class TopKOp : public TrtOpKernel { 24 | public: 25 | void Compile(TrtOpContext* ctx) override { 26 | Shape in_shape = ctx->SoleInputShape(); 27 | // CHECK_GE(in_shape.NumAxes(), 2); 28 | 29 | int32_t k = ctx->Attr("k"); 30 | // We only compute the top-k at the last dimension. 31 | // CHECK_GE(axis, 1); 32 | // CHECK_LT(axis, in_shape.NumAxes()); 33 | uint32_t reduce_axis = (1U << (in_shape.NumAxes() - 1)); 34 | nvinfer1::ITensor* in = ctx->SoleInput(); 35 | auto* layer = ctx->builder()->addTopK(*in, nvinfer1::TopKOperation::kMAX, k, 36 | reduce_axis); 37 | layer->setName(ctx->op_name().c_str()); 38 | ctx->SetSoleOutput(layer->getOutput(0)); 39 | } 40 | }; 41 | 42 | REGISTER_TRT_OP_KERNEL(top_k, TopKOp).EnableTrainPhase().Finalize(); 43 | 44 | } // namespace tensorrt 45 | } // namespace xrt 46 | } // namespace oneflow 47 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/transpose_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class TransposeOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | const auto& perm = ctx->Attr>("perm"); 28 | Shape in_shape = ctx->SoleInputShape(); 29 | CHECK_EQ(perm.size(), in_shape.NumAxes()); 30 | 31 | nvinfer1::ITensor* input = ctx->SoleInput(); 32 | if (IsIdentity(perm)) { 33 | ctx->SetSoleOutput(input); 34 | } else { 35 | ctx->SetSoleOutput(helpers::Transpose(ctx, input, perm)); 36 | } 37 | } 38 | 39 | bool IsIdentity(const std::vector& perm) const { 40 | bool is_identity = true; 41 | for (int i = 0; i < perm.size(); ++i) { 42 | if (i != perm[i]) { 43 | is_identity = false; 44 | break; 45 | } 46 | } 47 | return is_identity || (perm.size() <= 1); 48 | } 49 | }; 50 | 51 | REGISTER_TRT_OP_KERNEL(transpose, TransposeOp).EnableTrainPhase().Finalize(); 52 | 53 | } // namespace tensorrt 54 | } // namespace xrt 55 | } // namespace oneflow 56 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/unary_grad_ops.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | class SqrtGradOp : public TrtOpKernel { 25 | public: 26 | void Compile(TrtOpContext* ctx) override { 27 | nvinfer1::ITensor* x = ctx->Input("x_0"); 28 | nvinfer1::ITensor* dy = ctx->Input("dy_0"); 29 | // x^0.5 30 | auto* sqrt_layer = 31 | ctx->builder()->addUnary(*x, nvinfer1::UnaryOperation::kSQRT); 32 | std::string sqrt_name = ctx->op_name() + ".sqrt"; 33 | sqrt_layer->setName(sqrt_name.c_str()); 34 | 35 | // 2 * x^0.5 36 | const Shape& x_shape = ctx->InputShape("x_0"); 37 | Shape shape(DimVector(x_shape.NumAxes(), 1)); 38 | 39 | DataType data_type = ctx->InputType("x_0"); 40 | std::string name = ctx->op_name() + ".2"; 41 | nvinfer1::Weights constant = 42 | helpers::Constant(ctx, 2, shape, data_type, name); 43 | auto* constant_layer = 44 | ctx->builder()->addConstant(ShapeToXrtDims(shape), constant); 45 | constant_layer->setName(name.c_str()); 46 | auto* mul_layer = ctx->builder()->addElementWise( 47 | *(sqrt_layer->getOutput(0)), *(constant_layer->getOutput(0)), 48 | nvinfer1::ElementWiseOperation::kPROD); 49 | std::string mul_name = ctx->op_name() + ".mul"; 50 | mul_layer->setName(mul_name.c_str()); 51 | 52 | // 1 / (2 * 
x^0.5) 53 | auto* layer = ctx->builder()->addUnary(*(mul_layer->getOutput(0)), 54 | nvinfer1::UnaryOperation::kRECIP); 55 | std::string recip_name = ctx->op_name() + ".recip"; 56 | layer->setName(recip_name.c_str()); 57 | ctx->SetSoleOutput(layer->getOutput(0)); 58 | } 59 | }; 60 | 61 | REGISTER_TRT_OP_KERNEL(sqrt_grad, SqrtGradOp).EnableTrainPhase().Finalize(); 62 | 63 | } // namespace tensorrt 64 | } // namespace xrt 65 | } // namespace oneflow 66 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/unary_ops.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | template 25 | class UnaryOp : public TrtOpKernel { 26 | public: 27 | void Compile(TrtOpContext* ctx) override { 28 | nvinfer1::ITensor* x = ctx->SoleInput(); 29 | auto* layer = ctx->builder()->addUnary(*x, unary_op); 30 | layer->setName(ctx->op_name().c_str()); 31 | ctx->SetSoleOutput(layer->getOutput(0)); 32 | } 33 | }; 34 | 35 | REGISTER_TRT_OP_KERNEL(sqrt, UnaryOp) 36 | .EnableTrainPhase() 37 | .Finalize(); 38 | 39 | class RsqrtOp : public TrtOpKernel { 40 | public: 41 | void Compile(TrtOpContext* ctx) override { 42 | nvinfer1::ITensor* x = ctx->SoleInput(); 43 | auto* sqrt_layer = 44 | ctx->builder()->addUnary(*x, nvinfer1::UnaryOperation::kSQRT); 45 | std::string sqrt_name = ctx->op_name() + ".sqrt"; 46 | sqrt_layer->setName(sqrt_name.c_str()); 47 | auto* layer = ctx->builder()->addUnary(*(sqrt_layer->getOutput(0)), 48 | nvinfer1::UnaryOperation::kRECIP); 49 | std::string name = ctx->op_name() + ".recip"; 50 | layer->setName(name.c_str()); 51 | ctx->SetSoleOutput(layer->getOutput(0)); 52 | } 53 | }; 54 | 55 | REGISTER_TRT_OP_KERNEL(rsqrt, RsqrtOp).EnableTrainPhase().Finalize(); 56 | 57 | } // namespace tensorrt 58 | } // namespace xrt 59 | } // namespace oneflow 60 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/ops/upsample_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/common/shape_util.h" 17 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 18 | #include "oneflow_xrt/compiler/tensorrt/ops/op_kernel.h" 19 | #include "oneflow_xrt/compiler/tensorrt/trt_helpers.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace tensorrt { 24 | 25 | template 26 | class Upsample2dOp : public TrtOpKernel { 27 | public: 28 | void Compile(TrtOpContext* ctx) override { 29 | Shape in_shape = ctx->SoleInputShape(); 30 | CHECK_EQ(in_shape.NumAxes(), 4); 31 | 32 | const double height_scale = ctx->Attr("height_scale"); 33 | const double width_scale = ctx->Attr("width_scale"); 34 | std::vector output_size = 35 | ctx->Attr>("output_size"); 36 | CHECK(output_size.empty()) 37 | << "Upsample output_size is not supported in TensorRT"; 38 | 39 | nvinfer1::ITensor* in = ctx->SoleInput(); 40 | nvinfer1::IResizeLayer* layer = ctx->builder()->addResize(*in); 41 | layer->setName(ctx->op_name().c_str()); 42 | 43 | std::vector scales{1.0, 1.0, height_scale, width_scale}; 44 | layer->setScales(scales.data(), 4); 45 | layer->setResizeMode(resize_mode); 46 | layer->setSelectorForSinglePixel(nvinfer1::ResizeSelector::kFORMULA); 47 | layer->setNearestRounding(nvinfer1::ResizeRoundMode::kFLOOR); 48 | layer->setCoordinateTransformation( 49 | nvinfer1::ResizeCoordinateTransformation::kASYMMETRIC); 50 | ctx->SetSoleOutput(layer->getOutput(0)); 51 | } 52 | }; 53 | 54 | REGISTER_TRT_OP_KERNEL(upsample_nearest_2d, 55 | Upsample2dOp) 56 | .EnableTrainPhase() 57 | .Finalize(); 58 | 59 | } // namespace 
tensorrt 60 | } // namespace xrt 61 | } // namespace oneflow 62 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/plugin/README.md: -------------------------------------------------------------------------------- 1 | ## TensorRT Plugin 2 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/plugin/broadcast_like_plugin.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/plugin/broadcast_like_plugin.h" 17 | 18 | #include "oneflow_xrt/compiler/tensorrt/trt_shape.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | nvinfer1::DimsExprs BroadcastLikePlugin::getOutputDimensions( 25 | int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, 26 | nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { 27 | // inputs are x, like 28 | CHECK_EQ(nb_inputs, 2); 29 | CHECK_EQ(output_index, 0); 30 | return inputs[1]; 31 | } 32 | 33 | nvinfer1::DataType BroadcastLikePlugin::getOutputDataType( 34 | int index, const nvinfer1::DataType* input_types, 35 | int nb_inputs) const TRT_NOEXCEPT { 36 | CHECK_EQ(nb_inputs, 2); 37 | return input_types[0]; 38 | } 39 | 40 | bool BroadcastLikePlugin::supportsFormatCombination( 41 | int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs, 42 | int nb_outputs) TRT_NOEXCEPT { 43 | const auto& desc = in_out[pos]; 44 | return desc.type == in_out[0].type && 45 | desc.format == nvinfer1::TensorFormat::kLINEAR; 46 | } 47 | 48 | int BroadcastLikePlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc, 49 | const nvinfer1::PluginTensorDesc* output_desc, 50 | const void* const* inputs, 51 | void* const* outputs, void* workspace, 52 | cudaStream_t stream) TRT_NOEXCEPT { 53 | return 0; 54 | } 55 | 56 | nvinfer1::IPluginV2DynamicExt* BroadcastLikePlugin::clone() const TRT_NOEXCEPT { 57 | auto* plugin = new BroadcastLikePlugin(*this); 58 | plugin->setPluginNamespace(this->getPluginNamespace()); 59 | return plugin; 60 | } 61 | 62 | } // namespace tensorrt 63 | } // namespace xrt 64 | } // namespace oneflow 65 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/plugin/broadcast_like_plugin.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_PLUGIN_BROADCAST_LIKE_PLUGIN_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_PLUGIN_BROADCAST_LIKE_PLUGIN_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "oneflow_xrt/compiler/tensorrt/trt_plugin.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace tensorrt { 27 | 28 | class BroadcastLikePlugin : public TrtPlugin { 29 | public: 30 | BroadcastLikePlugin(const std::string& name, 31 | const std::vector& broadcast_axes) 32 | : name_(name), broadcast_axes_(broadcast_axes) {} 33 | 34 | const char* getPluginType() const TRT_NOEXCEPT override { 35 | return "BroadcastLike"; 36 | } 37 | 38 | int getNbOutputs() const TRT_NOEXCEPT { return 1; } 39 | 40 | nvinfer1::DimsExprs getOutputDimensions( 41 | int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, 42 | nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override; 43 | 44 | bool supportsFormatCombination(int pos, 45 | const nvinfer1::PluginTensorDesc* in_out, 46 | int nb_inputs, 47 | int nb_outputs) TRT_NOEXCEPT override; 48 | 49 | int enqueue(const nvinfer1::PluginTensorDesc* input_desc, 50 | const nvinfer1::PluginTensorDesc* output_desc, 51 | const void* const* inputs, void* const* outputs, void* workspace, 52 | cudaStream_t stream) TRT_NOEXCEPT override; 53 | 54 | nvinfer1::DataType getOutputDataType( 55 | int index, const nvinfer1::DataType* input_types, 56 | int nb_inputs) 
const TRT_NOEXCEPT override; 57 | 58 | void destroy() TRT_NOEXCEPT override { delete this; } 59 | 60 | nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; 61 | 62 | private: 63 | std::string name_; 64 | std::vector broadcast_axes_; 65 | }; 66 | 67 | } // namespace tensorrt 68 | } // namespace xrt 69 | } // namespace oneflow 70 | 71 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_PLUGIN_BROADCAST_LIKE_PLUGIN_H_ 72 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/trt_graph_compiler.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_TRT_GRAPH_COMPILER_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_TRT_GRAPH_COMPILER_H_ 18 | 19 | #include "NvInfer.h" 20 | #include "oneflow_xrt/compiler/graph_compiler.h" 21 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 22 | #include "oneflow_xrt/compiler/tensorrt/trt_builder.h" 23 | #include "oneflow_xrt/compiler/tensorrt/trt_executable.h" 24 | #include "oneflow_xrt/compiler/tensorrt/trt_value.h" 25 | 26 | namespace oneflow { 27 | namespace xrt { 28 | namespace tensorrt { 29 | 30 | class TrtGraphCompiler : public GraphCompiler::Impl { 31 | public: 32 | explicit TrtGraphCompiler(const std::string& name) 33 | : GraphCompiler::Impl(name) { 34 | builder_ = std::make_shared(name); 35 | } 36 | 37 | virtual ~TrtGraphCompiler() = default; 38 | 39 | std::shared_ptr Compile( 40 | const XrtGraph* graph, const std::vector& entry_params, 41 | const std::vector& return_params, 42 | const std::vector& aliases) override; 43 | 44 | private: 45 | void SetupKernelContextParam(const XrtNode* node, 46 | TrtOpContext::Param* context_param); 47 | 48 | void PopulateEntryParams(const std::vector& entry_params); 49 | 50 | Argument ArgFromParameter(const Parameter& param); 51 | 52 | private: 53 | std::shared_ptr builder_; 54 | 55 | std::unordered_map arguments_; 56 | std::unordered_map operands_; 57 | }; 58 | 59 | } // namespace tensorrt 60 | } // namespace xrt 61 | } // namespace oneflow 62 | 63 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_TRT_GRAPH_COMPILER_H_ 64 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/trt_helpers.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_TRT_HELPERS_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_TRT_HELPERS_H_ 18 | 19 | #include "oneflow/core/common/scalar.h" 20 | #include "oneflow_xrt/compiler/tensorrt/ops/op_context.h" 21 | #include "oneflow_xrt/compiler/tensorrt/trt_shape.h" 22 | 23 | namespace oneflow { 24 | namespace xrt { 25 | namespace tensorrt { 26 | 27 | namespace helpers { 28 | 29 | bool DimsEqual(const nvinfer1::Dims& dim1, const nvinfer1::Dims& dim2); 30 | 31 | nvinfer1::Weights Constant(TrtOpContext* ctx, const Scalar& value, 32 | const Shape& shape, DataType data_type, 33 | const std::string& name); 34 | 35 | nvinfer1::ITensor* Reshape(TrtOpContext* ctx, nvinfer1::ITensor* in, 36 | const Shape& shape); 37 | 38 | nvinfer1::ITensor* Reshape(TrtOpContext* ctx, nvinfer1::Weights in, 39 | const Shape& shape); 40 | 41 | nvinfer1::ITensor* Transpose(TrtOpContext* ctx, nvinfer1::ITensor* in, 42 | const std::vector& permute); 43 | 44 | nvinfer1::ITensor* Transpose(TrtOpContext* ctx, nvinfer1::Weights in, 45 | const Shape& shape, 46 | const std::vector& permute); 47 | 48 | } // namespace helpers 49 | 50 | } // namespace tensorrt 51 | } // namespace xrt 52 | } // namespace oneflow 53 | 54 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_TRT_HELPERS_H_ 55 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/trt_logger.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/tensorrt/trt_logger.h" 17 | 18 | #include "glog/logging.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace tensorrt { 23 | 24 | namespace nv { 25 | 26 | using ILogger = ::nvinfer1::ILogger; 27 | 28 | void Logger::log(ILogger::Severity severity, const char* msg) TRT_NOEXCEPT { 29 | switch (severity) { 30 | case ILogger::Severity::kVERBOSE: 31 | case ILogger::Severity::kINFO: { 32 | VLOG(2) << name_ << ": " << msg; 33 | break; 34 | } 35 | case ILogger::Severity::kWARNING: { 36 | VLOG(2) << name_ << ": " << msg; 37 | break; 38 | } 39 | case ILogger::Severity::kERROR: { 40 | LOG(FATAL) << name_ << ": " << msg; 41 | break; 42 | } 43 | case ILogger::Severity::kINTERNAL_ERROR: { 44 | LOG(FATAL) << name_ << ": " << msg; 45 | break; 46 | } 47 | default: { 48 | LOG(FATAL) << name_ << ": Unknow severity level " << int(severity) 49 | << " with message: " << msg; 50 | break; 51 | } 52 | } 53 | } 54 | 55 | } // namespace nv 56 | 57 | } // namespace tensorrt 58 | } // namespace xrt 59 | } // namespace oneflow 60 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/trt_logger.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_TRT_LOGGER_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_TRT_LOGGER_H_ 18 | 19 | #include 20 | 21 | #include "NvInfer.h" 22 | #include "oneflow_xrt/compiler/tensorrt/common.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace tensorrt { 27 | 28 | namespace nv { 29 | 30 | class Logger : public nvinfer1::ILogger { 31 | public: 32 | Logger() = default; 33 | 34 | Logger(const std::string& name) : name_(name) {} 35 | 36 | void log(nvinfer1::ILogger::Severity severity, 37 | const char* msg) TRT_NOEXCEPT override; 38 | 39 | private: 40 | std::string name_ = "TensorRT Logging"; 41 | }; 42 | 43 | } // namespace nv 44 | 45 | } // namespace tensorrt 46 | } // namespace xrt 47 | } // namespace oneflow 48 | 49 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_TRT_LOGGER_H_ 50 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/trt_unique_ptr.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_TRT_UNIQUE_PTR_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_TRT_UNIQUE_PTR_H_ 18 | #include <memory> 19 | #include "NvInferVersion.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace tensorrt { 24 | 25 | namespace nv { 26 | /* Deleter for TensorRT objects: uses `delete` when NV_TENSORRT_MAJOR >= 8 and obj->destroy() otherwise; null pointers are ignored. */ 27 | struct PtrDeleter { 28 | template <typename T> 29 | inline void operator()(T* obj) { 30 | if (obj) { 31 | #if NV_TENSORRT_MAJOR >= 8 32 | delete obj; 33 | #else 34 | obj->destroy(); 35 | #endif 36 | } 37 | } 38 | }; 39 | /* std::unique_ptr that releases its pointee through PtrDeleter. */ 40 | template <typename T> 41 | using unique_ptr = std::unique_ptr<T, PtrDeleter>; 42 | 43 | } // namespace nv 44 | 45 | } // namespace tensorrt 46 | } // namespace xrt 47 | } // namespace oneflow 48 | 49 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_TRT_UNIQUE_PTR_H_ 50 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/tensorrt/trt_value.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_TENSORRT_TRT_VALUE_H_ 17 | #define ONEFLOW_XRT_COMPILER_TENSORRT_TRT_VALUE_H_ 18 | 19 | #include "NvInfer.h" 20 | #include "oneflow_xrt/compiler/parameter.h" 21 | #include "oneflow_xrt/compiler/tensorrt/trt_builder.h" 22 | 23 | namespace oneflow { 24 | namespace xrt { 25 | namespace tensorrt { 26 | 27 | // TensorRT ITensor or Weights. 28 | class TrtValue { 29 | public: 30 | TrtValue() = default; 31 | /* Returns the stored handle at its full int64_t width (no narrowing); it is -1 until one of the factories below sets it. */ 32 | int64_t handle() const { return handle_; } 33 | /* Each accessor below CHECKs that `builder` is the builder recorded by the factory, then looks the handle up in it. */ 34 | TrtValueKind ValueKind(TrtBuilder* builder) const { 35 | CHECK_EQ(builder_, builder); 36 | return builder_->ValueKind(handle_); 37 | } 38 | 39 | nvinfer1::ITensor* AsTensor(TrtBuilder* builder) { 40 | CHECK_EQ(builder_, builder); 41 | return builder_->GetTensor(handle_); 42 | } 43 | 44 | nvinfer1::Weights& AsWeight(TrtBuilder* builder) { 45 | CHECK_EQ(builder_, builder); 46 | return builder_->GetWeight(handle_); 47 | } 48 | /* Factories: add the parameter/tensor/weight to `builder` and wrap the handle it returns together with the builder pointer. */ 49 | inline static TrtValue Parameter(TrtBuilder* builder, 50 | const xrt::Parameter& param); 51 | 52 | inline static TrtValue Tensor(TrtBuilder* builder, nvinfer1::ITensor* tensor); 53 | 54 | inline static TrtValue Weight(TrtBuilder* builder, nvinfer1::Weights& weight); 55 | 56 | private: 57 | // Unique id for the `TrtValue`. 
58 | int64_t handle_ = -1; 59 | TrtBuilder* builder_ = nullptr; 60 | }; 61 | 62 | TrtValue TrtValue::Parameter(TrtBuilder* builder, const xrt::Parameter& param) { 63 | TrtValue trt_value; 64 | trt_value.handle_ = builder->AddParameter(param); 65 | trt_value.builder_ = builder; 66 | return trt_value; 67 | } 68 | 69 | TrtValue TrtValue::Tensor(TrtBuilder* builder, nvinfer1::ITensor* tensor) { 70 | TrtValue trt_value; 71 | trt_value.handle_ = builder->AddTensor(tensor); 72 | trt_value.builder_ = builder; 73 | return trt_value; 74 | } 75 | 76 | TrtValue TrtValue::Weight(TrtBuilder* builder, nvinfer1::Weights& weight) { 77 | TrtValue trt_value; 78 | trt_value.handle_ = builder->AddWeight(weight); 79 | trt_value.builder_ = builder; 80 | return trt_value; 81 | } 82 | 83 | } // namespace tensorrt 84 | } // namespace xrt 85 | } // namespace oneflow 86 | 87 | #endif // ONEFLOW_XRT_COMPILER_TENSORRT_TRT_VALUE_H_ 88 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/README.md: -------------------------------------------------------------------------------- 1 | ## XLA 2 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/memory/device_buffer_allocator.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_MEMORY_DEVICE_BUFFER_ALLOCATOR_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_MEMORY_DEVICE_BUFFER_ALLOCATOR_H_ 18 | 19 | #include 20 | #include 21 | 22 | #include "oneflow_xrt/compiler/xla/memory/device_memory_pool.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace mola { 27 | /* Hands out raw buffers from a shared memory pool. Reserve() grows the pool only once every Lock() has been matched by an Unlock(), i.e. while lock_count_ is 0. */ 28 | class DeviceBufferAllocator { 29 | public: 30 | explicit DeviceBufferAllocator(std::shared_ptr mem_pool) 31 | : mem_pool_(mem_pool) { 32 | // mem_pool_->Reserve(256 * 1024 * 1024/*256MiB*/); 33 | } 34 | 35 | virtual ~DeviceBufferAllocator() {} 36 | /* Forwards to the pool's AllocateRaw(offset, size). */ 37 | void* AllocateRaw(size_t offset, size_t size) { 38 | return mem_pool_->AllocateRaw(offset, size); 39 | } 40 | /* While `size` exceeds the pool capacity: waits until lock_count_ is 0, then asks the pool to reserve `size`. NOTE(review): capacity() is read before mutex_ is taken - confirm this is safe with concurrent Reserve() callers. */ 41 | void Reserve(size_t size) { 42 | while (size > mem_pool_->capacity()) { 43 | std::unique_lock lock(mutex_); 44 | cond_.wait(lock, [&]() { return lock_count_ == 0; }); 45 | 46 | mem_pool_->Reserve(size); 47 | } 48 | } 49 | /* Registers a holder; Reserve() waits while any holder is registered. */ 50 | void Lock() { 51 | std::unique_lock lock(mutex_); 52 | ++lock_count_; 53 | } 54 | /* Releases a holder and wakes every waiting Reserve(). NOTE(review): lock_count_ is unsigned, so an Unlock() without a prior Lock() wraps around - confirm callers always pair them. */ 55 | void Unlock() { 56 | std::unique_lock lock(mutex_); 57 | --lock_count_; 58 | cond_.notify_all(); 59 | } 60 | 61 | private: 62 | volatile uint64_t lock_count_ = 0; 63 | 64 | std::mutex mutex_; 65 | 66 | std::condition_variable cond_; 67 | 68 | std::shared_ptr mem_pool_; 69 | }; 70 | 71 | } // namespace mola 72 | } // namespace xrt 73 | } // namespace oneflow 74 | 75 | #endif // ONEFLOW_XRT_COMPILER_XLA_MEMORY_DEVICE_BUFFER_ALLOCATOR_H_ 76 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/activation_grad_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/xla/xla_helpers.h" 19 | #include "tensorflow/compiler/xla/client/lib/constants.h" 20 | #include "tensorflow/compiler/xla/client/lib/math.h" 21 | #include "tensorflow/compiler/xla/client/xla_builder.h" 22 | 23 | namespace oneflow { 24 | namespace xrt { 25 | namespace mola { 26 | 27 | class TanhGradOp : public XlaOpKernel { 28 | public: 29 | void Compile(XlaOpContext* ctx) override { 30 | xla::XlaOp y = ctx->Input("y_0"); 31 | xla::XlaOp dy = ctx->Input("dy_0"); 32 | xla::XlaOp one = xla::ScalarLike(y, 1.f); 33 | // dx = dy * (1 - y * y) 34 | xla::XlaOp dx = dy * (one - (y * y)); 35 | ctx->SetOutput("dx_0", dx); 36 | } 37 | }; 38 | REGISTER_XLA_OP_KERNEL(tanh_grad, TanhGradOp).Finalize(); 39 | 40 | class GeluGradOp : public XlaOpKernel { 41 | public: 42 | void Compile(XlaOpContext* ctx) override { 43 | xla::XlaOp x = ctx->Input("x_0"); 44 | xla::XlaOp dy = ctx->Input("dy_0"); 45 | xla::XlaOp dot_5 = xla::ScalarLike(x, 0.5f); 46 | xla::XlaOp inv_sqrt2 = xla::ScalarLike(x, std::sqrt(0.5f)); 47 | xla::XlaOp one = xla::ScalarLike(x, 1.f); 48 | 49 | xla::XlaOp coef = xla::ScalarLike(x, std::sqrt(2.f / std::acos(-1.f))); 50 | // coef = 1 + erf(sqrt(0.5) * x) + x * coef * exp(-0.5 * x * x) 51 | coef = one + xla::Erf(inv_sqrt2 * x) + 52 | (x * coef * xla::Exp(xla::Neg(dot_5) * x * x)); 53 | 54 | ctx->SetOutput("dx_0", dot_5 * coef * dy); 55 | } 56 | }; 57 | REGISTER_XLA_OP_KERNEL(gelu_grad, 
GeluGradOp).Finalize(); 58 | 59 | } // namespace mola 60 | } // namespace xrt 61 | } // namespace oneflow 62 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/add_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/xla_builder.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace mola { 23 | 24 | class AddOp : public XlaOpKernel { 25 | public: 26 | void Compile(XlaOpContext* ctx) override { 27 | int num_inputs = ctx->num_inputs(); 28 | CHECK_GT(num_inputs, 0); 29 | Shape shape = ctx->InputShape("in_0"); 30 | xla::XlaOp sum = ctx->Input("in_0"); 31 | 32 | for (int i = 1; i < num_inputs; ++i) { 33 | std::string name = absl::StrCat("in_", i); 34 | CHECK_EQ(shape, ctx->InputShape(name)); 35 | sum = xla::Add(sum, ctx->Input(name)); 36 | } 37 | 38 | ctx->SetSoleOutput(sum); 39 | } 40 | }; 41 | 42 | REGISTER_XLA_OP_KERNEL(add_n, AddOp).Finalize(); 43 | 44 | } // namespace mola 45 | } // namespace xrt 46 | } // namespace oneflow 47 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/argument_op.cpp: 
-------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/xla_builder.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace mola { 23 | 24 | class XrtEntryOp : public XlaOpKernel { 25 | public: 26 | void Compile(XlaOpContext* ctx) override { 27 | xla::XlaOp value = ctx->Variable(); 28 | ctx->SetOutput("value", value); 29 | } 30 | }; 31 | 32 | class XrtReturnOp : public XlaOpKernel { 33 | public: 34 | void Compile(XlaOpContext* ctx) override { 35 | xla::XlaOp value = ctx->Input("value"); 36 | ctx->SetVariable(value); 37 | } 38 | }; 39 | 40 | REGISTER_XLA_OP_KERNEL(XrtEntry, XrtEntryOp).Finalize(); 41 | REGISTER_XLA_OP_KERNEL(XrtReturn, XrtReturnOp).Finalize(); 42 | 43 | } // namespace mola 44 | } // namespace xrt 45 | } // namespace oneflow 46 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/batch_matmul_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/lib/matrix.h" 19 | #include "tensorflow/compiler/xla/client/xla_builder.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace mola { 24 | 25 | class BatchMatMulOp : public XlaOpKernel { 26 | public: 27 | void Compile(XlaOpContext* ctx) override { 28 | Shape shape_a = ctx->InputShape("a_0"); 29 | Shape shape_b = ctx->InputShape("b_0"); 30 | CHECK_EQ(shape_a.NumAxes(), shape_b.NumAxes()); 31 | CHECK_GT(shape_a.NumAxes(), 2); 32 | 33 | bool transpose_a = ctx->Attr("transpose_a"); 34 | bool transpose_b = ctx->Attr("transpose_b"); 35 | 36 | xla::XlaOp a = ctx->Input("a_0"); 37 | xla::XlaOp b = ctx->Input("b_0"); 38 | 39 | xla::XlaOp out = xla::BatchDot(a, transpose_a, b, transpose_b); 40 | if (ctx->HasInput("_add_to_output_0")) { 41 | out = xla::Add(out, ctx->Input("_add_to_output_0")); 42 | } 43 | ctx->SetOutput("out_0", out); 44 | } 45 | }; 46 | 47 | REGISTER_XLA_OP_KERNEL(batch_matmul, BatchMatMulOp).Finalize(); 48 | 49 | } // namespace mola 50 | } // namespace xrt 51 | } // namespace oneflow 52 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/bias_add_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 
2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/xla_builder.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace mola { 23 | /* XLA kernel registered as bias_add: adds the rank-1 input "b_0" to "a_0" (rank >= 2) along dimension 1 and writes the sum to "out_0". */ 24 | class BiasAddOp : public XlaOpKernel { 25 | public: 26 | void Compile(XlaOpContext* ctx) override { 27 | Shape in_shape = ctx->InputShape("a_0"); 28 | Shape bias_shape = ctx->InputShape("b_0"); 29 | CHECK_GE(in_shape.NumAxes(), 2); 30 | CHECK_EQ(bias_shape.NumAxes(), 1); 31 | /* Both inputs must have the same data type. */ 32 | CHECK_EQ(ctx->InputType("a_0"), ctx->InputType("b_0")); 33 | 34 | xla::XlaOp in = ctx->Input("a_0"); 35 | xla::XlaOp bias = ctx->Input("b_0"); 36 | 37 | // Channel dim for NCHW data format 38 | int channel_dim = 1; 39 | ctx->SetOutput("out_0", xla::Add(in, bias, {channel_dim})); 40 | } 41 | }; 42 | 43 | REGISTER_XLA_OP_KERNEL(bias_add, BiasAddOp).Finalize(); 44 | 45 | } // namespace mola 46 | } // namespace xrt 47 | } // namespace oneflow 48 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/binary_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_OPS_BINARY_OP_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_OPS_BINARY_OP_H_ 18 | 19 | #include "tensorflow/compiler/xla/client/xla_builder.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace mola { 24 | namespace op { 25 | 26 | #define OFXLA_DECLARE_BINARY_OP(op) \ 27 | struct op { \ 28 | xla::XlaOp operator()(xla::XlaOp a, xla::XlaOp b) { \ 29 | return xla::op(a, b); \ 30 | } \ 31 | }; 32 | 33 | OFXLA_DECLARE_BINARY_OP(Add); 34 | OFXLA_DECLARE_BINARY_OP(Sub); 35 | OFXLA_DECLARE_BINARY_OP(Mul); 36 | OFXLA_DECLARE_BINARY_OP(Div); 37 | OFXLA_DECLARE_BINARY_OP(Min); 38 | OFXLA_DECLARE_BINARY_OP(Pow); 39 | 40 | #undef OFXLA_DECLARE_BINARY_OP 41 | 42 | } // namespace op 43 | } // namespace mola 44 | } // namespace xrt 45 | } // namespace oneflow 46 | 47 | #endif // ONEFLOW_XRT_COMPILER_XLA_OPS_BINARY_OP_H_ 48 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/broadcast_binary_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow/core/common/shape_view.h" 17 | #include "oneflow_xrt/compiler/xla/ops/binary_op.h" 18 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 20 | #include "oneflow_xrt/compiler/xla/xla_helpers.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | template 27 | class BcastBinaryOp : public XlaOpKernel { 28 | public: 29 | void Compile(XlaOpContext* ctx) override { 30 | Shape shape_a = ctx->InputShape("x_0"); 31 | Shape shape_b = ctx->InputShape("y_0"); 32 | 33 | int axes = std::max(shape_a.NumAxes(), shape_b.NumAxes()); 34 | shape_a = CreateLeftExtendedShape(ShapeView(shape_a), axes); 35 | shape_b = CreateLeftExtendedShape(ShapeView(shape_b), axes); 36 | 37 | xla::XlaOp a = Reshape(ctx->Input("x_0"), shape_a); 38 | xla::XlaOp b = Reshape(ctx->Input("y_0"), shape_b); 39 | ctx->SetOutput("z_0", BinaryOp()(a, b)); 40 | } 41 | }; 42 | 43 | REGISTER_XLA_OP_KERNEL(broadcast_add, BcastBinaryOp).Finalize(); 44 | REGISTER_XLA_OP_KERNEL(broadcast_sub, BcastBinaryOp).Finalize(); 45 | REGISTER_XLA_OP_KERNEL(broadcast_mul, BcastBinaryOp).Finalize(); 46 | REGISTER_XLA_OP_KERNEL(broadcast_div, BcastBinaryOp).Finalize(); 47 | REGISTER_XLA_OP_KERNEL(broadcast_min, BcastBinaryOp).Finalize(); 48 | 49 | } // namespace mola 50 | } // namespace xrt 51 | } // namespace oneflow 52 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/cast_op.cpp: 
-------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/xla/xla_data_type.h" 19 | #include "tensorflow/compiler/xla/client/xla_builder.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace mola { 24 | 25 | class CastOp : public XlaOpKernel { 26 | public: 27 | void Compile(XlaOpContext* ctx) override { 28 | DataType dest_dtype = ctx->Attr("dtype"); 29 | DataType src_dtype = ctx->SoleInputType(); 30 | xla::XlaOp in = ctx->SoleInput(); 31 | if (src_dtype == dest_dtype) { 32 | ctx->SetSoleOutput(in); 33 | } else { 34 | xla::PrimitiveType data_type = DataTypeToPrimitiveType(dest_dtype); 35 | ctx->SetSoleOutput(xla::ConvertElementType(in, data_type)); 36 | } 37 | } 38 | }; 39 | 40 | REGISTER_XLA_OP_KERNEL(cast, CastOp).Finalize(); 41 | 42 | } // namespace mola 43 | } // namespace xrt 44 | } // namespace oneflow 45 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/fc_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/xla_builder.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace mola { 23 | 24 | class FullyConnectedOp : public XlaOpKernel { 25 | public: 26 | void Compile(XlaOpContext* ctx) override { 27 | xla::XlaOp in = ctx->Input("in"); 28 | xla::XlaOp weight = xla::Transpose(ctx->Input("weight"), {1, 0}); 29 | xla::XlaOp result = xla::Dot(in, weight); 30 | 31 | if (ctx->Attr("use_bias")) { 32 | result = xla::Add(result, ctx->Input("bias")); 33 | } 34 | ctx->SetOutput("out", result); 35 | } 36 | }; 37 | 38 | REGISTER_XLA_OP_KERNEL(fully_connected, FullyConnectedOp).Finalize(); 39 | 40 | } // namespace mola 41 | } // namespace xrt 42 | } // namespace oneflow 43 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/matmul_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/xla_builder.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace mola { 23 | 24 | class MatMulOp : public XlaOpKernel { 25 | public: 26 | void Compile(XlaOpContext* ctx) override { 27 | Shape a_shape = ctx->InputShape("a_0"); 28 | Shape b_shape = ctx->InputShape("b_0"); 29 | CHECK_GE(a_shape.NumAxes(), 2); 30 | CHECK_EQ(a_shape.NumAxes(), b_shape.NumAxes()); 31 | 32 | if (a_shape.NumAxes() > 2) { 33 | auto batch_matmul_kernel = BuildOpKernel(ctx->device(), "batch_matmul"); 34 | batch_matmul_kernel->Compile(ctx); 35 | return; 36 | } 37 | 38 | bool transpose_a = ctx->Attr("transpose_a"); 39 | bool transpose_b = ctx->Attr("transpose_b"); 40 | 41 | xla::XlaOp a = ctx->Input("a_0"); 42 | xla::XlaOp b = ctx->Input("b_0"); 43 | 44 | auto lhs = transpose_a ? xla::Transpose(a, {1, 0}) : a; 45 | auto rhs = transpose_b ? 
xla::Transpose(b, {1, 0}) : b; 46 | xla::XlaOp out = xla::Dot(lhs, rhs); 47 | if (ctx->HasInput("_add_to_output_0")) { 48 | out = xla::Add(out, ctx->Input("_add_to_output_0")); 49 | } 50 | ctx->SetOutput("out_0", out); 51 | } 52 | }; 53 | 54 | REGISTER_XLA_OP_KERNEL(matmul, MatMulOp).Finalize(); 55 | 56 | } // namespace mola 57 | } // namespace xrt 58 | } // namespace oneflow 59 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/op_kernel.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_OPS_OP_KERNEL_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_OPS_OP_KERNEL_H_ 18 | 19 | #include "oneflow_xrt/common/device.h" 20 | #include "oneflow_xrt/common/registry.h" 21 | #include "oneflow_xrt/compiler/kernel/op_kernel.h" 22 | #include "oneflow_xrt/compiler/kernel/op_kernel_registry.h" 23 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 24 | #include "oneflow_xrt/compiler/xla/xla_macro.h" 25 | 26 | namespace oneflow { 27 | namespace xrt { 28 | namespace mola { 29 | 30 | class XlaOpKernel : public OpKernel { 31 | public: 32 | virtual void Compile(XlaOpContext* ctx) = 0; 33 | 34 | XlaOpKernel() = default; 35 | virtual ~XlaOpKernel() = default; 36 | }; 37 | 38 | #define REGISTER_XLA_OP_KERNEL(OpName, KernelType) \ 39 | static OpKernelRegistrar _xla_op_kernel_##OpName##_ \ 40 | __attribute__((unused)) = \ 41 | OpKernelRegistrar(#OpName) \ 42 | .SetEngine(XrtEngine::XLA) \ 43 | .EnableTrainPhase() \ 44 | .SetFactory([]() -> OpKernelBase* { return new KernelType; }) 45 | 46 | inline std::shared_ptr BuildOpKernel(const XrtDevice& device, 47 | const std::string& op_name) { 48 | OpKernelRegKey reg_key{op_name, XrtEngine::XLA, device}; 49 | const auto& f = XRT_REGISTER_LOOKUP(OpKernelRegId, reg_key); 50 | auto* xla_kernel = dynamic_cast(f()); 51 | CHECK(xla_kernel) << "failed to build xla op kernel for " << reg_key; 52 | return std::shared_ptr(xla_kernel); 53 | } 54 | 55 | } // namespace mola 56 | } // namespace xrt 57 | } // namespace oneflow 58 | 59 | #endif // ONEFLOW_XRT_COMPILER_XLA_OPS_OP_KERNEL_H_ 60 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/optimizer_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_OPS_OPTIMIZER_OP_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_OPS_OPTIMIZER_OP_H_ 18 | 19 | #include "oneflow/core/operator/op_conf.pb.h" 20 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 21 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 22 | #include "oneflow_xrt/compiler/xla/xla_helpers.h" 23 | #include "tensorflow/compiler/xla/client/lib/constants.h" 24 | #include "tensorflow/compiler/xla/client/lib/math.h" 25 | #include "tensorflow/compiler/xla/client/xla_builder.h" 26 | 27 | namespace oneflow { 28 | namespace xrt { 29 | namespace mola { 30 | 31 | class OptimizerOp : public XlaOpKernel { 32 | public: 33 | void Compile(XlaOpContext* ctx) override { 34 | xla::XlaOp gradient = ctx->Input("model_diff_0"); 35 | xla::XlaOp learning_rate = ctx->Input("learning_rate_0"); 36 | ApplyUpdate(ctx, gradient, learning_rate); 37 | } 38 | 39 | private: 40 | virtual void ApplyUpdate(XlaOpContext* ctx, xla::XlaOp gradient, 41 | xla::XlaOp learning_rate) = 0; 42 | }; 43 | 44 | } // namespace mola 45 | } // namespace xrt 46 | } // namespace oneflow 47 | 48 | #endif // ONEFLOW_XRT_COMPILER_XLA_OPS_OPTIMIZER_OP_H_ 49 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/reshape_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/xla/xla_helpers.h" 19 | #include "tensorflow/compiler/xla/client/xla_builder.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | namespace mola { 24 | 25 | class ReshapeOp : public XlaOpKernel { 26 | public: 27 | void Compile(XlaOpContext* ctx) override { 28 | Shape in_shape = ctx->SoleInputShape(); 29 | Shape shape = ctx->SoleOutputShape(); 30 | CHECK_EQ(shape.Count(0), in_shape.Count(0)); 31 | 32 | ctx->SetSoleOutput(Reshape(ctx->SoleInput(), shape)); 33 | } 34 | }; 35 | 36 | REGISTER_XLA_OP_KERNEL(reshape, ReshapeOp).Finalize(); 37 | 38 | class ReshapeLikeOp : public XlaOpKernel { 39 | public: 40 | void Compile(XlaOpContext* ctx) override { 41 | Shape x_shape = ctx->InputShape("in_0"); 42 | Shape like_shape = ctx->InputShape("like_0"); 43 | CHECK_EQ(x_shape.Count(0), like_shape.Count(0)); 44 | 45 | ctx->SetOutput("out_0", Reshape(ctx->Input("in_0"), like_shape)); 46 | } 47 | }; 48 | 49 | REGISTER_XLA_OP_KERNEL(reshape_like, ReshapeLikeOp).Finalize(); 50 | 51 | } // namespace mola 52 | } // namespace xrt 53 | } // namespace oneflow 54 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/scalar_binary_op.cpp: 
-------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/binary_op.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 18 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 19 | #include "oneflow_xrt/compiler/xla/xla_helpers.h" 20 | #include "tensorflow/compiler/xla/client/xla_builder.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | template 27 | class ScalarBinaryOp : public XlaOpKernel { 28 | public: 29 | void Compile(XlaOpContext* ctx) override { 30 | xla::XlaOp scalar = Scalar(ctx); 31 | xla::XlaOp in = ctx->SoleInput(); 32 | 33 | ctx->SetSoleOutput(BinaryOp()(in, scalar)); 34 | } 35 | 36 | xla::XlaOp Scalar(XlaOpContext* ctx) const { 37 | xla::XlaBuilder* builder = ctx->builder(); 38 | DataType data_type = ctx->SoleInputType(); 39 | if (ctx->Attr("has_int_operand")) { 40 | int64_t value = ctx->Attr("int_operand"); 41 | return IntegerLiteral(builder, data_type, value); 42 | } else if (ctx->Attr("has_float_operand")) { 43 | double value = ctx->Attr("float_operand"); 44 | return FloatLiteral(builder, data_type, value); 45 | } 46 | UNIMPLEMENTED(); 47 | return xla::XlaOp(); 48 | } 49 | }; 50 | 51 | REGISTER_XLA_OP_KERNEL(scalar_add, ScalarBinaryOp).Finalize(); 52 | REGISTER_XLA_OP_KERNEL(scalar_mul, 
ScalarBinaryOp).Finalize(); 53 | REGISTER_XLA_OP_KERNEL(scalar_div, ScalarBinaryOp).Finalize(); 54 | REGISTER_XLA_OP_KERNEL(scalar_pow, ScalarBinaryOp).Finalize(); 55 | 56 | } // namespace mola 57 | } // namespace xrt 58 | } // namespace oneflow 59 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/square_sum_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "oneflow_xrt/compiler/xla/xla_helpers.h" 19 | #include "tensorflow/compiler/xla/client/lib/math.h" 20 | #include "tensorflow/compiler/xla/client/xla_builder.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | class SquareSumOp : public XlaOpKernel { 27 | public: 28 | void Compile(XlaOpContext* ctx) override { 29 | xla::XlaOp x = ctx->Input("x_0"); 30 | Shape x_shape = ctx->InputShape("x_0"); 31 | xla::XlaBuilder* builder = ctx->builder(); 32 | DataType data_type = ctx->SoleInputType(); 33 | xla::XlaOp sum; 34 | std::vector x_dims(x_shape.NumAxes()); 35 | std::iota(x_dims.begin(), x_dims.end(), 0); 36 | xla::XlaComputation add_func = CreateAddFunc(data_type); 37 | sum = 38 | xla::Reduce(xla::Square(x), Zero(builder, data_type), add_func, x_dims); 39 | ctx->SetSoleOutput(Reshape(sum, Shape({1}))); 40 | } 41 | }; 42 | 43 | REGISTER_XLA_OP_KERNEL(square_sum, SquareSumOp).Finalize(); 44 | 45 | } // namespace mola 46 | } // namespace xrt 47 | } // namespace oneflow 48 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/transpose_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 17 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 18 | #include "tensorflow/compiler/xla/client/xla_builder.h" 19 | 20 | namespace oneflow { 21 | namespace xrt { 22 | namespace mola { 23 | 24 | class TransposeOp : public XlaOpKernel { 25 | public: 26 | void Compile(XlaOpContext* ctx) override { 27 | const auto& perm = ctx->Attr>("perm"); 28 | Shape x_shape = ctx->SoleInputShape(); 29 | CHECK_EQ(perm.size(), x_shape.NumAxes()); 30 | 31 | xla::XlaOp x = ctx->SoleInput(); 32 | if (IsIdentity(perm)) { 33 | ctx->SetSoleOutput(x); 34 | } else { 35 | std::vector transposed_order(x_shape.NumAxes()); 36 | for (int i = 0; i < x_shape.NumAxes(); ++i) { 37 | transposed_order[i] = perm[i]; 38 | } 39 | ctx->SetSoleOutput(xla::Transpose(x, transposed_order)); 40 | } 41 | } 42 | 43 | bool IsIdentity(const std::vector& perm) const { 44 | bool is_identity = true; 45 | for (int i = 0; i < perm.size(); ++i) { 46 | if (i != perm[i]) { 47 | is_identity = false; 48 | break; 49 | } 50 | } 51 | return is_identity || (perm.size() <= 1); 52 | } 53 | }; 54 | 55 | REGISTER_XLA_OP_KERNEL(transpose, TransposeOp).Finalize(); 56 | 57 | } // namespace mola 58 | } // namespace xrt 59 | } // namespace oneflow 60 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/unary_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/ops/unary_op.h" 17 | 18 | #include "oneflow_xrt/compiler/xla/ops/op_context.h" 19 | #include "oneflow_xrt/compiler/xla/ops/op_kernel.h" 20 | #include "tensorflow/compiler/xla/client/lib/constants.h" 21 | #include "tensorflow/compiler/xla/client/lib/math.h" 22 | #include "tensorflow/compiler/xla/client/xla_builder.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace mola { 27 | 28 | struct Gelu { 29 | xla::XlaOp operator()(const xla::XlaOp& x) { 30 | xla::XlaOp dot_5 = xla::ScalarLike(x, 0.5f); 31 | xla::XlaOp inv_sqrt2 = xla::ScalarLike(x, std::sqrt(0.5f)); 32 | xla::XlaOp one = xla::ScalarLike(x, 1.f); 33 | // cdf = erf(sqrt(0.5) * x) 34 | xla::XlaOp cdf = xla::Erf(xla::Mul(inv_sqrt2, x)); 35 | // return 0.5 * x * (1.0 + cdf) 36 | return xla::Mul(xla::Mul(dot_5, x), xla::Add(one, cdf)); 37 | } 38 | }; 39 | 40 | template 41 | class ApplyUnaryOp : public XlaOpKernel { 42 | public: 43 | void Compile(XlaOpContext* ctx) override { 44 | ctx->SetSoleOutput(UnaryOp()(ctx->SoleInput())); 45 | } 46 | }; 47 | 48 | REGISTER_XLA_OP_KERNEL(sigmoid, ApplyUnaryOp).Finalize(); 49 | REGISTER_XLA_OP_KERNEL(sigmoid_v2, ApplyUnaryOp).Finalize(); 50 | REGISTER_XLA_OP_KERNEL(tanh, ApplyUnaryOp).Finalize(); 51 | REGISTER_XLA_OP_KERNEL(gelu, ApplyUnaryOp).Finalize(); 52 | REGISTER_XLA_OP_KERNEL(rsqrt, ApplyUnaryOp).Finalize(); 53 | REGISTER_XLA_OP_KERNEL(sqrt, ApplyUnaryOp).Finalize(); 54 | 55 | struct Identity { 56 | xla::XlaOp operator()(const xla::XlaOp& x) { return x; } 57 | }; 58 | 59 | 
REGISTER_XLA_OP_KERNEL(identity, ApplyUnaryOp).Finalize(); 60 | 61 | } // namespace mola 62 | } // namespace xrt 63 | } // namespace oneflow 64 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/ops/unary_op.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_OPS_UNARY_OP_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_OPS_UNARY_OP_H_ 18 | 19 | #include "oneflow_xrt/compiler/xla/xla_data_type.h" 20 | #include "tensorflow/compiler/xla/client/lib/math.h" 21 | #include "tensorflow/compiler/xla/client/xla_builder.h" 22 | 23 | namespace oneflow { 24 | namespace xrt { 25 | namespace mola { 26 | namespace op { 27 | 28 | #define OFXLA_DECLARE_UNARY_OP(op) \ 29 | struct op { \ 30 | xla::XlaOp operator()(const xla::XlaOp& x) { return xla::op(x); } \ 31 | }; 32 | 33 | OFXLA_DECLARE_UNARY_OP(Abs); 34 | OFXLA_DECLARE_UNARY_OP(Logistic); 35 | OFXLA_DECLARE_UNARY_OP(Tanh); 36 | OFXLA_DECLARE_UNARY_OP(Rsqrt); 37 | OFXLA_DECLARE_UNARY_OP(Sqrt); 38 | 39 | #undef OFXLA_DECLARE_UNARY_OP 40 | 41 | } // namespace op 42 | } // namespace mola 43 | } // namespace xrt 44 | } // namespace oneflow 45 | 46 | #endif // ONEFLOW_XRT_COMPILER_XLA_OPS_UNARY_OP_H_ 47 | 
-------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_allocator.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_ALLOCATOR_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_ALLOCATOR_H_ 18 | 19 | #include "oneflow/core/common/util.h" 20 | // #include "oneflow/xrt/fix_ostream_nullptr.h" 21 | #include "oneflow_xrt/compiler/xla/memory/device_buffer_allocator.h" 22 | #include "tensorflow/compiler/xla/statusor.h" 23 | #include "tensorflow/core/framework/allocator.h" 24 | #include "tensorflow/stream_executor/device_memory_allocator.h" 25 | 26 | namespace oneflow { 27 | namespace xrt { 28 | namespace mola { 29 | 30 | namespace se = tensorflow::se; 31 | using uint64 = tensorflow::uint64; 32 | using int64 = tensorflow::int64; 33 | 34 | class XlaAllocator : public se::DeviceMemoryAllocator { 35 | public: 36 | explicit XlaAllocator(const se::Platform* platform, 37 | DeviceBufferAllocator* allocator); 38 | virtual ~XlaAllocator(); 39 | using se::DeviceMemoryAllocator::Allocate; 40 | xla::StatusOr Allocate( 41 | int device_ordinal, uint64 size, bool retry_on_failure, 42 | int64 /*memory_space*/) override; 43 | tensorflow::Status Deallocate(int device_ordinal, 44 | se::DeviceMemoryBase mem) override; 45 | 46 | bool 
AllowsAsynchronousDeallocation() const override { return true; } 47 | 48 | void ResetState(); 49 | void ReserveWorkspace(size_t workspace_bytes); 50 | void LockWorkspace(); 51 | void UnlockWorkspace(); 52 | 53 | void PopulateDeviceMemory( 54 | const std::vector& device_buffers, 55 | const std::vector& allocation_indices); 56 | stream_executor::port::StatusOr GetStream( 57 | int device_ordinal) override { 58 | UNIMPLEMENTED(); 59 | }; 60 | 61 | private: 62 | DeviceBufferAllocator* allocator_; 63 | int64_t allocate_offset_; 64 | int64_t allocate_index_; 65 | 66 | struct AllocationBuffer { 67 | bool populated = false; 68 | int64_t index = -1; 69 | se::DeviceMemoryBase memory; 70 | }; 71 | std::vector populated_buffers_; 72 | }; 73 | 74 | } // namespace mola 75 | } // namespace xrt 76 | } // namespace oneflow 77 | 78 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_ALLOCATOR_H_ 79 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_data_type.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/compiler/xla/xla_data_type.h" 17 | 18 | #include "glog/logging.h" 19 | #include "oneflow/core/common/data_type.pb.h" 20 | #include "tensorflow/compiler/xla/xla_data.pb.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | xla::PrimitiveType DataTypeToPrimitiveType(DataType data_type) { 27 | switch (data_type) { 28 | case oneflow::kFloat: 29 | return xla::F32; 30 | case oneflow::kDouble: 31 | return xla::F64; 32 | case oneflow::kInt8: 33 | return xla::S8; 34 | case oneflow::kInt32: 35 | return xla::S32; 36 | case oneflow::kInt64: 37 | return xla::S64; 38 | case oneflow::kChar: 39 | case oneflow::kUInt8: 40 | return xla::U8; 41 | case oneflow::kFloat16: 42 | return xla::F16; 43 | default: { 44 | LOG(FATAL) << "Unsupported data type (" << data_type 45 | << ") in DataTypeToPrimitiveType"; 46 | return xla::PRIMITIVE_TYPE_INVALID; 47 | } 48 | } 49 | } 50 | 51 | } // namespace mola 52 | } // namespace xrt 53 | } // namespace oneflow 54 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_data_type.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_DATA_TYPE_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_DATA_TYPE_H_ 18 | 19 | #include "oneflow/core/common/data_type.pb.h" 20 | #include "tensorflow/compiler/xla/xla_data.pb.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | // Convert oneflow `DataType` to xla `PrimitiveType` 27 | xla::PrimitiveType DataTypeToPrimitiveType(DataType data_type); 28 | 29 | } // namespace mola 30 | } // namespace xrt 31 | } // namespace oneflow 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_executable.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_EXECUTABLE_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_EXECUTABLE_H_ 18 | 19 | #include "oneflow_xrt/compiler/executable.h" 20 | #include "tensorflow/compiler/xla/client/local_client.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | class XlaExecutable : public Executable { 27 | public: 28 | XlaExecutable(const std::string& name, const XrtDevice& device, 29 | const std::vector& input_shapes, 30 | const xla::Shape& output_shape, 31 | std::unique_ptr&& executable) 32 | : Executable(name, XrtEngine::XLA), 33 | device_(device), 34 | input_shapes_(input_shapes), 35 | output_shape_(output_shape), 36 | executable_(std::move(executable)) {} 37 | 38 | virtual ~XlaExecutable() = default; 39 | 40 | bool Run(const std::vector& inputs, 41 | const ExecutableRunOptions& run_options, 42 | bool block_until_done = true) override; 43 | 44 | private: 45 | XrtDevice device_; 46 | 47 | std::vector input_shapes_; 48 | // The output shape is always a tuple. 49 | xla::Shape output_shape_; 50 | 51 | std::unique_ptr executable_; 52 | }; 53 | 54 | } // namespace mola 55 | } // namespace xrt 56 | } // namespace oneflow 57 | 58 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_EXECUTABLE_H_ 59 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_executable_scope.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_EXECUTABLE_SCOPE_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_EXECUTABLE_SCOPE_H_ 18 | 19 | #include "oneflow_xrt/compiler/xla/xla_executable_context.h" 20 | #include "tensorflow/compiler/jit/xla_lib/xla_runtime_util.h" 21 | 22 | namespace oneflow { 23 | namespace xrt { 24 | namespace mola { 25 | 26 | inline bool SupportMultiStream(const XrtDevice& device) { 27 | switch (device) { 28 | case XrtDevice::GPU_CUDA: 29 | return true; 30 | default: { 31 | return false; 32 | } 33 | } 34 | } 35 | 36 | class XlaExecutableRunScope { 37 | public: 38 | inline XlaExecutableRunScope(xla::LocalExecutable* executable, 39 | XlaExecutableRunContext& run_context); 40 | 41 | inline virtual ~XlaExecutableRunScope(); 42 | 43 | private: 44 | void* launch_stream_ = nullptr; 45 | XlaExecutableRunContext& run_context_; 46 | }; 47 | 48 | XlaExecutableRunScope::XlaExecutableRunScope( 49 | xla::LocalExecutable* executable, XlaExecutableRunContext& run_context) 50 | : run_context_(run_context) { 51 | // Swap cuda stream between the backend stream and context, so XLA could 52 | // launch kernel on the specified cuda stream of the context. Note that it 53 | // should do nothing for single stream device such as CPU. 
54 | launch_stream_ = run_context_.run_options().stream; 55 | #ifdef WITH_CUDA 56 | if (SupportMultiStream(run_context_.device())) { 57 | xla::SwapGpuStreamHandle(run_context_.stream(), &launch_stream_); 58 | } 59 | #endif // WITH_CUDA 60 | 61 | size_t workspace_size = xla::CalcWorkspaceByteSize(executable); 62 | run_context_.ReserveWorkspace(workspace_size); 63 | run_context_.LockWorkspace(); 64 | } 65 | 66 | XlaExecutableRunScope::~XlaExecutableRunScope() { 67 | #ifdef WITH_CUDA 68 | if (SupportMultiStream(run_context_.device())) { 69 | xla::SwapGpuStreamHandle(run_context_.stream(), &launch_stream_); 70 | } 71 | #endif // WITH_CUDA 72 | run_context_.UnlockWorkspace(); 73 | } 74 | 75 | } // namespace mola 76 | } // namespace xrt 77 | } // namespace oneflow 78 | 79 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_EXECUTABLE_SCOPE_H_ 80 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_helpers.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_HELPERS_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_HELPERS_H_ 18 | 19 | #include "oneflow_xrt/compiler/xla/xla_data_type.h" 20 | #include "oneflow_xrt/compiler/xla/xla_shape.h" 21 | #include "tensorflow/compiler/xla/client/xla_builder.h" 22 | 23 | namespace oneflow { 24 | namespace xrt { 25 | namespace mola { 26 | 27 | xla::XlaOp One(xla::XlaBuilder* builder, DataType data_type); 28 | 29 | xla::XlaOp Zero(xla::XlaBuilder* builder, DataType data_type); 30 | 31 | xla::XlaOp Ones(xla::XlaBuilder* builder, const Shape& shape, 32 | DataType data_type); 33 | 34 | xla::XlaOp Zeros(xla::XlaBuilder* builder, const Shape& shape, 35 | DataType data_type); 36 | 37 | xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, DataType data_type, 38 | int32_t value); 39 | 40 | xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, DataType data_type, 41 | float value); 42 | 43 | xla::XlaOp Reshape(xla::XlaOp input, Shape dest_shape); 44 | 45 | xla::XlaOp MinValue(xla::XlaBuilder* builder, DataType data_type); 46 | xla::XlaOp MaxValue(xla::XlaBuilder* builder, DataType data_type); 47 | 48 | // Create computation of max func with data_type 49 | xla::XlaComputation CreateMaxFunc(DataType data_type); 50 | 51 | // Create computation of min func with data_type 52 | xla::XlaComputation CreateMinFunc(DataType data_type); 53 | 54 | xla::XlaComputation CreateAddFunc(DataType data_type); 55 | 56 | xla::XlaComputation CreateSubFunc(DataType data_type); 57 | 58 | xla::XlaComputation CreateMulFunc(DataType data_type); 59 | 60 | xla::XlaComputation CreateDivFunc(DataType data_type); 61 | 62 | } // namespace mola 63 | } // namespace xrt 64 | } // namespace oneflow 65 | 66 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_HELPERS_H_ 67 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_macro.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The 
OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_MACRO_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_MACRO_H_ 18 | 19 | #define TF_CPP_VLOG_LEVEL_REQUARED(level) \ 20 | "Set env TF_CPP_MIN_VLOG_LEVEL=" #level " to see the details." 21 | 22 | #define MOLA_STATUS_MACROS_CONCAT_NAME(x, y) \ 23 | MOLA_STATUS_MACROS_CONCAT_NAME_IMPL(x, y) 24 | #define MOLA_STATUS_MACROS_CONCAT_NAME_IMPL(x, y) x##y 25 | 26 | #define MOLA_CHECK_AND_ASSIGN(lhs, rexpr) \ 27 | MOLA_CHECK_AND_ASSIGN_IMPL( \ 28 | MOLA_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ 29 | rexpr) 30 | 31 | #define MOLA_CHECK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ 32 | auto&& statusor = (rexpr); \ 33 | CHECK(statusor.ok()) << xla::WithLogBacktrace(statusor.status()) << ". " \ 34 | << TF_CPP_VLOG_LEVEL_REQUARED(2); \ 35 | lhs = std::move(statusor.ValueOrDie()); 36 | 37 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_MACRO_H_ 38 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_resource_manager.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_RESOURCE_MANAGER_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_RESOURCE_MANAGER_H_ 18 | 19 | #define EIGEN_USE_THREADS 20 | 21 | #include "oneflow_xrt/common/device.h" 22 | #include "oneflow_xrt/compiler/xla/xla_allocator.h" 23 | #include "tensorflow/compiler/xla/client/local_client.h" 24 | #include "tensorflow/compiler/xla/client/xla_builder.h" 25 | #include "tensorflow/stream_executor/stream.h" 26 | #include "unsupported/Eigen/CXX11/Tensor" 27 | #include "unsupported/Eigen/CXX11/ThreadPool" 28 | 29 | namespace oneflow { 30 | namespace xrt { 31 | namespace mola { 32 | 33 | namespace resource_mgr { 34 | 35 | se::Platform::Id GetPlatformId(const XrtDevice& device); 36 | 37 | const se::Platform* GetPlatform(const XrtDevice& device); 38 | 39 | Eigen::ThreadPoolDevice* GetOrCreateEigenHostDevice(); 40 | 41 | typedef void* StreamId; 42 | 43 | DeviceBufferAllocator* GetOrCreateBufferAllocator(const XrtDevice& device, 44 | const StreamId& stream_id, 45 | se::Stream* stream, 46 | int device_ordinal); 47 | 48 | xla::LocalClient* GetOrCreateLocalClient(const XrtDevice& device); 49 | 50 | } // namespace resource_mgr 51 | 52 | } // namespace mola 53 | } // namespace xrt 54 | } // namespace oneflow 55 | 56 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_RESOURCE_MANAGER_H_ 57 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_shape.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow 
Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/compiler/xla/xla_shape.h" 17 | 18 | #include 19 | 20 | #include "oneflow/core/common/data_type.h" 21 | #include "oneflow/core/common/shape.h" 22 | #include "oneflow_xrt/common/shape_util.h" 23 | #include "oneflow_xrt/compiler/xla/xla_data_type.h" 24 | #include "tensorflow/compiler/xla/layout_util.h" 25 | #include "tensorflow/compiler/xla/shape.h" 26 | #include "tensorflow/compiler/xla/shape_util.h" 27 | 28 | namespace oneflow { 29 | namespace xrt { 30 | namespace mola { 31 | 32 | Shape XlaShapeToOfShape(const xla::Shape& xla_shape) { 33 | CHECK(!xla_shape.IsTuple()); 34 | int rank = xla_shape.rank(); 35 | std::vector dimensions(rank); 36 | for (int i = 0; i < rank; ++i) { 37 | dimensions[i] = xla_shape.dimensions(i); 38 | } 39 | return AsShape(dimensions); 40 | } 41 | 42 | xla::Shape OfShapeToXlaShape(const Shape& shape, DataType dtype) { 43 | xla::PrimitiveType type = DataTypeToPrimitiveType(dtype); 44 | return OfShapeToXlaShape(shape, type); 45 | } 46 | 47 | xla::Shape OfShapeToXlaShape(const Shape& shape, xla::PrimitiveType type) { 48 | int rank = shape.NumAxes(); 49 | std::vector layout(rank); 50 | std::vector dimensions(rank); 51 | for (int i = 0; i < rank; ++i) { 52 | dimensions[i] = shape.At(i); 53 | } 54 | 55 | std::iota(layout.rbegin(), layout.rend(), 0); 56 | return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); 57 
| } 58 | 59 | Shape SliceShape(const Shape& shape, size_t start_dim, size_t end_dim) { 60 | CHECK_LE(start_dim, end_dim); 61 | CHECK_LE(end_dim, shape.NumAxes()); 62 | 63 | std::vector slice_shape(end_dim - start_dim); 64 | for (size_t i = start_dim; i < end_dim; ++i) { 65 | slice_shape[i] = shape.At(i); 66 | } 67 | return AsShape(slice_shape); 68 | } 69 | 70 | } // namespace mola 71 | } // namespace xrt 72 | } // namespace oneflow 73 | -------------------------------------------------------------------------------- /oneflow_xrt/compiler/xla/xla_shape.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_COMPILER_XLA_XLA_SHAPE_H_ 17 | #define ONEFLOW_XRT_COMPILER_XLA_XLA_SHAPE_H_ 18 | 19 | #include "oneflow/core/common/data_type.pb.h" 20 | #include "oneflow/core/common/shape.h" 21 | #include "tensorflow/compiler/xla/shape.h" 22 | #include "tensorflow/compiler/xla/xla_data.pb.h" 23 | 24 | namespace oneflow { 25 | namespace xrt { 26 | namespace mola { 27 | 28 | Shape XlaShapeToOfShape(const xla::Shape& xla_shape); 29 | 30 | xla::Shape OfShapeToXlaShape(const Shape& shape, DataType dtype); 31 | 32 | xla::Shape OfShapeToXlaShape(const Shape& shape, xla::PrimitiveType type); 33 | 34 | // Returns shape[start_dim:end_dim] 35 | Shape SliceShape(const Shape& shape, size_t start_dim, size_t end_dim); 36 | 37 | } // namespace mola 38 | } // namespace xrt 39 | } // namespace oneflow 40 | 41 | #endif // ONEFLOW_XRT_COMPILER_XLA_XLA_SHAPE_H_ 42 | -------------------------------------------------------------------------------- /oneflow_xrt/graph/node_util.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #ifndef ONEFLOW_XRT_GRAPH_NODE_UTIL_H_ 17 | #define ONEFLOW_XRT_GRAPH_NODE_UTIL_H_ 18 | 19 | #include "oneflow_xrt/graph/node.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | 24 | bool IsCanbeCompiledNode(const XrtNode* node, const XrtEngine& engine, 25 | const XrtDevice& device); 26 | bool IsModelUpdateNode(const XrtNode* node); 27 | 28 | bool IsMutableVariable(const Argument& argument, const std::string& op_type, 29 | const XrtEngine& engine); 30 | 31 | bool IsNodeInput(const XrtNode* node, const Argument& argument); 32 | bool IsNodeOutput(const XrtNode* node, const Argument& argument); 33 | 34 | } // namespace xrt 35 | } // namespace oneflow 36 | 37 | #endif // ONEFLOW_XRT_GRAPH_NODE_UTIL_H_ 38 | -------------------------------------------------------------------------------- /oneflow_xrt/int8_calibration/calibration.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include "oneflow_xrt/int8_calibration/calibration.h" 17 | 18 | #include 19 | #include 20 | 21 | #include "absl/strings/str_cat.h" 22 | #include "glog/logging.h" 23 | #include "oneflow/core/vm/vm_sync.h" 24 | 25 | namespace oneflow { 26 | namespace xrt { 27 | 28 | static std::unordered_map resources; 29 | 30 | /*static*/ bool Int8CalibratorResource::Record(const std::string& name, 31 | Int8CalibratorResource* res) { 32 | return resources.emplace(name, res).second; 33 | } 34 | 35 | /*static*/ Int8CalibratorResource* Int8CalibratorResource::Lookup( 36 | const std::string& name) { 37 | const auto& it = resources.find(name); 38 | return it == resources.end() ? NULL : it->second; 39 | } 40 | 41 | /*static*/ const std::unordered_map& 42 | Int8CalibratorResource::All() { 43 | return resources; 44 | } 45 | 46 | void CacheInt8Calibration() { 47 | // synchronize oneflow virtual machine to ensure that the kernel has been 48 | // complete executed 49 | vm::CurrentRankSync(); 50 | const auto& calib_resources = Int8CalibratorResource::All(); 51 | for (const auto& res : calib_resources) { 52 | res.second->WaitAndSetDone(); 53 | } 54 | } 55 | 56 | void CacheAndWriteInt8Calibration(const std::string& path) { 57 | // synchronize oneflow virtual machine to ensure that the kernel has been 58 | // complete executed 59 | vm::CurrentRankSync(); 60 | const auto& calib_resources = Int8CalibratorResource::All(); 61 | for (const auto& res : calib_resources) { 62 | if (!res.second->IsDone()) { 63 | res.second->WaitAndSetDone(); 64 | } 65 | const std::string& calibration_table_data = 66 | res.second->GetCalibrationTableAsString(); 67 | std::string calib_store_path = 68 | absl::StrCat(path, "/", res.first /*calibrator name*/); 69 | std::ofstream ofile(calib_store_path, std::ios::out); 70 | CHECK(ofile.good()) << "Could not open calibration file: " 71 | << calib_store_path; 72 | ofile << calibration_table_data; 73 | ofile.close(); 74 | } 75 | } 76 | 77 | } // namespace xrt 78 | } // 
// Interface of an int8 calibrator that can be registered by name in the
// table kept in calibration.cpp.
class Int8CalibratorResource {
 public:
  // Virtual so that a concrete calibrator is destroyed correctly when it is
  // deleted through a pointer to this interface (the registry only stores
  // Int8CalibratorResource*).
  virtual ~Int8CalibratorResource() = default;

  // Registers `res` under `name`. Returns false, leaving the existing entry
  // untouched, when `name` is already registered.
  static bool Record(const std::string& name, Int8CalibratorResource* res);
  // Returns the resource registered under `name`, or a null pointer when
  // there is none.
  static Int8CalibratorResource* Lookup(const std::string& name);

  // Returns every registered resource keyed by its name.
  // NOTE(review): the template arguments were lost when this file was
  // extracted; they are reconstructed from how Record()/Lookup() use the
  // table in calibration.cpp - confirm against the original header.
  static const std::unordered_map<std::string, Int8CalibratorResource*>& All();

  virtual void WaitAndSetDone() = 0;
  virtual bool IsDone() const = 0;
  virtual std::string GetCalibrationTableAsString() const = 0;
};
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include "oneflow_xrt/int8_calibration/calibration_mode.h" 17 | 18 | #include "glog/logging.h" 19 | #include "oneflow_xrt/int8_calibration/calibration.h" 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | 24 | // not thread safe 25 | static bool calibration_mode_enabled = false; 26 | 27 | PTQCalibrationMode::PTQCalibrationMode(const std::string& cache_path) 28 | : cache_path_(cache_path) { 29 | CHECK(!calibration_mode_enabled) << "calibration mode cannot be nested"; 30 | calibration_mode_enabled = true; 31 | } 32 | 33 | PTQCalibrationMode::~PTQCalibrationMode() { 34 | if (cache_path_.empty()) { 35 | CacheInt8Calibration(); 36 | } else { 37 | CacheAndWriteInt8Calibration(cache_path_); 38 | } 39 | calibration_mode_enabled = false; 40 | } 41 | 42 | /*static*/ bool PTQCalibrationMode::Enabled() { 43 | return calibration_mode_enabled; 44 | } 45 | 46 | } // namespace xrt 47 | } // namespace oneflow 48 | -------------------------------------------------------------------------------- /oneflow_xrt/int8_calibration/calibration_mode.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #ifndef ONEFLOW_XRT_INT8_CALIBRATION_CALIBRATION_MODE_H_ 17 | #define ONEFLOW_XRT_INT8_CALIBRATION_CALIBRATION_MODE_H_ 18 | 19 | #include 20 | 21 | namespace oneflow { 22 | namespace xrt { 23 | 24 | class PTQCalibrationMode { 25 | public: 26 | explicit PTQCalibrationMode(const std::string& cache_path); 27 | virtual ~PTQCalibrationMode(); 28 | 29 | static bool Enabled(); 30 | 31 | protected: 32 | std::string cache_path_; 33 | }; 34 | 35 | } // namespace xrt 36 | } // namespace oneflow 37 | 38 | #endif // ONEFLOW_XRT_INT8_CALIBRATION_CALIBRATION_MODE_H_ 39 | -------------------------------------------------------------------------------- /oneflow_xrt/python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(ONEFLOW_XRT_PYTHON_DIR ${PROJECT_SOURCE_DIR}/python) 2 | include(pybind11) 3 | 4 | function(ONEFLOW_XRT_ADD_MODULE target_name) 5 | pybind11_add_module(${target_name} SHARED ${ARGN}) 6 | set_target_properties(${target_name} PROPERTIES CXX_VISIBILITY_PRESET "default") 7 | set_target_properties(${target_name} PROPERTIES PREFIX "_") 8 | target_include_directories(${target_name} PRIVATE ${Python_INCLUDE_DIRS}) 9 | endfunction() 10 | 11 | set(XRT_PYTHON_SRCS 12 | stub.cpp 13 | graph.cpp 14 | options.cpp 15 | int8_calibration.cpp 16 | ) 17 | oneflow_xrt_add_module(oneflow_xrt_internal ${XRT_PYTHON_SRCS}) 18 | set_target_properties(oneflow_xrt_internal 19 | PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${ONEFLOW_XRT_PYTHON_DIR}/oneflow_xrt") 20 | 21 | 
target_include_directories(oneflow_xrt_internal PRIVATE ${ONEFLOW_INCLUDE_DIR}) 22 | target_link_libraries(oneflow_xrt_internal PUBLIC oneflow_xrt glog::glog) 23 | 24 | function(ONEFLOW_XRT_ADD_STUB target_name) 25 | oneflow_xrt_add_module(${target_name}_internal ${ARGN}) 26 | set_target_properties(${target_name}_internal 27 | PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${ONEFLOW_XRT_PYTHON_DIR}/${target_name}") 28 | 29 | target_link_libraries(${target_name}_internal PUBLIC 30 | -Wl,--no-as-needed 31 | ${target_name} 32 | -Wl,--as-needed 33 | glog::glog) 34 | 35 | add_custom_target( 36 | ${target_name}_create_python_module 37 | COMMAND ${Python_EXECUTABLE} 38 | ${PROJECT_SOURCE_DIR}/tools/create_python_module.py 39 | ${ONEFLOW_XRT_PYTHON_DIR} ${target_name} 40 | DEPENDS ${Python_EXECUTABLE} 41 | ) 42 | add_dependencies(${target_name}_internal ${target_name}_create_python_module) 43 | endfunction() 44 | 45 | if(BUILD_XLA) 46 | oneflow_xrt_add_stub(oneflow_xrt_xla xla_stub.cpp) 47 | endif() 48 | 49 | if(BUILD_TENSORRT) 50 | oneflow_xrt_add_stub(oneflow_xrt_tensorrt tensorrt_stub.cpp) 51 | endif() 52 | 53 | if(BUILD_OPENVINO) 54 | oneflow_xrt_add_stub(oneflow_xrt_openvino openvino_stub.cpp) 55 | endif() 56 | -------------------------------------------------------------------------------- /oneflow_xrt/python/graph.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | #include "oneflow_xrt/api/api_internal.h" 20 | 21 | namespace py = pybind11; 22 | 23 | using namespace oneflow::xrt; 24 | 25 | void InitXrtGraphApis(py::module_& m) { 26 | py::class_>(m, "Graph") 27 | .def(py::init([](const std::string& serialized_job) { 28 | oneflow::Job job; 29 | if (!job.ParseFromString(serialized_job)) { 30 | PyErr_SetString(PyExc_RuntimeError, 31 | "the first argument is not a valid job"); 32 | } 33 | return BuildGraph(job); 34 | })); 35 | } 36 | -------------------------------------------------------------------------------- /oneflow_xrt/python/int8_calibration.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include 17 | #include 18 | 19 | #include "oneflow_xrt/api/api_internal.h" 20 | 21 | namespace py = pybind11; 22 | 23 | using namespace oneflow::xrt; 24 | 25 | void InitInt8CalibrationApis(py::module_& m) { 26 | py::class_>( 27 | m, "PTQCalibrationMode") 28 | .def(py::init([](const std::string& cache_path) { 29 | return std::make_shared(cache_path); 30 | })) 31 | .def("__enter__", [](const PTQCalibrationMode&) {}) 32 | .def("__exit__", [](const PTQCalibrationMode&, const py::object&, 33 | const py::object&, const py::object&) {}); 34 | } 35 | -------------------------------------------------------------------------------- /oneflow_xrt/python/openvino_stub.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | PYBIND11_MODULE(_oneflow_xrt_openvino_internal, m) {} 20 | -------------------------------------------------------------------------------- /oneflow_xrt/python/stub.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | #include "oneflow_xrt/api/api_internal.h" 20 | 21 | namespace py = pybind11; 22 | 23 | using namespace oneflow; 24 | using namespace oneflow::xrt; 25 | 26 | extern void InitXrtGraphApis(py::module_& m); 27 | extern void InitClusteringOptionsApis(py::module_& m); 28 | extern void InitReBuildJobOptionsApis(py::module_& m); 29 | extern void InitInt8CalibrationApis(py::module_& m); 30 | 31 | PYBIND11_MODULE(_oneflow_xrt_internal, m) { 32 | m.def("rebuild_job", 33 | [](XrtGraph* graph, const std::string& serialized_origin_job, 34 | const ReBuildJobOptions& options) { 35 | Job origin_job; 36 | if (!origin_job.ParseFromString(serialized_origin_job)) { 37 | PyErr_SetString(PyExc_TypeError, 38 | "the second argument is not a valid job"); 39 | } 40 | auto job = RunRebuildJobPass(graph, origin_job, options); 41 | return py::bytes(job->SerializeAsString()); 42 | }); 43 | m.def("cluster_subgraph", &RunClusterSubGraphPass); 44 | 45 | InitXrtGraphApis(m); 46 | InitClusteringOptionsApis(m); 47 | InitReBuildJobOptionsApis(m); 48 | InitInt8CalibrationApis(m); 49 | } 50 | -------------------------------------------------------------------------------- /oneflow_xrt/python/tensorrt_stub.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | #include 17 | #include 18 | 19 | PYBIND11_MODULE(_oneflow_xrt_tensorrt_internal, m) {} 20 | -------------------------------------------------------------------------------- /oneflow_xrt/python/xla_stub.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | #include 17 | #include 18 | 19 | PYBIND11_MODULE(_oneflow_xrt_xla_internal, m) {} 20 | -------------------------------------------------------------------------------- /oneflow_xrt/version_script.lds: -------------------------------------------------------------------------------- 1 | { 2 | global: 3 | *xrt*; 4 | local: 5 | *; 6 | }; 7 | -------------------------------------------------------------------------------- /oneflow_xrt/xrt.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package oneflow.xrt; 3 | 4 | import "oneflow/core/operator/op_conf.proto"; 5 | import "oneflow/core/job/sbp_parallel.proto"; 6 | import "oneflow/core/register/blob_desc.proto"; 7 | 8 | enum XrtDevice { 9 | CPU_X86 = 1; 10 | GPU_CUDA = 2; 11 | GPU_CL = 3; 12 | CPU_ARM = 4; 13 | } 14 | 15 | enum XrtEngine { 16 | DEFAULT = 1; 17 | XLA = 2; 18 | TENSORRT = 3; 19 | TVM = 4; 20 | OPENVINO = 5; 21 | } 22 | 23 | message ExecuteOptionsProto { 24 | required XrtEngine engine = 1; 25 | required XrtDevice device = 2; 26 | 27 | optional bool use_fp16 = 3 [default = false]; 28 | optional bool use_int8 = 4 [default = false]; 29 | optional string int8_calibration = 5 [default = ""]; 30 | optional int64 max_batch_size = 6 [default = 1]; 31 | optional int64 max_workspace_size = 7 [default = 0]; 32 | 33 | optional int64 host_num_threads = 8 [default = -1]; 34 | optional int64 random_seed = 9 [default = -1]; 35 | 36 | optional bool force_compile = 10 [default = false]; 37 | 38 | // It does not guarantee to use low precision if just set use_int8 or 39 | // use_fp16, but you can set strict_types to enforce the engine to use 40 | // low precision 41 | optional bool strict_types = 11 [default = false]; 42 | 43 | // In order to reduce the computation precision loss, some ops specify 44 | // a precision constraint, but this constraint is not mandatory, and the 45 | // engine may still choose an appropriate precision based on it's tuning 
46 | // result. This option will make the constraint to be mandatory 47 | optional bool force_precision_constraints = 12 [default = true]; 48 | } 49 | 50 | message FunctionArgumentProto { 51 | required string name = 1; 52 | required string value = 2; 53 | } 54 | 55 | message FunctionProto { 56 | repeated FunctionArgumentProto input = 1; 57 | repeated FunctionArgumentProto output = 2; 58 | repeated OperatorConf node = 3; 59 | } 60 | 61 | message XrtLaunchProto { 62 | required ExecuteOptionsProto options = 1; 63 | 64 | required FunctionProto function = 2; 65 | repeated string liveout_entries = 3; 66 | 67 | // nd sbp signature for each folded node 68 | map nd_sbp_signatures = 5; 69 | map logical_blob_descs = 6; 70 | }; 71 | -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | /dist 3 | /build 4 | /oneflow_xrt_* 5 | -------------------------------------------------------------------------------- /python/oneflow_xrt/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 The OneFlow Authors. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
def rebuild_job(graph, origin_job, options):
    """Rebuild ``origin_job`` for ``graph`` and return the result as a ``Job``.

    The job crosses the extension-module boundary in serialized form: it is
    serialized here, passed to the internal ``rebuild_job`` together with
    ``graph`` and ``options``, and the returned bytes are parsed into a new
    ``oneflow.core.job.job_pb2.Job`` message.
    """
    origin_bytes = origin_job.SerializeToString()
    rebuilt_bytes = oneflow_xrt._oneflow_xrt_internal.rebuild_job(
        graph, origin_bytes, options
    )

    # The protobuf module is imported only when the function is called.
    import oneflow.core.job.job_pb2 as job_pb

    rebuilt_job = job_pb.Job()
    rebuilt_job.ParseFromString(rebuilt_bytes)
    return rebuilt_job
class ptq_calibration_mode:
    """Run code under ``PTQCalibrationMode``, as a decorator or a ``with`` block.

    Both forms drive the same context-manager protocol of ``PTQCalibrationMode``:

        @ptq_calibration_mode(cache_path)
        def calibrate(...): ...

        with ptq_calibration_mode(cache_path):
            ...

    Args:
        cache_path: Value handed to ``PTQCalibrationMode``. ``None`` is replaced
            by the empty string.
    """

    def __init__(self, cache_path=None):
        self.cache_path = "" if cache_path is None else cache_path

    def __call__(self, func):
        """Return ``func`` wrapped so every call runs inside ``PTQCalibrationMode``."""
        import functools

        # Keep the wrapped function's name/docstring visible on the wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with PTQCalibrationMode(self.cache_path):
                return func(*args, **kwargs)

        return wrapper

    def __enter__(self):
        """Create the underlying mode object, enter it, and return ``self``."""
        self.calibration_mode = PTQCalibrationMode(self.cache_path)
        # The decorator form uses `with PTQCalibrationMode(...)`; entering the
        # object here makes the `with ptq_calibration_mode(...)` form equivalent.
        self.calibration_mode.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the underlying mode object, forwarding the exception details."""
        # Delegate instead of ignoring the exit, so the mode entered in
        # __enter__ is left again exactly as in the decorator form.
        return self.calibration_mode.__exit__(exc_type, exc_val, exc_tb)
def try_import_engine(engine):
    """Import the Python package that backs the requested engine.

    Args:
        engine: Engine name. ``"XLA"``, ``"TENSORRT"`` and ``"OPENVINO"`` import
            ``oneflow_xrt_xla``, ``oneflow_xrt_tensorrt`` and
            ``oneflow_xrt_openvino`` respectively. Any other value is ignored.

    Raises:
        RuntimeError: If importing the engine package fails. The underlying
            error is attached as ``__cause__``.
    """
    try:
        if engine == "XLA":
            import oneflow_xrt_xla  # noqa: F401
        elif engine == "TENSORRT":
            import oneflow_xrt_tensorrt  # noqa: F401
        elif engine == "OPENVINO":
            import oneflow_xrt_openvino  # noqa: F401
    except Exception as err:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
        # SystemExit propagate instead of being reported as a missing engine.
        raise RuntimeError(
            f"the engine {engine} has not been installed, please install it via the command `pip3 install oneflow_xrt_{engine.lower()}`"
        ) from err
def get_env(name, default=""):
    """Return environment variable ``name``, or ``default`` when it is not set."""
    return os.environ.get(name, default)


class Env(object):
    """Exposes selected environment variables as read-only properties."""

    # Values that count as "enabled" for the boolean switches.
    _ENABLED = ("ON", "1")

    @property
    def cmake_build_type(self):
        """``"Debug"`` when DEBUG is enabled, else CMAKE_BUILD_TYPE, else ``"Release"``."""
        if get_env("DEBUG") in self._ENABLED:
            return "Debug"
        configured = get_env("CMAKE_BUILD_TYPE")
        return configured if configured != "" else "Release"

    @property
    def build_xla(self):
        """True when BUILD_XLA is ``"ON"`` or ``"1"``."""
        return get_env("BUILD_XLA") in self._ENABLED

    @property
    def build_tensorrt(self):
        """True when BUILD_TENSORRT is ``"ON"`` or ``"1"``."""
        return get_env("BUILD_TENSORRT") in self._ENABLED

    @property
    def build_openvino(self):
        """True when BUILD_OPENVINO is ``"ON"`` or ``"1"``."""
        return get_env("BUILD_OPENVINO") in self._ENABLED

    @property
    def tensorrt_root(self):
        """Value of TENSORRT_ROOT, or the empty string when unset."""
        return get_env("TENSORRT_ROOT")

    @property
    def openvino_root(self):
        """Value of OPENVINO_ROOT, or the empty string when unset."""
        return get_env("OPENVINO_ROOT")


env = Env()
11 | ) 12 | 13 | parser.add_argument( 14 | "--bin", default="cmake-format", help="Path of cmake-format binary" 15 | ) 16 | parser.add_argument( 17 | "--fix", default=False, action="store_true", help="Format all sources in place" 18 | ) 19 | parser.add_argument( 20 | "--source_dir", default=".", help="Root directory of the source code" 21 | ) 22 | parser.add_argument( 23 | "-j", 24 | "--jobs", 25 | type=int, 26 | default=cpu_count(), 27 | help="Specifies the number of jobs (commands) to run simultaneously", 28 | ) 29 | 30 | args = parser.parse_args() 31 | 32 | patterns = [ 33 | "cmake/**/*.cmake", 34 | "oneflow/**/*.cmake", 35 | "oneflow/**/CMakeLists.txt", 36 | "tools/**/*.cmake", 37 | "tools/**/CMakeLists.txt", 38 | "CMakeLists.txt", 39 | ] 40 | 41 | files = [] 42 | for pattern in patterns: 43 | files.extend(glob(str(Path(args.source_dir) / pattern), recursive=True)) 44 | 45 | def gen_cmd(file): 46 | cmd = [args.bin, file] 47 | cmd.append("-i" if args.fix else "--check") 48 | return cmd 49 | 50 | tp = ThreadPool(args.jobs) 51 | res = tp.map_async(call, [gen_cmd(file) for file in files]) 52 | 53 | tp.close() 54 | tp.join() 55 | 56 | count = sum(map(lambda x: 0 if x == 0 else 1, res.get())) 57 | total = len(files) 58 | if args.fix: 59 | print(f"cmake-format -i done. {total} total") 60 | else: 61 | print(f"cmake-format --check done. 
if __name__ == "__main__":
    # Script entry point: verify the interpreter and the pinned black version,
    # then run black over --source_dir (checking only, unless --fix is given).
    import subprocess

    # Compare against the version tuple directly. No f-strings are used in this
    # script so that it still parses, and this guard can run, on older Pythons.
    if sys.version_info < (3, 6):
        print("WARNING: python >= 3.6 required, python source format won't run")
        sys.exit(0)

    parser = argparse.ArgumentParser(
        description="Runs py-format on all of the source files. "
        "If --fix is specified enforce format by modifying in place."
    )
    parser.add_argument(
        "--source_dir", required=True, help="Root directory of the source code"
    )
    parser.add_argument(
        "--fix",
        default=False,
        action="store_true",
        help="If specified, will re-format the source",
    )

    arguments = parser.parse_args()
    os.chdir(arguments.source_dir)

    BLACK_VER = "19.10b0"
    # Run black through the current interpreter with an argument list, so that
    # neither a shell nor an external `grep` is needed and an interpreter path
    # containing spaces still works.
    version_probe = subprocess.run(
        [sys.executable, "-m", "black", "--version"],
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    if BLACK_VER not in (version_probe.stdout or ""):
        print(
            "Please install black {ver}. For instance, run 'python3 -m pip install black=={ver} --user'".format(
                ver=BLACK_VER
            )
        )
        sys.exit(1)

    black_cmd = [sys.executable, "-m", "black", "."]
    if not arguments.fix:
        # Without --fix only report files that would be reformatted.
        black_cmd.append("--check")
    if subprocess.call(black_cmd) != 0:
        sys.exit(1)