├── .gitmodules ├── CMakeLists.txt ├── README.md ├── build.sh ├── src ├── compiler.cpp ├── compiler.h └── register.cpp └── test.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "asmjit"] 2 | path = asmjit 3 | url = https://github.com/asmjit/asmjit.git 4 | [submodule "pybind11"] 5 | path = pybind11 6 | url = https://github.com/pybind/pybind11.git 7 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | 3 | file(GLOB COMPILER_SRCS 4 | ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp 5 | ) 6 | 7 | set(CMAKE_CXX_STANDARD 11) 8 | SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 9 | set (CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address") 10 | set (CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address") 11 | 12 | # PYTORCH_DIR 13 | IF(DEFINED ENV{PYTORCH_DIR}) 14 | SET(PYTORCH_DIR $ENV{PYTORCH_DIR}) 15 | ENDIF() 16 | 17 | IF ("${PYTORCH_DIR}" STREQUAL "") 18 | message(FATAL_ERROR "Please specify the PyTorch directory with -DPYTORCH_DIR=/path/to/pytorch/dir") 19 | ENDIF() 20 | 21 | message("Using PyTorch directory ${PYTORCH_DIR}") 22 | 23 | link_directories(${PYTORCH_DIR}/lib) 24 | 25 | add_subdirectory(pybind11) 26 | add_subdirectory(asmjit) 27 | 28 | pybind11_add_module(pointwise_compiler SHARED ${COMPILER_SRCS}) 29 | target_link_libraries(pointwise_compiler PUBLIC torch pybind11 asmjit) 30 | 31 | target_include_directories(pointwise_compiler PUBLIC 32 | ${CMAKE_CURRENT_SOURCE_DIR}/src 33 | ${PYTORCH_DIR}/include 34 | ${PYBIND11_INCLUDE_DIR} 35 | asmjit/src 36 | ) 37 | 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the codebase 
associated with the [PyTorch JIT compiler tutorial](https://jott.live/markdown/Writing%20a%20Toy%20Backend%20Compiler%20for%20PyTorch). 2 | 3 | ## Build and Test 4 | 5 | You will likely need a recent version of PyTorch installed (either a nightly build or one built from source). 6 | 7 | ``` 8 | git clone https://github.com/bwasti/pytorch_compiler_tutorial.git --recursive 9 | cd pytorch_compiler_tutorial 10 | ./build.sh 11 | PYTHONPATH=build python test.py 12 | ``` 13 | 14 | This test benchmarks the following function with inputs of size 1024, run 100 times: 15 | 16 | ``` 17 | def foo(a, b): 18 | c = a.mul(b) 19 | a = c.mul(c) 20 | a = c.mul(a) 21 | return a 22 | ``` 23 | 24 | We expect to see this output: 25 | 26 | ``` 27 | -- Default IR -- 28 | graph(%a.1 : Float(*), 29 | %b.1 : Float(*)): 30 | %c.1 : Float(*) = aten::mul(%a.1, %b.1) # test.py:20:7 31 | %a.3 : Float(*) = aten::mul(%c.1, %c.1) # test.py:21:7 32 | %a.5 : Float(*) = aten::mul(%c.1, %a.3) # test.py:22:7 33 | return (%a.5) 34 | 35 | Default version took 26.74ms 36 | 37 | -- Transformed IR -- 38 | graph(%a.1 : Float(*), 39 | %b.1 : Float(*)): 40 | %a.5 : Float(*) = pw::CompilationGroup_0(%a.1, %b.1) 41 | return (%a.5) 42 | with pw::CompilationGroup_0 = graph(%4 : Float(*), 43 | %5 : Float(*)): 44 | %c.1 : Float(*) = aten::mul(%4, %5) # test.py:34:7 45 | %a.3 : Float(*) = aten::mul(%c.1, %c.1) # test.py:35:7 46 | %a.5 : Float(*) = aten::mul(%c.1, %a.3) # test.py:36:7 47 | return (%a.5) 48 | 49 | Compiled version took 8.20ms 50 | ``` 51 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | PYTORCH_DIR=$(python -c 'import os, torch; print(os.path.dirname(os.path.realpath(torch.__file__)))') 2 | mkdir -p build && cd build 3 | cmake .. 
-DPYTORCH_DIR=${PYTORCH_DIR} 4 | make -j 24 5 | -------------------------------------------------------------------------------- /src/compiler.cpp: -------------------------------------------------------------------------------- 1 | #include "compiler.h" 2 | 3 | #include 4 | 5 | using namespace torch::jit; 6 | 7 | // This class will be used to allocate registers 8 | // as we traverse the frontier of the PyTorch IR graph. 9 | class RegisterManager { 10 | public: 11 | RegisterManager() = default; 12 | asmjit::X86Gp getFreeAddrReg() { 13 | TORCH_CHECK(free_addr_regs.size() > 0); 14 | auto reg = free_addr_regs.back(); 15 | free_addr_regs.pop_back(); 16 | return reg; 17 | } 18 | 19 | asmjit::X86Xmm getFreeValueReg() { 20 | TORCH_CHECK(free_value_regs.size() > 0); 21 | auto reg = free_value_regs.back(); 22 | free_value_regs.pop_back(); 23 | return reg; 24 | } 25 | 26 | asmjit::X86Gp getAddrReg(const Value* v) { 27 | TORCH_CHECK(addr_regs.find(v) != addr_regs.end()); 28 | return addr_regs[v]; 29 | } 30 | 31 | asmjit::X86Xmm getValueReg(const Value* v) { 32 | TORCH_CHECK(value_regs.find(v) != value_regs.end()); 33 | return value_regs[v]; 34 | } 35 | 36 | void mapReg(const Value* v, asmjit::X86Gp gp) { 37 | addr_regs[v] = gp; 38 | } 39 | 40 | void mapReg(const Value* v, asmjit::X86Xmm xmm) { 41 | value_regs[v] = xmm; 42 | } 43 | 44 | void free(asmjit::X86Gp reg) { 45 | free_addr_regs.push_back(reg); 46 | } 47 | 48 | void free(asmjit::X86Xmm reg) { 49 | free_value_regs.push_back(reg); 50 | } 51 | 52 | private: 53 | std::unordered_map addr_regs; 54 | std::unordered_map value_regs; 55 | 56 | std::vector free_addr_regs = { 57 | asmjit::x86::rsi, 58 | asmjit::x86::rdx, 59 | asmjit::x86::r8, 60 | asmjit::x86::r9, 61 | asmjit::x86::r10, 62 | asmjit::x86::r11, 63 | }; 64 | 65 | std::vector free_value_regs = { 66 | asmjit::x86::xmm0, 67 | asmjit::x86::xmm1, 68 | asmjit::x86::xmm2, 69 | asmjit::x86::xmm3, 70 | asmjit::x86::xmm4, 71 | asmjit::x86::xmm5, 72 | asmjit::x86::xmm6, 73 | 
asmjit::x86::xmm7, 74 | }; 75 | }; 76 | 77 | bool PointwiseCompiler::supported(const torch::jit::Node* node) { 78 | switch (node->kind()) { 79 | case aten::mul: 80 | return true; 81 | default: 82 | return false; 83 | } 84 | return false; 85 | } 86 | 87 | void PointwiseCompiler::run(torch::jit::Stack& stack) { 88 | // Get the number of expected inputs to the graph we are compiling 89 | const at::ArrayRef& graph_inputs = subgraph_->inputs(); 90 | const auto num_inputs = graph_inputs.size(); 91 | 92 | // Pop these inputs from the stack. 93 | at::ArrayRef inputs = last(stack, num_inputs); 94 | 95 | // If we haven't compiled for the shape/device of these inputs before, 96 | // do so now. 97 | CompleteArgumentSpec spec{false, ArrayRef(inputs)}; 98 | if (cache_.find(spec) == cache_.end()) { 99 | cache_[spec] = compile(inputs); 100 | } 101 | 102 | // Run the compiled function! 103 | auto outputs = cache_[spec](inputs); 104 | 105 | drop(stack, num_inputs); 106 | for (auto& output : outputs) { 107 | auto var = torch::autograd::make_variable(output.toTensor()); 108 | stack.push_back(IValue(var)); 109 | } 110 | } 111 | 112 | void PointwiseCompiler::emitOperation( 113 | const Node* node, 114 | const std::set& seen, 115 | asmjit::X86Assembler& assembler, 116 | RegisterManager& reg_manager) { 117 | switch (node->kind()) { 118 | case aten::mul: { 119 | auto A = node->inputs()[0]; 120 | auto C_reg = reg_manager.getValueReg(A); 121 | for (auto use : A->uses()) { 122 | if (seen.find(use.user) != seen.end()) { 123 | C_reg = reg_manager.getFreeValueReg(); 124 | assembler.movups(C_reg, reg_manager.getValueReg(A)); 125 | } 126 | } 127 | auto B = node->inputs()[1]; 128 | assembler.mulss(C_reg, reg_manager.getValueReg(B)); 129 | reg_manager.mapReg(node->outputs()[0], C_reg); 130 | } 131 | } 132 | for (auto& input : node->inputs()) { 133 | bool used = true; 134 | for (auto use : input->uses()) { 135 | if (seen.find(use.user) == seen.end()) { 136 | used = false; 137 | } 138 | } 139 | if 
(used) { 140 | reg_manager.free(reg_manager.getValueReg(input)); 141 | } 142 | } 143 | } 144 | 145 | CompiledCode PointwiseCompiler::compile( 146 | at::ArrayRef& inputs) { 147 | // First we run through some checks to make sure the inputs are Tensors and 148 | // that the implied semantics are pointwise. 149 | TORCH_CHECK(inputs.size(), "Need at least one input."); 150 | for (const auto& input : inputs) { 151 | TORCH_CHECK(input.isTensor(), "Compiler can only handle Tensor inputs."); 152 | } 153 | auto size = inputs[0].toTensor().numel(); 154 | for (const auto& input : inputs) { 155 | TORCH_CHECK( 156 | input.toTensor().numel() == size, 157 | "Compiler can only handle pointwise operations without broadcasting."); 158 | } 159 | 160 | // Then we setup code generation utilities. 161 | auto reg_manager = RegisterManager(); 162 | asmjit::CodeHolder code; 163 | code.init(jit_runtime_.getCodeInfo()); 164 | asmjit::StringLogger asm_logger; 165 | code.setLogger(&asm_logger); 166 | asmjit::X86Assembler assembler(&code); 167 | 168 | const bool isWinOS = static_cast(ASMJIT_OS_WINDOWS); 169 | asmjit::X86Gp pointers = isWinOS ? 
asmjit::x86::rcx : asmjit::x86::rdi; 170 | 171 | // Move all the input Tensor addresses into registers 172 | for (auto i = 0; i < inputs.size(); ++i) { 173 | auto reg = reg_manager.getFreeAddrReg(); 174 | auto mem_ptr = asmjit::x86::ptr(pointers, i * sizeof(void*)); 175 | reg_manager.mapReg(subgraph_->inputs()[i], reg); 176 | assembler.mov(reg, mem_ptr); 177 | } 178 | 179 | // Do the same with output Tensors 180 | for (auto i = 0; i < subgraph_->outputs().size(); ++i) { 181 | auto reg = reg_manager.getFreeAddrReg(); 182 | auto mem_ptr = 183 | asmjit::x86::ptr(pointers, (i + inputs.size()) * sizeof(void*)); 184 | reg_manager.mapReg(subgraph_->outputs()[i], reg); 185 | assembler.mov(reg, mem_ptr); 186 | } 187 | 188 | // Setup a label for looping 189 | auto iter = reg_manager.getFreeAddrReg(); 190 | assembler.mov(iter, 0); 191 | auto loop_label = assembler.newLabel(); 192 | assembler.bind(loop_label); 193 | 194 | // Now we iterate through the nodes, keeping track of which ones we've 195 | // seen. If the input to a node has been totally consumed (no nodes 196 | // we haven't seen will use it), we free that register. 197 | std::set seen; 198 | 199 | for (auto input : subgraph_->inputs()) { 200 | auto reg = reg_manager.getFreeValueReg(); 201 | assembler.movd( 202 | reg, asmjit::x86::ptr(reg_manager.getAddrReg(input), iter, 2)); 203 | reg_manager.mapReg(input, reg); 204 | } 205 | 206 | // Iterating over graph nodes is guaranteed to be topologically sorted 207 | for (auto node : subgraph_->nodes()) { 208 | seen.insert(node); 209 | emitOperation(node, seen, assembler, reg_manager); 210 | } 211 | 212 | // Store all the output values into memory. 
213 | for (auto output : subgraph_->outputs()) { 214 | assembler.movd( 215 | asmjit::x86::ptr(reg_manager.getAddrReg(output), iter, 2), 216 | reg_manager.getValueReg(output)); 217 | } 218 | 219 | assembler.add(iter, 1); 220 | assembler.cmp(iter, size); 221 | assembler.jb(loop_label); 222 | 223 | assembler.ret(); 224 | 225 | // Now we bind a function the assembly we generated. 226 | void (*fn)(void**); 227 | asmjit::Error err = jit_runtime_.add(&fn, &code); 228 | TORCH_CHECK( 229 | !err, 230 | "Couldn't create function, asm:\n", 231 | std::string(asm_logger.getString())); 232 | 233 | // This function wraps the function pointer we bound our assembly to 234 | // Adheres to the CompiledCode interface defined in compiler.h 235 | auto compiled_func = [this, fn, size](at::ArrayRef& inputs) { 236 | std::vector args; 237 | for (auto input : inputs) { 238 | TORCH_CHECK(input.isTensor()); 239 | TORCH_CHECK(input.toTensor().is_contiguous()); 240 | TORCH_CHECK(input.toTensor().device().is_cpu()); 241 | args.emplace_back(input.toTensor().data_ptr()); 242 | } 243 | std::vector outputs; 244 | for (auto output : subgraph_->outputs()) { 245 | outputs.emplace_back(at::empty({size})); 246 | } 247 | for (auto output : outputs) { 248 | args.emplace_back(output.toTensor().data_ptr()); 249 | } 250 | 251 | // Run the function 252 | fn(args.data()); 253 | 254 | return outputs; 255 | }; 256 | 257 | return compiled_func; 258 | } 259 | -------------------------------------------------------------------------------- /src/compiler.h: -------------------------------------------------------------------------------- 1 | // All we need to understand PyTorch 2 | #include 3 | // CompleteArgumentSpec (useful for caching) 4 | #include 5 | // Our assembler 6 | #include 7 | 8 | using CompiledCode = std::function( 9 | at::ArrayRef&)>; 10 | class RegisterManager; 11 | 12 | class PointwiseCompiler { 13 | public: 14 | PointwiseCompiler(const torch::jit::Node* node) 15 | : 
subgraph_(node->g(torch::jit::attr::Subgraph)) {} 16 | void run(torch::jit::Stack& stack); 17 | static bool supported(const torch::jit::Node* node); 18 | 19 | private: 20 | void emitOperation( 21 | const torch::jit::Node* node, 22 | const std::set& seen, 23 | asmjit::X86Assembler& assembler, 24 | RegisterManager& reg_manager); 25 | CompiledCode compile(at::ArrayRef&); 26 | std::shared_ptr subgraph_; 27 | std::unordered_map cache_; 28 | asmjit::JitRuntime jit_runtime_; 29 | }; 30 | -------------------------------------------------------------------------------- /src/register.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // Register our compiler as handling a 4 | // specific type of operator 5 | #include 6 | 7 | // Register a pass to convert the IR into one with our operator 8 | #include 9 | // CustomFuseGraph is a helper to use simple whitelisting 10 | #include 11 | 12 | #include "compiler.h" 13 | 14 | namespace py = pybind11; 15 | using namespace torch::jit; 16 | 17 | PYBIND11_MODULE(pointwise_compiler, m) { 18 | // PyTorch makes heavy use of interned strings, which are called Symbols 19 | const auto pointwise_compiler_symbol = 20 | Symbol::fromQualString("pw::CompilationGroup"); 21 | 22 | // Let's hook up the compiler! 23 | 24 | // First, register a pass that will coalesce operators we can handle 25 | // into a single operator containing a subgraph. 26 | RegisterPass pass([pointwise_compiler_symbol](std::shared_ptr& g) { 27 | CustomFuseGraph(g, PointwiseCompiler::supported, pointwise_compiler_symbol); 28 | }); 29 | 30 | // We are only dealing with pure operations (no aliasing or in place 31 | // mutation), so our subgraph will always be pure. 
32 | auto options = c10::OperatorOptions(); 33 | options.setAliasAnalysis(AliasAnalysisKind::PURE); 34 | 35 | RegisterOperators op({Operator( 36 | pointwise_compiler_symbol, 37 | [](const Node* node) { 38 | auto compiler = std::make_shared(node); 39 | return [compiler](Stack& stack) { 40 | compiler->run(stack); 41 | return 0; 42 | }; 43 | }, 44 | options)}); 45 | } 46 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | def benchmark(f): 5 | A_ = torch.randn(1024) 6 | B_ = torch.randn(1024) 7 | # Warmup 8 | for _ in range(10): 9 | _ = f(A_,B_) 10 | t = time.time() 11 | for _ in range(100): 12 | _ = f(A_,B_) 13 | return time.time() - t 14 | 15 | A = torch.randn(1024) 16 | B = torch.randn(1024) 17 | 18 | @torch.jit.script 19 | def foo_jit(a, b): 20 | c = a.mul(b) 21 | a = c.mul(c) 22 | a = c.mul(a) 23 | return a 24 | 25 | print("-- Default IR --\n", foo_jit.graph_for(A,B)) 26 | C_jit = foo_jit(A,B) 27 | print("Default version took {:.2f}ms".format(1000 * benchmark(foo_jit))) 28 | 29 | import pointwise_compiler 30 | print() 31 | 32 | @torch.jit.script 33 | def foo_compiled(a, b): 34 | c = a.mul(b) 35 | a = c.mul(c) 36 | a = c.mul(a) 37 | return a 38 | 39 | print("-- Transformed IR --\n", foo_compiled.graph_for(A,B)) 40 | C_compiled = foo_compiled(A,B) 41 | print("Compiled version took {:.2f}ms".format(1000 * benchmark(foo_compiled))) 42 | 43 | assert torch.allclose(C_jit, C_compiled) 44 | --------------------------------------------------------------------------------