├── test ├── __init__.py ├── cat.png ├── util.py ├── benchmarks.py ├── test_core.py ├── test_models.py └── test_operators.py ├── VERSION_NUMBER ├── torch_tvm ├── custom_tvm_ops │ ├── relay │ │ ├── __init__.py │ │ └── custom_fp32_dense.py │ ├── topi │ │ ├── __init__.py │ │ └── custom_fp32_dense.py │ ├── __init__.py │ ├── cpp │ │ ├── relay │ │ │ ├── utils.h │ │ │ ├── weight_pack_attrs.h │ │ │ ├── utils.cc │ │ │ ├── custom_layer_norm_attrs.h │ │ │ ├── custom_dense.h │ │ │ ├── custom_layer_norm.h │ │ │ ├── custom_layer_norm_init.cc │ │ │ ├── custom_dense_init.cc │ │ │ ├── quantize_attrs.h │ │ │ ├── quantize.h │ │ │ ├── custom_dense.cc │ │ │ ├── custom_layer_norm.cc │ │ │ ├── quantize_init.cc │ │ │ └── quantize.cc │ │ └── topi │ │ │ ├── custom_layer_norm_generic_sched.h │ │ │ ├── custom_layer_norm.h │ │ │ ├── quantize.h │ │ │ ├── contrib │ │ │ ├── quantize.h │ │ │ └── quantize.cc │ │ │ ├── generic │ │ │ └── quantize_generic_sched.h │ │ │ ├── custom_layer_norm_generic_sched.cc │ │ │ ├── quantize.cc │ │ │ ├── custom_layer_norm.cc │ │ │ ├── x86 │ │ │ └── quantize_data_mm_dequantize.h │ │ │ └── custom_topi_ops.cc │ └── test │ │ └── test_custom_layer_norm.py ├── remove_dropout.h ├── fuse_concat.h ├── fuse_linear.h ├── fusion_pass.h ├── debug_utils.h ├── __init__.py ├── memory_utils.h ├── remove_dropout.cpp ├── operators.h ├── fuse_concat.cpp ├── fuse_linear.cpp ├── debug_utils.cpp ├── memory_utils.cpp ├── register.h ├── register.cpp ├── compiler.h ├── fusion_pass.cpp └── compiler.cpp ├── pt_execution.png ├── .gitmodules ├── .gitignore ├── setup.cfg ├── CMakeLists.txt ├── .circleci └── config.yml ├── .clang-format ├── README.md └── setup.py /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /VERSION_NUMBER: -------------------------------------------------------------------------------- 1 | 0.0.1 2 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/relay/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/topi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tvm/HEAD/test/cat.png -------------------------------------------------------------------------------- /pt_execution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tvm/HEAD/pt_execution.png -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .relay import custom_fp32_dense 2 | from .topi import custom_fp32_dense 3 | -------------------------------------------------------------------------------- /torch_tvm/remove_dropout.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | TORCH_API void RemoveDropout(std::shared_ptr& graph); 6 | -------------------------------------------------------------------------------- /torch_tvm/fuse_concat.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | using namespace torch::jit; 6 | 7 | TORCH_API void FuseConcat(std::shared_ptr& graph); 8 | -------------------------------------------------------------------------------- /torch_tvm/fuse_linear.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | using namespace torch::jit; 6 | 7 | TORCH_API void FuseLinear(std::shared_ptr& graph); 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pybind11"] 2 | path = pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | [submodule "tvm"] 5 | path = tvm 6 | url = https://github.com/facebookexperimental/tvm.git 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | compile_commands.json 3 | *.egg-info 4 | *.pyc 5 | __pycache__ 6 | *.so 7 | *.a 8 | .eggs/ 9 | 10 | 11 | # ignore version file generated by build 12 | torch_tvm/version.py 13 | -------------------------------------------------------------------------------- /torch_tvm/fusion_pass.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | void FuseSupportedOps(std::shared_ptr graph); 6 | 7 | const torch::jit::Symbol& getTVMSymbol(); 8 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace tvm { 6 | namespace relay { 7 | 8 | namespace helper { 9 | int32_t get_pack_width(int32_t dim_size, int32_t pack_factor=16); 10 | } // namespace helper 11 | 12 | } // namespace relay 13 | } // namespace tvm 14 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/custom_layer_norm_generic_sched.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace topi { 8 | namespace generic { 9 | tvm::Schedule schedule_custom_layer_norm(const tvm::Array& outs); 10 | } // namespace generic 11 | } // namespace topi 12 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/custom_layer_norm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace topi { 7 | Tensor custom_layer_norm( 8 | const Tensor& data, 9 | const Tensor& gamma, 10 | const Tensor& beta, 11 | const int num_axis_to_normalize, 12 | const bool affine, 13 | const float eps); 14 | } // namespace topi 15 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/weight_pack_attrs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tvm { 7 | namespace relay { 8 | 9 | struct WeightPackAttrs : public tvm::AttrsNode { 10 | int32_t pack_width; 11 | TVM_DECLARE_ATTRS(WeightPackAttrs, "relay.attrs.WeightPackAttrs") { 12 | 
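// pack_width is the innermost vectorization width used when packing an
// (N, K) weight matrix into (N / pack_width, K, pack_width); the packed
// layout satisfies packed[n / pack_width, k, n % pack_width] == weight[n, k]
// (see dense_weight_pack in topi/custom_fp32_dense.py), and
// helper::get_pack_width picks the largest power of two (up to 16) dividing N.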
TVM_ATTR_FIELD(pack_width).set_default(1); 13 | } 14 | }; 15 | } // namespace relay 16 | } // namespace tvm 17 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/utils.cc: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | namespace tvm { 4 | namespace relay { 5 | 6 | namespace helper { 7 | int32_t get_pack_width(int32_t dim_size, int32_t pack_factor) { 8 | int32_t vec_width = pack_factor; 9 | while (vec_width > 1) { 10 | if (dim_size % vec_width == 0) { 11 | return vec_width; 12 | } 13 | vec_width /= 2; 14 | } 15 | return 1; 16 | } 17 | } // namespace helper 18 | } // namespace relay 19 | } // namespace tvm 20 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_layer_norm_attrs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tvm { 7 | namespace relay { 8 | struct CustomLayerNormAttrs : public tvm::AttrsNode { 9 | int num_axis_to_normalize; 10 | bool affine; 11 | double eps; 12 | TVM_DECLARE_ATTRS(CustomLayerNormAttrs, "relay.attrs.CustomLayerNormAttrs") { 13 | TVM_ATTR_FIELD(num_axis_to_normalize).set_default(-1); 14 | TVM_ATTR_FIELD(affine).set_default(false); 15 | TVM_ATTR_FIELD(eps).set_default(1e-5); 16 | } 17 | }; 18 | } // namespace relay 19 | } // namespace tvm 20 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_dense.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tvm { 9 | namespace relay { 10 | 11 | bool CustomDenseRel( 12 | const Array& types, 13 | int num_inputs, /* unused */ 14 | const Attrs& attrs, 15 | const TypeReporter& reporter); 16 | 17 | bool DenseWeightPackRel( 18 | const Array& types, 19 | int num_inputs, /* unused */ 20 | const Attrs& attrs, 21 | const TypeReporter& reporter); 22 | } // namespace relay 23 | } // namespace tvm 24 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_layer_norm.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tvm { 9 | namespace relay { 10 | 11 | Expr MakeCustomLayerNorm( 12 | Expr data, 13 | Expr gamma, 14 | Expr beta, 15 | const int num_axis_to_normalize, 16 | const bool affine, 17 | const double eps); 18 | 19 | bool CustomLayerNormRel( 20 | const Array& types, 21 | int num_inputs, /* unused */ 22 | const Attrs& attrs, 23 | const TypeReporter& reporter); 24 | } // namespace relay 25 | } // namespace tvm 26 | -------------------------------------------------------------------------------- /torch_tvm/debug_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | class DebugLogger { 13 | public: 14 | // Delete copy constructor and assign. 
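// DebugLogger is handed out as a thread-local singleton via getDebugLogger()
// and owns the std::ofstream it writes to, so copying is disallowed.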
15 | DebugLogger(const DebugLogger&) = delete; 16 | DebugLogger operator=(const DebugLogger&) = delete; 17 | DebugLogger(); 18 | 19 | void printGraph(const std::shared_ptr& subgraph); 20 | 21 | void printLoweredFuncs(tvm::runtime::Module& build_mod); 22 | 23 | void printASM(tvm::runtime::Module& mod); 24 | 25 | private: 26 | std::ofstream debug_file_; 27 | }; 28 | 29 | DebugLogger& getDebugLogger(); 30 | -------------------------------------------------------------------------------- /torch_tvm/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | from tvm import relay # This registers all the schedules 8 | 9 | from ._torch_tvm import * 10 | from ._torch_tvm import _push_relay_expr 11 | from tvm._ffi.function import _init_api # This lets us use PackedFunc with torch_tvm 12 | from torch_tvm import custom_tvm_ops 13 | _init_api("torch_tvm") 14 | 15 | def to_relay(pt_func, inputs): 16 | if type(pt_func) is not torch._C.ScriptFunction: 17 | pt_func = torch.jit.trace(pt_func, inputs) 18 | handle = _push_relay_expr(pt_func.graph_for(*inputs), inputs) 19 | return _pop_relay_expr(handle) 20 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/quantize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace topi { 9 | 10 | Array data_int8_quantize( 11 | const Tensor& data, 12 | const Tensor& zero_point, 13 | const Tensor& scale, 14 | bool is_signed, 15 | int precision); 16 | 17 | Array data_int8_row_offset(const Tensor& quantized_data); 18 | 19 | Array data_int8_mm_dequantize( 20 | const Tensor& data, 21 | const Tensor& weight, 22 | const Tensor& weight_acc, 23 | const Tensor& data_acc, 24 | const Tensor& data_scale, 25 | const Tensor& data_zero_point, 26 | const double weight_scale, 27 | const int weight_zero_point, 28 | const int N); 29 | 30 | } // namespace topi 31 | -------------------------------------------------------------------------------- /torch_tvm/memory_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include 10 | 11 | namespace torch_tvm { 12 | namespace utils { 13 | 14 | struct DLManagedTensorDeleter { 15 | void operator()(DLManagedTensor* manager_ctx) { 16 | if (manager_ctx == nullptr) { 17 | return; 18 | } 19 | 20 | auto dl_tensor = manager_ctx->dl_tensor; 21 | TORCH_CHECK(dl_tensor.ctx.device_type == kDLCPU); 22 | if (dl_tensor.data) { 23 | TORCH_CHECK((dl_tensor.shape && dl_tensor.strides), "If DLTensor's data" 24 | " pointer is valid then shape and strides must be as well.") 25 | std::free(dl_tensor.data); 26 | delete[] dl_tensor.shape; 27 | delete[] dl_tensor.strides; 28 | } 29 | delete manager_ctx; 30 | } 31 | }; 32 | 33 | bool isAligned(void* data_ptr, std::uintptr_t alignment_in_bytes); 34 | 35 | DLManagedTensor* allocAndCopyData(const at::Tensor& tensor); 36 | using DLManagedTensorPtr = std::unique_ptr; 38 | 39 | } // utils 40 | } // torch_tvm 41 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 
| 4 | [tool:pytest] 5 | testpaths = test 6 | 7 | [flake8] 8 | select = B,C,E,F,P,T4,W,B9 9 | max-line-length = 80 10 | ### DEFAULT IGNORES FOR 4-space INDENTED PROJECTS ### 11 | # E127, E128 are hard to silence in certain nested formatting situations. 12 | # E265, E266 talk about comment formatting which is too opinionated. 13 | # E402 warns on imports coming after statements. There are important use cases 14 | # like demandimport (https://fburl.com/demandimport) that require statements 15 | # before imports. 16 | # E501 is not flexible enough, we're using B950 instead. 17 | # E722 is a duplicate of B001. 18 | # F405 is hard to silence since we indeed do star imports. 19 | # P207 is a duplicate of B003. 20 | # P208 is a duplicate of C403. 21 | # W503 talks about operator formatting which is too opinionated. 22 | # F401 clashes with PEP484 requiring us to import types that are only used in 23 | # type comments. 24 | ignore = E127, E128, E265, E266, E402, E501, E722, F405, P207, P208, W503, F401 25 | exclude = 26 | .git, 27 | __pycache__, 28 | build/*, 29 | third_party/*, 30 | pybind11/*, 31 | tvm/*, 32 | *_pb2.py, 33 | .cache/*, 34 | .eggs 35 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_layer_norm_init.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "custom_layer_norm.h" 8 | #include "custom_layer_norm_attrs.h" 9 | 10 | namespace tvm { 11 | namespace relay { 12 | 13 | TVM_REGISTER_API("relay.op.nn._make.custom_layer_norm") 14 | .set_body_typed(MakeCustomLayerNorm); 15 | 16 | RELAY_REGISTER_OP("nn.custom_layer_norm") 17 | .describe(R"code(Applies the layer norm transformation, with a per-element 18 | affine transform applied after normalization. 19 | 20 | - **data**: `Tensor with N dims` 21 | - **out**: `Tensor with N dims` 22 | 23 | )code" TVM_ADD_FILELINE) 24 | .set_attrs_type_key("relay.attrs.CustomLayerNormAttrs") 25 | .set_num_inputs(3) 26 | .add_argument("data", "ND Tensor", "Input data.") 27 | .add_argument("gamma", "ND Tensor", "Scale (gamma) parameter.") 28 | .add_argument("beta", "ND Tensor", "Shift (beta) parameter.") 29 | .set_support_level(1) 30 | .add_type_rel("CustomLayerNorm", CustomLayerNormRel); 31 | 32 | } // namespace relay 33 | } // namespace tvm 34 | -------------------------------------------------------------------------------- /torch_tvm/remove_dropout.cpp: -------------------------------------------------------------------------------- 1 | #include "remove_dropout.h" 2 | 3 | using namespace torch::jit; 4 | 5 | bool isDropoutRemovable(const Node* node) { 6 | const auto inputs = node->inputs(); 7 | TORCH_INTERNAL_ASSERT(inputs.size() == 3); 8 | const Value* training_input = inputs[2]; 9 | auto optional_ivalue = toIValue(training_input); 10 | TORCH_INTERNAL_ASSERT(optional_ivalue.has_value()); 11 | const IValue& val = optional_ivalue.value(); 12 | TORCH_INTERNAL_ASSERT(val.isBool()); 13 | const bool is_training = val.toBool(); 14 | return !is_training; 15 | } 16 | 17 | void RemoveDropout(std::shared_ptr& graph) { 18 | auto block = graph->block(); 19 | std::vector deleted_nodes; 20 | 21 | for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) { 22 | Node* node = *it; 23 | if (node->kind() == aten::dropout && isDropoutRemovable(*it)) { 24 | // Input tensor of dropout. 25 | Value* input_value = node->inputs()[0]; 26 | // Output tensor.
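// Rewiring every use of the dropout output to read the dropout input makes
// the training=False dropout a no-op; the node itself is destroyed only
// after the loop so the reverse iterator stays valid.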
27 | Value* output_value = node->outputs()[0]; 28 | output_value->replaceAllUsesWith(input_value); 29 | deleted_nodes.push_back(node); 30 | } 31 | } 32 | for(auto del_node : deleted_nodes) { 33 | del_node->destroy(); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_dense_init.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "custom_dense.h" 8 | 9 | namespace tvm { 10 | namespace relay { 11 | 12 | RELAY_REGISTER_OP("nn.custom_dense") 13 | .describe(R"code(Applies GEMM op on data with weights which are 14 | prepacked for cache friendly vectorization. 15 | - **data**: `Tensor with 2 dims` 16 | - **weight**: `Tensor with 3 dims` 17 | - **out**: `Tensor with 2 dims` 18 | 19 | )code" TVM_ADD_FILELINE) 20 | .set_num_inputs(2) 21 | .add_argument("data", "ND Tensor", "Input data.") 22 | .add_argument("weight", "ND Tensor", "Input data.") 23 | .set_support_level(1) 24 | .add_type_rel("CustomDense", CustomDenseRel); 25 | 26 | RELAY_REGISTER_OP("nn.dense_weight_pack") 27 | .describe(R"code(Packs weight data for cache friendly vectorization. 28 | - **weight**: `Tensor with 2 dims` 29 | - **out**: `Tensor with 3 dims` 30 | 31 | )code" TVM_ADD_FILELINE) 32 | .set_num_inputs(1) 33 | .add_argument("weight", "ND Tensor", "Input data.") 34 | .set_support_level(1) 35 | .add_type_rel("DenseWeightPack", DenseWeightPackRel); 36 | } // namespace relay 37 | } // namespace tvm 38 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/relay/custom_fp32_dense.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import topi 4 | from tvm.relay.op import op as reg 5 | from tvm.relay.op.op import OpPattern, schedule_injective 6 | 7 | from torch_tvm.custom_tvm_ops.topi import custom_fp32_dense 8 | 9 | # dense 10 | @reg.register_compute("nn.custom_dense") 11 | def compute_custom_dense(attrs, inputs, out_type, target): 12 | out_dtype = inputs[0].dtype 13 | return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)] 14 | 15 | 16 | @reg.register_schedule("nn.custom_dense") 17 | def schedule_dense(attrs, outputs, target): 18 | with target: 19 | return topi.generic.schedule_dense(outputs) 20 | 21 | 22 | reg.register_pattern("nn.custom_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) 23 | 24 | # weight pack 25 | @reg.register_compute("nn.dense_weight_pack") 26 | def compute_dense_weight_pack(attrs, inputs, out_type, target): 27 | out_dtype = inputs[0].dtype 28 | return [custom_fp32_dense.dense_weight_pack(inputs[0], attrs.pack_width)] 29 | 30 | 31 | @reg.register_schedule("nn.dense_weight_pack") 32 | def schedule_dense_weight_pack(attrs, outputs, target): 33 | with target: 34 | return custom_fp32_dense.schedule_dense_weight_pack(outputs) 35 | 36 | 37 | reg.register_pattern("nn.dense_weight_pack", reg.OpPattern.OPAQUE) 38 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/quantize_attrs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace tvm { 7 | namespace relay { 8 | struct QuantizeSchemeAttrs : public tvm::AttrsNode { 9 | int precision; 10 | bool is_signed; 11 | 12 | TVM_DECLARE_ATTRS(QuantizeSchemeAttrs, 
"relay.attrs.QuantizedParamsAttrs") { 13 | TVM_ATTR_FIELD(precision).set_default(8) 14 | .describe("The integer precision we want to quantize to."); 15 | TVM_ATTR_FIELD(is_signed).set_default(false) 16 | .describe("Signed or unsigned integer we want to quantize to."); 17 | } 18 | }; 19 | 20 | struct QuantizedParamsAttrs : public tvm::AttrsNode { 21 | double w_scale; 22 | int w_zp; 23 | /* 24 | * This param appears here because, input shape of weight change due to 25 | * packing thus need this to convey what is the true output shape. 26 | */ 27 | int N; 28 | 29 | TVM_DECLARE_ATTRS(QuantizedParamsAttrs, "relay.attrs.QuantizedParamsAttrs") { 30 | TVM_ATTR_FIELD(w_scale).set_default(1.0) 31 | .describe("weight scale."); 32 | TVM_ATTR_FIELD(w_zp).set_default(0) 33 | .describe("weight zero point."); 34 | TVM_ATTR_FIELD(N).set_default(-1) 35 | .describe("N dim of output matrix in MxN."); 36 | } 37 | }; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | 3 | file(GLOB TORCH_TVM_SRCS 4 | ${CMAKE_CURRENT_SOURCE_DIR}/torch_tvm/*.cpp 5 | ${CMAKE_CURRENT_SOURCE_DIR}/torch_tvm/custom_tvm_ops/cpp/relay/*.cc 6 | ${CMAKE_CURRENT_SOURCE_DIR}/torch_tvm/custom_tvm_ops/cpp/topi/*.cc 7 | ${CMAKE_CURRENT_SOURCE_DIR}/torch_tvm/custom_tvm_ops/cpp/topi/contrib/*.cc 8 | ) 9 | 10 | set(CMAKE_CXX_STANDARD 14) 11 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 12 | 13 | # PYTORCH_DIR 14 | IF(DEFINED ENV{PYTORCH_DIR}) 15 | SET(PYTORCH_DIR $ENV{PYTORCH_DIR}) 16 | ENDIF() 17 | message("Using pytorch dir ${PYTORCH_DIR}") 18 | 19 | # TVM_DIR 20 | IF(DEFINED ENV{TVM_DIR}) 21 | SET(TVM_DIR $ENV{TVM_DIR}) 22 | ENDIF() 23 | message("Using tvm dir ${TVM_DIR}") 24 | 25 | include(${TVM_DIR}/cmake/util/FindLLVM.cmake) 26 | find_llvm(${USE_LLVM}) 27 | 28 | link_directories(${PYTORCH_DIR}/lib) 29 | 30 | add_subdirectory(pybind11) 31 | add_subdirectory(${TVM_DIR}) 32 | 33 | pybind11_add_module(_torch_tvm SHARED ${TORCH_TVM_SRCS}) 34 | target_link_libraries(_torch_tvm PUBLIC 35 | torch pybind11 tvm tvm_topi) 36 | 37 | target_include_directories(_torch_tvm PUBLIC 38 | ${CMAKE_CURRENT_SOURCE_DIR} 39 | ${CMAKE_CURRENT_SOURCE_DIR}/torch_tvm 40 | ${TVM_DIR}/include 41 | ${TVM_DIR}/src 42 | ${TVM_DIR}/include/HalideIR 43 | ${TVM_DIR}/3rdparty/dmlc-core/include/ 44 | ${TVM_DIR}/3rdparty/dlpack/include/ 45 | ${PYTORCH_DIR}/include 46 | ${PYBIND11_INCLUDE_DIR} 47 | ${LLVM_INCLUDE_DIRS} 48 | ) 49 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: ubuntu:latest 6 | resource_class: 2xlarge+ 7 | steps: 8 | - run: 9 | name: Install dependencies 10 | command: | 11 | apt update 12 | apt-get install -y python3 python3-venv git cmake ninja-build g++ python3-dev llvm 13 | - checkout # Important to run checkout step after git has been 14 | # installed, otherwise submodules will not be setup 15 | # properly 16 | - run: 17 | name: Git Submodules 18 | command: | 19 | git submodule sync --recursive 20 | git submodule update --init --recursive 21 | - run: 22 | name: Setup Python Virtualenv 23 | command: | 24 | python3 -m venv env 25 | - run: 26 | name: Checkout PyTorch 27 | command: | 28 | cd .. 
29 | git clone https://github.com/pytorch/pytorch.git --recursive 30 | - run: 31 | name: Build PyTorch 32 | command: | 33 | source env/bin/activate 34 | cd ../pytorch 35 | pip install -r requirements.txt 36 | BUILD_BINARY=OFF BUILD_TEST=0 BUILD_CAFFE2_OPS=0 python setup.py install 37 | - run: 38 | name: Build and test pytorch/tvm 39 | command: | 40 | source env/bin/activate 41 | 42 | python setup.py install --cmake 43 | mkdir -p tvm/build/ 44 | 45 | OMP_NUM_THREADS=1 python setup.py test 46 | -------------------------------------------------------------------------------- /torch_tvm/operators.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #define PARAM_INDICES_convolution {1, 2} 7 | #define PARAM_INDICES_layer_norm {2, 3} 8 | #define PARAM_INDICES_linear {1, 2} 9 | #define PARAM_INDICES_quantized_linear {1, 3, 6} 10 | 11 | #define PARAM_INDICES(op_name) PARAM_INDICES_##op_name 12 | 13 | bool isSupported(torch::jit::Node* node); 14 | tvm::relay::Expr getOperator( 15 | torch::jit::Node* node, 16 | tvm::Array inputs); 17 | 18 | bool relayIsNone(tvm::relay::Expr e); 19 | uint64_t getNoneSentinel(); 20 | 21 | const std::vector& getParamIndices(torch::jit::Node* node); 22 | 23 | using TVMOpFunctor = std::function inputs)>; 26 | using TVMScheduleFunctor = std::function; 27 | 28 | struct TVMOpMap { 29 | TVMOpMap(torch::jit::Symbol sym_, TVMOpFunctor fn_, std::string name_ = "" 30 | ,std::vector param_indices_={}) 31 | : sym(sym_), fn(fn_), param_indices(param_indices_), name(name_) {} 32 | 33 | torch::jit::Symbol sym; 34 | TVMOpFunctor fn; 35 | std::vector param_indices; 36 | std::string name; 37 | }; 38 | 39 | struct RegisterTVMOperator { 40 | RegisterTVMOperator(std::vector ops); 41 | }; 42 | 43 | struct RegisterTVMOperatorSchedule { 44 | RegisterTVMOperatorSchedule( 45 | std::vector> scheds); 46 | }; 47 | 48 | TVMContext cpuContext(); 49 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/contrib/quantize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "topi/detail/extern.h" 4 | #include "topi/tags.h" 5 | #include "tvm/operation.h" 6 | 7 | namespace topi { 8 | namespace contrib { 9 | 10 | using namespace tvm; 11 | using namespace topi::detail; 12 | 13 | inline Array quantize_findminmax( 14 | const Tensor& data) { 15 | 16 | return make_extern( 17 | {{1}, {1}}, 18 | {data->dtype, data->dtype}, 19 | {data}, 20 | [&](Array ins, Array outs) { 21 | return call_packed({ 22 | Expr("tvm.contrib.find_minmax"), 23 | pack_buffer(ins[0]), 24 | pack_buffer(outs[0]), 25 | pack_buffer(outs[1])}); 26 | }, 27 | "C", 28 | "findminmax", 29 | {}); 30 | } 31 | 32 | inline Array choose_quantize_params( 33 | const Tensor& data_min, 34 | const Tensor& data_max, 35 | bool is_signed, 36 | int precision) { 37 | auto q_min = is_signed? -(1 << (precision - 1)) : 0; 38 | auto q_max = is_signed? 
((1 << (precision - 1)) - 1) : (1 << precision) - 1; 39 | 40 | return make_extern( 41 | {{1}, {1}}, 42 | {Int(32), Float(32)}, 43 | {data_min, data_max}, 44 | [&](Array ins, Array outs) { 45 | return call_packed({ 46 | Expr("tvm.contrib.choose_quantize_params"), 47 | pack_buffer(ins[0]), 48 | pack_buffer(ins[1]), 49 | pack_buffer(outs[0]), 50 | pack_buffer(outs[1]), 51 | q_min, 52 | q_max}); 53 | }, 54 | "C", 55 | "chooseQuantizeParams", 56 | {}); 57 | } 58 | 59 | } // namespace contrib 60 | } // namespace topi 61 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/quantize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace tvm { 11 | namespace relay { 12 | 13 | Expr MakeDataInt8Quantization( 14 | Expr data, 15 | Expr zero_point, 16 | Expr scale, 17 | bool is_signed, 18 | int precision); 19 | 20 | bool DataInt8QuantizationRel( 21 | const Array& types, 22 | int num_inputs, 23 | const Attrs& attrs, 24 | const TypeReporter& reporter); 25 | 26 | Expr MakeDataInt8RowOffset(Expr data); 27 | 28 | bool DataInt8RowOffsetRel( 29 | const Array& types, 30 | int num_inputs, 31 | const Attrs& attrs, 32 | const TypeReporter& reporter); 33 | 34 | Expr MakeFindMinMax(Expr data); 35 | 36 | bool FindMinMaxRel( 37 | const Array& types, 38 | int num_inputs, 39 | const Attrs& attrs, 40 | const TypeReporter& reporter); 41 | 42 | Expr MakeDataMMDequantize( 43 | Expr data, 44 | Expr weight, 45 | Expr weight_acc, 46 | Expr data_acc, 47 | Expr data_scale, 48 | Expr data_zero_point, 49 | const double w_scale, 50 | const int w_zp, 51 | const int N); 52 | 53 | bool DataMMDequantizeRel( 54 | const Array& types, 55 | int num_inputs, 56 | const Attrs& attrs, 57 | const TypeReporter& reporter); 58 | 59 | Expr MakeChooseQuantizeParams( 60 | Expr data_min, 61 | Expr data_max, 62 | bool is_signed, 63 | int precision); 64 | 65 | bool ChooseQuantizeParamsRel( 66 | const Array& types, 67 | int num_inputs, 68 | const Attrs& attrs, 69 | const TypeReporter& reporter); 70 | 71 | } // namespace relay 72 | } // namespace tvm 73 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/generic/quantize_generic_sched.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "tvm/operation.h" 4 | 5 | namespace topi { 6 | using namespace tvm; 7 | 8 | namespace generic { 9 | /*! 10 | * \brief Create an x86 schedule for the given quantize ops. 11 | * 12 | * \param target The target to generate a schedule for. 13 | * \param outs The output tensors. 14 | * 15 | * \return A schedule for the given ops. 
16 | */ 17 | inline Schedule schedule_choose_quantize_params( 18 | const Array& outs) { 19 | Array out_ops; 20 | for (auto out : outs) { 21 | out_ops.push_back(out->op); 22 | } 23 | auto s = create_schedule(out_ops); 24 | return s; 25 | } 26 | 27 | inline Schedule schedule_quantize_findminmax( 28 | const Array& outs) { 29 | Array out_ops; 30 | for (auto out : outs) { 31 | out_ops.push_back(out->op); 32 | } 33 | auto s = create_schedule(out_ops); 34 | return s; 35 | } 36 | 37 | inline Schedule schedule_quantize_data_int8_quantize( 38 | const Array& outs) { 39 | Array out_ops; 40 | for (auto out : outs) { 41 | out_ops.push_back(out->op); 42 | } 43 | auto s = create_schedule(out_ops); 44 | return s; 45 | } 46 | 47 | inline Schedule schedule_quantize_data_int8_row_offset( 48 | const Array& outs) { 49 | Array out_ops; 50 | for (auto out : outs) { 51 | out_ops.push_back(out->op); 52 | } 53 | auto s = create_schedule(out_ops); 54 | return s; 55 | } 56 | 57 | inline Schedule schedule_quantized_mm_dequantize( 58 | const Target& target, 59 | const Array& outs) { 60 | Array out_ops; 61 | for (auto out : outs) { 62 | out_ops.push_back(out->op); 63 | } 64 | auto s = create_schedule(out_ops); 65 | return s; 66 | } 67 | 68 | } // namespace generic 69 | } // namespace topi 70 | -------------------------------------------------------------------------------- /torch_tvm/fuse_concat.cpp: -------------------------------------------------------------------------------- 1 | #include "fuse_concat.h" 2 | #include "operators.h" 3 | 4 | using namespace torch::jit; 5 | 6 | const size_t subgraph_arg_limit_ = 128; 7 | 8 | bool isFusableCatNode(Node* node) { 9 | if (node->kind() != aten::cat) { 10 | return false; 11 | } 12 | if (!node->is_constant(attr::dim)) { 13 | return false; 14 | } 15 | 16 | auto tensors_node = node->namedInput(attr::tensors)->node(); 17 | if ((tensors_node->inputs().size() + node->outputs().size()) > 18 | subgraph_arg_limit_) { 19 | return false; 20 | } 21 | if (tensors_node->kind() != prim::ListConstruct) { 22 | return false; 23 | } 24 | 25 | if (tensors_node->output()->uses().size() > 1) { 26 | return false; 27 | } 28 | 29 | return true; 30 | } 31 | 32 | Node* createFusedConcat(Node* node) { 33 | AT_ASSERT(node->kind() == aten::cat); 34 | Graph* graph = node->owningGraph(); 35 | Node* list_construct = node->namedInput(attr::tensors)->node(); 36 | int64_t dim = node->get(attr::dim).value(); 37 | 38 | Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs()) 39 | ->i_(attr::dim, dim); 40 | fused_cat->insertBefore(list_construct); 41 | fused_cat->output()->copyMetadata(node->output()); 42 | 43 | node->output()->replaceAllUsesWith(fused_cat->output()); 44 | if (list_construct->output()->uses().empty()) { 45 | list_construct->destroy(); 46 | } 47 | return fused_cat; 48 | } 49 | 50 | void fuseConcats(Block* block_) { 51 | std::vector deleted_nodes; 52 | for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend(); 53 | ++it) { 54 | if (!isFusableCatNode(*it)) { 55 | continue; 56 | } 57 | createFusedConcat(*it); 58 | deleted_nodes.push_back(*it); 59 | } 60 | for(auto del_node : deleted_nodes) { 61 | del_node->destroy(); 62 | } 63 | } 64 | 65 | void FuseConcat(std::shared_ptr& graph) { 66 | auto block = graph->block(); 67 | fuseConcats(block); 68 | } 69 | -------------------------------------------------------------------------------- /torch_tvm/fuse_linear.cpp: -------------------------------------------------------------------------------- 1 | #include "fuse_linear.h" 2 | 
#include 3 | 4 | // This pass fuse the addmm or matmul + add generated by JIT back to linear 5 | // to allow direct support with tvm integration with Relay IR 6 | // This pass can be deleted once the JIT can emit the aten::linear in the future 7 | void FuseLinear(std::shared_ptr& graph) { 8 | std::string addmm_pattern = R"IR( 9 | graph(%input, %weight, %bias, %4): 10 | %weight_t = aten::t(%weight) 11 | %res = aten::addmm(%bias, %input, %weight_t, %4, %4) 12 | return (%res))IR"; 13 | std::string matmul_add_pattern = R"IR( 14 | graph(%input, %weight, %bias, %4): 15 | %weight_t = aten::t(%weight) 16 | %output = aten::matmul(%input, %weight_t) 17 | %res = aten::add_(%output, %bias, %4) 18 | return (%res))IR"; 19 | std::string fused_linear = R"IR( 20 | graph(%input, %weight, %bias, %4): 21 | %res = aten::linear(%input, %weight, %bias) 22 | return (%res))IR"; 23 | 24 | std::string matmul_pattern = R"IR( 25 | graph(%input, %weight): 26 | %weight_t = aten::t(%weight) 27 | %output = aten::matmul(%input, %weight_t) 28 | return (%output))IR"; 29 | std::string fused_linear_bias_none = R"IR( 30 | graph(%input, %weight): 31 | %bias: Tensor? = prim::Constant() 32 | %res = aten::linear(%input, %weight, %bias) 33 | return (%res))IR"; 34 | 35 | 36 | // replace addmm pattern to linear 37 | SubgraphRewriter addmm_to_linear; 38 | addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear); 39 | addmm_to_linear.runOnGraph(graph); 40 | 41 | // replace matmul + add pattern to linear 42 | SubgraphRewriter matmuladd_to_linear; 43 | matmuladd_to_linear.RegisterRewritePattern(matmul_add_pattern, fused_linear); 44 | matmuladd_to_linear.runOnGraph(graph); 45 | 46 | // replace matmul with bias=None pattern to linear 47 | SubgraphRewriter matmul_to_linear; 48 | matmul_to_linear.RegisterRewritePattern(matmul_pattern, fused_linear_bias_none); 49 | matmul_to_linear.runOnGraph(graph); 50 | } -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_dense.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "custom_dense.h" 9 | #include "weight_pack_attrs.h" 10 | 11 | namespace tvm { 12 | namespace relay { 13 | 14 | bool DenseWeightPackRel( 15 | const Array& types, 16 | int num_inputs, /* unused */ 17 | const Attrs& attrs, 18 | const TypeReporter& reporter) { 19 | CHECK_EQ(types.size(), 2); 20 | 21 | const auto* weight = types[0].as(); 22 | if (weight == nullptr) { 23 | return false; 24 | } 25 | CHECK(weight->dtype == Float(32)); 26 | CHECK_EQ(weight->shape.size(), 2); 27 | 28 | auto weight_pack_attrs = attrs.as(); 29 | int32_t pack_width = weight_pack_attrs->pack_width; 30 | int32_t out_dim = *as_const_int(weight->shape[0]); 31 | CHECK_EQ((out_dim % pack_width), 0); 32 | out_dim = out_dim / pack_width; 33 | 34 | Array oshape = weight->shape; 35 | oshape.Set(0, out_dim); 36 | oshape.push_back(pack_width); 37 | reporter->Assign(types[1], TensorTypeNode::make(oshape, weight->dtype)); 38 | return true; 39 | } 40 | 41 | bool CustomDenseRel( 42 | const Array& types, 43 | int num_inputs, /* unused */ 44 | const Attrs& attrs, 45 | const TypeReporter& reporter) { 46 | CHECK_EQ(types.size(), 3); 47 | 48 | const auto* data = types[0].as(); 49 | const auto* weight = types[1].as(); 50 | if (data == nullptr || weight == nullptr) { 51 | return false; 52 | } 53 | CHECK_EQ(data->shape.size(), 2); 54 | CHECK_EQ(weight->shape.size(), 3); 55 | 
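// data is (M, K) and weight arrives pre-packed as
// (N / pack_width, K, pack_width), so dim 1 of both must agree on K and the
// logical output width is weight->shape[0] * weight->shape[2].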
CHECK_EQ(*as_const_int(data->shape[1]), 56 | *as_const_int(weight->shape[1])); 57 | int32_t out_dim = (*as_const_int(weight->shape[0])) * 58 | (*(as_const_int(weight->shape[2]))); 59 | CHECK_GT(out_dim, 0); 60 | 61 | CHECK(data->dtype == Float(32)); 62 | 63 | Array oshape = data->shape; 64 | oshape.Set((oshape.size() - 1), out_dim); 65 | reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); 66 | return true; 67 | } 68 | 69 | } // namespace relay 70 | } // namespace tvm 71 | -------------------------------------------------------------------------------- /torch_tvm/debug_utils.cpp: -------------------------------------------------------------------------------- 1 | #include "debug_utils.h" 2 | 3 | DebugLogger& getDebugLogger() { 4 | static thread_local DebugLogger debug_logger; 5 | return debug_logger; 6 | } 7 | 8 | DebugLogger::DebugLogger() { 9 | std::ostringstream ss; 10 | ss << std::this_thread::get_id(); 11 | std::string file_name = "/tmp/debug_output_" + 12 | ss.str() + 13 | ".txt"; 14 | debug_file_ = std::ofstream(file_name, std::ios::out); 15 | if (!debug_file_.is_open()) { 16 | LOG(WARNING) 17 | << "Could not open file:" << file_name 18 | << ", will dump debug info to stdout\n"; 19 | } 20 | } 21 | 22 | void DebugLogger::printGraph( 23 | const std::shared_ptr& subgraph) { 24 | if (debug_file_.is_open()) { 25 | debug_file_ <<"subgraph \n"; 26 | debug_file_ << *subgraph << std::endl; 27 | debug_file_ <<"END OF Input subgraph\n"; 28 | } else { 29 | std::cout <<"subgraph \n"; 30 | std::cout << *subgraph << std::endl; 31 | std::cout <<"END OF Input subgraph\n"; 32 | } 33 | } 34 | 35 | void DebugLogger::printLoweredFuncs(tvm::runtime::Module& build_mod) { 36 | tvm::runtime::PackedFunc lowered_f; 37 | try { 38 | lowered_f = build_mod.GetFunction("get_lowered_funcs", false); 39 | } catch (const std::exception& e) { 40 | LOG(WARNING) << "TVM runtime is not exposed lowered_funcs:" 41 | << e.what() << std::endl; 42 | return; 43 | } 44 | tvm::Map > lowered_funcs = lowered_f(); 45 | for (auto funcs : lowered_funcs) { 46 | for (auto f : funcs.second) { 47 | if (debug_file_.is_open()) { 48 | debug_file_ << "===== lowered func=====\n"; 49 | debug_file_ << f->body << std::endl; 50 | debug_file_ << "===== end of lowered func=====\n"; 51 | } else { 52 | std::cout << "===== lowered func=====\n"; 53 | std::cout << f->body << std::endl; 54 | std::cout << "===== end of lowered func=====\n"; 55 | } 56 | } 57 | } 58 | } 59 | 60 | void DebugLogger::printASM(tvm::runtime::Module& mod) { 61 | if (debug_file_.is_open()) { 62 | debug_file_ << "======= ASM ========\n"; 63 | debug_file_ << mod->GetSource("asm") << std::endl; 64 | debug_file_ << "======= END OF ASM========\n"; 65 | } else { 66 | std::cout << "======= ASM ========\n"; 67 | std::cout << mod->GetSource("asm") << std::endl; 68 | std::cout << "======= END OF ASM========\n"; 69 | } 70 | } 71 | 72 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/custom_layer_norm_generic_sched.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "custom_layer_norm_generic_sched.h" 7 | 8 | namespace topi { 9 | namespace generic { 10 | 11 | namespace { 12 | void layer_norm_sched(tvm::Schedule& s, const tvm::Tensor& out) { 13 | auto divide_1 = out->op.as()->InputTensors()[1]; 14 | auto divide_2 = out->op.as()->InputTensors()[2]; 15 | auto mean_var_sum = 
divide_1->op.as()->InputTensors()[0]; 16 | auto squared_data = mean_var_sum->op.as()->InputTensors()[1]; 17 | s[divide_1].compute_inline(); 18 | s[divide_2].compute_inline(); 19 | auto k = s[mean_var_sum]->op.as()->reduce_axis; 20 | tvm::IterVar ko, ki; 21 | s[mean_var_sum].split(k[0], 16, &ko, &ki); 22 | auto factored_tensors = s.rfactor(mean_var_sum, ki, -1); 23 | s[mean_var_sum].compute_at(s[out], out->op.as()->axis[0]); 24 | s[factored_tensors[0]].compute_at(s[out], 25 | out->op.as()->axis[0]); 26 | s[squared_data].compute_inline(); 27 | } 28 | } // namespace 29 | 30 | tvm::Schedule schedule_custom_layer_norm(const tvm::Array& outs) { 31 | tvm::Array out_ops; 32 | for (auto out : outs) { 33 | out_ops.push_back(out->op); 34 | } 35 | auto s = create_schedule(out_ops); 36 | //Copy paste traverse logic from elsewhere. This is also the traverse logic 37 | //in dense schedule in dense.py 38 | std::function traverse; 39 | traverse = [&](const tvm::Operation& op) { 40 | // Inline all one-to-one-mapping operators except the last stage (output) 41 | if (is_injective(op->tag)) { 42 | if (!detail::contains(s->outputs, op)) { 43 | s[op].compute_inline(); 44 | } 45 | for (auto tensor : op->InputTensors()) { 46 | if (tensor->op->InputTensors().size() > 0) { 47 | traverse(tensor->op); 48 | } 49 | } 50 | } 51 | else if (op->tag == "custom_layer_norm_tag") { 52 | auto layer_norm = op.output(0); 53 | layer_norm_sched(s, layer_norm); 54 | } else { 55 | LOG(ERROR) << "Unsupported operator " << op->tag; 56 | } 57 | }; 58 | traverse(outs[0]->op); 59 | return s; 60 | } 61 | } // namespace generic 62 | } // namespace topi 63 | -------------------------------------------------------------------------------- /torch_tvm/memory_utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "memory_utils.h" 4 | 5 | namespace torch_tvm { 6 | namespace utils { 7 | 8 | bool isAligned(void* data_ptr, std::uintptr_t alignment_in_bytes) { 9 | auto mask = alignment_in_bytes - 1; 10 | TORCH_CHECK((alignment_in_bytes & mask) == 0); 11 | return (reinterpret_cast(data_ptr) & mask) == 0; 12 | } 13 | 14 | DLManagedTensor* allocAndCopyData(const at::Tensor& tensor) { 15 | TORCH_CHECK(tensor.device().is_cpu()); 16 | DLManagedTensor* dl_managed_tensor = new DLManagedTensor(); 17 | auto contig_tensor = tensor; 18 | if (!tensor.is_contiguous()) { 19 | auto contig_tensor = tensor.contiguous(); 20 | } 21 | // managed_tensor_deleter is supplied to unique_ptr as a deleter 22 | // of this managed memory. 
Thus setting deleter to nullptr; 23 | dl_managed_tensor->deleter = nullptr; 24 | dl_managed_tensor->manager_ctx = dl_managed_tensor; 25 | auto& dl_tensor = dl_managed_tensor->dl_tensor; 26 | 27 | auto num_dims = contig_tensor.dim(); 28 | dl_tensor.ndim = num_dims; 29 | dl_tensor.dtype = at::getDLDataType(contig_tensor); 30 | int64_t device_id = 0; 31 | dl_tensor.ctx = getDLContext(contig_tensor, device_id); 32 | dl_tensor.shape = dl_tensor.strides = nullptr; 33 | dl_tensor.data = nullptr; 34 | dl_tensor.shape = new int64_t[num_dims]; 35 | dl_tensor.strides = new int64_t[num_dims]; 36 | TORCH_CHECK(dl_tensor.shape != nullptr && dl_tensor.strides != nullptr, 37 | "Memory allocation failed for DLTensor shape and strides" 38 | "by ManagedTensors."); 39 | 40 | auto tensor_sizes = contig_tensor.sizes(); 41 | auto tensor_strides = contig_tensor.strides(); 42 | for (int64_t i = 0; i < num_dims; ++i) { 43 | dl_tensor.shape[i] = tensor_sizes[i]; 44 | dl_tensor.strides[i] = tensor_strides[i]; 45 | } 46 | 47 | // make sure the allocated size is a multiple of alignment 48 | auto nbytes_alloc = contig_tensor.nbytes(); 49 | auto rem = nbytes_alloc % tvm::runtime::kAllocAlignment; 50 | if (rem > 0) { 51 | nbytes_alloc += tvm::runtime::kAllocAlignment - rem; 52 | } 53 | dl_tensor.data = aligned_alloc(tvm::runtime::kAllocAlignment, nbytes_alloc); 54 | TORCH_CHECK(dl_tensor.data != nullptr, 55 | "Memory allocation failed for DLTensor data by ManagedTensors."); 56 | 57 | std::memcpy(dl_tensor.data, contig_tensor.data_ptr(), contig_tensor.nbytes()); 58 | dl_tensor.byte_offset = 0; 59 | 60 | return dl_managed_tensor; 61 | } 62 | 63 | } // utils 64 | } // torch_tvm 65 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/custom_layer_norm.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "custom_layer_norm.h" 9 | #include "custom_layer_norm_attrs.h" 10 | 11 | namespace tvm { 12 | namespace relay { 13 | 14 | Expr MakeCustomLayerNorm( 15 | Expr data, 16 | Expr gamma, 17 | Expr beta, 18 | const int num_axis_to_normalize, 19 | const bool affine, 20 | const double eps) { 21 | auto attrs = make_node(); 22 | attrs->num_axis_to_normalize = num_axis_to_normalize; 23 | attrs->affine = affine; 24 | static const Op& op = Op::Get("nn.custom_layer_norm"); 25 | return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {}); 26 | } 27 | 28 | bool CustomLayerNormRel( 29 | const Array& types, 30 | int num_inputs, /* unused */ 31 | const Attrs& attrs, 32 | const TypeReporter& reporter) { 33 | CHECK_EQ(types.size(), 4); 34 | 35 | const auto* data = types[0].as(); 36 | if (data == nullptr) { 37 | return false; 38 | } 39 | CHECK_GT(data->shape.size(), 1); 40 | int64_t data_size = data->shape.size(); 41 | int64_t num_elements = 1; 42 | for (int64_t i = 0; i < data_size; ++i) { 43 | CHECK_LE(*as_const_int(data->shape[i]), std::numeric_limits::max()); 44 | num_elements *= *as_const_int(data->shape[i]); 45 | CHECK_LE(num_elements, std::numeric_limits::max()); 46 | } 47 | 48 | CHECK(data->dtype == Float(32)); 49 | 50 | auto layer_norm_attrs_ptr = attrs.as(); 51 | auto num_axis_to_normalize = layer_norm_attrs_ptr->num_axis_to_normalize; 52 | CHECK_GT(num_axis_to_normalize, 0); 53 | CHECK_LT(num_axis_to_normalize, data->shape.size()); 54 | 55 | const auto* gamma = types[1].as(); 56 | const auto* beta = types[2].as(); 57 | if (gamma && beta) { 58 | 
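// When the affine parameters are present, gamma and beta must each cover
// exactly the trailing num_axis_to_normalize axes of data.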
CHECK_EQ(gamma->shape.size(), num_axis_to_normalize); 59 | CHECK_EQ(beta->shape.size(), num_axis_to_normalize); 60 | for (int64_t i = 0; i < num_axis_to_normalize; ++i) { 61 | int64_t data_index = i + (data_size - num_axis_to_normalize); 62 | CHECK_EQ( 63 | *as_const_int(data->shape[data_index]), 64 | *as_const_int(gamma->shape[i])); 65 | CHECK_EQ( 66 | *as_const_int(data->shape[data_index]), 67 | *as_const_int(beta->shape[i])); 68 | } 69 | } 70 | reporter->Assign(types[3], types[0]); 71 | return true; 72 | } 73 | 74 | } // namespace relay 75 | } // namespace tvm 76 | -------------------------------------------------------------------------------- /torch_tvm/register.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "compiler.h" 9 | #include "debug_utils.h" 10 | #include "fuse_concat.h" 11 | #include "fuse_linear.h" 12 | #include "fusion_pass.h" 13 | #include "remove_dropout.h" 14 | 15 | namespace tvm { 16 | namespace { 17 | // control if the run mode is strict or not, if it's strict, we throw to 18 | // user with the relevant conversion errors, otherwise we bail out to JIT 19 | static bool strict = false; 20 | static int opt_level = 2; 21 | static bool debug = false; 22 | static bool debug_runtime = false; 23 | static std::string device_type = "cpu"; 24 | static std::string device = "llvm -mcpu=core-avx2"; 25 | static std::string host = "llvm -mcpu=core-avx2"; 26 | static int device_id = 0; 27 | static bool is_training_mode = false; 28 | 29 | void registerTVMOp() { 30 | auto options = c10::OperatorOptions(); 31 | options.setAliasAnalysis(AliasAnalysisKind::PURE_FUNCTION); 32 | torch::jit::RegisterOperators op({torch::jit::Operator( 33 | getTVMSymbol(), 34 | [](const torch::jit::Node* node) -> torch::jit::Operation { 35 | auto cc = std::make_shared( 36 | node, opt_level, strict, debug, debug_runtime, device_type, device, 37 | host, device_id); 38 | return [cc](Stack& stack) { 39 | RECORD_FUNCTION("TVM", std::vector()); 40 | cc->run(stack); 41 | return 0; 42 | }; 43 | }, 44 | options)}); 45 | } 46 | 47 | void set_build_config( 48 | int opt_level_, 49 | bool strict_, 50 | bool debug_, 51 | bool debug_runtime_, 52 | const std::string& device_type_, 53 | const std::string& device_, 54 | const std::string& host_, 55 | int device_id_, 56 | bool is_training_) { 57 | opt_level = opt_level_; 58 | strict = strict_; 59 | debug = debug_; 60 | debug_runtime = debug_runtime_; 61 | device_type = device_type_; 62 | device = device_; 63 | host = host_; 64 | device_id = device_id_; 65 | is_training_mode = is_training_; 66 | } 67 | 68 | bool is_training() { 69 | return is_training_mode; 70 | } 71 | 72 | void torch_tvm_enable(std::function enableTVMCompile) { 73 | registerTVMOp(); 74 | torch::jit::RegisterPass pass( 75 | [enableTVMCompile = 76 | std::move(enableTVMCompile)](std::shared_ptr& g) { 77 | if (enableTVMCompile()) { 78 | FuseLinear(g); 79 | FuseConcat(g); 80 | RemoveDropout(g); 81 | FuseSupportedOps(g); 82 | } 83 | }); 84 | } 85 | 86 | } // namespace 87 | } // namespace tvm 88 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -1 3 | AlignAfterOpenBracket: AlwaysBreak 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlinesLeft: true 7 | AlignOperands: false 8 | 
AlignTrailingComments: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortFunctionsOnASingleLine: Empty 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterReturnType: None 16 | AlwaysBreakBeforeMultilineStrings: true 17 | AlwaysBreakTemplateDeclarations: true 18 | BinPackArguments: false 19 | BinPackParameters: false 20 | BraceWrapping: 21 | AfterClass: false 22 | AfterControlStatement: false 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterObjCDeclaration: false 27 | AfterStruct: false 28 | AfterUnion: false 29 | BeforeCatch: false 30 | BeforeElse: false 31 | IndentBraces: false 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeBraces: Attach 34 | BreakBeforeTernaryOperators: true 35 | BreakConstructorInitializersBeforeComma: false 36 | BreakAfterJavaFieldAnnotations: false 37 | BreakStringLiterals: false 38 | ColumnLimit: 80 39 | CommentPragmas: '^ IWYU pragma:' 40 | CompactNamespaces: false 41 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 42 | ConstructorInitializerIndentWidth: 4 43 | ContinuationIndentWidth: 4 44 | Cpp11BracedListStyle: true 45 | DerivePointerAlignment: false 46 | DisableFormat: false 47 | ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ] 48 | IncludeCategories: 49 | - Regex: '^<.*\.h(pp)?>' 50 | Priority: 1 51 | - Regex: '^<.*' 52 | Priority: 2 53 | - Regex: '.*' 54 | Priority: 3 55 | IndentCaseLabels: true 56 | IndentWidth: 2 57 | IndentWrappedFunctionNames: false 58 | KeepEmptyLinesAtTheStartOfBlocks: false 59 | MacroBlockBegin: '' 60 | MacroBlockEnd: '' 61 | MaxEmptyLinesToKeep: 1 62 | NamespaceIndentation: None 63 | ObjCBlockIndentWidth: 2 64 | ObjCSpaceAfterProperty: false 65 | ObjCSpaceBeforeProtocolList: false 66 | PenaltyBreakBeforeFirstCallParameter: 1 67 | PenaltyBreakComment: 300 68 | PenaltyBreakFirstLessLess: 120 69 | PenaltyBreakString: 1000 70 | PenaltyExcessCharacter: 1000000 71 | PenaltyReturnTypeOnItsOwnLine: 2000000 72 | PointerAlignment: Left 73 | ReflowComments: true 74 | SortIncludes: true 75 | SpaceAfterCStyleCast: false 76 | SpaceBeforeAssignmentOperators: true 77 | SpaceBeforeParens: ControlStatements 78 | SpaceInEmptyParentheses: false 79 | SpacesBeforeTrailingComments: 1 80 | SpacesInAngles: false 81 | SpacesInContainerLiterals: true 82 | SpacesInCStyleCastParentheses: false 83 | SpacesInParentheses: false 84 | SpacesInSquareBrackets: false 85 | Standard: Cpp11 86 | TabWidth: 8 87 | UseTab: Never 88 | ... 89 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/quantize.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include "topi/reduction.h" 8 | #include "topi/tags.h" 9 | 10 | #include "quantize.h" 11 | 12 | namespace topi { 13 | using namespace tvm; 14 | 15 | Array data_int8_quantize( 16 | const Tensor& data, 17 | const Tensor& zero_point, 18 | const Tensor& scale, 19 | bool is_signed, 20 | int precision) { 21 | auto q_min = is_signed ? -(1 << (precision - 1)) : 0; 22 | auto q_max = is_signed ? ((1 << (precision - 1)) - 1) : (1 << precision) - 1; 23 | auto target_type = is_signed ? 
Int(8) : UInt(8); 24 | auto inverse_scale = 1 /scale(0); 25 | 26 | auto clamp_output = tvm::compute( 27 | data->shape, 28 | [&](Var i, Var j) { 29 | return tvm::cast(target_type, tvm::nearbyint( 30 | tvm::min( 31 | tvm::max(tvm::cast(Float(32), zero_point(0)) + data(i, j)*inverse_scale, q_min), 32 | q_max 33 | ) 34 | )); 35 | }, 36 | "tensor", 37 | "int8_quantize_data" 38 | ); 39 | 40 | return {clamp_output}; 41 | } 42 | 43 | Array data_int8_row_offset(const Tensor& quantized_data) { 44 | 45 | auto k = tvm::reduce_axis(Range(0, quantized_data->shape[1]), "k"); 46 | auto data_acc = tvm::compute( 47 | {quantized_data->shape[0]}, 48 | [&](Var i) { 49 | return tvm::sum(tvm::cast(Int(32), quantized_data(i, k)), {k}); 50 | }, 51 | "tensor", 52 | "int8_quantize_row_offset" 53 | ); 54 | 55 | return {data_acc}; 56 | } 57 | 58 | Array data_int8_mm_dequantize( 59 | const Tensor& data, 60 | const Tensor& weight, 61 | const Tensor& weight_acc, 62 | const Tensor& data_acc, 63 | const Tensor& data_scale, 64 | const Tensor& data_zero_point, 65 | const double weight_scale, 66 | const int weight_zero_point, 67 | const int32_t N) { 68 | // assume M, K and N, K on input shape 69 | CHECK(weight->shape.size() == 4); 70 | auto k = tvm::reduce_axis(Range(0, data->shape[1]), "k"); 71 | auto scale_mul = make_const(Float(32), weight_scale) * data_scale(0); 72 | auto out_shape = {data->shape[0], weight->shape[0] * weight->shape[2]}; 73 | 74 | auto quantized_mm = tvm::compute( 75 | out_shape, 76 | [&](Var i, Var j) { 77 | return tvm::sum(tvm::cast(Int(32), data(i, k)) * tvm::cast(Int(32), weight(j / 16, k / 4, j % 16, k % 4)), {k}); 78 | }, 79 | "tensor", 80 | "quantized_mm" 81 | ); 82 | 83 | auto result = tvm::compute( 84 | {data->shape[0], Expr(N)}, 85 | [&](Var i, Var j) { 86 | return scale_mul*(tvm::cast(Float(32), (quantized_mm(i, j)-data_acc(i)*weight_zero_point- 87 | weight_acc(j)*data_zero_point(0)))); 88 | }, 89 | "tensor", 90 | "mm_dequantize" 91 | ); 92 | 93 | return {result}; 94 | } 95 | } // namespace topi 96 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/topi/custom_fp32_dense.py: -------------------------------------------------------------------------------- 1 | import tvm 2 | from tvm import autotvm 3 | from tvm.autotvm.task.space import SplitEntity 4 | 5 | from topi.x86.util import get_fp32_len 6 | from topi import generic, tag, nn 7 | from topi.util import traverse_inline, get_const_tuple 8 | 9 | 10 | @autotvm.register_topi_compute(nn.dense, "cpu", "direct", override=True) 11 | def _declaration_custom_dense(cfg, data, packed_weight, bias, out_dtype=None): 12 | if out_dtype is None: 13 | out_dtype = data.dtype 14 | batch, in_dim = get_const_tuple(data.shape) 15 | pack_outer, _, pack_width = get_const_tuple(packed_weight.shape) 16 | out_dim = pack_outer * pack_width 17 | 18 | k = tvm.reduce_axis((0, in_dim), name="k") 19 | C = tvm.compute((batch, out_dim), 20 | lambda y, x: tvm.sum( 21 | data[y, k].astype(out_dtype) * \ 22 | packed_weight[x // pack_width, k, x % pack_width].astype(out_dtype), 23 | axis=k), 24 | tag="custom_dense_pack") 25 | if bias is not None: 26 | C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype), 27 | tag=tag.BROADCAST) 28 | return C 29 | 30 | 31 | @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct", override=True) 32 | def _schedule_custom_dense_pack(cfg, outs): 33 | s = tvm.create_schedule([x.op for x in outs]) 34 | 35 | def _callback(op): 36 | if "dense_pack" in 
op.tag: 37 | _schedule_custom_dense_pack_template(cfg, s, op.output(0)) 38 | traverse_inline(s, outs[0].op, _callback) 39 | return s 40 | 41 | 42 | def _schedule_custom_dense_pack_template(cfg, s, C): 43 | CC = s.cache_write(C, "global") 44 | y, x = s[C].op.axis 45 | k, = s[CC].op.reduce_axis 46 | 47 | data = C.op.input_tensors[0] 48 | weight = C.op.input_tensors[1] 49 | _, _, pack_factor = get_const_tuple(weight.shape) 50 | 51 | xo, xi = s[C].split(x, factor=pack_factor) 52 | yo, yi = s[C].split(y, factor=8) 53 | xt, xo = s[C].split(xo, factor=16) 54 | yt, yo = s[C].split(yo, factor=16) 55 | s[C].reorder(yt, xt, yo, xo, yi, xi) 56 | s[C].unroll(yi) 57 | s[C].vectorize(xi) 58 | 59 | s[CC].compute_at(s[C], xo) 60 | y, x = s[CC].op.axis 61 | ko, ki = s[CC].split(k, factor=16) 62 | s[CC].reorder(ko, ki, y, x) 63 | s[CC].vectorize(x) 64 | s[CC].unroll(y) 65 | 66 | return s 67 | 68 | 69 | 70 | def dense_weight_pack(weight, pack_width): 71 | N, K = get_const_tuple(weight.shape) 72 | 73 | packw_shape = (N // pack_width, K, pack_width) 74 | C = tvm.compute(packw_shape, \ 75 | lambda z, y, x: weight[z * pack_width+ x, y], name="packed_weight") 76 | return C 77 | 78 | 79 | def schedule_dense_weight_pack(outs): 80 | s = tvm.create_schedule([x.op for x in outs]) 81 | packedB = outs[0].op 82 | z, y, x = s[packedB].op.axis 83 | s[packedB].reorder(z, x, y) 84 | s[packedB].parallel(z) 85 | s[packedB].vectorize(y) 86 | return s 87 | -------------------------------------------------------------------------------- /test/util.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy 3 | from numbers import Number 4 | import unittest 5 | import os 6 | 7 | import torch 8 | from torch.autograd.profiler import profile 9 | 10 | import random 11 | 12 | import torch_tvm 13 | from tvm import autotvm 14 | 15 | # base TVMTest class 16 | 17 | 18 | class TVMTest(unittest.TestCase): 19 | precision = 1e-5 20 | 21 | def __init__(self, method_name="runTest"): 22 | super(TVMTest, self).__init__(method_name) 23 | 24 | @classmethod 25 | def rand_int(cls, min_=1, max_=4): 26 | def get(): 27 | return random.randint(min_, max_) 28 | 29 | return get 30 | 31 | @classmethod 32 | def rand_list(cls, elem_fn, num_elem): 33 | def get(): 34 | return [elem_fn() for _ in range(num_elem)] 35 | 36 | return get 37 | 38 | @classmethod 39 | def rand_shape( 40 | cls, min_rank=1, max_rank=4, min_dim=1, max_dim=2 ** 4, like=None, rank=None 41 | ): 42 | def get(): 43 | rank_ = rank 44 | if not rank_: 45 | rank_ = cls.rand_int(min_rank, max_rank)() 46 | return cls.rand_list(cls.rand_int(min_dim, max_dim), rank_)() 47 | 48 | return get 49 | 50 | @classmethod 51 | def given(*args_, examples=5, **kwargs_): 52 | def f(fn): 53 | def f_impl(*args, **kwargs): 54 | for _ in range(examples): 55 | for k in kwargs_: 56 | kwargs[k] = kwargs_[k]() 57 | try: 58 | fn(*args, **kwargs) 59 | except Exception as e: 60 | print("Inputs:", kwargs) 61 | raise (e) 62 | 63 | return f_impl 64 | 65 | return f 66 | 67 | def runBoth(self, func, *inputs, check_tvm=True): 68 | with torch.no_grad(): 69 | # jit the function 70 | trace_jit = torch.jit.trace(func, inputs) 71 | ref_out = trace_jit(*inputs) 72 | 73 | # jit the function and lower to TVM 74 | torch_tvm.enable() 75 | d = os.path.dirname(os.path.abspath(__file__)) 76 | fn = os.path.join(d, "autotvm_tuning.log") 77 | 78 | with autotvm.apply_history_best(fn): 79 | trace_tvm = torch.jit.trace(func, inputs) 80 | try: 81 | tvm_out = trace_tvm(*inputs) 82 | except 
Exception as e: 83 | print("Error with graph\n{}".format(trace_tvm.graph)) 84 | raise e 85 | 86 | if check_tvm == True: 87 | tvm_unused = "TVM was not able to optimize this trace." 88 | assert "tvm::CompilationGroup" in str( 89 | trace_tvm.graph_for(*inputs) 90 | ), tvm_unused + " Graph:\n" + str(trace_tvm.graph_for(*inputs)) 91 | # tvm compile the graph and ensure TVM is used 92 | with profile() as p: 93 | _ = trace_tvm(*inputs) 94 | assert "TVM" in [_.name for _ in p.function_events], tvm_unused 95 | 96 | torch_tvm.disable() 97 | 98 | return ref_out, tvm_out 99 | -------------------------------------------------------------------------------- /test/benchmarks.py: -------------------------------------------------------------------------------- 1 | from test.test_models import resnet18, resnext101_32x8d 2 | from skimage import io 3 | import torch 4 | from torch.autograd.profiler import profile 5 | import torch_tvm 6 | import time 7 | from tvm import autotvm 8 | import sys 9 | import os 10 | 11 | 12 | def genImage(): 13 | d = os.path.dirname(os.path.abspath(__file__)) 14 | fn = os.path.join(d, "cat.png") 15 | image = io.imread(fn)[:, :, :3].transpose(2, 0, 1) 16 | image = torch.unsqueeze(torch.Tensor(image), 0) 17 | return [image] 18 | 19 | 20 | def benchmark(model, csv_file, input_fn=genImage, iters=100, warmup=10): 21 | with torch.no_grad(): 22 | inputs = input_fn() 23 | print("Tracing model with JIT") 24 | trace_jit = torch.jit.trace(model, inputs) 25 | print("Warming JIT up with {} runs".format(warmup)) 26 | for _ in range(warmup): 27 | _ = trace_jit(*inputs) 28 | 29 | print("Running JIT {} times".format(iters)) 30 | start = time.time() 31 | for _ in range(iters): 32 | _ = trace_jit(*inputs) 33 | jit_time = time.time() - start 34 | print("Done benchmarking JIT") 35 | 36 | d = os.path.dirname(os.path.abspath(__file__)) 37 | fn = os.path.join(d, "autotvm_tuning.log") 38 | with autotvm.apply_history_best(fn): 39 | torch_tvm.enable(opt_level=3) 40 | print("Tracing model with TVM") 41 | trace_tvm = torch.jit.trace(model, inputs) 42 | print("Warming TVM up with {} iters".format(warmup)) 43 | for _ in range(warmup): 44 | _ = trace_tvm(*inputs) 45 | 46 | print("Running TVM {} times".format(iters)) 47 | start = time.time() 48 | for _ in range(iters): 49 | _ = trace_tvm(*inputs) 50 | tvm_time = time.time() - start 51 | with torch.autograd.profiler.profile() as prof: 52 | _ = trace_tvm(*inputs) 53 | tvm_profiled_time = 0 54 | total_profiled_time = 0 55 | for p in prof.key_averages(): 56 | total_profiled_time += int(p.cpu_time) 57 | if p.key == "TVM": 58 | tvm_profiled_time += int(p.cpu_time) 59 | print("Done benchmarking TVM, which compiled {:.2f}% of compute".format( 60 | 100 * tvm_profiled_time / total_profiled_time)) 61 | if csv_file: 62 | exists = os.path.isfile(csv_file) 63 | with open(csv_file, 'a' if exists else 'w') as f: 64 | if not exists: 65 | f.write("timestamp,iter_per_sec\n") 66 | f.write("{},{}\n".format( 67 | int(time.time()), iters / tvm_time)) 68 | print("JIT: {} iter/s\nTVM: {} iter/s".format(iters / 69 | jit_time, iters / tvm_time)) 70 | 71 | 72 | def run_benchmark(csv_file): 73 | model = resnet18(True) 74 | model.eval() 75 | benchmark(model, csv_file) 76 | 77 | 78 | if __name__ == "__main__": 79 | csv_file = None 80 | if len(sys.argv) == 3 and sys.argv[1] == "--csv": 81 | csv_file = sys.argv[2] 82 | run_benchmark(csv_file) 83 | -------------------------------------------------------------------------------- /torch_tvm/register.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "register.h" 9 | 10 | namespace py = pybind11; 11 | 12 | // control if we enable tvm fusion or not 13 | static bool fusion_enabled = false; 14 | 15 | static std::unordered_map relay_exprs; 16 | static size_t relay_exprs_uuid = 0; 17 | 18 | PYBIND11_MODULE(_torch_tvm, m) { 19 | std::function is_enabled = []() { return fusion_enabled; }; 20 | tvm::torch_tvm_enable(is_enabled); 21 | // python API to enable and disable tvm fusion 22 | m.def( 23 | "enable", 24 | [](int opt_level_, 25 | bool strict_, 26 | bool debug_, 27 | bool debug_runtime_, 28 | std::string device_type_, 29 | std::string device_, 30 | std::string host_, 31 | int device_id_, 32 | bool is_training_) { 33 | fusion_enabled = true; 34 | tvm::set_build_config( 35 | opt_level_, strict_, debug_, debug_runtime_, device_type_, device_, 36 | host_, device_id_, is_training_); 37 | }, 38 | py::arg("opt_level") = 2, 39 | py::arg("strict") = false, 40 | py::arg("debug") = false, 41 | py::arg("debug_runtime") = false, 42 | py::arg("device_type") = "cpu", 43 | py::arg("device") = "llvm -mcpu=core-avx2", 44 | py::arg("host") = "llvm -mcpu=core-avx2", 45 | py::arg("device_id") = 0, 46 | py::arg("is_training") = false); 47 | 48 | m.def("disable", []() { fusion_enabled = false; }); 49 | 50 | m.def( 51 | "_push_relay_expr", 52 | [](std::shared_ptr g, std::vector inputs) { 53 | size_t count = 0; 54 | for (auto node : g->nodes()) { 55 | count++; 56 | } 57 | TORCH_CHECK( 58 | count == 1, 59 | "This program cannot be exported as a single Relay expression."); 60 | for (auto node : g->nodes()) { 61 | if (node->kind() == getTVMSymbol()) { 62 | std::vector v; 63 | auto subgraph = node->g(attr::Subgraph); 64 | TORCH_CHECK( 65 | subgraph->inputs().size() == inputs.size(), 66 | "Expected ", 67 | subgraph->inputs().size(), 68 | " inputs"); 69 | for (auto i = 0; i < inputs.size(); ++i) { 70 | subgraph->inputs()[i]->inferTypeFrom(inputs[i]); 71 | } 72 | TVMContext ctx; 73 | ctx.device_type = kDLCPU; 74 | ctx.device_id = 0; 75 | auto expr = TVMCompiler::convertToRelay(subgraph, ctx); 76 | relay_exprs[++relay_exprs_uuid] = expr; 77 | return relay_exprs_uuid; 78 | } else { 79 | TORCH_CHECK( 80 | 0, 81 | "This program contains non-Relay expressions that cannot be exported."); 82 | } 83 | } 84 | return 0UL; 85 | }); 86 | 87 | m.doc() = "This module does nothing but register a TVM backend."; 88 | } 89 | 90 | TVM_REGISTER_GLOBAL("torch_tvm._pop_relay_expr") 91 | .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { 92 | size_t id = args[0]; 93 | *rv = relay_exprs[id]; //.top(); 94 | relay_exprs.erase(id); 95 | }); 96 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/quantize_init.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "quantize_attrs.h" 11 | #include "quantize.h" 12 | 13 | #include 14 | namespace tvm { 15 | namespace relay { 16 | 17 | TVM_REGISTER_NODE_TYPE(QuantizedParamsAttrs); 18 | 19 | TVM_REGISTER_NODE_TYPE(QuantizeSchemeAttrs); 20 | 21 | TVM_REGISTER_API("relay.op.nn._make.quantize_data_int8_quantize") 22 | .set_body_typed(MakeDataInt8Quantization); 23 | 24 | TVM_REGISTER_API("relay.op.nn._make.quantize_data_int8_row_offset") 25 | 
.set_body_typed(MakeDataInt8RowOffset); 26 | 27 | 28 | RELAY_REGISTER_OP("nn.quantize_data_int8_quantize") 29 | .describe(R"code(dynamic quantization of activation. 30 | - **data**: (M, N) 31 | )code" TVM_ADD_FILELINE) 32 | .set_num_inputs(3) 33 | .add_argument("data", "Tensor", "The input tensor.") 34 | .add_argument("zero_point", "Tensor", "The zero_point parameter for quantization") 35 | .add_argument("scale", "Tensor", "the scale parameter for quantization") 36 | .set_attrs_type_key("relay.attrs.QuantizeSchemeAttrs") 37 | .set_support_level(10) 38 | .add_type_rel("DataInt8Quantization", DataInt8QuantizationRel); 39 | 40 | 41 | RELAY_REGISTER_OP("nn.quantize_data_int8_row_offset") 42 | .describe(R"code(dynamic row offset calculation of quantized data. 43 | - **data**: (M, N) 44 | )code" TVM_ADD_FILELINE) 45 | .set_num_inputs(1) 46 | .add_argument("data", "Tensor", "Quantized input tensor.") 47 | .set_support_level(10) 48 | .add_type_rel("DataInt8RowOffset", DataInt8RowOffsetRel); 49 | 50 | 51 | TVM_REGISTER_API("relay.op.nn._make.quantize_findminmax") 52 | .set_body_typed(MakeFindMinMax); 53 | 54 | RELAY_REGISTER_OP("nn.quantize_findminmax") 55 | .describe(R"code(find min and max of the input data. 56 | - **data**: (M, N) 57 | )code" TVM_ADD_FILELINE) 58 | .set_num_inputs(1) 59 | .add_argument("data", "Tensor", "The input data tensor.") 60 | .set_support_level(5) 61 | .add_type_rel("FindMinMax", FindMinMaxRel); 62 | 63 | 64 | TVM_REGISTER_API("relay.op.nn._make.quantize_data_mm_dequantize") 65 | .set_body_typed(MakeDataMMDequantize); 66 | 67 | RELAY_REGISTER_OP("nn.quantize_data_mm_dequantize") 68 | .describe(R"code(multiply the weight and data, then dequantize the data into floating point. 69 | - **data**: (M, N) 70 | )code" TVM_ADD_FILELINE) 71 | .set_num_inputs(6) 72 | .add_argument("data", "Tensor", "The input data tensor.") 73 | .add_argument("weight", "Tensor", "The input weight tensor.") 74 | .add_argument("weight_acc", "Tensor", "The accumulation of each column") 75 | .add_argument("data_acc", "Tensor", "The accumulation of each row") 76 | .add_argument("data_scale", "Tensor", "The activation scale") 77 | .add_argument("data_zero_point", "Tensor", "The activation zero_point") 78 | .set_attrs_type_key("relay.attrs.QuantizedParamsAttrs") 79 | .set_support_level(5) 80 | .add_type_rel("DataMMDequantize", DataMMDequantizeRel); 81 | 82 | 83 | TVM_REGISTER_API("relay.op.nn._make.choose_quantize_params") 84 | .set_body_typed(MakeChooseQuantizeParams); 85 | 86 | RELAY_REGISTER_OP("nn.choose_quantize_params") 87 | .describe(R"code(calculate the zero_point and scale. 
88 | )code" TVM_ADD_FILELINE) 89 | .set_num_inputs(2) 90 | .set_attrs_type_key("relay.attrs.QuantizeSchemeAttrs") 91 | .add_argument("data_min", "Tensor", "The min of input data.") 92 | .add_argument("data_max", "Tensor", "The max of input data.") 93 | .set_support_level(4) 94 | .add_type_rel("ChooseQuantizeParams", ChooseQuantizeParamsRel); 95 | 96 | } // namespace relay 97 | } // namespace tvm 98 | -------------------------------------------------------------------------------- /torch_tvm/compiler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include "memory_utils.h" 16 | #include "debug_utils.h" 17 | 18 | struct TVMGraphInputInfo { 19 | TVMGraphInputInfo(bool is_param_, std::string tvm_var_name_) { 20 | is_param = is_param_; 21 | tvm_var_name = std::move(tvm_var_name_); 22 | } 23 | TVMGraphInputInfo(bool is_param_, std::string&& tvm_var_name_) { 24 | is_param = is_param_; 25 | tvm_var_name = tvm_var_name_; 26 | } 27 | std::string tvm_var_name; 28 | bool is_param; 29 | // DLManagedTensorPtr = unique_ptr 30 | torch_tvm::utils::DLManagedTensorPtr tvm_tensor; 31 | }; 32 | 33 | struct TVMObject { 34 | tvm::PackedFunc kernel; 35 | tvm::PackedFunc set_input; 36 | tvm::PackedFunc get_output; 37 | tvm::PackedFunc setup_external_storage; 38 | // Map input indices to values in the subgraph 39 | // Plus indicates if the corresponding value is immutable, 40 | // e.g., a parameter such as weight. 41 | std::unordered_map input_values; 42 | void populateParamTVMTensors( 43 | const std::unordered_map& value_to_ivalue); 45 | tvm::Map generateParamConstantMap(); 46 | }; 47 | 48 | struct TVMCompiler { 49 | explicit TVMCompiler( 50 | const torch::jit::Node* node, 51 | int opt_level = 2, 52 | bool strict = false, 53 | bool debug = false, 54 | bool debug_runtime = false, 55 | std::string device_type = "cpu", 56 | std::string device = "llvm", 57 | std::string host = "llvm", 58 | int device_id = 0); 59 | void run(torch::jit::Stack& stack); 60 | 61 | private: 62 | std::shared_ptr subgraph_; 63 | std::unordered_map cache_; 64 | TVMContext ctx_; 65 | int opt_level_; 66 | bool strict_; 67 | bool debug_; 68 | bool debug_runtime_; 69 | std::string device_type_; 70 | std::string device_; 71 | std::string host_; 72 | int device_id_; 73 | tvm::runtime::Module build_mod_; 74 | DebugLogger debug_logger_; 75 | // Not used in OSS 76 | at::Tensor activation_buffer_; 77 | std::vector activation_buffer_shape_; 78 | int64_t max_activation_buffer_size_{0}; 79 | 80 | std::string handle_str_; 81 | std::unique_ptr fallback_interpreter_; 82 | // stores argument specs for which we couldn't compile and fall back to the 83 | // interpreter 84 | std::unordered_set bad_specs_; 85 | 86 | std::string getTVMCompilerHandle(std::shared_ptr subgraph); 87 | 88 | public: 89 | static tvm::relay::Var convertToRelay(torch::jit::Value* val, TVMContext ctx); 90 | static tvm::relay::Expr convertToRelay( 91 | const torch::jit::IValue& val, 92 | TVMContext ctx); 93 | static tvm::relay::Function convertToRelay( 94 | std::shared_ptr subgraph, 95 | TVMContext ctx, 96 | std::unordered_map* 97 | input_values = nullptr); 98 | 99 | #ifdef TVM_USE_FB_GRAPH_RUNTIME 100 | void allocateMemoryAndSetParams( 101 | TVMObject& obj, 102 | const tvm::Map& params, 103 | const std::string& json_str); 104 | #endif 105 | }; 106 | 
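// A rough usage sketch (illustrative only; the actual call sites live in
// register.cpp and the fusion pass): the fusion pass wraps supported
// subgraphs in a tvm::CompilationGroup node, a TVMCompiler is constructed
// from that node, and run() compiles the subgraph for the current input
// spec, reusing cache_ on repeat calls or falling back to the JIT
// interpreter for specs recorded in bad_specs_:
//
//   TVMCompiler compiler(node, /*opt_level=*/2, /*strict=*/false);
//   torch::jit::Stack stack;
//   // ... push the subgraph inputs onto the stack ...
//   compiler.run(stack);  // consumes inputs, leaves outputs on the stack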
-------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/contrib/quantize.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace tvm { 9 | namespace contrib { 10 | 11 | using namespace runtime; 12 | 13 | TVM_REGISTER_GLOBAL("tvm.contrib.find_minmax") 14 | .set_body([](TVMArgs args, TVMRetValue* ret) { 15 | DLTensor* input = args[0]; 16 | DLTensor* data_min = args[1]; 17 | DLTensor* data_max = args[2]; 18 | // calculate the data_min and data_max 19 | CHECK(input->strides == nullptr) << "find_minmax does not support DLTensor with strides"; 20 | auto data_ptr = static_cast(input->data); 21 | CHECK(input->ndim == 2) << "find_minmax only supports two-dimensional input"; 22 | int m = input->shape[0]; 23 | int n = input->shape[1]; 24 | int num_els = m * n; 25 | int num_iters = num_els / 16; 26 | int num_left_overs = num_els % 16; 27 | float min_v[16]; std::fill_n(min_v, 16, std::numeric_limits::max()); // fill every lane; brace-init would only set lane 0 28 | float max_v[16]; std::fill_n(max_v, 16, std::numeric_limits::lowest()); 29 | for (int i = 0; i < num_iters; i++) { 30 | for (int j = 0; j < 16; j++) { 31 | min_v[j] = std::min(data_ptr[i*16 + j], min_v[j]); 32 | max_v[j] = std::max(data_ptr[i*16 + j], max_v[j]); 33 | } 34 | } 35 | float min_value = min_v[0]; 36 | float max_value = max_v[0]; 37 | for (int i = 0; i < 16; ++i) { 38 | min_value = std::min(min_v[i], min_value); 39 | max_value = std::max(max_v[i], max_value); 40 | } 41 | for (int i = (num_iters*16); i < (num_iters*16+num_left_overs); ++i) { 42 | min_value = std::min(data_ptr[i], min_value); 43 | max_value = std::max(data_ptr[i], max_value); 44 | } 45 | auto out_ptr_min = static_cast(data_min->data); 46 | auto out_ptr_max = static_cast(data_max->data); 47 | *out_ptr_min = min_value; 48 | *out_ptr_max = max_value; 49 | }); 50 | 51 | TVM_REGISTER_GLOBAL("tvm.contrib.choose_quantize_params") 52 | .set_body([](TVMArgs args, TVMRetValue* ret) { 53 | DLTensor* data_min_ptr = args[0]; 54 | DLTensor* data_max_ptr = args[1]; 55 | DLTensor* zero_point_ptr = args[2]; 56 | DLTensor* scale_ptr = args[3]; 57 | int32_t qmin = args[4]; 58 | int32_t qmax = args[5]; 59 | 60 | float data_min = *(static_cast(data_min_ptr->data)); 61 | float data_max = *(static_cast(data_max_ptr->data)); 62 | // copied from the fbgemm implementation; e.g. data_min = -1, data_max = 3, qmin = 0, qmax = 255 gives scale = 4 / 255 ~= 0.0157 and an initial zero_point of qmin - data_min / scale ~= 63.75, nudged to 64 below 63 | double scale = 64 | (std::max(data_max, 0.f) - std::min(data_min, 0.f)) / ((double)qmax - qmin); 65 | if (scale == 0) { 66 | scale = 0.1; 67 | } 68 | data_min = std::min(data_min, 0.f); 69 | data_max = std::max(data_max, 0.f); 70 | double zero_point_from_min = qmin - data_min / scale; 71 | double zero_point_from_max = qmax - data_max / scale; 72 | double zero_point_from_min_error = std::fabs(qmin) + std::fabs(data_min / scale); 73 | double zero_point_from_max_error = std::fabs(qmax) + std::fabs(data_max / scale); 74 | double initial_zero_point = 75 | zero_point_from_min_error < zero_point_from_max_error 76 | ?
zero_point_from_min 77 | : zero_point_from_max; 78 | 79 | int32_t nudged_zero_point = 0; 80 | if (initial_zero_point < qmin) { 81 | nudged_zero_point = qmin; 82 | } else if (initial_zero_point > qmax) { 83 | nudged_zero_point = qmax; 84 | } else { 85 | nudged_zero_point = std::nearbyint(initial_zero_point); 86 | } 87 | 88 | auto zero_point_data_ptr = static_cast(zero_point_ptr->data); 89 | auto scale_data_ptr = static_cast(scale_ptr->data); 90 | *zero_point_data_ptr = nudged_zero_point; 91 | *scale_data_ptr = scale; 92 | }); 93 | 94 | } // namespace contrib 95 | } // namespace tvm 96 | -------------------------------------------------------------------------------- /torch_tvm/fusion_pass.cpp: -------------------------------------------------------------------------------- 1 | #include "fusion_pass.h" 2 | #include "operators.h" 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace torch::jit; 9 | 10 | value_list sortReverseTopological(ArrayRef inputs, Block* block) { 11 | value_list result; 12 | for (auto i : inputs) { 13 | if (i->node()->owningBlock() == block) { 14 | result.push_back(i); 15 | } 16 | } 17 | // Sort in reverse topological order 18 | std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { 19 | return a->node()->isAfter(b->node()); 20 | }); 21 | return result; 22 | } 23 | 24 | bool canHandle(Block* block, AliasDb& aliasDb); 25 | bool canHandle(Node* node, AliasDb& aliasDb) { 26 | if (node->kind() == prim::Constant) { 27 | return true; 28 | } 29 | if (node->kind() == prim::Loop) { 30 | return false; // TODO 31 | Block* body = node->blocks().at(0); 32 | return canHandle(body, aliasDb); 33 | } 34 | return isSupported(node); 35 | } 36 | 37 | bool canHandle(Block* block, AliasDb& aliasDb) { 38 | for (Node* node : block->nodes()) { 39 | if (!canHandle(node, aliasDb)) { 40 | return false; 41 | } 42 | } 43 | return true; 44 | } 45 | 46 | #define REQ(cond) \ 47 | if (!(cond)) { \ 48 | GRAPH_DEBUG("Failed cond " #cond "\n"); \ 49 | return c10::nullopt; \ 50 | } 51 | c10::optional tryMerge( 52 | Node* consumer, 53 | Node* producer, 54 | AliasDb& aliasDb) { 55 | GRAPH_DEBUG( 56 | "Trying producer ", 57 | producer->kind().toQualString(), 58 | " and consumer ", 59 | consumer->kind().toQualString(), 60 | ":\n"); 61 | 62 | // Symbolic checks 63 | REQ(canHandle(producer, aliasDb)); 64 | REQ((canHandle(consumer, aliasDb) || consumer->kind() == getTVMSymbol())); 65 | 66 | // Alias checks 67 | // Requirement: 68 | // - moveAfterTopologicallyValid(consumer, producer) 69 | // - One of: 70 | // 1) Both are in-place ops 71 | // 2) Consumer is in-place, producer !hasInputWriters 72 | // 3) Producer is in-place, consumer !hasOutputWriters 73 | REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer)); 74 | 75 | // 1) 76 | if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) { 77 | // 2) 78 | if (aliasDb.isMutable(consumer)) { 79 | REQ(!aliasDb.hasInputWriters(producer)); 80 | // 3) 81 | } else if (aliasDb.isMutable(producer)) { 82 | REQ(!aliasDb.hasOutputWriters(consumer)); 83 | } 84 | } 85 | 86 | if (!consumer->hasAttribute(attr::Subgraph) && 87 | consumer->kind() != getTVMSymbol()) { 88 | consumer = SubgraphUtils::createSingletonSubgraph(consumer, getTVMSymbol()); 89 | } 90 | if (producer->kind() == prim::Constant) { 91 | auto& subgraph = consumer->g(attr::Subgraph); 92 | Node* in_const = subgraph->createClone(producer, [](Value*) -> Value* { 93 | throw std::runtime_error("unexpected input"); 94 | }); 95 | subgraph->insertNode(in_const); 96 | } else { 97 | 
SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer); 98 | } 99 | 100 | return consumer; 101 | } 102 | #undef REQ 103 | 104 | std::pair scanNode( 105 | Node* consumer, 106 | AliasDb& aliasDb, 107 | Block* block) { 108 | auto inputs = sortReverseTopological(consumer->inputs(), block); 109 | for (auto input : inputs) { 110 | if (auto group = tryMerge(consumer, input->node(), aliasDb)) { 111 | // we successfully merged, so the new group's `inputs` may have 112 | // changed. So rescan the new group for more merging opportunities. 113 | return {group.value()->reverseIterator(), true}; 114 | } 115 | } 116 | return {++consumer->reverseIterator(), false}; 117 | } 118 | 119 | void FuseSupportedOps(std::shared_ptr graph) { 120 | AliasDb aliasDb(graph); 121 | auto block = graph->block(); 122 | 123 | bool any_changed{true}; 124 | while (any_changed) { 125 | any_changed = false; 126 | for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) { 127 | bool changed; 128 | std::tie(it, changed) = scanNode(*it, aliasDb, block); 129 | any_changed |= changed; 130 | } 131 | } 132 | EliminateCommonSubexpression(graph); 133 | EliminateDeadCode(graph); 134 | } 135 | 136 | const torch::jit::Symbol& getTVMSymbol() { 137 | static torch::jit::Symbol tvm_sym = 138 | torch::jit::Symbol::fromQualString("tvm::CompilationGroup"); 139 | return tvm_sym; 140 | } 141 | -------------------------------------------------------------------------------- /test/test_core.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from test.util import TVMTest 3 | import torch 4 | import torch_tvm 5 | 6 | 7 | class TestCore(TVMTest): 8 | def test_get_handle(self): 9 | shape = 8 10 | x = torch.rand(shape) 11 | y = torch.rand(shape) 12 | z = torch.rand(shape) 13 | 14 | def add(a, b, c): 15 | return a + b + c 16 | 17 | @torch.jit.script 18 | def mul(a, b, c): 19 | return a * b * c 20 | 21 | inputs = [x, y, z] 22 | 23 | torch_tvm.enable() 24 | 25 | trace_tvm = torch.jit.trace(add, inputs) 26 | 27 | relay_graph = torch_tvm.to_relay(trace_tvm, inputs) 28 | relay_graph = torch_tvm.to_relay(add, inputs) 29 | relay_graph = torch_tvm.to_relay(mul, inputs) 30 | 31 | torch_tvm.disable() 32 | 33 | @TVMTest.given(shape=TVMTest.rand_shape(rank=1)) 34 | @unittest.skip("causing segfaults, need to fix operator registration before enable it") 35 | def test_registry(self, shape): 36 | x = torch.rand(shape) 37 | y0 = torch.ops.tvm.relu(x) 38 | y1 = torch.relu(x) 39 | 40 | torch.testing.assert_allclose(y0, y1) 41 | 42 | @TVMTest.given(shape=TVMTest.rand_shape(rank=1)) 43 | def test_core(self, shape): 44 | x = torch.rand(shape) 45 | y = torch.rand(shape) 46 | z = torch.rand(shape) 47 | 48 | def add(a, b, c): 49 | return a + b + c 50 | 51 | inputs = [x, y, z] 52 | 53 | trace_jit = torch.jit.trace(add, inputs) 54 | jit_out = trace_jit(*inputs) 55 | 56 | torch_tvm.enable() 57 | trace_tvm = torch.jit.trace(add, inputs) 58 | tvm_out = trace_tvm(*inputs) 59 | torch_tvm.disable() 60 | torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01) 61 | 62 | torch_tvm.enable(opt_level=1) 63 | trace_tvm = torch.jit.trace(add, inputs) 64 | tvm_out = trace_tvm(*inputs) 65 | torch_tvm.disable() 66 | torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01) 67 | 68 | torch_tvm.enable(opt_level=3) 69 | trace_tvm = torch.jit.trace(add, inputs) 70 | tvm_out = trace_tvm(*inputs) 71 | torch_tvm.disable() 72 | torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01) 73 | 74 | 
torch_tvm.enable(device_type="cpu", device="llvm", host="llvm") 75 | trace_tvm = torch.jit.trace(add, inputs) 76 | tvm_out = trace_tvm(*inputs) 77 | torch_tvm.disable() 78 | torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01) 79 | 80 | @TVMTest.given( 81 | shape=TVMTest.rand_shape(rank=4, min_dim=4), 82 | examples=1 83 | ) 84 | def test_fall_back(self, shape): 85 | inputs = torch.rand(shape) 86 | 87 | def add(input): 88 | return torch.add(input, 1, 2) 89 | 90 | jit_script_reshape = torch.jit.script(add) 91 | jit_out = jit_script_reshape(inputs) 92 | 93 | with self.assertRaises(RuntimeError): 94 | tvm_strict_script_reshape = torch.jit.script(add) 95 | torch_tvm.enable(strict=True) 96 | tvm_out = tvm_strict_script_reshape(inputs) 97 | torch_tvm.disable() 98 | 99 | torch_tvm.enable(strict=False) 100 | tvm_script_reshape = torch.jit.script(add) 101 | tvm_out = tvm_script_reshape(inputs) 102 | torch_tvm.disable() 103 | 104 | torch.testing.assert_allclose(jit_out, tvm_out, rtol=0.01, atol=0.01) 105 | 106 | @TVMTest.given( 107 | shape=TVMTest.rand_shape(rank=4, min_dim=4), 108 | examples=1 109 | ) 110 | def test_dropout_removal(self, shape): 111 | input_a = torch.rand(shape) 112 | input_b = torch.rand(shape) 113 | input_c = torch.rand(shape) 114 | 115 | def dropout_training(a, b, c): 116 | t = a + b 117 | s = torch.dropout(t, 0.1, True) 118 | return s + c 119 | 120 | def dropout_inference(a, b, c): 121 | t = a + b 122 | s = torch.dropout(t, 0.1, False) 123 | return s + c 124 | 125 | torch_tvm.enable() 126 | tvm_graph_training = torch.jit.trace(dropout_training, \ 127 | (input_a, input_b, input_c)) 128 | tvm_graph_inference = torch.jit.trace(dropout_inference, \ 129 | (input_a, input_b, input_c)) 130 | torch_tvm.disable() 131 | assert "aten::dropout" in \ 132 | str(tvm_graph_training.graph_for(input_a, input_b, input_c)), \ 133 | "dropout must not be removed during training." 134 | assert "aten::dropout" not in \ 135 | str(tvm_graph_inference.graph_for(input_a, input_b, input_c)), \ 136 | "dropout must be removed during inference." 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch TVM Extension 2 | [![CircleCI](https://circleci.com/gh/pytorch/tvm.svg?style=svg)](https://circleci.com/gh/pytorch/tvm) 3 | 4 | 5 | 6 | ## Build 7 | 8 | Install the latest Nightly build of PyTorch. 9 | 10 | Then, build this repo 11 | ``` 12 | # Make sure the right llvm-config is in your PATH 13 | python setup.py install 14 | ``` 15 | 16 | ## Test 17 | 18 | ``` 19 | python setup.py test 20 | ``` 21 | 22 | ## Usage 23 | 24 | This package transparently hooks into PyTorch's JIT, so the same tooling is applicable (see `@torch.jit.script`, `torch.jit.trace` and `graph_for`). See below for an example. 25 | 26 | ``` 27 | import torch 28 | import torch_tvm 29 | 30 | torch_tvm.enable() 31 | 32 | # The following function will be compiled with TVM 33 | @torch.jit.script 34 | def my_func(a, b, c): 35 | return a * b + c 36 | ``` 37 | 38 | To disable the JIT hooks, use `torch_tvm.disable()`. 39 | 40 | ## Code Layout 41 | 42 | - `register.cpp`: Sets up pybind bindings and invokes the registration of a TVM backend. 43 | - `compiler.{h,cpp}`: Main logic to compile a PyTorch JIT graph with TVM. 44 | - `operators.{h,cpp}`: Location of mapping from JIT IR to TVM operators. 
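These pieces surface at runtime as fused `tvm::CompilationGroup` nodes in the JIT graph. A small sketch of how to confirm that a trace was actually captured by TVM, mirroring the checks in `test/util.py`:

```
import torch
import torch_tvm

torch_tvm.enable()

def add(a, b, c):
    return a + b + c

inputs = [torch.rand(8) for _ in range(3)]
trace_tvm = torch.jit.trace(add, inputs)

# The optimized graph should contain a fused TVM subgraph.
assert "tvm::CompilationGroup" in str(trace_tvm.graph_for(*inputs))
```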
45 | 46 | ![TVM Integration](https://github.com/pytorch/tvm/blob/master/pt_execution.png?raw=true) 47 | 48 | ## FAQ 49 | 50 | ### How do I configure TVM compilation? 51 | 52 | All options are available as keyword arguments in the `enable` function exposed by `torch_tvm`. 53 | The optimization level, device type, device and host compilation targets are all exposed directly from TVM. 54 | 55 | ``` 56 | torch_tvm.enable( 57 |     opt_level=3, 58 |     device_type="cpu", 59 |     device="llvm", 60 |     host="llvm") 61 | ``` 62 | 63 | ### How do I register a new TVM operator? 64 | 65 | First, ensure the operator is [registered with Relay](https://docs.tvm.ai/dev/relay_add_op.html#registering-an-operator). 66 | 67 | Then, register a map from PyTorch symbols to a Relay `CallNode` with `RegisterTVMOperator`. 68 | This can be done in any compilation unit provided it is linked into the final `torch_tvm` library. 69 | See [`torch_tvm/operators.cpp`](https://github.com/pytorch/tvm/blob/master/torch_tvm/operators.cpp) for examples. 70 | 71 | ``` 72 | RegisterTVMOperator reg_relu({ 73 |     {Symbol::fromQualString("aten::relu"), 74 |      [](Node* node, tvm::Array inputs) { 75 |        auto op = tvm::relay::Op::Get("nn.relu"); 76 |        return tvm::relay::CallNode::make(op, inputs, tvm::Attrs(), {}); 77 |      }}, 78 | }); 79 | ``` 80 | 81 | ### How do I extract the Relay expression associated with a PyTorch Graph? 82 | 83 | If the PyTorch function can be fully converted to Relay, it is possible to extract the expression itself 84 | using `torch_tvm.to_relay(func, inputs)`. Example inputs must be passed in to calculate type information. 85 | 86 | ``` 87 | def add(a, b, c): 88 |     return a + b + c 89 | 90 | # via tracing 91 | relay_graph = torch_tvm.to_relay(add, inputs) 92 | 93 | @torch.jit.script 94 | def mul(a, b, c): 95 |     return a * b * c 96 | 97 | # via script 98 | relay_graph = torch_tvm.to_relay(mul, inputs) 99 | ``` 100 | 101 | Note that not all functions can be converted to Relay in their entirety; attempting expression extraction 102 | on such functions raises an exception. To solve this issue, simply refactor the function. 103 | 104 | ## v0.1 Roadmap 105 | 106 | Below, in order, is a prioritized list of tasks for this repository.
107 | 108 | - [x] End to end build and runtime 109 | - [x] Operator translation 110 |   - [x] Add 111 |   - [x] Multiply 112 |   - [x] Convolution 113 |   - [x] BatchNorm 114 |   - [x] Relu 115 |   - [x] AveragePool 116 |   - [x] MaxPool 117 |   - [x] Linear 118 |   - [x] Reshape 119 |   - [x] AdaptiveAveragePool 120 | - [x] Tooling 121 |   - [x] Model coverage checks 122 |   - [x] Benchmarks for master 123 | - [x] User exposed configurations 124 |   - [x] Backend selection (CPU/Cuda/OpenCL) 125 |   - [x] Optimization level 126 | - [ ] Custom TVM operator registration 127 |   - [ ] Enable Python/C++ mechanism to use custom TVM operators and schedules 128 |   - [x] Enable Relay op registration 129 | - [x] Bail-out mechanism 130 |   - When TVM cannot compile a subgraph, invoke PyTorch JIT fallback 131 | - [x] Extract Relay expression 132 | - [x] Enable exposure of ops registered in eager mode under `torch.ops.tvm.*` 133 | 134 | ### v0.2 Plan 135 | 136 | - [ ] View support 137 | - [x] Zero copy `set_input` 138 | - [ ] Subsystem integration 139 |   - [ ] Threadpool integration 140 |   - [ ] Allocator integration 141 |     - `tvm/include/tvm/runtime/device_api.h` 142 |   - [ ] Distributed communication 143 | - [ ] IR integration 144 |   - [ ] Control flow 145 |   - [ ] Aliasing 146 | - [ ] Operators 147 |   - [ ] transpose 148 |   - [ ] chunk 149 |   - [ ] repeat 150 |   - [ ] cat 151 |   - [ ] unsqueeze 152 |   - [ ] slice 153 |   - [ ] softmax 154 |   - [ ] bmm 155 |   - [ ] layernorm 156 | 157 | 158 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/relay/quantize.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "quantize_attrs.h" 11 | 12 | #include 13 | namespace tvm { 14 | namespace relay { 15 | 16 | Expr MakeDataInt8Quantization(Expr data, Expr zero_point, Expr scale, bool is_signed, int precision) { 17 | static const Op& op = Op::Get("nn.quantize_data_int8_quantize"); 18 | auto attrs = make_node(); 19 | attrs->precision = precision; 20 | attrs->is_signed = is_signed; 21 | return CallNode::make(op, {data, zero_point, scale}, Attrs(attrs), {}); 22 | } 23 | 24 | Expr MakeDataInt8RowOffset(Expr quantized_data) { 25 | static const Op& op = Op::Get("nn.quantize_data_int8_row_offset"); 26 | return CallNode::make(op, {quantized_data}, Attrs(), {}); 27 | } 28 | 29 | bool DataInt8QuantizationRel(const Array& types, 30 | int num_inputs, 31 | const Attrs& attrs, 32 | const TypeReporter& reporter) { 33 | // TODO: add axis to decide which dim to do the accumulation 34 | CHECK_EQ(types.size(), 4); 35 | const QuantizeSchemeAttrs* param = attrs.as(); 36 | const auto* data = types[0].as(); 37 | // unchanged shape 38 | Array oshape = data->shape; 39 | Array acc_oshape = {oshape[0]}; 40 | 41 | DataType out_dtype; 42 | if (param->is_signed) { 43 | out_dtype = Int(param->precision); 44 | } else { 45 | out_dtype = UInt(param->precision); 46 | } 47 | reporter->Assign(types[3], TensorTypeNode::make(oshape, out_dtype)); 48 | return true; 49 | } 50 | 51 | bool DataInt8RowOffsetRel(const Array& types, 52 | int num_inputs, 53 | const Attrs& attrs, 54 | const TypeReporter& reporter) { 55 | // TODO: add axis to decide which dim to do the accumulation 56 | CHECK_EQ(types.size(), 2); 57 | const auto* data = types[0].as(); 58 | // unchanged shape 59 | Array oshape = data->shape; 60 | Array acc_oshape = {oshape[0]}; 61 | 62 | reporter->Assign(types[1],
TensorTypeNode::make(acc_oshape, Int(32))); 63 | return true; 64 | } 65 | 66 | Expr MakeFindMinMax(Expr data) { 67 | static const Op& op = Op::Get("nn.quantize_findminmax"); 68 | return CallNode::make(op, {data}, Attrs(), {}); 69 | } 70 | 71 | bool FindMinMaxRel(const Array& types, 72 | int num_inputs, 73 | const Attrs& attrs, 74 | const TypeReporter& reporter) { 75 | CHECK_EQ(types.size(), 2); 76 | const auto* data = types[0].as(); 77 | std::vector oshape = {1}; 78 | std::vector fields; 79 | fields.push_back(TensorTypeNode::make(oshape, data->dtype)); 80 | fields.push_back(TensorTypeNode::make(oshape, data->dtype)); 81 | reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); 82 | return true; 83 | } 84 | 85 | Expr MakeDataMMDequantize(Expr data, 86 | Expr weight, 87 | Expr weight_acc, 88 | Expr data_acc, 89 | Expr data_scale, 90 | Expr data_zero_point, 91 | const double w_scale, 92 | const int w_zp, 93 | const int N) { 94 | auto attrs = make_node(); 95 | attrs->w_scale = w_scale; 96 | attrs->w_zp = w_zp; 97 | attrs->N = N; 98 | static const Op& op = Op::Get("nn.quantize_data_mm_dequantize"); 99 | return CallNode::make(op, {data, weight, weight_acc, data_acc, data_scale, data_zero_point}, Attrs(attrs), {}); 100 | } 101 | 102 | bool DataMMDequantizeRel(const Array& types, 103 | int num_inputs, 104 | const Attrs& attrs, 105 | const TypeReporter& reporter) { 106 | CHECK_EQ(types.size(), 7); 107 | const auto* data = types[0].as(); 108 | const auto* weight = types[1].as(); 109 | auto* quantized_params = attrs.as(); 110 | // TODO: check the acc shape 111 | // Assume acc32 input 112 | Array wshape = weight->shape; 113 | Array oshape = data->shape; 114 | oshape.Set((oshape.size() - 1), quantized_params->N); 115 | reporter->Assign(types[6], TensorTypeNode::make(oshape, Float(32))); 116 | return true; 117 | } 118 | 119 | Expr MakeChooseQuantizeParams(Expr data_min, Expr data_max, bool is_signed, int precision) { 120 | auto attrs = make_node(); 121 | attrs->precision = precision; 122 | attrs->is_signed = is_signed; 123 | static const Op& op = Op::Get("nn.choose_quantize_params"); 124 | return CallNode::make(op, {data_min, data_max}, Attrs(attrs), {}); 125 | } 126 | 127 | bool ChooseQuantizeParamsRel(const Array& types, 128 | int num_inputs, 129 | const Attrs& attrs, 130 | const TypeReporter& reporter) { 131 | CHECK_EQ(types.size(), 3); 132 | const auto* data = types[0].as(); 133 | std::vector oshape = {1}; 134 | std::vector fields; 135 | fields.push_back(TensorTypeNode::make(oshape, Int(32))); 136 | fields.push_back(TensorTypeNode::make(oshape, data->dtype)); 137 | reporter->Assign(types[2], TupleTypeNode::make(Array(fields))); 138 | return true; 139 | } 140 | 141 | } // namespace relay 142 | } // namespace tvm 143 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/custom_layer_norm.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include "topi/reduction.h" 7 | #include "topi/tags.h" 8 | 9 | #include "custom_layer_norm.h" 10 | 11 | namespace topi { 12 | using namespace tvm; 13 | 14 | inline Array calculate_mean_and_variance( 15 | const Tensor& data, 16 | const Array& normalized_axis) { 17 | auto ndim = data->shape.size(); 18 | auto real_axis = GetRealAxis(static_cast(ndim), normalized_axis); 19 | auto reduce_axes = MakeReduceAxes(real_axis, data); 20 | // Gets the shape of result after reduction. 
21 | auto target_shape = MakeReduceTargetShape(real_axis, data, false, false); 22 | 23 | auto fidentity = [](std::vector types) { 24 | Array result; 25 | CHECK(types.size() == 2); 26 | CHECK(types[0] == types[1]); 27 | result.push_back(tvm::make_const(types[0], 0)); // running sum(x) 28 | result.push_back(tvm::make_const(types[0], 0)); // running sum(x * x): the second moment, not the variance itself 29 | return result; 30 | }; 31 | auto fcombine = [](Array lhs, Array rhs) { 32 | Array result; 33 | result.push_back(lhs[0] + rhs[0]); // sum(x) 34 | result.push_back(lhs[1] + rhs[1]); // sum(x * x) 35 | return result; 36 | }; 37 | auto squared_data = tvm::compute(data->shape, [&data](const Array& indices) { 38 | return data(indices) * data(indices); 39 | }, data->op->name + "_squared"); 40 | auto reducer = MakeCommReducer(fcombine, fidentity, "mean_variance_sum"); 41 | auto compute = [ndim, &real_axis, &reduce_axes, &reducer, &data, &squared_data]( 42 | const Array& indices) { 43 | Array eval_range; 44 | int arg_counter = 0; 45 | int red_counter = 0; 46 | 47 | // eval_range takes the index value from indices for the non-reduction axes, 48 | // and for the reduction axes adds reduce_axes, which is a Range axis. 49 | // Thus for a 2-dim tensor [5, 5] with dim 1 reduced and a dim-0 50 | // index of 3, eval_range would be (3, ReduceAxis(0, 4)) 51 | for (size_t i = 0; i < ndim; ++i) { 52 | if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { 53 | // real_axis contains i 54 | eval_range.push_back(reduce_axes[red_counter]); 55 | red_counter++; 56 | } else { 57 | eval_range.push_back(indices[i]); 58 | } 59 | } 60 | return reducer({data(eval_range), squared_data(eval_range)}, reduce_axes, nullptr); 61 | }; 62 | 63 | tvm::relay::IndexExpr num_elements = make_const(data->dtype, 1); 64 | for (int64_t i = 0; i < normalized_axis.size(); ++i) { 65 | num_elements *= data->shape[normalized_axis[i]->value]; 66 | } 67 | auto reduce_outputs = tvm::compute( 68 | target_shape, compute, data->op->name + "_mean_var_sum", kCommReduce); 69 | return {reduce_outputs[0] / num_elements, reduce_outputs[1] / num_elements}; 70 | } 71 | 72 | inline Tensor custom_layer_norm_impl( 73 | const Tensor& data, 74 | const Array& normalized_axis) { 75 | auto ndim = data->shape.size(); 76 | auto normalized_axis_dim = normalized_axis.size(); 77 | CHECK(ndim > normalized_axis_dim); 78 | auto mean_variance = calculate_mean_and_variance(data, normalized_axis); 79 | auto layer_norm_compute = [&mean_variance, &data, ndim, normalized_axis_dim]( 80 | const Array& indices) { 81 | Array mean_variance_indices( 82 | indices.begin(), indices.begin() + (ndim - normalized_axis_dim)); 83 | auto mean = mean_variance[0]; 84 | auto variance = mean_variance[1]; 85 | auto epsilon = tvm::make_const(Float(32), 1e-9); 86 | auto var_0 = tvm::max( 87 | (variance(mean_variance_indices) - 88 | mean(mean_variance_indices) * mean(mean_variance_indices)), 89 | epsilon); 90 | Expr one = make_const(data->dtype, 1); 91 | auto var_rsqrt = one / tvm::sqrt(var_0); 92 | return ( 93 | data(indices) * var_rsqrt - var_rsqrt * mean(mean_variance_indices)); 94 | }; 95 | return tvm::compute( 96 | data->shape, layer_norm_compute, data->op->name, "custom_layer_norm_tag"); 97 | } 98 | 99 | inline Tensor custom_layer_norm_impl_affine( 100 | const Tensor& data, 101 | const Tensor& gamma, 102 | const Tensor& beta, 103 | const Array& normalized_axis) { 104 | auto ndim = data->shape.size(); 105 | auto normalized_axis_dim = normalized_axis.size(); 106 | CHECK(ndim > normalized_axis_dim); 107 | auto mean_variance =
calculate_mean_and_variance(data, normalized_axis); 108 | auto layer_norm_compute = [&mean_variance, 109 | &data, 110 | ndim, 111 | normalized_axis_dim, 112 | &gamma, 113 | &beta](const Array& indices) { 114 | Array mean_variance_indices( 115 | indices.begin(), indices.begin() + (ndim - normalized_axis_dim)); 116 | Array affine_indices( 117 | indices.begin() + (ndim - normalized_axis_dim), indices.end()); 118 | auto mean = mean_variance[0]; 119 | auto variance = mean_variance[1]; 120 | auto epsilon = tvm::make_const(Float(32), 1e-9); 121 | auto var_0 = tvm::max( 122 | (variance(mean_variance_indices) - 123 | mean(mean_variance_indices) * mean(mean_variance_indices)), 124 | epsilon); 125 | Expr one = make_const(data->dtype, 1); 126 | auto var_rsqrt = one / tvm::sqrt(var_0); 127 | return ( 128 | (data(indices) * var_rsqrt - var_rsqrt * mean(mean_variance_indices)) * 129 | gamma(affine_indices) + 130 | beta(affine_indices)); 131 | }; 132 | return tvm::compute( 133 | data->shape, layer_norm_compute, data->op->name, "custom_layer_norm_tag"); 134 | } 135 | 136 | Tensor custom_layer_norm( 137 | const Tensor& data, 138 | const Tensor& gamma, 139 | const Tensor& beta, 140 | const int num_axis_to_normalize, 141 | const bool affine, 142 | const float eps) { 143 | int data_num_dims = data->shape.size(); 144 | CHECK(num_axis_to_normalize < data_num_dims); 145 | int index = (data_num_dims - num_axis_to_normalize); 146 | Array normalized_axis; 147 | for (int i = 0; index < data_num_dims; ++i, ++index) { 148 | normalized_axis.push_back(Integer(index)); 149 | } 150 | if (affine) { 151 | return custom_layer_norm_impl_affine(data, gamma, beta, normalized_axis); 152 | } else { 153 | return custom_layer_norm_impl(data, normalized_axis); 154 | } 155 | } 156 | 157 | } // namespace topi 158 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/x86/quantize_data_mm_dequantize.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "topi/detail/array_utils.h" 5 | #include "topi/detail/fuse.h" 6 | #include "topi/tags.h" 7 | #include "tvm/buffer.h" 8 | #include "tvm/operation.h" 9 | #include "tvm/tensor_intrin.h" 10 | 11 | namespace topi { 12 | using namespace tvm; 13 | 14 | namespace x86 { 15 | 16 | // for avx2 17 | TensorIntrin dot_1x4x16_int8_int8_int32_avx2() { 18 | tvm::Tensor data = tvm::placeholder({4}, UInt(8), "data"); 19 | tvm::Tensor kernel = tvm::placeholder({16, 4}, Int(8), "kernel"); 20 | auto k = tvm::reduce_axis(tvm::Range{0, 4}, "k"); 21 | auto C = tvm::compute( 22 | {16}, 23 | [&](Var i) { 24 | return tvm::sum( 25 | tvm::cast(Int(32), data(k)) * tvm::cast(Int(32), kernel(i, k)), 26 | {k}); 27 | }, 28 | "tensor", 29 | "dense"); 30 | auto a_buf = BufferNode::make( 31 | /*Var ptr*/ Var("a_buf", Handle()), 32 | /*Type dtype*/ UInt(8), 33 | /*Array shape*/ {4}, 34 | /*Array strides*/ {1}, 35 | /*Expr elem_offset*/ Var("a_buf_elem_offset"), 36 | /*std::string name*/ "a_buf", 37 | /*std::string scope*/ "", 38 | /*int data_alignment*/ -1, 39 | /*int offset_factor*/ 1, 40 | /*BufferType buffer_type*/ kDefault); 41 | auto b_buf = BufferNode::make( 42 | Var("b_buf", Handle()), 43 | Int(8), 44 | {16, 4}, 45 | {Var("ldw"), 1}, 46 | Var("b_buf_elem_offset"), 47 | "b_buf", 48 | "", 49 | -1, 50 | 1, 51 | kDefault); 52 | auto c_buf = BufferNode::make( 53 | Var("c_buf", Handle()), 54 | Int(32), 55 | {16}, 56 | {1}, 57 | Var("c_buf_elem_offset"), 58 | "c_buf", 59 | "", 60 | -1, 61 
| 1, 62 | kDefault); 63 | Expr a_int8 = a_buf.vload({0}, UInt(8, 4)); 64 | Expr re_int32 = tvm::reinterpret(Int(32), a_int8); 65 | Expr vec_ai32 = tvm::cast(Int(32, 8), re_int32); 66 | Expr vec_a = tvm::reinterpret(Int(8, 32), vec_ai32); 67 | Expr vec_b_0 = b_buf.vload({0, 0}, Int(8, 32)); 68 | Expr vec_b_1 = b_buf.vload({8, 0}, Int(8, 32)); 69 | Expr vec_one = make_const(Int(16, 16), 1); 70 | Expr vec_zero = make_const(UInt(32), 0); 71 | 72 | constexpr auto pair = "llvm.x86.avx2.pmadd.ub.sw"; 73 | constexpr auto quad = "llvm.x86.avx2.pmadd.wd"; 74 | 75 | Expr llvm_pair = 76 | make_const(UInt(32), llvm::Function::lookupIntrinsicID(pair)); 77 | Expr llvm_quad = 78 | make_const(UInt(32), llvm::Function::lookupIntrinsicID(quad)); 79 | 80 | Expr pair_reduction_0 = ir::Call::make( 81 | /*Type type*/ Int(16, 16), 82 | /*std::string name*/ "llvm_intrin", 83 | /*Array args*/ {llvm_pair, vec_zero, vec_a, vec_b_0}, 84 | /*CallType call_type*/ ir::Call::PureIntrinsic 85 | /*FunctionRef func = FunctionRef()*/ 86 | /*int value_index = 0*/ 87 | ); 88 | Expr quad_reduction_0 = ir::Call::make( 89 | Int(32, 8), 90 | "llvm_intrin", 91 | {llvm_quad, vec_zero, pair_reduction_0, vec_one}, 92 | ir::Call::PureIntrinsic); 93 | Expr pair_reduction_1 = ir::Call::make( 94 | Int(16, 16), 95 | "llvm_intrin", 96 | {llvm_pair, vec_zero, vec_a, vec_b_1}, 97 | ir::Call::PureIntrinsic); 98 | Expr quad_reduction_1 = ir::Call::make( 99 | Int(32, 8), 100 | "llvm_intrin", 101 | {llvm_quad, vec_zero, pair_reduction_1, vec_one}, 102 | ir::Call::PureIntrinsic); 103 | 104 | Stmt reduce_init = c_buf.vstore({0}, make_const(Int(32, 16), 0)); 105 | Stmt body_0 = c_buf.vstore({0}, quad_reduction_0); 106 | Stmt body_1 = c_buf.vstore({8}, quad_reduction_1); 107 | Stmt body = ir::Block::make(body_0, body_1); 108 | Stmt reduce_update_0 = 109 | c_buf.vstore({0}, quad_reduction_0 + c_buf.vload({0}, Int(32, 8))); 110 | Stmt reduce_update_1 = 111 | c_buf.vstore({8}, quad_reduction_1 + c_buf.vload({8}, Int(32, 8))); 112 | Stmt reduce_update = ir::Block::make(reduce_update_0, reduce_update_1); 113 | 114 | return TensorIntrinNode::make( 115 | /*std::string name*/ "tensor_intrin", 116 | /*Operation op*/ C->op, 117 | /*Array inputs*/ C->op->InputTensors(), 118 | /*Array buffers*/ {a_buf, b_buf, c_buf}, 119 | /*Array scalar_params*/ {}, 120 | /*Stmt body*/ body, 121 | /*Stmt body*/ reduce_init, 122 | /*Stmt body*/ reduce_update); 123 | } 124 | 125 | // for avx512 126 | TensorIntrin dot_16x1x16_int8_int8_int32_avx512() { 127 | tvm::Tensor data = tvm::placeholder({4}, UInt(8), "data"); 128 | tvm::Tensor kernel = tvm::placeholder({16, 4}, Int(8), "kernel"); 129 | auto k = tvm::reduce_axis(tvm::Range{0, 4}, "k"); 130 | auto C = tvm::compute( 131 | {16}, 132 | [&](Var i) { 133 | return tvm::sum( 134 | tvm::cast(Int(32), data(k)) * tvm::cast(Int(32), kernel(i, k)), 135 | {k}); 136 | }, 137 | "tensor", 138 | "dense"); 139 | auto a_buf = BufferNode::make( 140 | Var("a_buf", Handle()), 141 | UInt(8), 142 | {4}, 143 | {1}, 144 | Var("a_buf_elem_offset"), 145 | "a_buf", 146 | "", 147 | -1, 148 | 1, 149 | kDefault); 150 | auto b_buf = BufferNode::make( 151 | Var("b_buf", Handle()), 152 | Int(8), 153 | {16, 4}, 154 | {Var("ldw"), 1}, 155 | Var("b_buf_elem_offset"), 156 | "b_buf", 157 | "", 158 | -1, 159 | 1, 160 | kDefault); 161 | auto c_buf = BufferNode::make( 162 | Var("c_buf", Handle()), 163 | Int(32), 164 | {16}, 165 | {1}, 166 | Var("c_buf_elem_offset"), 167 | "c_buf", 168 | "", 169 | -1, 170 | 1, 171 | kDefault); 172 | 173 | Expr a_int8 = a_buf.vload({0}, 
UInt(8, 4)); 174 | Expr vec_b = b_buf.vload({0, 0}, Int(8, 64)); 175 | Expr vec_one = make_const(Int(16, 32), 1); 176 | Expr re_int32 = tvm::reinterpret(Int(32), a_int8); 177 | Expr vec_ai32 = tvm::cast(Int(32, 16), re_int32); 178 | Expr vec_a = tvm::reinterpret(Int(8, 64), vec_ai32); 179 | Expr vec_zero = make_const(UInt(32), 0); 180 | constexpr auto pair = "llvm.x86.avx512.pmaddubs.w.512"; 181 | constexpr auto quad = "llvm.x86.avx512.pmaddw.d.512"; 182 | Expr llvm_pair = 183 | make_const(UInt(32), llvm::Function::lookupIntrinsicID(pair)); 184 | Expr llvm_quad = 185 | make_const(UInt(32), llvm::Function::lookupIntrinsicID(quad)); 186 | Expr pair_reduction = ir::Call::make( 187 | Int(16, 32), 188 | "llvm_intrin", 189 | {llvm_pair, vec_zero, vec_a, vec_b}, 190 | ir::Call::PureIntrinsic); 191 | Expr quad_reduction = ir::Call::make( 192 | Int(32, 16), 193 | "llvm_intrin", 194 | {llvm_quad, vec_zero, pair_reduction, vec_one}, 195 | ir::Call::PureIntrinsic); 196 | 197 | Stmt reduce_init = c_buf.vstore({0}, make_const(Int(32, 16), 0)); 198 | Stmt body = c_buf.vstore({0}, quad_reduction); 199 | Stmt reduce_update = 200 | c_buf.vstore({0}, quad_reduction + c_buf.vload({0}, Int(32, 16))); 201 | 202 | return TensorIntrinNode::make( 203 | "tensor_intrin", 204 | C->op, 205 | C->op->InputTensors(), 206 | {a_buf, b_buf, c_buf}, 207 | {}, 208 | body, 209 | reduce_init, 210 | reduce_update); 211 | } 212 | 213 | inline Schedule schedule_quantized_mm_dequantize( 214 | const Target& target, 215 | const Array& outs) { 216 | Array out_ops; 217 | for (auto t : outs) { 218 | out_ops.push_back(t->op); 219 | } 220 | auto s = create_schedule(out_ops); 221 | 222 | auto _schedule_quantized_mm = [&](const Tensor& input) { 223 | auto axis = s[input]->op.as()->axis; 224 | CHECK_EQ(axis.size(), 2); 225 | auto y = axis[0]; 226 | auto x = axis[1]; 227 | auto reduce_axis = s[input]->op.as()->reduce_axis; 228 | CHECK_EQ(reduce_axis.size(), 1); 229 | auto k = reduce_axis[0]; 230 | auto x_dim_size = input->shape[1]; 231 | if (*as_const_int(x_dim_size) >= 16) { 232 | IterVar xo, xi; 233 | IterVar ko, ki; 234 | s[input].split(x, 16, &xo, &xi); 235 | s[input].split(k, 4, &ko, &ki); 236 | s[input].reorder({xo, ko, y, xi, ki}); 237 | s[input].unroll(y); 238 | if (target->options_array[0].as()->value == 239 | "-mcpu=skylake-avx512") { 240 | auto pc = dot_16x1x16_int8_int8_int32_avx512(); 241 | s[input].tensorize(xi, pc); 242 | } else { 243 | auto pc = dot_1x4x16_int8_int8_int32_avx2(); 244 | s[input].tensorize(xi, pc); 245 | } 246 | } else { 247 | s[input].reorder({y, x}); 248 | s[input].unroll(y); 249 | s[input].vectorize(x); 250 | } 251 | }; 252 | 253 | auto _schedule_mm_dequantize = [&](const Tensor& output) { 254 | for (auto tensor : output->op->InputTensors()) { 255 | if (tensor->op->tag.rfind("quantized_mm", 0) == 0) { 256 | _schedule_quantized_mm(tensor); 257 | } 258 | } 259 | }; 260 | 261 | std::function traverse; 262 | traverse = [&](const Operation& op) { 263 | // Inline all one-to-one-mapping operators except the last stage (output) 264 | if (is_broadcast(op->tag)) { 265 | if (!detail::contains(s->outputs, op)) { 266 | s[op].compute_inline(); 267 | } 268 | for (auto tensor : op->InputTensors()) { 269 | if (tensor->op->InputTensors().size() > 0) { 270 | traverse(tensor->op); 271 | } 272 | } 273 | } 274 | if (op->tag.rfind("mm_dequantize", 0) == 0) { 275 | auto output = op.output(0); 276 | _schedule_mm_dequantize(output); 277 | } 278 | }; 279 | 280 | traverse(outs[0]->op); 281 | return s; 282 | } 283 | } // namespace x86 
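// Note on the tensorized kernels above (illustrative summary): both the AVX2
// and AVX512 paths reduce along K in two steps. pmaddubsw multiplies
// adjacent (uint8, int8) pairs and adds each pair into an int16 lane, and
// pmaddwd then multiplies those int16 lanes by the vector of ones and adds
// adjacent pairs into int32 lanes, so every int32 output lane holds a dot
// product of four int8 values. For one lane:
//
//   a = {1, 2, 3, 4} (uint8), b = {5, 6, 7, 8} (int8)
//   pmaddubsw -> {1*5 + 2*6, 3*7 + 4*8} = {17, 53} (int16)
//   pmaddwd with ones -> 17 + 53 = 70 (int32), i.e. dot(a, b)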
284 | } // namespace topi 285 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from distutils.spawn import find_executable 7 | from distutils import sysconfig, log 8 | import setuptools 9 | import setuptools.command.build_py 10 | import setuptools.command.develop 11 | import setuptools.command.build_ext 12 | import setuptools.command.install 13 | 14 | from collections import namedtuple 15 | from contextlib import contextmanager 16 | import glob 17 | import os 18 | import shlex 19 | import subprocess 20 | import sys 21 | from textwrap import dedent 22 | import multiprocessing 23 | 24 | try: 25 | import torch 26 | except ImportError as e: 27 | print('Unable to import torch. Error:') 28 | print('\t', e) 29 | print('You need to install pytorch first.') 30 | sys.exit(1) 31 | 32 | 33 | TOP_DIR = os.path.realpath(os.path.dirname(__file__)) 34 | SRC_DIR = os.path.join(TOP_DIR, 'torch_tvm') 35 | CMAKE_BUILD_DIR = os.path.join(TOP_DIR, 'build') 36 | 37 | CMAKE = find_executable('cmake') or find_executable('cmake3') 38 | if not CMAKE: 39 | print('Could not find "cmake". Make sure it is in your PATH.') 40 | sys.exit(1) 41 | 42 | llvm_config_versions = ['', '-7', '-6.0'] 43 | for ver in llvm_config_versions: 44 | LLVM_CONFIG = find_executable('llvm-config{}'.format(ver)) 45 | if LLVM_CONFIG: 46 | break 47 | else: 48 | print('Could not find "llvm-config". Make sure it is in your PATH.') 49 | sys.exit(1) 50 | 51 | install_requires = [] 52 | setup_requires = [] 53 | tests_require = [] 54 | extras_require = {} 55 | 56 | ################################################################################ 57 | # Global flags and environment variables 58 | ################################################################################ 59 | 60 | DEBUG = bool(os.getenv('DEBUG')) 61 | RERUN_CMAKE = False 62 | filtered_args = [] 63 | for i, arg in enumerate(sys.argv): 64 | if arg == '--cmake': 65 | RERUN_CMAKE = True 66 | continue 67 | filtered_args.append(arg) 68 | sys.argv = filtered_args 69 | 70 | ################################################################################ 71 | # Version 72 | ################################################################################ 73 | 74 | try: 75 | git_version = subprocess.check_output(['git', 'rev-parse', 'HEAD'], 76 | cwd=TOP_DIR).decode('ascii').strip() 77 | except (OSError, subprocess.CalledProcessError): 78 | git_version = None 79 | 80 | with open(os.path.join(TOP_DIR, 'VERSION_NUMBER')) as version_file: 81 | VersionInfo = namedtuple('VersionInfo', ['version', 'git_version'])( 82 | version=version_file.read().strip(), 83 | git_version=git_version 84 | ) 85 | 86 | ################################################################################ 87 | # Utilities 88 | ################################################################################ 89 | 90 | 91 | @contextmanager 92 | def cd(path): 93 | if not os.path.isabs(path): 94 | raise RuntimeError('Can only cd to absolute path, got: {}'.format(path)) 95 | orig_path = os.getcwd() 96 | os.chdir(path) 97 | try: 98 | yield 99 | finally: 100 | os.chdir(orig_path) 101 | 102 | 103 | ################################################################################ 104 | # Customized commands 105 | 
################################################################################ 106 | 107 | class build_py(setuptools.command.build_py.build_py): 108 | def run(self): 109 | with open(os.path.join(SRC_DIR, 'version.py'), 'w') as f: 110 | f.write(dedent('''\ 111 | # This file is generated by setup.py. DO NOT EDIT! 112 | 113 | from __future__ import absolute_import 114 | from __future__ import division 115 | from __future__ import print_function 116 | from __future__ import unicode_literals 117 | 118 | version = '{version}' 119 | git_version = '{git_version}' 120 | '''.format(**dict(VersionInfo._asdict())))) 121 | return setuptools.command.build_py.build_py.run(self) 122 | 123 | 124 | class cmake_build(setuptools.Command): 125 | """ 126 | Compiles everything when `python setup.py develop` is run using cmake. 127 | 128 | Custom args can be passed to cmake by specifying the `CMAKE_ARGS` 129 | environment variable. 130 | """ 131 | def initialize_options(self): 132 | pass 133 | 134 | def finalize_options(self): 135 | pass 136 | 137 | def _run_cmake(self): 138 | with cd(CMAKE_BUILD_DIR): 139 | cmake_args = [ 140 | CMAKE, 141 | '-DCMAKE_EXPORT_COMPILE_COMMANDS=ON', 142 | '-DCMAKE_BUILD_TYPE={}'.format('Debug' if DEBUG else 'Release'), 143 | '-DPYTHON_EXECUTABLE={}'.format(sys.executable), 144 | # PyTorch cmake args 145 | '-DPYTORCH_DIR={}'.format( 146 | os.path.dirname(os.path.realpath(torch.__file__))), 147 | # TVM cmake args 148 | '-DUSE_LLVM={}'.format(LLVM_CONFIG), 149 | '-DTVM_DIR={}'.format(os.path.join(TOP_DIR, 'tvm')), 150 | ] 151 | if 'CMAKE_ARGS' in os.environ: 152 | extra_cmake_args = shlex.split(os.environ['CMAKE_ARGS']) 153 | log.info('Extra cmake args: {}'.format(extra_cmake_args)) 154 | cmake_args.extend(extra_cmake_args) 155 | cmake_args.append(TOP_DIR) 156 | subprocess.check_call(cmake_args) 157 | 158 | def _run_build(self): 159 | with cd(CMAKE_BUILD_DIR): 160 | build_args = [ 161 | CMAKE, 162 | '--build', os.curdir, 163 | '--', '-j', str(multiprocessing.cpu_count()), 164 | ] 165 | subprocess.check_call(build_args) 166 | 167 | 168 | def run(self): 169 | if not os.path.exists(CMAKE_BUILD_DIR): 170 | os.makedirs(CMAKE_BUILD_DIR) 171 | 172 | is_initial_build = not os.path.exists(os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt")) 173 | if is_initial_build or RERUN_CMAKE: 174 | self._run_cmake() 175 | 176 | self._run_build() 177 | 178 | 179 | class develop(setuptools.command.develop.develop): 180 | def run(self): 181 | self.run_command('build_ext') 182 | setuptools.command.develop.develop.run(self) 183 | 184 | 185 | class build_ext(setuptools.command.build_ext.build_ext): 186 | def run(self): 187 | self.run_command('cmake_build') 188 | setuptools.command.build_ext.build_ext.run(self) 189 | 190 | def build_extensions(self): 191 | for ext in self.extensions: 192 | fullname = self.get_ext_fullname(ext.name) 193 | filename = os.path.basename(self.get_ext_filename(fullname)) 194 | 195 | src = os.path.join(CMAKE_BUILD_DIR, filename) 196 | dst = os.path.join(os.path.realpath(self.build_lib), 'torch_tvm', filename) 197 | if not os.path.exists(os.path.dirname(dst)): 198 | os.makedirs(os.path.dirname(dst)) 199 | self.copy_file(src, dst) 200 | 201 | class install(setuptools.command.install.install): 202 | def run(self): 203 | setuptools.command.install.install.run(self) 204 | mk_tvm_dir = [ 205 | 'mkdir', 206 | '-p', 207 | '{}'.format(os.path.join(TOP_DIR, 'tvm', 'build')), 208 | ] 209 | subprocess.check_call(mk_tvm_dir) 210 | copy_tvm_files = [ 211 | 'cp', 212 | '-a', 213 | 
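            # Copy the TVM artifacts produced by the CMake build into tvm/build
            # so that the vendored tvm and topi setup.py installs (run below)
            # can find the compiled libraries.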
'{}'.format(os.path.join(CMAKE_BUILD_DIR, 'tvm', '.')), 214 | '{}'.format(os.path.join(TOP_DIR, 'tvm', 'build')), 215 | ] 216 | subprocess.check_call(copy_tvm_files) 217 | with cd(os.path.join(TOP_DIR, 'tvm', 'python')): 218 | subprocess.check_call("{} setup.py install".format(sys.executable), shell=True) 219 | with cd(os.path.join(TOP_DIR, 'tvm', 'topi', 'python')): 220 | subprocess.check_call("{} setup.py install".format(sys.executable), shell=True) 221 | 222 | cmdclass = { 223 | 'cmake_build': cmake_build, 224 | 'build_py': build_py, 225 | 'develop': develop, 226 | 'build_ext': build_ext, 227 | 'install': install, 228 | } 229 | 230 | ################################################################################ 231 | # Extensions 232 | ################################################################################ 233 | 234 | ext_modules = [ 235 | setuptools.Extension( 236 | name=str('torch_tvm._torch_tvm'), 237 | sources=[]) 238 | ] 239 | 240 | ################################################################################ 241 | # Packages 242 | ################################################################################ 243 | 244 | # no need to do fancy stuff so far 245 | packages = setuptools.find_packages() 246 | 247 | ################################################################################ 248 | # Test 249 | ################################################################################ 250 | 251 | setup_requires.append('pytest-runner') 252 | tests_require.append('pytest') 253 | tests_require.append('scikit-image') 254 | 255 | ################################################################################ 256 | # Final 257 | ################################################################################ 258 | 259 | setuptools.setup( 260 | name="torch_tvm", 261 | version=VersionInfo.version, 262 | description="PyTorch + TVM", 263 | ext_modules=ext_modules, 264 | cmdclass=cmdclass, 265 | packages=packages, 266 | include_package_data=True, 267 | install_requires=install_requires, 268 | setup_requires=setup_requires, 269 | tests_require=tests_require, 270 | extras_require=extras_require, 271 | author='bddppq', 272 | author_email='jbai@fb.com', 273 | url='https://github.com/pytorch/tvm', 274 | ) 275 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/test/test_custom_layer_norm.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | import unittest 4 | from collections import namedtuple 5 | import itertools 6 | import logging 7 | import tvm 8 | import os 9 | 10 | from torch_tvm._torch_tvm import * 11 | 12 | logger = logging.getLogger() 13 | # Important to import relay after loading the library because 14 | # importing python/tvm/relay/op/nn/_make.py happens during 15 | # import relay and _make.py initializes internal dictionary of 16 | # apis. If we import it before then the registration that happens 17 | # during loading relay lib will not be visible on the python side. 18 | # TODO: Fix this. 
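# In short, keep the order used below: `from torch_tvm._torch_tvm import *`
# first (already done above), then `from tvm import relay`; otherwise the
# custom op registrations are missing from relay's internal op dictionary.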
19 | import tvm 20 | from tvm import relay 21 | 22 | BuildConfig = namedtuple('BuildConfig', 'ctx target') 23 | 24 | class CustomLayerNormUtils(object): 25 | EPSILON_FLOAT = 1e-9 26 | @staticmethod 27 | def pt_layer_norm(a, shape, normalized_axis, weight, bias): 28 | a_np = a.asnumpy() 29 | weight_pt = bias_pt = None 30 | if weight: 31 | weight_pt = torch.from_numpy(weight.asnumpy()) 32 | if bias: 33 | bias_pt = torch.from_numpy(bias.asnumpy()) 34 | a_pt = torch.from_numpy(a_np) 35 | pt_normalized_axis = [] 36 | for i in range(len(normalized_axis)): 37 | pt_normalized_axis.append(shape[normalized_axis[i]]) 38 | a_out = torch.layer_norm(a_pt, pt_normalized_axis, weight_pt, \ 39 | bias_pt, eps=CustomLayerNormUtils.EPSILON_FLOAT, cudnn_enable=False) 40 | return a_out.numpy() 41 | 42 | @staticmethod 43 | def print_schedule(schedule, input_ph, output_ph): 44 | print(tvm.lower(schedule, [*input_ph, output_ph], simple_mode=True)) 45 | 46 | @staticmethod 47 | def optimize_schedule(schedule, output_tensor): 48 | divide_1 = output_tensor.op.input_tensors[1] 49 | divide_2 = output_tensor.op.input_tensors[2] 50 | mean_var_sum = divide_1.op.input_tensors[0] 51 | squared_data = mean_var_sum.op.input_tensors[1] 52 | schedule[divide_1].compute_inline() 53 | schedule[divide_2].compute_inline() 54 | ko, ki = schedule[mean_var_sum].split(mean_var_sum.op.reduce_axis[0], factor=8) 55 | BF = schedule.rfactor(mean_var_sum, ki) 56 | schedule[mean_var_sum].compute_at(schedule[output_tensor], output_tensor.op.axis[0]) 57 | schedule[BF[0]].compute_at(schedule[output_tensor], output_tensor.op.axis[0]) 58 | schedule[squared_data].compute_inline() 59 | 60 | @staticmethod 61 | def tvm_layer_norm_via_topi(a, a_out, shape, normalized_axis, \ 62 | build_config, weight=None, bias=None): 63 | ctx = build_config.ctx 64 | target = build_config.target 65 | num_axis_to_normalize = len(normalized_axis) 66 | start_index = len(shape) - num_axis_to_normalize 67 | weight_shape = shape[start_index:] 68 | input_ph = tvm.placeholder(shape, name='input_placeholder') 69 | weights_ph = tvm.placeholder(weight_shape, name='weights_placeholder') 70 | bias_ph = tvm.placeholder(weight_shape, name='bias_placeholder') 71 | affine = False 72 | layer_norm = tvm.get_global_func("nn.custom_layer_norm") 73 | if weight is not None and bias is not None: 74 | affine = True 75 | normalized_output_ph = layer_norm(input_ph, weights_ph, bias_ph, \ 76 | num_axis_to_normalize, affine, CustomLayerNormUtils.EPSILON_FLOAT) 77 | s = tvm.create_schedule([normalized_output_ph.op]) 78 | CustomLayerNormUtils.optimize_schedule(s, normalized_output_ph) 79 | if not affine: 80 | CustomLayerNormUtils.print_schedule(s, [input_ph], normalized_output_ph) 81 | else: 82 | CustomLayerNormUtils.print_schedule(s, [input_ph, weights_ph, bias_ph], normalized_output_ph) 83 | if affine: 84 | layer_norm_func = tvm.build(s, [input_ph, weights_ph, \ 85 | bias_ph, normalized_output_ph], target=target) 86 | layer_norm_func(a, weight, bias, a_out) 87 | else: 88 | layer_norm_func = tvm.build(s, [input_ph, normalized_output_ph], 89 | target=target) 90 | layer_norm_func(a, a_out) 91 | return a_out.asnumpy() 92 | 93 | @staticmethod 94 | def tvm_layer_norm_via_relay(a, shape, normalized_axis, build_config, \ 95 | weight=None, bias=None): 96 | ctx = build_config.ctx 97 | target = build_config.target 98 | # Must clear the compile engine as it caches the functions compiled. 
99 |         # Have not dug through the code, but the engine appears to specialize
100 |         # on the shapes passed in the placeholders below. The issue is that
101 |         # there is only one layer_norm relay op, yet making this op returns a
102 |         # different tensor expression depending on whether the affine parameter
103 |         # is true or not. If we don't clear the cache, the previously compiled
104 |         # function gets reused, which may or may not correspond to the current
105 |         # affine setting. This is a general problem: the cache should key on
106 |         # op parameters as well, and not just on the shapes of input tensors.
107 |         compile_engine = tvm.relay.backend.compile_engine.get()
108 |         tvm.relay.backend.compile_engine.CompileEngine.clear(compile_engine)
109 |         # Assume normalized axis is sorted.
110 |         num_axis_to_normalize = len(normalized_axis)
111 |         start_index = len(shape) - num_axis_to_normalize
112 |         weight_shape = shape[start_index:]
113 |         data_ph = relay.var("data", relay.TensorType(shape, "float32"))
114 |         weight_ph = relay.var("weight_ph", relay.TensorType(weight_shape, "float32"))
115 |         bias_ph = relay.var("bias_ph", relay.TensorType(weight_shape, "float32"))
116 |         affine = False
117 |         if weight is not None and bias is not None:
118 |             affine = True
119 |         # TODO: Should really fix this. Add a python wrapper in the usual tvm
120 |         # style to hide the _make call.
121 |         out_ph = relay.op.nn._make.custom_layer_norm(data_ph, weight_ph, bias_ph, \
122 |             num_axis_to_normalize, affine, CustomLayerNormUtils.EPSILON_FLOAT)
123 |         func = relay.Function([data_ph, weight_ph, bias_ph], out_ph)
124 |         func_mod = relay.module.Module.from_expr(func)
125 |         intrp = relay.create_executor("graph", mod=func_mod, ctx=ctx, target=target)
126 |         if affine:
127 |             tvm_out = intrp.evaluate()(a, weight, bias)
128 |         else:
129 |             tvm_out = intrp.evaluate()(a)
130 |         return tvm_out.asnumpy()
131 |
132 |     @staticmethod
133 |     def test_custom_layer_norm(shape, normalized_axis, build_config, dtype, affine):
134 |         ctx = build_config.ctx
135 |         a = tvm.nd.array(numpy.random.rand(*shape).astype(dtype), ctx)
136 |         a_out = tvm.nd.array(numpy.zeros(shape, dtype=dtype), ctx)
137 |         weight = None
138 |         bias = None
139 |         if affine:
140 |             weight_shape = shape[normalized_axis[0]:]
141 |             weight = tvm.nd.array(numpy.random.rand(*weight_shape).astype(dtype), ctx)
142 |             bias = tvm.nd.array(numpy.random.rand(*weight_shape).astype(dtype), ctx)
143 |
144 |         pt_out = CustomLayerNormUtils.pt_layer_norm(a, shape, normalized_axis, weight, bias)
145 |         tvm_out_via_topi = CustomLayerNormUtils.tvm_layer_norm_via_topi(a, \
146 |             a_out, shape, normalized_axis, build_config, weight, bias)
147 |         assert torch.allclose(torch.from_numpy(pt_out), torch.from_numpy(tvm_out_via_topi))
148 |         # The relay-based comparison below is commented out due to library
149 |         # loading issues, because of which registered functions get overwritten.
150 | #tvm_out_via_relay = CustomLayerNormUtils.tvm_layer_norm_via_relay(a, \ 151 | # shape, normalized_axis, build_config, weight, bias) 152 | #numpy.testing.assert_array_almost_equal(pt_out, \ 153 | # tvm_out_via_relay, decimal=5) 154 | 155 | @staticmethod 156 | def gen_shapes(shape_list, dims=1): 157 | shapes = [] 158 | for _ in range(dims): 159 | shapes.append([s for s in shape_list]) 160 | return itertools.product(*shapes) 161 | 162 | class TestFBLayerNorm(unittest.TestCase): 163 | def setUp(self): 164 | super(TestFBLayerNorm, self).setUp() 165 | target="llvm -mcpu=broadwell" 166 | ctx = tvm.context(target, 0) 167 | self.build_config = BuildConfig(ctx, target) 168 | self.dtype = "float32" 169 | numpy.random.seed(42) 170 | 171 | def test_twodim_array(self): 172 | dims = [8, 16, 32, 64] 173 | shapes = CustomLayerNormUtils.gen_shapes(dims, 2) 174 | for shape in shapes: 175 | logger.info("Testing shape {},{}".format(*shape)) 176 | normalized_axis = [1] 177 | 178 | CustomLayerNormUtils.test_custom_layer_norm(shape, normalized_axis, \ 179 | self.build_config, self.dtype, False) 180 | CustomLayerNormUtils.test_custom_layer_norm(shape, normalized_axis, \ 181 | self.build_config, self.dtype, True) 182 | 183 | def test_threedim_array(self): 184 | dims = [8, 16, 32, 64] 185 | shapes = CustomLayerNormUtils.gen_shapes(dims, 3) 186 | for shape in shapes: 187 | logger.info("Testing shape {},{}".format(*shape)) 188 | normalized_axis = [2] 189 | 190 | CustomLayerNormUtils.test_custom_layer_norm(shape, normalized_axis, \ 191 | self.build_config, self.dtype, False) 192 | CustomLayerNormUtils.test_custom_layer_norm(shape, normalized_axis, \ 193 | self.build_config, self.dtype, True) 194 | 195 | def test_threedim_array_2(self): 196 | dims = [8, 16, 32, 64] 197 | shapes = CustomLayerNormUtils.gen_shapes(dims, 3) 198 | for shape in shapes: 199 | logger.info("Testing shape {},{}".format(*shape)) 200 | normalized_axis = [1, 2] 201 | 202 | CustomLayerNormUtils.test_custom_layer_norm(shape, normalized_axis, \ 203 | self.build_config, self.dtype, False) 204 | CustomLayerNormUtils.test_custom_layer_norm(shape, normalized_axis, \ 205 | self.build_config, self.dtype, True) 206 | 207 | if __name__ == "__main__": 208 | unittest.main() 209 | -------------------------------------------------------------------------------- /torch_tvm/custom_tvm_ops/cpp/topi/custom_topi_ops.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include "custom_layer_norm.h" 15 | #include "custom_layer_norm_generic_sched.h" 16 | #include "quantize.h" 17 | #include "contrib/quantize.h" 18 | #include "generic/quantize_generic_sched.h" 19 | #include "x86/quantize_data_mm_dequantize.h" 20 | 21 | namespace tvm { 22 | 23 | using tvm::relay::OpPatternKind; 24 | 25 | class CustomTOPIOpRegisterer { 26 | public: 27 | CustomTOPIOpRegisterer() { 28 | auto reg_ptr = runtime::Registry::Get("relay.op._Register"); 29 | CHECK(reg_ptr) << "Cannot find function for relay.op._Register"; 30 | (*reg_ptr)( 31 | "nn.custom_layer_norm", 32 | "TOpPattern", 33 | static_cast(OpPatternKind::kOutEWiseFusable), 34 | 10); 35 | (*reg_ptr)( 36 | "nn.custom_layer_norm", 37 | "FTVMCompute", 38 | tvm::relay::FTVMCompute( 39 | [](const tvm::Attrs& attrs, 40 | const tvm::Array& inputs, 41 | const tvm::relay::Type& out_type, 42 | const tvm::Target& target) -> tvm::Array { 43 | const 
relay::CustomLayerNormAttrs* param = 44 | attrs.as(); 45 | auto num_axis_to_normalize = param->num_axis_to_normalize; 46 | auto affine = param->affine; 47 | auto eps = param->eps; 48 | return tvm::Array{topi::custom_layer_norm( 49 | inputs[0], 50 | inputs[1], 51 | inputs[2], 52 | num_axis_to_normalize, 53 | affine, 54 | eps)}; 55 | }), 56 | 10); 57 | (*reg_ptr)( 58 | "nn.custom_layer_norm", 59 | "FTVMSchedule", 60 | tvm::relay::FTVMSchedule( 61 | [](const tvm::Attrs& attrs, 62 | const tvm::Array& outs, 63 | const tvm::Target& target) -> tvm::Schedule { 64 | // dispatches x86 schedule when the target is "llvm" and the 65 | return topi::generic::schedule_custom_layer_norm(outs); 66 | }), 67 | 10); 68 | (*reg_ptr)( 69 | "nn.quantize_findminmax", 70 | "TOpPattern", 71 | static_cast(OpPatternKind::kCommReduce), 72 | 10); 73 | (*reg_ptr)( 74 | "nn.quantize_findminmax", 75 | "FTVMCompute", 76 | tvm::relay::FTVMCompute( 77 | [](const tvm::Attrs& attrs, 78 | const tvm::Array& inputs, 79 | const tvm::relay::Type& out_type, 80 | const tvm::Target& target) -> tvm::Array { 81 | return topi::contrib::quantize_findminmax(inputs[0]); 82 | }), 83 | 10); 84 | (*reg_ptr)( 85 | "nn.quantize_findminmax", 86 | "FTVMSchedule", 87 | tvm::relay::FTVMSchedule( 88 | [](const tvm::Attrs& attrs, 89 | const tvm::Array& outs, 90 | const tvm::Target& target) -> tvm::Schedule { 91 | return topi::generic::schedule_quantize_findminmax(outs); 92 | }), 93 | 10); 94 | (*reg_ptr)( 95 | "nn.choose_quantize_params", 96 | "TOpPattern", 97 | static_cast(OpPatternKind::kOpaque), 98 | 10); 99 | (*reg_ptr)( 100 | "nn.choose_quantize_params", 101 | "FTVMCompute", 102 | tvm::relay::FTVMCompute( 103 | [](const tvm::Attrs& attrs, 104 | const tvm::Array& inputs, 105 | const tvm::relay::Type& out_type, 106 | const tvm::Target& target) -> tvm::Array { 107 | const auto* param = attrs.as(); 108 | auto precision = param->precision; 109 | auto is_signed = param->is_signed; 110 | return topi::contrib::choose_quantize_params(inputs[0], inputs[1], is_signed, precision); 111 | }), 112 | 10); 113 | (*reg_ptr)( 114 | "nn.choose_quantize_params", 115 | "FTVMSchedule", 116 | tvm::relay::FTVMSchedule( 117 | [](const tvm::Attrs& attrs, 118 | const tvm::Array& outs, 119 | const tvm::Target& target) -> tvm::Schedule { 120 | return topi::generic::schedule_choose_quantize_params(outs); 121 | }), 122 | 10); 123 | (*reg_ptr)( 124 | "nn.quantize_data_int8_quantize", 125 | "TOpPattern", 126 | static_cast(OpPatternKind::kOutEWiseFusable), 127 | 10); 128 | (*reg_ptr)( 129 | "nn.quantize_data_int8_quantize", 130 | "FTVMCompute", 131 | tvm::relay::FTVMCompute( 132 | [](const tvm::Attrs& attrs, 133 | const tvm::Array& inputs, 134 | const tvm::relay::Type& out_type, 135 | const tvm::Target& target) -> tvm::Array { 136 | const auto* param = attrs.as(); 137 | auto precision = param->precision; 138 | auto is_signed = param->is_signed; 139 | return topi::data_int8_quantize(inputs[0], inputs[1], inputs[2], is_signed, precision); 140 | }), 141 | 10); 142 | (*reg_ptr)( 143 | "nn.quantize_data_int8_quantize", 144 | "FTVMSchedule", 145 | tvm::relay::FTVMSchedule( 146 | [](const tvm::Attrs& attrs, 147 | const tvm::Array& outs, 148 | const tvm::Target& target) -> tvm::Schedule { 149 | return topi::generic::schedule_quantize_data_int8_quantize(outs); 150 | }), 151 | 10); 152 | (*reg_ptr)( 153 | "nn.quantize_data_int8_row_offset", 154 | "TOpPattern", 155 | static_cast(OpPatternKind::kOutEWiseFusable), 156 | 10); 157 | (*reg_ptr)( 158 | "nn.quantize_data_int8_row_offset", 159 | 
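        // Same three-step registration pattern as the ops above: "TOpPattern"
        // tells the fusion machinery how the op composes, "FTVMCompute" lowers
        // it to its TOPI implementation, and "FTVMSchedule" selects a schedule.
        // (The trailing 10 in each call appears to be a registration priority.)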
"FTVMCompute", 160 | tvm::relay::FTVMCompute( 161 | [](const tvm::Attrs& attrs, 162 | const tvm::Array& inputs, 163 | const tvm::relay::Type& out_type, 164 | const tvm::Target& target) -> tvm::Array { 165 | return topi::data_int8_row_offset(inputs[0]); 166 | }), 167 | 10); 168 | (*reg_ptr)( 169 | "nn.quantize_data_int8_row_offset", 170 | "FTVMSchedule", 171 | tvm::relay::FTVMSchedule( 172 | [](const tvm::Attrs& attrs, 173 | const tvm::Array& outs, 174 | const tvm::Target& target) -> tvm::Schedule { 175 | return topi::generic::schedule_quantize_data_int8_row_offset(outs); 176 | }), 177 | 10); 178 | (*reg_ptr)( 179 | "nn.quantize_data_mm_dequantize", 180 | "TOpPattern", 181 | static_cast(OpPatternKind::kOutEWiseFusable), 182 | 10); 183 | (*reg_ptr)( 184 | "nn.quantize_data_mm_dequantize", 185 | "FTVMCompute", 186 | tvm::relay::FTVMCompute( 187 | [](const tvm::Attrs& attrs, 188 | const tvm::Array& inputs, 189 | const tvm::relay::Type& out_type, 190 | const tvm::Target& target) -> tvm::Array { 191 | const auto* param = attrs.as(); 192 | auto w_scale = param->w_scale; 193 | auto w_zp = param->w_zp; 194 | auto N = param->N; 195 | 196 | return topi::data_int8_mm_dequantize( 197 | inputs[0], inputs[1], inputs[2], inputs[3], 198 | inputs[4], inputs[5], w_scale, w_zp, N); 199 | }), 200 | 10); 201 | (*reg_ptr)( 202 | "nn.quantize_data_mm_dequantize", 203 | "FTVMSchedule", 204 | tvm::relay::FTVMSchedule( 205 | [](const tvm::Attrs& attrs, 206 | const tvm::Array& outs, 207 | const tvm::Target& target) -> tvm::Schedule { 208 | auto schedule_quantized_mm_dequantize = 209 | tvm::GenericFunc::Get("schedule_quantized_mm_dequantize"); 210 | return schedule_quantized_mm_dequantize(outs); 211 | }), 212 | 10); 213 | } 214 | }; 215 | 216 | static CustomTOPIOpRegisterer custom_top_op_registerer; 217 | 218 | } // namespace tvm 219 | 220 | namespace topi { 221 | 222 | using namespace tvm; 223 | using namespace tvm::runtime; 224 | 225 | /*! \brief Builder function for instantiating schedules. */ 226 | using FTVMScheduleBuilder = std::function& outs)>; 229 | 230 | /*! 231 | * \brief Helper function for registering generic functions matching the 232 | * FTVMScheduleBuilder signature. The schedule builder function is wrapped 233 | * with a PackedFunc suitable for passing to a tvm::GenericFunc. 234 | * 235 | * \param builder The schedule builder to wrap. 
236 | * 237 | * \return The wrapped schedule builder 238 | */ 239 | inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { 240 | return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { 241 | auto target = Target::Current(false); 242 | Array outs; 243 | NodeRef argNodeRef = args[0]; 244 | if (argNodeRef->type_index() == outs->type_index()) { 245 | outs = args[0]; 246 | } else { 247 | outs = Array{args[0]}; 248 | } 249 | 250 | *ret = builder(target, outs); 251 | }); 252 | } 253 | 254 | // For python API 255 | TVM_REGISTER_GLOBAL("nn.custom_layer_norm") 256 | .set_body([](TVMArgs args, TVMRetValue* rv) { 257 | CHECK(args.size() == 6); 258 | *rv = custom_layer_norm( 259 | args[0], 260 | args[1], 261 | args[2], 262 | args[3], 263 | args[4], 264 | static_cast(args[5])); 265 | }); 266 | 267 | TVM_REGISTER_GLOBAL("nn.compute_quantized_mm_dequantize") 268 | .set_body([](TVMArgs args, TVMRetValue* rv) { 269 | CHECK(args.size() == 9); 270 | *rv = data_int8_mm_dequantize( 271 | args[0], args[1], args[2], args[3], 272 | args[4], args[5], args[6], args[7], args[8]); 273 | }); 274 | 275 | TVM_REGISTER_GLOBAL("nn.compute_data_int8_quantize") 276 | .set_body([](TVMArgs args, TVMRetValue* rv) { 277 | CHECK(args.size() == 5); 278 | *rv = data_int8_quantize( 279 | args[0], args[1], args[2], args[3], args[4]); 280 | }); 281 | 282 | TVM_REGISTER_GLOBAL("nn.compute_data_int8_row_offset") 283 | .set_body([](TVMArgs args, TVMRetValue* rv) { 284 | CHECK(args.size() == 1); 285 | *rv = data_int8_row_offset(args[0]); 286 | }); 287 | 288 | TVM_REGISTER_GLOBAL("topi.generic.schedule_quantized_mm_dequantize") 289 | .set_body([](TVMArgs args, TVMRetValue* rv) { 290 | *rv = topi::generic::schedule_quantized_mm_dequantize(args[0], args[1]); 291 | }); 292 | 293 | TVM_REGISTER_GLOBAL("topi.x86.schedule_quantized_mm_dequantize") 294 | .set_body([](TVMArgs args, TVMRetValue* rv) { 295 | *rv = topi::x86::schedule_quantized_mm_dequantize(args[0], args[1]); 296 | }); 297 | 298 | TVM_REGISTER_GENERIC_FUNC(schedule_quantized_mm_dequantize) 299 | .set_default(WrapSchedule(topi::generic::schedule_quantized_mm_dequantize)) 300 | .register_func( 301 | {"cpu"}, 302 | WrapSchedule(topi::x86::schedule_quantized_mm_dequantize)); 303 | 304 | } // namespace topi 305 | -------------------------------------------------------------------------------- /test/test_models.py: -------------------------------------------------------------------------------- 1 | ######################################################################################## 2 | ## Source: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py ## 3 | ######################################################################################## 4 | 5 | from skimage import io, transform 6 | import torch_tvm 7 | import torch.nn.functional as F 8 | import torch 9 | from test.util import TVMTest 10 | import unittest 11 | import torch.nn as nn 12 | import torch.utils.model_zoo as model_zoo 13 | import os 14 | 15 | 16 | __all__ = [ 17 | "ResNet", 18 | "resnet18", 19 | "resnet34", 20 | "resnet50", 21 | "resnet101", 22 | "resnet152", 23 | "resnext50_32x4d", 24 | "resnext101_32x8d", 25 | ] 26 | 27 | 28 | model_urls = { 29 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 30 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 31 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 32 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 33 | "resnet152": 
"https://download.pytorch.org/models/resnet152-b121ed2d.pth", 34 | } 35 | 36 | 37 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 38 | """3x3 convolution with padding""" 39 | return nn.Conv2d( 40 | in_planes, 41 | out_planes, 42 | kernel_size=3, 43 | stride=stride, 44 | padding=1, 45 | groups=groups, 46 | bias=False, 47 | ) 48 | 49 | 50 | def conv1x1(in_planes, out_planes, stride=1): 51 | """1x1 convolution""" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 53 | 54 | 55 | class BasicBlock(nn.Module): 56 | expansion = 1 57 | 58 | def __init__( 59 | self, 60 | inplanes, 61 | planes, 62 | stride=1, 63 | downsample=None, 64 | groups=1, 65 | base_width=64, 66 | norm_layer=None, 67 | ): 68 | super(BasicBlock, self).__init__() 69 | if norm_layer is None: 70 | norm_layer = nn.BatchNorm2d 71 | if groups != 1 or base_width != 64: 72 | raise ValueError( 73 | "BasicBlock only supports groups=1 and base_width=64") 74 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 75 | self.conv1 = conv3x3(inplanes, planes, stride) 76 | self.bn1 = norm_layer(planes) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.conv2 = conv3x3(planes, planes) 79 | self.bn2 = norm_layer(planes) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | identity = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | 93 | if self.downsample is not None: 94 | identity = self.downsample(x) 95 | 96 | out += identity 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class Bottleneck(nn.Module): 103 | expansion = 4 104 | 105 | def __init__( 106 | self, 107 | inplanes, 108 | planes, 109 | stride=1, 110 | downsample=None, 111 | groups=1, 112 | base_width=64, 113 | norm_layer=None, 114 | ): 115 | super(Bottleneck, self).__init__() 116 | if norm_layer is None: 117 | norm_layer = nn.BatchNorm2d 118 | width = int(planes * (base_width / 64.0)) * groups 119 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 120 | self.conv1 = conv1x1(inplanes, width) 121 | self.bn1 = norm_layer(width) 122 | self.conv2 = conv3x3(width, width, stride, groups) 123 | self.bn2 = norm_layer(width) 124 | self.conv3 = conv1x1(width, planes * self.expansion) 125 | self.bn3 = norm_layer(planes * self.expansion) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | def forward(self, x): 131 | identity = x 132 | 133 | out = self.conv1(x) 134 | out = self.bn1(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv2(out) 138 | out = self.bn2(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv3(out) 142 | out = self.bn3(out) 143 | 144 | if self.downsample is not None: 145 | identity = self.downsample(x) 146 | 147 | out += identity 148 | out = self.relu(out) 149 | 150 | return out 151 | 152 | 153 | class ResNet(nn.Module): 154 | def __init__( 155 | self, 156 | block, 157 | layers, 158 | num_classes=1000, 159 | zero_init_residual=False, 160 | groups=1, 161 | width_per_group=64, 162 | norm_layer=None, 163 | ): 164 | super(ResNet, self).__init__() 165 | if norm_layer is None: 166 | norm_layer = nn.BatchNorm2d 167 | 168 | self.inplanes = 64 169 | self.groups = groups 170 | self.base_width = width_per_group 171 | self.conv1 = nn.Conv2d( 172 | 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 173 | ) 174 | self.bn1 = 
norm_layer(self.inplanes) 175 | self.relu = nn.ReLU(inplace=True) 176 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 177 | self.layer1 = self._make_layer( 178 | block, 64, layers[0], norm_layer=norm_layer) 179 | self.layer2 = self._make_layer( 180 | block, 128, layers[1], stride=2, norm_layer=norm_layer 181 | ) 182 | self.layer3 = self._make_layer( 183 | block, 256, layers[2], stride=2, norm_layer=norm_layer 184 | ) 185 | self.layer4 = self._make_layer( 186 | block, 512, layers[3], stride=2, norm_layer=norm_layer 187 | ) 188 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 189 | self.fc = nn.Linear(512 * block.expansion, num_classes) 190 | 191 | for m in self.modules(): 192 | if isinstance(m, nn.Conv2d): 193 | nn.init.kaiming_normal_( 194 | m.weight, mode="fan_out", nonlinearity="relu") 195 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 196 | nn.init.constant_(m.weight, 1) 197 | nn.init.constant_(m.bias, 0) 198 | 199 | # Zero-initialize the last BN in each residual branch, 200 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 201 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 202 | if zero_init_residual: 203 | for m in self.modules(): 204 | if isinstance(m, Bottleneck): 205 | nn.init.constant_(m.bn3.weight, 0) 206 | elif isinstance(m, BasicBlock): 207 | nn.init.constant_(m.bn2.weight, 0) 208 | 209 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 210 | if norm_layer is None: 211 | norm_layer = nn.BatchNorm2d 212 | downsample = None 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append( 221 | block( 222 | self.inplanes, 223 | planes, 224 | stride, 225 | downsample, 226 | self.groups, 227 | self.base_width, 228 | norm_layer, 229 | ) 230 | ) 231 | self.inplanes = planes * block.expansion 232 | for _ in range(1, blocks): 233 | layers.append( 234 | block( 235 | self.inplanes, 236 | planes, 237 | groups=self.groups, 238 | base_width=self.base_width, 239 | norm_layer=norm_layer, 240 | ) 241 | ) 242 | 243 | return nn.Sequential(*layers) 244 | 245 | def forward(self, x): 246 | x = self.conv1(x) 247 | x = self.bn1(x) 248 | x = self.relu(x) 249 | x = self.maxpool(x) 250 | 251 | x = self.layer1(x) 252 | x = self.layer2(x) 253 | x = self.layer3(x) 254 | x = self.layer4(x) 255 | 256 | x = self.avgpool(x) 257 | x = torch.reshape(x, (1, -1)) 258 | x = self.fc(x) 259 | 260 | return x 261 | 262 | 263 | def resnet18(pretrained=False, **kwargs): 264 | """Constructs a ResNet-18 model. 265 | 266 | Args: 267 | pretrained (bool): If True, returns a model pre-trained on ImageNet 268 | """ 269 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 270 | if pretrained: 271 | model.load_state_dict(model_zoo.load_url(model_urls["resnet18"])) 272 | return model 273 | 274 | 275 | def resnet34(pretrained=False, **kwargs): 276 | """Constructs a ResNet-34 model. 277 | 278 | Args: 279 | pretrained (bool): If True, returns a model pre-trained on ImageNet 280 | """ 281 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 282 | if pretrained: 283 | model.load_state_dict(model_zoo.load_url(model_urls["resnet34"])) 284 | return model 285 | 286 | 287 | def resnet50(pretrained=False, **kwargs): 288 | """Constructs a ResNet-50 model. 
289 | 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | """ 293 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 294 | if pretrained: 295 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"])) 296 | return model 297 | 298 | 299 | def resnet101(pretrained=False, **kwargs): 300 | """Constructs a ResNet-101 model. 301 | 302 | Args: 303 | pretrained (bool): If True, returns a model pre-trained on ImageNet 304 | """ 305 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 306 | if pretrained: 307 | model.load_state_dict(model_zoo.load_url(model_urls["resnet101"])) 308 | return model 309 | 310 | 311 | def resnet152(pretrained=False, **kwargs): 312 | """Constructs a ResNet-152 model. 313 | 314 | Args: 315 | pretrained (bool): If True, returns a model pre-trained on ImageNet 316 | """ 317 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 318 | if pretrained: 319 | model.load_state_dict(model_zoo.load_url(model_urls["resnet152"])) 320 | return model 321 | 322 | 323 | def resnext50_32x4d(pretrained=False, **kwargs): 324 | model = ResNet(Bottleneck, [3, 4, 6, 3], 325 | groups=32, width_per_group=4, **kwargs) 326 | # if pretrained: 327 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 328 | return model 329 | 330 | 331 | def resnext101_32x8d(pretrained=False, **kwargs): 332 | model = ResNet(Bottleneck, [3, 4, 23, 3], 333 | groups=32, width_per_group=8, **kwargs) 334 | # if pretrained: 335 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 336 | return model 337 | 338 | 339 | # TVM Tests here: 340 | 341 | 342 | class resnetish(nn.Module): 343 | def __init__(self): 344 | super(resnetish, self).__init__() 345 | block = BasicBlock 346 | layers = [2, 2, 2, 2] 347 | self.groups = 1 348 | width_per_group = 64 349 | self.base_width = width_per_group 350 | self.inplanes = 64 351 | self.conv1 = nn.Conv2d( 352 | 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 353 | ) 354 | norm_layer = nn.BatchNorm2d 355 | self.bn1 = norm_layer(self.inplanes) 356 | self.relu = nn.ReLU(inplace=True) 357 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 358 | self.layer1 = self._make_layer( 359 | block, 64, layers[0], norm_layer=norm_layer) 360 | self.layer2 = self._make_layer( 361 | block, 128, layers[1], stride=2, norm_layer=norm_layer 362 | ) 363 | self.layer3 = self._make_layer( 364 | block, 256, layers[2], stride=2, norm_layer=norm_layer 365 | ) 366 | self.layer4 = self._make_layer( 367 | block, 512, layers[3], stride=2, norm_layer=norm_layer 368 | ) 369 | 370 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 371 | if norm_layer is None: 372 | norm_layer = nn.BatchNorm2d 373 | downsample = None 374 | if stride != 1 or self.inplanes != planes * block.expansion: 375 | downsample = nn.Sequential( 376 | conv1x1(self.inplanes, planes * block.expansion, stride), 377 | norm_layer(planes * block.expansion), 378 | ) 379 | 380 | layers = [] 381 | layers.append( 382 | block( 383 | self.inplanes, 384 | planes, 385 | stride, 386 | downsample, 387 | self.groups, 388 | self.base_width, 389 | norm_layer, 390 | ) 391 | ) 392 | self.inplanes = planes * block.expansion 393 | for _ in range(1, blocks): 394 | layers.append( 395 | block( 396 | self.inplanes, 397 | planes, 398 | groups=self.groups, 399 | base_width=self.base_width, 400 | norm_layer=norm_layer, 401 | ) 402 | ) 403 | 404 | return nn.Sequential(*layers) 405 | 406 | def forward(self, x): 407 | x = self.conv1(x) 
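        # Standard ResNet stem; only layer1 and layer2 are applied below.
        # resnetish is a truncated ResNet-18 checked with an exact allclose
        # comparison, whereas the full models are checked via top-k indices.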
408 | x = self.bn1(x) 409 | x = self.relu(x) 410 | x = self.maxpool(x) 411 | 412 | x = self.layer1(x) 413 | x = self.layer2(x) 414 | 415 | return x 416 | 417 | 418 | class TestModels(TVMTest): 419 | def model_test(self, constructor): 420 | model = constructor(True) 421 | model.eval() 422 | d = os.path.dirname(os.path.abspath(__file__)) 423 | fn = os.path.join(d, "cat.png") 424 | image = io.imread(fn)[:, :, :3].transpose(2, 0, 1) 425 | input_image = torch.unsqueeze(torch.Tensor(image), 0) 426 | ref_out, tvm_out = self.runBoth(model, input_image) 427 | # With CNN model tests we look for top-k similarity rather than 428 | # exact numerical matches post-softmax. 429 | k = 10 430 | torch.testing.assert_allclose( 431 | ref_out.topk(k).indices, tvm_out.topk(k).indices) 432 | 433 | def test_resnet18(self): 434 | self.model_test(resnet18) 435 | 436 | def test_resnet34(self): 437 | self.model_test(resnet34) 438 | 439 | def test_resnet50(self): 440 | self.model_test(resnet50) 441 | 442 | def test_resnet101(self): 443 | self.model_test(resnet101) 444 | 445 | def test_resnet152(self): 446 | self.model_test(resnet152) 447 | 448 | def test_resnext50_32x4d(self): 449 | self.model_test(resnext50_32x4d) 450 | 451 | def test_resnext101_32x8d(self): 452 | self.model_test(resnext101_32x8d) 453 | 454 | def test_resnetish(self): 455 | model = resnetish() 456 | model.eval() 457 | d = os.path.dirname(os.path.abspath(__file__)) 458 | fn = os.path.join(d, "cat.png") 459 | image = io.imread(fn)[:, :, :3].transpose(2, 0, 1) 460 | input_image = torch.unsqueeze(torch.Tensor(image), 0) 461 | ref_out, tvm_out = self.runBoth(model, input_image) 462 | torch.testing.assert_allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 463 | 464 | 465 | if __name__ == "__main__": 466 | unittest.main() 467 | -------------------------------------------------------------------------------- /test/test_operators.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from test.util import TVMTest 3 | from torch.testing import FileCheck 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch 8 | 9 | # test jit tvm operators 10 | 11 | 12 | class TestOperators(TVMTest): 13 | @TVMTest.given(shape=TVMTest.rand_shape(rank=1)) 14 | def test_add(self, shape): 15 | x = torch.rand(shape) 16 | y = torch.rand(shape) 17 | z = torch.rand(shape) 18 | 19 | def add(a, b, c): 20 | return a + b + c 21 | 22 | ref_out, tvm_out = self.runBoth(add, x, y, z) 23 | assert torch.allclose(ref_out, tvm_out) 24 | 25 | @TVMTest.given(shape=TVMTest.rand_shape(rank=1)) 26 | def test_mul(self, shape): 27 | x = torch.rand(shape) 28 | y = torch.rand(shape) 29 | z = torch.rand(shape) 30 | 31 | def mul(a, b, c): 32 | return a * b * c 33 | 34 | ref_out, tvm_out = self.runBoth(mul, x, y, z) 35 | assert torch.allclose(ref_out, tvm_out) 36 | 37 | @TVMTest.given( 38 | shape=TVMTest.rand_shape(min_rank=3, max_rank=4, min_dim=4, max_dim=4), 39 | kernel_size=TVMTest.rand_int(3, 3), 40 | num_kernels=TVMTest.rand_int(5, 5), 41 | examples=10, 42 | ) 43 | def test_conv_simple(self, shape, kernel_size, num_kernels): 44 | # NCHW 45 | X = torch.rand(shape) 46 | if (len(shape) == 4): 47 | W = torch.rand((num_kernels, shape[1], kernel_size, kernel_size)) 48 | conv_fn = F.conv2d 49 | else: 50 | # Case of 1D conv 51 | W = torch.rand((num_kernels, shape[1], kernel_size)) 52 | conv_fn = F.conv1d 53 | bias = torch.rand(num_kernels) 54 | 55 | def conv(a, b): 56 | return conv_fn(a + a, b) 57 | 58 | def conv_bias(a, b, c): 59 | return 
conv_fn(a + a, b, c) 60 | 61 | ref_out, tvm_out = self.runBoth(conv, X, W) 62 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 63 | ref_out, tvm_out = self.runBoth(conv_bias, X, W, bias) 64 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 65 | 66 | @TVMTest.given( 67 | shape=TVMTest.rand_shape(min_rank=3, max_rank=4, min_dim=15), 68 | kernel_size=TVMTest.rand_int(3, 6), 69 | num_kernels=TVMTest.rand_int(), 70 | stride=TVMTest.rand_list(TVMTest.rand_int(1, 2), 2), 71 | padding=TVMTest.rand_list(TVMTest.rand_int(0, 4), 2), 72 | dilation=TVMTest.rand_list(TVMTest.rand_int( 73 | 1, 1), 2), # TODO known broken in TVM 74 | examples=10, 75 | ) 76 | def test_conv_complex( 77 | self, shape, kernel_size, num_kernels, stride, padding, dilation 78 | ): 79 | # NCHW 80 | X = torch.rand(shape) 81 | if (len(shape) == 4): 82 | W = torch.rand(num_kernels, shape[1], kernel_size, kernel_size) 83 | conv_fn = F.conv2d 84 | else: 85 | W = torch.rand(num_kernels, shape[1], kernel_size) 86 | conv_fn = F.conv1d 87 | stride = [stride[0]] 88 | padding = [padding[0]] 89 | dilation = [dilation[0]] 90 | 91 | def conv(a, b): 92 | return conv_fn(a + a, b, stride=stride, padding=padding, dilation=dilation) 93 | 94 | ref_out, tvm_out = self.runBoth(conv, X, W) 95 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 96 | 97 | @TVMTest.given( 98 | shape=TVMTest.rand_shape(min_rank=3, max_rank=4, min_dim=15), 99 | kernel_size=TVMTest.rand_int(3, 8), 100 | stride=TVMTest.rand_list(TVMTest.rand_int(1, 2), 2), 101 | padding=TVMTest.rand_list(TVMTest.rand_int(0, 4), 2), 102 | dilation=TVMTest.rand_list(TVMTest.rand_int(1, 2), 2), 103 | groups=TVMTest.rand_int(4, 8), 104 | in_ch_per_group=TVMTest.rand_int(1, 4), 105 | out_ch_per_group=TVMTest.rand_int(1, 8), 106 | examples=10, 107 | ) 108 | def test_group_conv( 109 | self, shape, kernel_size, stride, padding, dilation, groups, in_ch_per_group, out_ch_per_group 110 | ): 111 | # NCHW 112 | in_channels = in_ch_per_group * groups 113 | out_channels = out_ch_per_group * groups 114 | if (len(shape) == 4): 115 | X = torch.rand(shape[0], in_channels, shape[1], shape[2]) 116 | W = torch.rand(out_channels, in_ch_per_group, kernel_size, kernel_size) 117 | conv_fn = F.conv2d 118 | else: 119 | X = torch.rand(shape[0], in_channels, shape[1]) 120 | W = torch.rand(out_channels, in_ch_per_group, kernel_size) 121 | conv_fn = F.conv1d 122 | stride = [stride[0]] 123 | padding = [padding[0]] 124 | dilation = [dilation[0]] 125 | 126 | def conv(a, b): 127 | return conv_fn(a + a, b, stride=stride, padding=padding, dilation=dilation, groups=groups) 128 | 129 | ref_out, tvm_out = self.runBoth(conv, X, W) 130 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 131 | 132 | @TVMTest.given(shape=TVMTest.rand_shape(rank=2, min_dim=5)) 133 | def test_batch_norm(self, shape): 134 | a = torch.rand(shape) 135 | b = torch.rand(shape[1]) 136 | c = torch.rand(shape[1]) 137 | d = torch.rand(shape) 138 | 139 | def batch_norm(a, b, c, d): 140 | return F.batch_norm(a + d, b, c) 141 | 142 | ref_out, tvm_out = self.runBoth(batch_norm, a, b, c, d) 143 | assert torch.allclose(ref_out, tvm_out, rtol=0.05, atol=0.01) 144 | 145 | @TVMTest.given(shape=TVMTest.rand_shape(rank=2, min_dim=5)) 146 | def test_batch_norm_weighted(self, shape): 147 | a = torch.rand(shape) 148 | b = torch.rand(shape[1]) 149 | c = torch.rand(shape[1]) 150 | d = torch.rand(shape) 151 | 152 | def batch_norm_weighted(a, b, c, d, weight, bias): 153 | return F.batch_norm(a + d, b, c, weight=weight, bias=bias) 
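
        # Note: the `a + d` above (like `a + a` in the conv tests) appears to
        # be there so the subgraph contains more than one op; as the concat
        # test later in this file notes, a single op will not be lowered.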
154 |
155 |         ref_out, tvm_out = self.runBoth(batch_norm_weighted, a, b, c, d, c, b)
156 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
157 |
158 |     @TVMTest.given(shape=TVMTest.rand_shape(min_rank=2, max_rank=4, min_dim=8),\
159 |                    examples=20)
160 |     def test_layer_norm(self, shape):
161 |         a = torch.rand(shape)
162 |         axis = shape[1:]
163 |         d = torch.rand(shape)
164 |
165 |         def layer_norm(a, d):
166 |             return F.layer_norm(a + d, axis)
167 |
168 |         ref_out, tvm_out = self.runBoth(layer_norm, a, d)
169 |         assert torch.allclose(ref_out, tvm_out, rtol=0.05, atol=0.01)
170 |
171 |     @TVMTest.given(shape=TVMTest.rand_shape(min_rank=2, max_rank=4, min_dim=8),\
172 |                    examples=20)
173 |     def test_layer_norm_weighted(self, shape):
174 |         a = torch.rand(shape)
175 |         b = torch.rand(shape[1:])
176 |         c = torch.rand(shape[1:])
177 |         axis = shape[1:]
178 |         d = torch.rand(shape)
179 |
180 |         def layer_norm(a, b, c, d):
181 |             return F.layer_norm(a + d, axis, weight=b, bias=c)
182 |
183 |         ref_out, tvm_out = self.runBoth(layer_norm, a, b, c, d)
184 |         assert torch.allclose(ref_out, tvm_out, rtol=0.05, atol=0.01)
185 |
186 |     @TVMTest.given(shape=TVMTest.rand_shape())
187 |     def test_relu(self, shape):
188 |         X = torch.rand(shape)
189 |
190 |         def relu(a):
191 |             return F.relu(F.relu(a))
192 |
193 |         ref_out, tvm_out = self.runBoth(relu, X)
194 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
195 |
196 |     @TVMTest.given(shape=TVMTest.rand_shape(min_rank=2, max_rank=4, min_dim=8, max_dim=32),
197 |                    examples=20,
198 |                    )
199 |     def test_max(self, shape):
200 |         X = torch.rand(shape)
201 |         axis = TVMTest.rand_int(-len(shape), len(shape)-1)()
202 |
203 |         def max_fn(a):
204 |             return torch.max(a + a, axis=axis)
205 |
206 |         ref_out, tvm_out = self.runBoth(max_fn, X)
207 |         ref_out_values, ref_out_indices = ref_out
208 |         tvm_out_values, tvm_out_indices = tvm_out
209 |         assert torch.allclose(ref_out_values, tvm_out_values, rtol=0.01, atol=0.01)
210 |         assert torch.equal(ref_out_indices, tvm_out_indices)
211 |
212 |     # Known bug -- stride > 2 has mismatched padding
213 |     @TVMTest.given(
214 |         shape=TVMTest.rand_shape(rank=4, min_dim=4),
215 |         stride=TVMTest.rand_list(TVMTest.rand_int(2, 2), 2),
216 |     )
217 |     def test_avg_pool2d(self, shape, stride):
218 |         X = torch.rand(shape)
219 |
220 |         def avg_pool2d(a):
221 |             return F.avg_pool2d(a, 2)
222 |
223 |         def avg_pool2d_strides(a):
224 |             return F.avg_pool2d(
225 |                 a, 2, stride=stride
226 |             )
227 |
228 |         ref_out, tvm_out = self.runBoth(avg_pool2d, X)
229 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
230 |         ref_out, tvm_out = self.runBoth(avg_pool2d_strides, X)
231 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
232 |
233 |     @TVMTest.given(
234 |         shape=TVMTest.rand_shape(rank=4, min_dim=4),
235 |     )
236 |     def test_adaptive_avg_pool2d(self, shape):
237 |         X = torch.rand(shape)
238 |
239 |         def adaptive_avg_pool2d(a):
240 |             return F.adaptive_avg_pool2d(a, 3)
241 |
242 |         ref_out, tvm_out = self.runBoth(adaptive_avg_pool2d, X)
243 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
244 |
245 |     # Known bug -- ceil_mode=True sometimes has mismatched shapes
246 |     @TVMTest.given(
247 |         shape=TVMTest.rand_shape(rank=4, min_dim=4),
248 |         stride=TVMTest.rand_list(TVMTest.rand_int(1, 2), 2),
249 |     )
250 |     def test_max_pool2d(self, shape, stride):
251 |         X = torch.rand(shape)
252 |
253 |         def max_pool2d(a):
254 |             return F.max_pool2d(a, 3) + 2.0
255 |
256 |         def max_pool2d_strides_padding_ceil_mode(a):
257 |             return F.max_pool2d(
258 |                 a, 2, stride=stride, padding=1, ceil_mode=False
259 |             )
260 |
261 |         # TODO: fix the instability of the ceil_mode=True case
262 |
263 |         ref_out, tvm_out = self.runBoth(max_pool2d, X)
264 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
265 |         ref_out, tvm_out = self.runBoth(max_pool2d_strides_padding_ceil_mode, X)
266 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
267 |
268 |
269 |     @TVMTest.given(
270 |         shape=TVMTest.rand_shape(rank=2, min_dim=4),
271 |         out_features=TVMTest.rand_int(3, 6),
272 |     )
273 |     def test_fuse_linear_pattern_match(self, shape, out_features):
274 |         input = torch.rand(shape)
275 |         weight = torch.rand(out_features, shape[1])
276 |         bias = torch.rand(out_features)
277 |
278 |         def linear_addmm(input, weight, bias):
279 |             return torch.addmm(bias, input, weight.t())
280 |
281 |         def linear_matmul_add(input, weight, bias):
282 |             output = input.matmul(weight.t())
283 |             output += bias
284 |             return output
285 |
286 |         def linear_matmul(input, weight):
287 |             return input.matmul(weight.t())
288 |
289 |         import torch_tvm
290 |         torch_tvm.enable()
291 |         # test addmm
292 |         scripted_addmm = torch.jit.script(linear_addmm)
293 |         addmm_graph = scripted_addmm.graph_for(input, weight, bias)
294 |         FileCheck().check("aten::linear").check_not("addmm").check_not("aten::t").run(str(addmm_graph))
295 |
296 |         # test matmul + add
297 |         scripted_matmul_add = torch.jit.script(linear_matmul_add)
298 |         matmul_add_graph = scripted_matmul_add.graph_for(input, weight, bias)
299 |         FileCheck().check("aten::linear").check_not("matmul").check_not("aten::t").run(str(matmul_add_graph))
300 |
301 |         # test matmul
302 |         scripted_matmul = torch.jit.script(linear_matmul)
303 |         matmul_graph = scripted_matmul.graph_for(input, weight)
304 |         FileCheck().check("aten::linear").check_not("matmul").check_not("aten::t").run(str(matmul_graph))
305 |         torch_tvm.disable()
306 |
307 |
308 |     @TVMTest.given(
309 |         shape=TVMTest.rand_shape(min_rank=2, max_rank=4, min_dim=4),
310 |         out_features=TVMTest.rand_int(3, 128),
311 |         examples=20,
312 |     )
313 |     def test_linear(self, shape, out_features):
314 |         input = torch.rand(shape)
315 |         weight = torch.rand(out_features, shape[-1])
316 |         bias = torch.rand(out_features)
317 |
318 |         def linear(input, weight, bias):
319 |             return F.linear(input + 3.0, weight, bias) + 2.0
320 |
321 |         def linear_no_bias(input, weight):
322 |             return F.linear(input, weight) + 2.0
323 |
324 |         ref_out, tvm_out = self.runBoth(linear, input, weight, bias)
325 |         assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
326 |
327 |         ref_out_no_bias, tvm_out_no_bias = self.runBoth(linear_no_bias, input, weight)
328 |         assert torch.allclose(ref_out_no_bias, tvm_out_no_bias, rtol=0.01, atol=0.01)
329 |
330 |
331 |     # Have to make min_dim large since we may compare a non-tensorized
332 |     # implementation against fbgemm, and that requires the K dim to be large enough.
333 |     @TVMTest.given(
334 |         shape=TVMTest.rand_shape(rank=2, min_dim=32, max_dim=64),
335 |         out_features=TVMTest.rand_int(15, 64),
336 |     )
337 |     def test_quantized_linear(self, shape, out_features):
338 |         # This is necessary since for N > 16 we enforce N to be a
339 |         # multiple of 16. Right now this is what defines the packing of the
340 |         # weights, and that is specifically what makes the requirement necessary.
341 |         # Loosening the requirement would need changes on the relay
342 |         # side, which otherwise complains during shape
343 |         # propagation in shape inference.
344 |         # The same holds for K.
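        # Concrete example of the adjustment below: a drawn out_features of 20
        # becomes 20 * 16 = 320 (a multiple of 16), and K grows from the drawn
        # 32..64 range to 128..256.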
345 | shape[1] = shape[1] * 4 346 | if out_features > 16: 347 | out_features = out_features * 16 348 | input = torch.normal(torch.rand(shape)) 349 | weight = torch.normal(torch.rand(out_features, shape[1])) 350 | bias = torch.normal(torch.rand(out_features)) 351 | q_weight, col_offsets, scale, zero_point = \ 352 | torch.fbgemm_linear_quantize_weight(weight.clone().float()) 353 | packed_weight = torch.fbgemm_pack_quantized_matrix(q_weight.clone()) 354 | 355 | def fbgemm_quantized_linear(input, weight, bias, col_offsets): 356 | return torch.fbgemm_linear_int8_weight( 357 | input.float(), weight, packed_weight, col_offsets, scale, zero_point, bias.float()) 358 | ref_out, tvm_out = self.runBoth(fbgemm_quantized_linear, input, q_weight, bias, col_offsets) 359 | # Too loose a bound. We can make this 0.05 360 | # however this works only when we can tensorize, i.e. use AVX instructions. 361 | # Since without that we cast int8 to int32 before mul we always get more precision 362 | # which can result in mismatch. Fix this test once we have AVX2 impl as well. 363 | assert torch.allclose(ref_out, tvm_out, rtol=0.5, atol=0.5) 364 | 365 | 366 | @TVMTest.given( 367 | shape=TVMTest.rand_shape(rank=2, min_dim=4), 368 | ) 369 | def test_concat_fuse(self, shape): 370 | input1 = torch.rand(shape) 371 | input2 = torch.rand(shape) 372 | 373 | def concat(x1, x2): 374 | return torch.cat((x1, x2), 0) 375 | 376 | import torch_tvm 377 | torch_tvm.enable() 378 | # test concat 379 | scripted_concat = torch.jit.script(concat) 380 | concat_graph = scripted_concat.graph_for(input1, input2) 381 | FileCheck().check("prim::FusedConcat").check_not("prim::ListConstruct").check_not("aten::cat").run(str(concat_graph)) 382 | 383 | @TVMTest.given( 384 | shape=TVMTest.rand_shape(rank=2, min_dim=4), 385 | ) 386 | def test_concat_op(self, shape): 387 | input1 = torch.rand(shape) 388 | input2 = torch.rand(shape) 389 | 390 | # if we didn't use relu, single op will not be lowered 391 | def concat_relu(x1, x2): 392 | return F.relu(torch.cat((x1, x2), 0)) 393 | 394 | ref_out, tvm_out = self.runBoth(concat_relu, input1, input2) 395 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 396 | 397 | @TVMTest.given( 398 | shape=TVMTest.rand_shape(rank=2, min_dim=4), 399 | ) 400 | def test_reshape(self, shape): 401 | input = torch.rand(shape) 402 | 403 | def reshape(input): 404 | return torch.reshape(input, (-1,)) 405 | 406 | ref_out, tvm_out = self.runBoth(reshape, input) 407 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 408 | 409 | def reshape(input): 410 | return torch.reshape(input, (1, 1, *shape)) 411 | 412 | ref_out, tvm_out = self.runBoth(reshape, input) 413 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 414 | 415 | def reshape(input): 416 | return torch.reshape(input, (1, -1)) 417 | 418 | ref_out, tvm_out = self.runBoth(reshape, input) 419 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 420 | 421 | def reshape(input): 422 | return torch.reshape(input, (shape[0], 1, 1, shape[1])) 423 | 424 | ref_out, tvm_out = self.runBoth(reshape, input) 425 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 426 | 427 | @TVMTest.given( 428 | shape=TVMTest.rand_shape(rank=2, min_dim=4), 429 | axis=TVMTest.rand_int(0, 1), 430 | ) 431 | def test_softmax(self, shape, axis): 432 | input = torch.rand(shape) 433 | 434 | def softmax(input): 435 | return torch.softmax(input, axis=axis) 436 | 437 | ref_out, tvm_out = self.runBoth(softmax, input) 438 | assert torch.allclose(ref_out, 
tvm_out, rtol=0.01, atol=0.01) 439 | 440 | @TVMTest.given( 441 | shape=TVMTest.rand_shape(rank=1), 442 | ) 443 | def test_clamp(self, shape): 444 | input = torch.rand(shape) 445 | 446 | def clamp(input): 447 | return torch.clamp(input + 3.0, 0.0, 6.0) 448 | 449 | ref_out, tvm_out = self.runBoth(clamp, input) 450 | assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01) 451 | 452 | if __name__ == "__main__": 453 | unittest.main() 454 | -------------------------------------------------------------------------------- /torch_tvm/compiler.cpp: -------------------------------------------------------------------------------- 1 | #include "compiler.h" 2 | #include "operators.h" 3 | #include "register.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | using namespace torch::jit; 16 | 17 | using torch_tvm::utils::DLManagedTensorPtr; 18 | using tvm::runtime::DeviceAPI; 19 | 20 | namespace { 21 | std::vector set_input( 22 | std::unordered_map& value_to_ivalue, 23 | TVMObject& cache) { 24 | std::vector input_tensors; 25 | for (auto& input_value : cache.input_values) { 26 | auto* value = input_value.first; 27 | TVMGraphInputInfo& graph_input = input_value.second; 28 | if (graph_input.is_param) { 29 | continue; 30 | } 31 | if (!value_to_ivalue.count(value)) { 32 | auto optional_ivalue = toIValue(value); 33 | TORCH_INTERNAL_ASSERT(optional_ivalue.has_value()); 34 | value_to_ivalue[value] = optional_ivalue.value(); 35 | } 36 | auto ivalue = value_to_ivalue.at(value); 37 | //auto tensor = ivalue.toTensor().to(at::kFloat); 38 | auto tensor = ivalue.toTensor(); 39 | DLManagedTensor* dl_tensor; 40 | if (tensor.is_contiguous() && ( 41 | tvm::is_training() || torch_tvm::utils::isAligned(tensor.data_ptr(), 42 | tvm::runtime::kAllocAlignment))) { 43 | dl_tensor = at::toDLPack(tensor); 44 | } else { 45 | if (tvm::is_training()) { 46 | auto contig_tensor = tensor.contiguous(); 47 | dl_tensor = at::toDLPack(contig_tensor); 48 | } else { 49 | dl_tensor = 50 | torch_tvm::utils::allocAndCopyData(tensor); 51 | input_tensors.emplace_back( 52 | dl_tensor); 53 | } 54 | } 55 | cache.set_input(graph_input.tvm_var_name, 56 | tvm::runtime::NDArray::FromDLPack(dl_tensor)); 57 | } 58 | return input_tensors; 59 | } 60 | 61 | DLManagedTensorPtr createParamTensor(const IValue& param_val) { 62 | auto tensor = param_val.toTensor(); 63 | DLManagedTensor* dl_tensor = nullptr; 64 | if (tvm::is_training()) { 65 | if (tensor.is_contiguous()) { 66 | dl_tensor = at::toDLPack(tensor); 67 | } else { 68 | auto contig_tensor = tensor.contiguous(); 69 | dl_tensor = at::toDLPack(contig_tensor); 70 | } 71 | } else { 72 | dl_tensor = torch_tvm::utils::allocAndCopyData(tensor); 73 | } 74 | return DLManagedTensorPtr(dl_tensor); 75 | } 76 | 77 | tvm::relay::Constant createParamConstant( 78 | const DLManagedTensorPtr& dl_tensor_ptr) { 79 | auto nd_array = tvm::runtime::NDArray::FromDLPack(dl_tensor_ptr.get()); 80 | return tvm::relay::ConstantNode::make(nd_array); 81 | } 82 | 83 | template 84 | tvm::relay::Expr doubleNode(double d, TVMContext ctx) { 85 | TORCH_CHECK(size == 16 || size == 32); 86 | std::string type; 87 | if (size == 16) { 88 | type = "float16"; 89 | } else { 90 | type = "float32"; 91 | } 92 | 93 | auto x = tvm::runtime::NDArray::Empty( 94 | {}, tvm::runtime::String2TVMType(type), ctx); 95 | TORCH_CHECK(d <= std::numeric_limits::max()); 96 | TORCH_CHECK(d >= std::numeric_limits::lowest()); 97 | T f = static_cast(d); 98 | 99 | 
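  // Copy the scalar through the device API rather than a plain memcpy,
  // since the destination NDArray may live on a non-CPU context.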
template <typename T, int size>
tvm::relay::Expr doubleNode(double d, TVMContext ctx) {
  TORCH_CHECK(size == 16 || size == 32);
  std::string type;
  if (size == 16) {
    type = "float16";
  } else {
    type = "float32";
  }

  auto x = tvm::runtime::NDArray::Empty(
      {}, tvm::runtime::String2TVMType(type), ctx);
  TORCH_CHECK(d <= std::numeric_limits<T>::max());
  TORCH_CHECK(d >= std::numeric_limits<T>::lowest());
  T f = static_cast<T>(d);

  DeviceAPI::Get(ctx)->CopyDataFromTo(
      &f, 0,
      x->data, 0,
      sizeof(f), // `size` is in bits; the copy length must be in bytes
      cpuContext(),
      ctx,
      DLDataType{kDLFloat, size, 1}, nullptr);

  return tvm::relay::ConstantNode::make(x);
}

} // namespace

void TVMObject::populateParamTVMTensors(
    const std::unordered_map<Value*, IValue>& value_to_ivalue) {
  for (auto& input_value : input_values) {
    auto* jit_value = input_value.first;
    auto& graph_input = input_value.second;
    if (graph_input.is_param) {
      const auto& input_ivalue = value_to_ivalue.at(jit_value);
      graph_input.tvm_tensor = createParamTensor(input_ivalue);
    }
  }
}

tvm::Map<std::string, tvm::relay::Constant>
TVMObject::generateParamConstantMap() {
  tvm::Map<std::string, tvm::relay::Constant> params_map;
  for (const auto& input_value : input_values) {
    const auto& graph_input = input_value.second;
    if (graph_input.is_param) {
      const auto& tvm_var_name = graph_input.tvm_var_name;
      params_map.Set(tvm_var_name, createParamConstant(graph_input.tvm_tensor));
    }
  }
  return params_map;
}

tvm::relay::DataType scalarTypeToTVMType(at::ScalarType pt_type) {
  static const std::unordered_map<at::ScalarType, tvm::relay::DataType>
      type_mapping = {
          {at::ScalarType::Half, ::tvm::Float(16)},
          {at::ScalarType::Float, ::tvm::Float(32)},
          {at::ScalarType::Double, ::tvm::Float(64)},
          {at::ScalarType::Int, ::tvm::Int(32)},
          {at::ScalarType::Long, ::tvm::Int(64)},
          {at::ScalarType::Bool, ::tvm::Bool()},
          {at::ScalarType::Char, ::tvm::Int(8)},
          {at::ScalarType::Byte, ::tvm::UInt(8)},
          {at::ScalarType::QInt8, ::tvm::Int(8)},
          {at::ScalarType::QUInt8, ::tvm::UInt(8)},
          {at::ScalarType::QInt32, ::tvm::Int(32)},
      };

  TORCH_CHECK(type_mapping.find(pt_type) != type_mapping.end(),
      "could not handle the type ", pt_type,
      " when creating tensor type node in TVM");
  return type_mapping.at(pt_type);
}
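
// Note that the quantized dtypes map to their underlying storage types
// (e.g. QInt8 -> Int(8)); scale and zero-point travel with the quantized
// ops rather than with the tensor type. Illustrative usage, with results
// read straight off the table above:
//
//   auto dt = scalarTypeToTVMType(at::ScalarType::Float);  // ::tvm::Float(32)
//   auto qt = scalarTypeToTVMType(at::ScalarType::QUInt8); // ::tvm::UInt(8)
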
tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
  auto optional_ivalue = toIValue(val);
  if (optional_ivalue.has_value()) {
    if (optional_ivalue.value().isTensor()) {
      auto t = optional_ivalue.value().toTensor();
      val->inferTypeFrom(optional_ivalue.value().toTensor());
    } else {
      auto expr = convertToRelay(optional_ivalue.value(), ctx)
                      .as<tvm::relay::ConstantNode>();
      return tvm::relay::VarNode::make(
          val->debugName() +
              std::to_string(reinterpret_cast<std::uintptr_t>(val)),
          expr->tensor_type());
    }
  }
  if (val->isCompleteTensor()) {
    auto pt_t = val->type()->cast<TensorType>();
    TORCH_INTERNAL_ASSERT(pt_t);
    auto optional_device_type = pt_t->device();
    TORCH_INTERNAL_ASSERT(optional_device_type);
    tvm::Array<tvm::relay::IndexExpr> sizes;
    const auto& varying_sizes = pt_t->sizes();
    const auto& optional_sizes = varying_sizes.sizes();
    TORCH_INTERNAL_ASSERT(optional_sizes);
    const auto& pt_sizes = optional_sizes.value();
    for (const auto& optional_size : pt_sizes) {
      TORCH_INTERNAL_ASSERT(optional_size);
      sizes.push_back(tvm::relay::IndexExpr(
          static_cast<int32_t>(optional_size.value())));
    }
    auto optional_dtype = pt_t->scalarType();
    TORCH_INTERNAL_ASSERT(optional_dtype);
    at::ScalarType pt_type = optional_dtype.value();
    auto t =
        tvm::relay::TensorTypeNode::make(sizes, scalarTypeToTVMType(pt_type));
    auto v = tvm::relay::VarNode::make(
        val->debugName() +
            std::to_string(reinterpret_cast<std::uintptr_t>(val)),
        t);
    return v;
  }
  TORCH_INTERNAL_ASSERT(0);
}

tvm::relay::Expr TVMCompiler::convertToRelay(
    const IValue& val,
    TVMContext ctx) {
  // All doubles are converted to fp32 constants (fp16 when training)
  if (val.isDouble()) {
    if (tvm::is_training()) {
      return doubleNode<at::Half, 16>(val.toDouble(), ctx);
    } else {
      return doubleNode<float, 32>(val.toDouble(), ctx);
    }
  }
  // All Ints are converted to zero-rank int64 constants
  if (val.isInt()) {
    auto x = tvm::runtime::NDArray::Empty({}, tvm::Int(64), ctx);
    auto l = val.toInt();
    DeviceAPI::Get(ctx)->CopyDataFromTo(
        &l, 0,
        x->data, 0,
        sizeof(l),
        cpuContext(),
        ctx,
        DLDataType{kDLInt, 64, 1}, nullptr);
    auto v = tvm::relay::ConstantNode::make(x);
    return v;
  }
  if (val.isBool()) {
    auto x = tvm::runtime::NDArray::Empty(
        {}, tvm::runtime::String2TVMType("bool"), ctx);
    auto b = val.toBool();
    DeviceAPI::Get(ctx)->CopyDataFromTo(
        &b, 0,
        x->data, 0,
        sizeof(b),
        cpuContext(),
        ctx,
        DLDataType{kDLUInt, 64, 1}, nullptr);
    auto v = tvm::relay::ConstantNode::make(x);
    return v;
  }
  // TODO Add None type to Relay
  // HACK sentinel value used for None type
  if (val.isNone()) {
    auto x = tvm::runtime::NDArray::Empty(
        {}, tvm::runtime::String2TVMType("uint64"), ctx);
    auto n = getNoneSentinel();
    DeviceAPI::Get(ctx)->CopyDataFromTo(
        &n, 0,
        x->data, 0,
        sizeof(n),
        cpuContext(),
        ctx,
        DLDataType{kDLUInt, 64, 1}, nullptr);
    auto v = tvm::relay::ConstantNode::make(x);
    return v;
  }
  if (val.isIntList()) {
    tvm::Array<tvm::relay::Expr> tuple_elems;
    for (const auto& elem : val.toIntList()) {
      auto x = tvm::runtime::NDArray::Empty({}, tvm::Int(64), ctx);
      DeviceAPI::Get(ctx)->CopyDataFromTo(
          &elem, 0,
          x->data, 0,
          sizeof(elem),
          cpuContext(),
          ctx,
          DLDataType{kDLInt, 64, 1}, nullptr);
      auto v = tvm::relay::ConstantNode::make(x);
      tuple_elems.push_back(v);
    }
    return tvm::relay::TupleNode::make(tuple_elems);
  }
  TORCH_CHECK(
      0, "Cannot convert value ", val, " to Relay yet. Please file a bug.\n");
}
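
// Illustrative examples of the scalar conversions above (each produces a
// zero-rank, i.e. scalar, Relay constant):
//
//   convertToRelay(IValue(2.5), ctx);  // fp32 constant (fp16 when training)
//   convertToRelay(IValue(7), ctx);    // int64 constant
//   convertToRelay(IValue(), ctx);     // uint64 None-sentinel constant
//   // an int list becomes a relay Tuple of zero-rank int64 constants
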
tvm::relay::Function TVMCompiler::convertToRelay(
    std::shared_ptr<Graph> subgraph,
    TVMContext ctx,
    std::unordered_map<Value*, TVMGraphInputInfo>* input_values) {
  std::unordered_map<Value*, tvm::relay::Expr> value_map;
  tvm::Array<tvm::relay::Var> input_vars;

  for (const auto& input : subgraph->inputs()) {
    TORCH_INTERNAL_ASSERT(input->isCompleteTensor());
    auto v = convertToRelay(input, ctx);
    input_vars.push_back(v);
    if (input_values) {
      // Primary inputs are always mutable.
      input_values->emplace(std::piecewise_construct,
          std::forward_as_tuple(input),
          std::forward_as_tuple(false,
              v.as<tvm::relay::VarNode>()->name_hint()));
    }
    value_map[input] = v;
  }

  auto frontier = subgraph->inputs().vec();
  // TODO: error-handle incorrectly formed graphs (not dominated by the frontier)
  while (frontier.size()) {
    std::vector<Value*> new_frontier = {};
    for (const auto& value : frontier) {
      auto uses = value->uses();
      for (const auto& use : uses) {
        tvm::Array<tvm::relay::Expr> relay_inputs;
        // Skip users with no outputs, e.g. prim::Return.
        // Should we be more explicit here, and skip only prim::Return?
        if (use.user->outputs().size() < 1) {
          continue;
        }
        auto skip_user = false;
        if (std::any_of(use.user->outputs().begin(), use.user->outputs().end(),
                [&value_map](Value* const output) {
                  return value_map.count(output);
                })) {
          continue;
        }
        const auto& param_indices = getParamIndices(use.user);
        int input_index{0};
        for (const auto& input : use.user->inputs()) {
          if (value_map.find(input) == value_map.end()) {
            // We may be dealing with a constant; handle that here
            auto optional_ivalue = toIValue(input);
            if (!optional_ivalue.has_value()) {
              skip_user = true;
              break;
            } else {
              if (optional_ivalue.value().isTensor()) {
                auto input_var = convertToRelay(input, ctx);
                input_vars.push_back(input_var);
                value_map[input] = input_var;
                if (input_values) {
                  input_values->emplace(std::piecewise_construct,
                      std::forward_as_tuple(input),
                      std::forward_as_tuple(false,
                          input_var.as<tvm::relay::VarNode>()->name_hint()));
                }
              } else {
                value_map[input] = convertToRelay(optional_ivalue.value(), ctx);
              }
            }
          }
          // Annotate the value: whether it corresponds to a parameter
          // and is thus expected to be immutable.
          if (!skip_user && input_values &&
              std::find(param_indices.begin(),
                  param_indices.end(), input_index) != param_indices.end()) {
            auto it = input_values->find(input);
            if (it != input_values->end()) {
              (*it).second.is_param = true;
            }
          }
          relay_inputs.push_back(value_map[input]);
          input_index++;
        }
        if (skip_user) {
          continue;
        }
        // if there are 2+ outputs, getOperator returns a tuple
        if (use.user->outputs().size() == 1) {
          value_map[use.user->output()] = getOperator(use.user, relay_inputs);
          new_frontier.emplace_back(use.user->output());
        } else {
          auto tuple = getOperator(use.user, relay_inputs);
          int index = 0;
          for (const auto& output : use.user->outputs()) {
            auto n = tvm::make_node<tvm::relay::TupleGetItemNode>();
            n->tuple = tuple;
            n->index = index;
            value_map[output] = tvm::relay::TupleGetItem(n);
            index++;
            new_frontier.emplace_back(output);
          }
        }
      }
    }
    frontier = new_frontier;
  }

  tvm::NodePtr<tvm::relay::TupleNode> n =
      tvm::make_node<tvm::relay::TupleNode>();
  tvm::Array<tvm::relay::Expr> fields;
  for (const auto& sg_output : subgraph->outputs()) {
    TORCH_INTERNAL_ASSERT(value_map.find(sg_output) != value_map.end());
    fields.push_back(value_map[sg_output]);
  }
  n->fields = std::move(fields);
  auto output = tvm::relay::Tuple(n);

  tvm::Array<tvm::relay::Var> free_vars = tvm::relay::FreeVars(output);
  TORCH_CHECK(
      free_vars.size() <= input_vars.size(),
      "Determined ",
      free_vars.size(),
      " free vars but only ",
      input_vars.size(),
      " inputs");

  return tvm::relay::FunctionNode::make(
      input_vars, output, tvm::relay::Type(), {});
}

std::string TVMCompiler::getTVMCompilerHandle(
    std::shared_ptr<Graph> subgraph) {
  std::ostringstream oss;
  oss << "TVM";
  for (const auto& n : subgraph->nodes()) {
    oss << "_" << n->kind().toUnqualString();
  }
  return oss.str();
}
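
// Example: for a fused subgraph containing aten::mm followed by aten::relu,
// toUnqualString() drops the "aten::" namespace, so the handle would be
// "TVM_mm_relu". This string is later used as the profiler record name in
// the RECORD_FUNCTION call inside run().
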
TVMCompiler::TVMCompiler(
    const Node* node,
    int opt_level,
    bool strict,
    bool debug,
    bool debug_runtime,
    std::string device_type,
    std::string device,
    std::string host,
    int device_id)
    : opt_level_(opt_level),
      strict_(strict),
      debug_(debug),
      debug_runtime_(debug_runtime),
      device_type_(device_type),
      device_(device),
      host_(host),
      device_id_(device_id) {
  if (device_type_ == "gpu") {
    ctx_.device_type = kDLGPU;
  } else {
    ctx_.device_type = kDLCPU;
  }
  ctx_.device_id = device_id_;
  subgraph_ = node->g(attr::Subgraph);
  handle_str_ = getTVMCompilerHandle(subgraph_);
  fallback_interpreter_ = std::unique_ptr<torch::jit::InterpreterState>(
      new torch::jit::InterpreterState(Code(subgraph_)));
  auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
  TORCH_INTERNAL_ASSERT(pfb);
  build_mod_ = (*pfb)();
  activation_buffer_ = at::zeros({0}, at::kCPU);
}

void TVMCompiler::run(Stack& stack) {
  std::unordered_map<Value*, IValue> value_to_ivalue;
  int num_inputs = subgraph_->inputs().size();
  at::ArrayRef<IValue> inputs = last(stack, num_inputs);

  for (auto i = 0; i < inputs.size(); ++i) {
    auto value_input = subgraph_->inputs()[i];
    value_to_ivalue[value_input] = inputs[i];
  }

  CompleteArgumentSpec spec{false, ArrayRef<IValue>(inputs)};
  if (bad_specs_.count(spec)) {
    fallback_interpreter_->run(stack);
    return;
  }

  if (tvm::is_training()) {
    bool is_all_cuda = true;
    for (auto i = 0; i < inputs.size(); ++i) {
      if (inputs[i].isTensor()) {
        if (!inputs[i].toTensor().device().is_cuda()) {
          is_all_cuda = false;
        }
      }
    }

    if (!is_all_cuda) {
      bad_specs_.insert(spec);
      fallback_interpreter_->run(stack);
      return;
    }
  }

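  // The compiled module is cached per CompleteArgumentSpec, i.e. per
  // combination of input shapes, dtypes and devices. Roughly (illustrative
  // pseudocode; only cache_ and bad_specs_ are actual members):
  //
  //   if spec in bad_specs_:  run the fallback interpreter
  //   if spec not in cache_:  convert -> build -> create runtime, then cache
  //   cache_[spec]:           set inputs, run kernel, fetch outputs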
  if (cache_.find(spec) == cache_.end()) {
    for (auto& kv : value_to_ivalue) {
      if (kv.second.isTensor()) {
        kv.first->inferTypeFrom(kv.second.toTensor());
      } else if (kv.second.isInt()) {
        kv.first->setType(IntType::get());
      } else if (kv.second.isDouble()) {
        kv.first->setType(FloatType::get());
      } else {
        LOG(ERROR) << "Cannot handle this type yet "
                   << kv.second
                   << "\nGraph:\n"
                   << *subgraph_
                   << " (this: " << ((void*)this) << ")";
        fallback_interpreter_->run(stack);
        bad_specs_.insert(spec);
        return;
      }
    }

    if (debug_) {
      getDebugLogger().printGraph(subgraph_);
    }

    // Bail-out mechanism: try to convert to Relay. If the graph fails to
    // convert for any reason (e.g. an unsupported op), then, depending on
    // user preference, either throw or fall back to the JIT interpreter
    // for execution.
    tvm::relay::Function tvm_func;
    try {
      tvm_func = convertToRelay(subgraph_, ctx_, &cache_[spec].input_values);
    } catch (const std::exception& e) {
      cache_.erase(spec);
      bad_specs_.insert(spec);
      if (strict_) {
        AT_ERROR(
            "PyTorch TVM: failed to convert to Relay, exception: ", e.what());
      }
      LOG(WARNING)
          << "PyTorch TVM: failed to convert to Relay, falling back to the JIT for execution; exception: "
          << e.what() << "\n";
      fallback_interpreter_->run(stack);
      return;
    }
    auto build_f = build_mod_.GetFunction("build", false);
    auto json_f = build_mod_.GetFunction("get_graph_json", false);
    auto set_params = build_mod_.GetFunction("set_params", false);
    auto get_params = build_mod_.GetFunction("get_params", false);
    auto mod_f = build_mod_.GetFunction("get_module", false);
    tvm::Map<tvm::Integer, tvm::Target> target_map = {
        {ctx_.device_type, tvm::Target::Create(device_)}};
    cache_[spec].populateParamTVMTensors(value_to_ivalue);
    auto params_constant_map = cache_[spec].generateParamConstantMap();
    set_params(params_constant_map);
    auto build_config = tvm::BuildConfig::Current();
    // This sets loop partitioning such that, if a loop is partitioned by a
    // non-divisible factor, it is split into a divisible part plus a tail
    // that is not divisible. This helps performance.
    build_config->partition_const_loop = true;
    build_f(tvm_func, target_map, tvm::Target::Create(host_));
    tvm::runtime::Module mod = mod_f();
    std::string json = json_f();
    if (debug_) {
      getDebugLogger().printLoweredFuncs(build_mod_);
      getDebugLogger().printASM(mod);
    }

#ifdef TVM_USE_FB_GRAPH_RUNTIME
    auto pfr = tvm::runtime::Registry::Get("tvm.fb_graph_runtime.create");
#else
    auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
#endif

    TORCH_INTERNAL_ASSERT(pfr);
    if (debug_runtime_) {
      pfr = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create");
      TORCH_CHECK(pfr, "TVM must be compiled with the debug runtime. "
          "Enable USE_GRAPH_RUNTIME_DEBUG in the TVM CMake configuration.");
    }
    tvm::runtime::Module run_mod =
        (*pfr)(json, mod, (int)ctx_.device_type, (int)ctx_.device_id);
    cache_[spec].set_input = run_mod.GetFunction("set_input_zero_copy", false);

    if (debug_runtime_) {
      cache_[spec].kernel = run_mod.GetFunction("run_individual", false);
    } else {
      cache_[spec].kernel = run_mod.GetFunction("run", false);
    }
    cache_[spec].get_output = run_mod.GetFunction("get_output", false);
    auto get_num_outputs = run_mod.GetFunction("get_num_outputs", false);
    cache_[spec].setup_external_storage =
        run_mod.GetFunction("setup_external_storage", false);

    tvm::Map<std::string, tvm::relay::Constant> local_params = get_params();

#ifdef TVM_USE_FB_GRAPH_RUNTIME
    allocateMemoryAndSetParams(cache_[spec], local_params, json);
#else
    for (const auto& param : local_params) {
      const auto& param_name = param.first;
      const auto& param_ndarray_val = param.second->data;
      cache_[spec].set_input(param_name, param_ndarray_val);
    }
#endif

    int n = get_num_outputs();
    TORCH_CHECK(
        subgraph_->outputs().size() == n,
        "Compiled subgraph with a mismatched number of outputs");
  }

  // Use a vector of unique pointers with a custom deleter so the allocated
  // memory is freed when it goes out of scope. This covers only inputs that
  // are not parameters; parameters are managed by the cached
  // tvm_param_tensors and are deallocated when cache_ is destroyed.
  std::vector<DLManagedTensorPtr> dl_tensor_list =
      set_input(value_to_ivalue, cache_[spec]);

  {
    RECORD_FUNCTION(handle_str_, last(stack, num_inputs));
    if (debug_runtime_) {
      cache_[spec].kernel(10, 10, 1);
    } else {
      cache_[spec].kernel();
    }
  }

  // Clean the inputs off the stack and push the outputs
  drop(stack, num_inputs);
  int i = 0;
  for (const auto& output : subgraph_->outputs()) {
    tvm::runtime::NDArray ret_val = cache_[spec].get_output(i);
    auto dl_tensor = ret_val.ToDLPack();
    auto tensor = at::fromDLPack(dl_tensor);
    auto var = torch::autograd::make_variable(tensor);
    stack.push_back(IValue(var));
    i++;
  }
}
--------------------------------------------------------------------------------