├── docs ├── .gitignore ├── .DS_Store ├── tilegraph.png ├── design.md └── tilegraph.drawio ├── .gitignore ├── src ├── kernels │ ├── kernel_emitter.cpp │ ├── gemm.cpp │ ├── var.cpp │ ├── cuda │ │ ├── cuda_iteration.cpp │ │ ├── tensor_core.cpp │ │ ├── cuda_var.cpp │ │ ├── function.cpp │ │ └── gemm.cpp │ └── iteration.cpp ├── optimizer │ ├── tilling │ │ └── split.cpp │ └── fusion │ │ ├── subgraph_fusion │ │ ├── gemm_add_relu.cpp │ │ ├── subgraph_fusion_base.cpp │ │ └── gemm_relu_fusion.cpp │ │ └── persistent_kernel_fusion.cpp ├── ir │ └── graphene.cpp ├── core │ ├── graph │ │ ├── graph.cpp │ │ ├── gedge.cpp │ │ ├── gnode.cpp │ │ ├── subgraph_match.cpp │ │ └── graph_base.cpp │ ├── operators │ │ ├── operator.cpp │ │ ├── unary.cpp │ │ ├── binary.cpp │ │ ├── fused.cpp │ │ ├── elementwise.cpp │ │ └── gemm.cpp │ ├── tensor.cpp │ ├── platform.cpp │ └── type.cpp └── codegen │ ├── cuda_compiler.cpp │ └── compiler.cpp ├── include ├── core │ ├── operators │ │ ├── elementwise.hpp │ │ ├── operator.hpp │ │ ├── unary.hpp │ │ ├── fused.hpp │ │ ├── binary.hpp │ │ └── gemm.hpp │ ├── graph │ │ ├── subgraph.hpp │ │ ├── graph.hpp │ │ ├── gedge.hpp │ │ ├── graph_base.hpp │ │ ├── gnode.hpp │ │ └── subgraph_match.hpp │ ├── platform.hpp │ ├── tensor.hpp │ └── type.hpp ├── kernels │ ├── cuda │ │ ├── cutlass │ │ │ └── gemm.hpp │ │ ├── cute │ │ │ └── gemm.hpp │ │ ├── tensor_core.hpp │ │ ├── cuda_kernel_unit.hpp │ │ ├── cuda_iteration.hpp │ │ ├── cuda_var.hpp │ │ ├── header.hpp │ │ ├── sync.hpp │ │ ├── function.hpp │ │ ├── memory.hpp │ │ └── gemm.hpp │ ├── kernel_unit.hpp │ ├── kernel_emiter.hpp │ ├── gemm.hpp │ ├── var.hpp │ └── iteration.hpp ├── optimizer │ ├── fusion │ │ ├── graph_attenetion_fuse.hpp │ │ ├── subgraph_fusion │ │ │ ├── gemm_add_relu.hpp │ │ │ ├── gemm_relu_fusion.hpp │ │ │ └── subgraph_fusion_base.hpp │ │ ├── graph_fusion_base.hpp │ │ └── persistent_kernel_fusion.hpp │ └── tilling │ │ └── split.h ├── common │ ├── common.hpp │ └── error_handler.hpp ├── ir │ └── graphene.hpp └── codegen │ ├── generator.hpp │ ├── cuda_compiler.hpp │ └── compiler.hpp ├── .clang-format ├── examples ├── gemm_kernel.cpp ├── graph_base.cpp ├── fuse.cpp ├── subgraph_match.cpp └── persistent_kernel_fusion.cpp ├── .gitmodules ├── .devcontainer ├── Dockerfile └── devcontainer.json ├── tests ├── operators │ └── test_gemm.cpp ├── codegen │ └── test_simple_codegen.cpp ├── graph │ ├── test_subgraph_match.cpp │ └── test_toposort.cpp └── fusion │ ├── test_persisten_kernel_fusion.cpp │ └── test_gemm_fusion.cpp ├── .github └── workflows │ └── build.yml ├── Makefile ├── .vscode ├── c_cpp_properties.json └── settings.json ├── README.md ├── CMakeLists.txt └── LICENSE /docs/.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | .DS_Store -------------------------------------------------------------------------------- /src/kernels/kernel_emitter.cpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/optimizer/tilling/split.cpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/core/operators/elementwise.hpp: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/kernels/cuda/cutlass/gemm.hpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/optimizer/fusion/graph_attenetion_fuse.hpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/kernels/gemm.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/gemm.hpp" 2 | -------------------------------------------------------------------------------- /include/optimizer/fusion/subgraph_fusion/gemm_add_relu.hpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/optimizer/fusion/subgraph_fusion/gemm_add_relu.cpp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /include/common/common.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "common/error_handler.hpp" -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KuangjuX/TileGraph/HEAD/docs/.DS_Store -------------------------------------------------------------------------------- /docs/tilegraph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KuangjuX/TileGraph/HEAD/docs/tilegraph.png -------------------------------------------------------------------------------- /include/ir/graphene.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace tilegraph::ir {} // namespace tilegraph::ir -------------------------------------------------------------------------------- /src/ir/graphene.cpp: -------------------------------------------------------------------------------- 1 | #include "ir/graphene.hpp" 2 | 3 | namespace tilegraph::ir {} // namespace tilegraph::ir -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | --- 4 | Language: Cpp 5 | PointerAlignment: Right 6 | SortIncludes: false 7 | ColumnLimit: 80 8 | NamespaceIndentation: All 9 | IndentWidth: 4 -------------------------------------------------------------------------------- /include/core/graph/subgraph.hpp: -------------------------------------------------------------------------------- 1 | // #pragma once 2 | // #include "core/graph/graph_base.hpp" 3 | 4 | // namespace tilegraph::graph { 5 | // // class SubGraph : public GraphBase {}; 6 | // } // namespace tilegraph::graph -------------------------------------------------------------------------------- /include/kernels/cuda/cute/gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "kernels/gemm.hpp" 4 | 5 | namespace tilegraph::kernel::cuda::cute { 6 | class 
CuteGEMMKernel : public GEMMKernel {}; 7 | } // namespace tilegraph::kernel::cuda::cute -------------------------------------------------------------------------------- /src/core/graph/graph.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | #include "core/graph/graph.hpp" 7 | #include "core/graph/subgraph.hpp" 8 | 9 | namespace tilegraph::graph {} // namespace tilegraph::graph -------------------------------------------------------------------------------- /include/kernels/cuda/tensor_core.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tilegraph::kernel::cuda { 6 | std::string genWmmaSync(int indient, std::string a, std::string b, 7 | std::string c, std::string d); 8 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /src/core/operators/operator.cpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/operator.hpp" 2 | 3 | namespace tilegraph::operators { 4 | std::vector Operator::inferShape( 5 | std::vector inputs) { 6 | // Empty Implementation. 7 | return {}; 8 | } 9 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /include/optimizer/tilling/split.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace tilegraph::tilling { 4 | class Split { 5 | public: 6 | std::vector split_dims; 7 | 8 | Split(std::vector split_dims); 9 | ~Split() = default; 10 | }; 11 | } // namespace tilegraph::tilling -------------------------------------------------------------------------------- /include/optimizer/fusion/graph_fusion_base.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/graph/graph.hpp" 3 | 4 | using namespace tilegraph::graph; 5 | 6 | namespace tilegraph::fusion { 7 | class GraphFusionBase { 8 | public: 9 | virtual bool fusion(Graph::Pointer graph) { return true; } 10 | }; 11 | } // namespace tilegraph::fusion -------------------------------------------------------------------------------- /src/kernels/var.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/var.hpp" 2 | 3 | namespace tilegraph::kernel { 4 | Var::Var(MemoryType memory_level, DataType data_type, uint32_t len, 5 | std::string name) 6 | : memory_level(memory_level), 7 | data_type(data_type), 8 | len(len), 9 | name(name) {} 10 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /include/kernels/kernel_unit.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | namespace tilegraph::kernel { 4 | class KernelUnit {}; 5 | 6 | std::string insertIndient(int indient) { 7 | std::string res; 8 | for (int i = 0; i < indient; i++) { 9 | res += " "; 10 | } 11 | return res; 12 | } 13 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /include/core/graph/graph.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include "core/type.hpp" 7 | #include "core/platform.hpp" 8 | #include 
"core/graph/subgraph.hpp" 9 | #include "core/graph/graph_base.hpp" 10 | 11 | namespace tilegraph::graph { 12 | 13 | using Graph = GraphBase; 14 | 15 | } // namespace tilegraph::graph -------------------------------------------------------------------------------- /include/codegen/generator.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | namespace tilegraph::codegen { 5 | class Generator { 6 | public: 7 | std::string code; 8 | virtual ~Generator() = default; 9 | virtual void generate_head(); 10 | virtual void generate_kernel(); 11 | virtual void generate_host(); 12 | }; 13 | } // namespace tilegraph::codegen -------------------------------------------------------------------------------- /include/core/operators/operator.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/tensor.hpp" 3 | #include "core/type.hpp" 4 | 5 | namespace tilegraph::operators { 6 | class Operator { 7 | public: 8 | virtual std::vector inferShape( 9 | std::vector inputs); 10 | 11 | OperatorType op_type; 12 | 13 | using OpBox = std::shared_ptr; 14 | }; 15 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /include/core/operators/unary.hpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/operator.hpp" 2 | #include "core/type.hpp" 3 | 4 | namespace tilegraph::operators { 5 | class Unary : public Operator { 6 | public: 7 | Unary(OperatorType type); 8 | ~Unary() = default; 9 | std::vector inferShape( 10 | std::vector inputs) override; 11 | 12 | OperatorType type; 13 | }; 14 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /src/core/operators/unary.cpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/unary.hpp" 2 | #include "common/common.hpp" 3 | 4 | namespace tilegraph::operators { 5 | 6 | Unary::Unary(OperatorType type) : type(type) {} 7 | 8 | std::vector Unary::inferShape( 9 | std::vector inputs) { 10 | ASSERT(inputs.size() == 1, "Unary operator should have 1 input"); 11 | return inputs; 12 | } 13 | 14 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /src/core/operators/binary.cpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/binary.hpp" 2 | #include "common/common.hpp" 3 | 4 | namespace tilegraph::operators { 5 | 6 | Binary::Binary(OperatorType type) : binary_type(type) {} 7 | 8 | std::vector Binary::inferShape( 9 | std::vector inputs) { 10 | ASSERT(inputs.size() == 2, "Binary operator should have 2 inputs"); 11 | return {inputs[0]}; 12 | } 13 | 14 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /include/codegen/cuda_compiler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "codegen/compiler.hpp" 3 | 4 | namespace tilegraph::codegen { 5 | 6 | class CudaCompiler final : public Compiler { 7 | protected: 8 | std::string_view hardware() const noexcept final; 9 | std::string_view extension() const noexcept final; 10 | void *_compile(std::filesystem::path const &src, 11 | const char *symbol) final; 12 | }; 13 | 14 | } // namespace tilegraph::codegen 
-------------------------------------------------------------------------------- /include/core/operators/fused.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/operators/operator.hpp" 3 | 4 | namespace tilegraph::operators { 5 | class FusedOp : public Operator { 6 | public: 7 | FusedOp(std::vector ops = {}); 8 | ~FusedOp() = default; 9 | virtual std::vector inferShape( 10 | std::vector inputs) override; 11 | 12 | std::vector ops; 13 | }; 14 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /src/kernels/cuda/cuda_iteration.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/cuda/cuda_iteration.hpp" 2 | 3 | namespace tilegraph::kernel::cuda { 4 | CudaIteration::CudaIteration( 5 | std::unique_ptr iter_var, 6 | std::variant> step, 7 | std::variant> start, 8 | std::variant> end) 9 | : Iteration(std::move(iter_var), step, start, end) {} 10 | } // namespace tilegraph::kernel::cuda 11 | -------------------------------------------------------------------------------- /include/core/operators/binary.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/operators/operator.hpp" 3 | #include "core/type.hpp" 4 | 5 | namespace tilegraph::operators { 6 | 7 | class Binary : public Operator { 8 | public: 9 | Binary(OperatorType type); 10 | ~Binary() = default; 11 | virtual std::vector inferShape( 12 | std::vector inputs) override; 13 | 14 | OperatorType binary_type; 15 | }; 16 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /examples/gemm_kernel.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/cuda/gemm.hpp" 2 | #include "core/type.hpp" 3 | #include 4 | 5 | using namespace tilegraph; 6 | using namespace tilegraph::kernel::cuda; 7 | 8 | int main() { 9 | auto gemm_kernel = std::make_shared( 10 | 5376, 5376, 2048, 128, 128, 32, 64, 64, 16, 16, 16, 16, false, true, 11 | MemoryType::Global, MemoryType::Global); 12 | auto kernel = gemm_kernel->genTCGEMM("matmul"); 13 | fmt::println("GEMM Kernel:\n{}", kernel); 14 | } -------------------------------------------------------------------------------- /include/kernels/cuda/cuda_kernel_unit.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "kernels/kernel_unit.hpp" 3 | #include "kernels/cuda/header.hpp" 4 | #include 5 | 6 | namespace tilegraph::kernel::cuda { 7 | class CudaKernelUnit : public KernelUnit { 8 | public: 9 | CudaKernelUnit() = default; 10 | ~CudaKernelUnit() = default; 11 | 12 | std::set headers; 13 | 14 | void addHeader(CudaHeader header) { headers.insert(header); } 15 | }; 16 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /src/core/tensor.cpp: -------------------------------------------------------------------------------- 1 | #include "core/tensor.hpp" 2 | 3 | namespace tilegraph { 4 | 5 | int64_t Tensor::tensor_count = 0; 6 | 7 | Tensor::Tensor(const std::vector &dimension, 8 | std::string name_value, TensorDatatype dtype, 9 | TensorType type) 10 | : name(name_value), 11 | index(tensor_count++), 12 | tensor_datatype(dtype), 13 | tensor_type(type), 14 | tensor_dimension(dimension) {} 15 | 16 | } // namespace tilegraph 
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rd-party/googletest"] 2 | path = 3rd-party/googletest 3 | url = git@github.com:google/googletest.git 4 | [submodule "3rd-party/cutlass"] 5 | path = 3rd-party/cutlass 6 | url = git@github.com:NVIDIA/cutlass.git 7 | [submodule "3rd-party/fmtlog"] 8 | path = 3rd-party/fmtlog 9 | url = git@github.com:MengRao/fmtlog.git 10 | [submodule "3rd-party/fmt"] 11 | path = 3rd-party/fmt 12 | url = git@github.com:fmtlib/fmt.git 13 | [submodule "3rd-party/result"] 14 | path = 3rd-party/result 15 | url = git@github.com:oktal/result.git 16 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM ubuntu:22.04 2 | FROM mcr.microsoft.com/vscode/devcontainers/base:ubuntu-22.04 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | # Install dependencies. 6 | RUN apt update && apt-get install -y git make cmake build-essential python-is-python3 python-dev-is-python3 python3-pip libdw-dev openssh-client 7 | 8 | # Update pip and switch to Tsinghua source. 9 | RUN python -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade pip && pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple 10 | 11 | 12 | -------------------------------------------------------------------------------- /include/kernels/kernel_emiter.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include "core/graph/gnode.hpp" 5 | #include "core/tensor.hpp" 6 | 7 | namespace tilegraph::kernel { 8 | class KernelEmiter { 9 | public: 10 | std::shared_ptr op; 11 | // The input tensor 12 | std::shared_ptr inputs; 13 | // The output tensor 14 | std::shared_ptr outputs; 15 | // The allocated tensor 16 | std::shared_ptr allocated_tensor; 17 | }; 18 | } // namespace tilegraph::kernels 19 | -------------------------------------------------------------------------------- /include/core/operators/gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/operators/operator.hpp" 3 | 4 | namespace tilegraph::operators { 5 | 6 | class GEMM : public Operator { 7 | public: 8 | GEMM(float alpha = 1.0f, float beta = 1.0f, bool transA = false, 9 | bool transB = false); 10 | ~GEMM() = default; 11 | virtual std::vector inferShape( 12 | std::vector inputs) override; 13 | 14 | float alpha, beta; 15 | bool transA, transB; 16 | }; 17 | 18 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /src/core/platform.cpp: -------------------------------------------------------------------------------- 1 | #include "core/platform.hpp" 2 | 3 | namespace tilegraph { 4 | 5 | #define CASE(TYPE, STR) \ 6 | case Platform::TYPE: \ 7 | return STR 8 | 9 | const char *Platform::toString() const { 10 | switch (type) { 11 | CASE(CUDA, "CUDA"); 12 | CASE(BANG, "BANG"); 13 | default: 14 | return "Unknown"; 15 | } 16 | } 17 | 18 | bool Platform::isCUDA() const { return type == Platform::CUDA; } 19 | 20 | bool Platform::isBANG() const { return type == Platform::BANG; } 21 | 22 | } // namespace tilegraph -------------------------------------------------------------------------------- /include/kernels/cuda/cuda_iteration.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "kernels/cuda/cuda_var.hpp" 3 | #include "kernels/iteration.hpp" 4 | #include 5 | 6 | namespace tilegraph::kernel::cuda { 7 | class CudaIteration : public Iteration { 8 | public: 9 | CudaIteration(std::unique_ptr iter_var, 10 | std::variant> step, 11 | std::variant> start, 12 | std::variant> end); 13 | }; 14 | 15 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /src/core/type.cpp: -------------------------------------------------------------------------------- 1 | #include "core/type.hpp" 2 | 3 | namespace tilegraph { 4 | std::string toString(OperatorType op) { 5 | switch (op) { 6 | case OperatorType::ADD: 7 | return "Add"; 8 | case OperatorType::SUB: 9 | return "Sub"; 10 | case OperatorType::GEMM: 11 | return "Gemm"; 12 | case OperatorType::RELU: 13 | return "Relu"; 14 | case OperatorType::GEMM_RELU: 15 | return "GemmRelu"; 16 | default: 17 | return "Unknown"; 18 | } 19 | } 20 | } // namespace tilegraph -------------------------------------------------------------------------------- /src/kernels/cuda/tensor_core.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/cuda/tensor_core.hpp" 2 | 3 | namespace tilegraph::kernel::cuda { 4 | std::string genWmmaSync(int indient, std::string a, std::string b, 5 | std::string c, std::string d) { 6 | std::string mma_sync; 7 | for (int i = 0; i < indient; i++) { 8 | mma_sync += " "; 9 | } 10 | 11 | // D = A * B + C / C = A * B + C 12 | mma_sync += fmt::format("nvcuda::wmma::mma_sync({}, {}, {}, {});\n;", d, 13 | a, b, c); 14 | 15 | return mma_sync; 16 | } 17 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /tests/operators/test_gemm.cpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/gemm.hpp" 2 | 3 | #include 4 | 5 | using namespace tilegraph; 6 | using namespace tilegraph::operators; 7 | 8 | TEST(OPERATORS, gemm) { 9 | auto gemm = std::make_shared(); 10 | auto input0 = std::make_shared(std::vector{2, 3}); 11 | auto input1 = std::make_shared(std::vector{3, 4}); 12 | auto output = gemm->inferShape({input0, input1}); 13 | 14 | ASSERT_EQ(output.size(), 1); 15 | ASSERT_EQ(output[0]->tensor_dimension.size(), 2); 16 | ASSERT_EQ(output[0]->tensor_dimension[0], 2); 17 | ASSERT_EQ(output[0]->tensor_dimension[1], 4); 18 | } 19 | -------------------------------------------------------------------------------- /include/kernels/cuda/cuda_var.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "kernels/var.hpp" 3 | #include "kernels/cuda/tensor_core.hpp" 4 | #include 5 | 6 | namespace tilegraph::kernel::cuda { 7 | class CudaVar : public Var { 8 | public: 9 | CudaVar(MemoryType memory_level, DataType data_type, uint32_t len, 10 | std::string name); 11 | ~CudaVar() = default; 12 | 13 | std::string declareVar(int indient) override; 14 | std::string initVar(int indient) override; 15 | std::string getVarIndex(uint32_t index) override; 16 | std::string getVarIndexByVar(std::string index) override; 17 | }; 18 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /include/kernels/gemm.hpp: -------------------------------------------------------------------------------- 1 | 
#pragma once 2 | #include "core/type.hpp" 3 | 4 | namespace tilegraph::kernel { 5 | class GEMMKernel { 6 | public: 7 | // GEMM parametes. 8 | uint32_t M; 9 | uint32_t N; 10 | uint32_t K; 11 | uint32_t ShardedM; 12 | uint32_t ShardedN; 13 | uint32_t ShardedK; 14 | uint32_t WarpM; 15 | uint32_t WarpN; 16 | uint32_t WarpK; 17 | uint32_t WmmaM; 18 | uint32_t WmmaN; 19 | uint32_t WmmaK; 20 | 21 | bool transpose_a; 22 | bool transpose_b; 23 | 24 | MemoryType memory_level; 25 | MemoryType output_level; 26 | }; 27 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /include/optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/graph/graph.hpp" 3 | #include "core/graph/subgraph_match.hpp" 4 | #include "optimizer/fusion/subgraph_fusion/subgraph_fusion_base.hpp" 5 | 6 | namespace tilegraph::fusion::subgraph { 7 | 8 | class GemmReluFusion : public SubgraphFusionBase { 9 | public: 10 | GemmReluFusion(std::shared_ptr graph); 11 | virtual ~GemmReluFusion() = default; 12 | 13 | virtual void create_subgraphs() override; 14 | virtual Result fuse_subgraph( 15 | graph::SubGraphRecord::Pointer subgraph_record) override; 16 | }; 17 | } // namespace tilegraph::fusion::subgraph -------------------------------------------------------------------------------- /include/kernels/var.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/type.hpp" 3 | 4 | namespace tilegraph::kernel { 5 | class Var { 6 | public: 7 | MemoryType memory_level; 8 | DataType data_type; 9 | uint32_t len; 10 | std::string name; 11 | 12 | Var(MemoryType memory_level, DataType data_type, uint32_t len, 13 | std::string name); 14 | ~Var() = default; 15 | 16 | virtual std::string declareVar(int indient) = 0; 17 | virtual std::string initVar(int indient) = 0; 18 | virtual std::string getVarIndex(uint32_t index) = 0; 19 | virtual std::string getVarIndexByVar(std::string index) = 0; 20 | }; 21 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /src/core/operators/fused.cpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/fused.hpp" 2 | #include "common/common.hpp" 3 | 4 | #include 5 | 6 | namespace tilegraph::operators { 7 | FusedOp::FusedOp(std::vector ops) : ops(ops) {} 8 | 9 | std::vector FusedOp::inferShape( 10 | std::vector inputs) { 11 | if (ops.empty()) { 12 | loge("[FusedOperator::inferShape] No operators in fused operator."); 13 | return {}; 14 | } else { 15 | loge( 16 | "[FusedOperator::inferShape] Fused operator is not " 17 | "implemented."); 18 | UNREACHABLE(); 19 | return {}; 20 | } 21 | } 22 | 23 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /include/core/platform.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tilegraph { 6 | 7 | struct Platform { 8 | using underlying_t = uint16_t; 9 | 10 | enum : underlying_t { CUDA, BANG } type; 11 | 12 | constexpr Platform(decltype(type) t) : type(t) {} 13 | constexpr explicit Platform(underlying_t val) 14 | : type((decltype(type))val) {} 15 | constexpr underlying_t underlying() const { return type; } 16 | 17 | bool operator==(Platform others) const { return 
type == others.type; } 18 | bool operator!=(Platform others) const { return type != others.type; } 19 | 20 | const char *toString() const; 21 | bool isCUDA() const; 22 | bool isBANG() const; 23 | }; 24 | 25 | } // namespace tilegraph -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build And Test CI 2 | on: 3 | push: 4 | paths-ignore: 5 | - '**.md' 6 | - 'LICENSE' 7 | pull_request: 8 | paths-ignore: 9 | - '**.md' 10 | - 'LICENSE' 11 | 12 | jobs: 13 | build: 14 | name: Build 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | type: [debug, release] 20 | steps: 21 | 22 | - uses: actions/checkout@v3 23 | with: 24 | submodules: recursive 25 | 26 | - name: Build fmt 27 | run: | 28 | cd 3rd-party/fmt && mkdir build && cd build && cmake -DCMAKE_POSITION_INDEPENDENT_CODE=TRUE .. && make 29 | 30 | - name: build 31 | run: make build CUDA=OFF 32 | 33 | - name: examples 34 | run: make examples 35 | 36 | - name: test 37 | run: make test -------------------------------------------------------------------------------- /include/codegen/compiler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace tilegraph::codegen { 8 | 9 | class Compiler { 10 | public: 11 | virtual ~Compiler() = default; 12 | void *compile(const char *dir, const char *code, const char *symbol); 13 | void *fetch(const char *dir); 14 | 15 | protected: 16 | std::unordered_map _dirs; 17 | 18 | virtual std::string_view hardware() const = 0; 19 | virtual std::string_view extension() const = 0; 20 | virtual void *_compile(std::filesystem::path const &src, 21 | const char *symbol) = 0; 22 | 23 | static std::filesystem::path const &repo_path(); 24 | }; 25 | 26 | } // namespace tilegraph::codegen -------------------------------------------------------------------------------- /include/core/tensor.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include "core/type.hpp" 7 | 8 | namespace tilegraph { 9 | class Tensor { 10 | private: 11 | static int64_t tensor_count; 12 | 13 | public: 14 | using Pointer = std::shared_ptr; 15 | std::string name; 16 | const int64_t index; 17 | TensorDatatype tensor_datatype; 18 | TensorType tensor_type; 19 | std::vector tensor_dimension; 20 | 21 | public: 22 | Tensor() = delete; 23 | Tensor(const std::vector &dimension, 24 | std::string name_value = "", 25 | TensorDatatype dtype = TensorDatatype::FLOAT, 26 | TensorType type = TensorType::VARIABLE); 27 | 28 | ~Tensor() = default; 29 | }; 30 | } // namespace tilegraph -------------------------------------------------------------------------------- /src/core/operators/elementwise.cpp: -------------------------------------------------------------------------------- 1 | // #include "core/operators/elementwise.hpp" 2 | 3 | // namespace tilegraph::operators { 4 | 5 | // using namespace tilegraph::graph; 6 | 7 | // Binary::Binary(OperatorType type, std::vector inputs_list, 8 | // std::vector outputs_list, std::string name_value, 9 | // int64_t outputs_num_value) 10 | // : Node(inputs_list, outputs_list, name_value, outputs_num_value) { 11 | // operator_type = type; 12 | // } 13 | 14 | // Unary::Unary(OperatorType type, std::vector inputs_list, 15 | // std::vector outputs_list, 
std::string name_value, 16 | // int64_t outputs_num_value) 17 | // : Node(inputs_list, outputs_list, name_value, outputs_num_value) { 18 | // operator_type = type; 19 | // } 20 | 21 | // } // namespace tilegraph -------------------------------------------------------------------------------- /docs/design.md: -------------------------------------------------------------------------------- 1 | # Design 2 | 3 | ## Design Philosophy 4 | 5 | TileGraph is an experimental static code generation framework for DNNs that focuses on operator fusion and efficient code generation. For compute-intensive operators it relies mainly on subgraph matching to fuse operators that follow specific patterns, such as the Persistent Kernel Fusion and Attention Fusion described in Bolt. The result of fusion is an optimized graph; for optimizations with a fixed pattern, the fused node can replace the matched subgraph with a new operator type. 6 | 7 | After one operator fusion pass produces the optimized graph, the graph information is lowered; after lowering, the focus is on inter-operator memory allocation and the memory hierarchy. Tensors are tiled first, then a Welder pass runs over the whole graph to perform heuristic memory-aware operator fusion, marking memory at different levels with subgraph structures of the corresponding level. A performance profiler is then used to select certain parameters, and code is finally generated from the kernel graph: each subgraph is generated recursively, and a synchronization primitive such as `__syncthreads()` is inserted when leaving a subgraph. The end result is CUDA code. 8 | 9 | ## Fusion 10 | 11 | ### Persistent Kernel Fusion 12 | 13 | Idea: 14 | 15 | - First use subgraph matching to match and fuse patterns such as GEMM + ADD and GEMM + ADD + activation. 16 | 17 | - Topologically sort the graph, find all operators of type GEMM, GEMM_ADD, and GEMM_ADD_ACTIVATION, group them, and store the groups in a `std::unordered_map`. 18 | 19 | - Then traverse all GEMM, GEMM_ADD, and GEMM_ADD_ACTIVATION operators and inspect their successors; if a successor is an operator of the same kind, merge the two (union-find), continuing until every operator has been visited. The result is the set of operators eligible for persistent kernel fusion. 20 | -------------------------------------------------------------------------------- /include/kernels/cuda/header.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace tilegraph::kernel::cuda { 5 | enum class CudaHeader { 6 | cuda, 7 | cuda_runtime, 8 | cublas, 9 | cudnn, 10 | cutlass, 11 | }; 12 | 13 | std::string generateHeader(CudaHeader header) { 14 | switch (header) { 15 | case CudaHeader::cuda: 16 | return "#include \n"; 17 | case CudaHeader::cuda_runtime: 18 | return "#include \n"; 19 | case CudaHeader::cublas: 20 | return "#include \n"; 21 | case CudaHeader::cudnn: 22 | return "#include \n"; 23 | case CudaHeader::cutlass: 24 | return "#include \n"; 25 | default: 26 | return ""; 27 | } 28 | } 29 | 30 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /include/kernels/cuda/sync.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/type.hpp" 3 | #include 4 | #include 5 | 6 | namespace tilegraph::kernel::cuda { 7 | std::string insertSyncnorize(int indient, MemoryType memory_level) { 8 | std::string sync; 9 | for (int i = 0; i < indient; i++) { 10 | sync += " "; 11 | } 12 | switch (memory_level) { 13 | case MemoryType::Global: 14 | sync += "cudaDeviceSynchronize();\n"; 15 | break; 16 | case MemoryType::Shared: 17 | sync += "__syncthreads();\n"; 18 | break; 19 | case MemoryType::Register: 20 | fmt::println("Register level not supported"); 21 | break; 22 | default: 23 | fmt::println("Failed to get memory level"); 24 | } 25 | 26 | return sync; 27 | } 28 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /include/kernels/cuda/function.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/type.hpp" 3 | #include 4 | #include 5 | 6 | namespace tilegraph::kernel::cuda { 7 | class CudaFunctionKernelUnit { 8 | public: 9 | std::string declareGlobal( 10 | std::string name, 11 | std::vector> arguments); 12 | std::string declareDevice( 13 | std::string name, 14 | std::vector> 
arguments); 15 | }; 16 | 17 | class CudaFunction { 18 | public: 19 | std::string name; 20 | FuncType func_type; 21 | std::vector> arguments; 22 | DataType return_type; 23 | 24 | CudaFunction(std::string name, FuncType func_type, 25 | DataType return_type); 26 | 27 | std::string declareFunction(); 28 | }; 29 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /tests/codegen/test_simple_codegen.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "codegen/cuda_compiler.hpp" 4 | 5 | using namespace ::tilegraph::codegen; 6 | 7 | constexpr static const char *code = R"~( 8 | #include 9 | __global__ void kernel(float* a) { 10 | a[threadIdx.x] += 1.0; 11 | } 12 | extern "C" { 13 | void launchKernel(float* a) { 14 | float* dev_a; 15 | cudaMalloc(&dev_a, 100 * sizeof(float)); 16 | cudaMemcpy(dev_a, a, 100 * sizeof(float), cudaMemcpyHostToDevice); 17 | kernel<<<1, 100>>>(dev_a); 18 | cudaDeviceSynchronize(); 19 | cudaMemcpy(a, dev_a, 100 * sizeof(float), cudaMemcpyDeviceToHost); 20 | } 21 | } 22 | )~"; 23 | 24 | TEST(Codegen, simple_cuda) { 25 | CudaCompiler nvcc; 26 | auto function = nvcc.compile("add", code, "launchKernel"); 27 | 28 | float *a = new float[100]; 29 | memset(a, 0, 100 * sizeof(float)); 30 | reinterpret_cast(function)(a); 31 | EXPECT_EQ(a[0], 1.0); 32 | } -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CUDA ?= OFF 2 | 3 | CC := g++ 4 | EXAMPLE := gemm_kernel 5 | EXAMPLE_SRCS := $(wildcard examples/*.cpp) 6 | EXAMPLES := $(patsubst examples/%.cpp, %, $(EXAMPLE_SRCS)) 7 | 8 | LD_FLAGS := -Lbuild/ -ltilegraph -Wl,-rpath=build/ 9 | INC_FLAGS := -Iinclude -I3rd-party/result -I3rd-party/fmt/include -I3rd-party/fmtlog 10 | MACRO_FLAGS := -DFMTLOG_HEADER_ONLY -DFMT_HEADER_ONLY 11 | 12 | CMAKE_OPTS = -DUSE_CUDA=$(CUDA) 13 | 14 | build: 15 | @mkdir build 16 | @cd build && cmake $(CMAKE_OPTS) .. 
&& make -j8 17 | 18 | test: build 19 | @cd build && make test 20 | 21 | example: build 22 | @$(CC) examples/$(EXAMPLE).cpp $(INC_FLAGS) $(LD_FLAGS) $(MACRO_FLAGS) -o build/$(EXAMPLE) 23 | @./build/$(EXAMPLE) 24 | 25 | examples: build 26 | @for example in $(EXAMPLES); do \ 27 | $(CC) examples/$$example.cpp $(INC_FLAGS) $(LD_FLAGS) $(MACRO_FLAGS) -o build/$$example; \ 28 | echo "Running example: $$example"; \ 29 | ./build/$$example; \ 30 | done 31 | 32 | 33 | clean: 34 | @rm -rf build -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/include", 7 | "${workspaceFolder}/3rd-party/cutlass/include", 8 | "${workspaceFolder}/3rd-party/cutlass/tools/util/include", 9 | "${workspaceFolder}/3rd-party/cutlass/examples/common", 10 | "${workspaceFolder}/3rd-party/fmt/include", 11 | "${workspaceFolder}/3rd-party/fmtlog", 12 | "${workspaceFolder}/3rd-party/googletest/googletest/include", 13 | "${workspaceFolder}/3rd-party/googletest/googlemock/include", 14 | "${workspaceFolder}/3rd-party/result", 15 | ], 16 | "intelliSenseMode": "linux-gcc-x64", 17 | "compilerPath": "/usr/bin/gcc", 18 | "cStandard": "c17", 19 | "cppStandard": "gnu++17" 20 | } 21 | ], 22 | "version": 4 23 | } -------------------------------------------------------------------------------- /include/kernels/iteration.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "kernels/var.hpp" 3 | #include 4 | #include 5 | 6 | namespace tilegraph::kernel { 7 | template 8 | class Iteration { 9 | public: 10 | std::unique_ptr iter_var; 11 | // The step of the iteration. 12 | std::variant> step; 13 | // The start and end of the iteration. 14 | std::variant> start; 15 | std::variant> end; 16 | 17 | public: 18 | Iteration(std::unique_ptr iter_var, 19 | std::variant> step, 20 | std::variant> start, 21 | std::variant> end); 22 | 23 | virtual std::string genIter(int indient); 24 | virtual std::string getIterVar(); 25 | }; 26 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /include/optimizer/fusion/subgraph_fusion/subgraph_fusion_base.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/graph/graph.hpp" 3 | #include "core/graph/subgraph_match.hpp" 4 | 5 | #include 6 | 7 | namespace tilegraph::fusion::subgraph { 8 | class SubgraphFusionBase { 9 | public: 10 | struct FusionError { 11 | enum class Kind { 12 | UnmatchedSubgraph, 13 | FailedToFuseSubgraph, 14 | }; 15 | 16 | Kind kind; 17 | }; 18 | SubgraphFusionBase(graph::Graph::Pointer graph); 19 | virtual ~SubgraphFusionBase() = default; 20 | 21 | virtual void create_subgraphs() = 0; 22 | virtual Result fuse_subgraph( 23 | graph::SubGraphRecord::Pointer subgraph_record) = 0; 24 | Result match_and_fuse_subgraph(); 25 | 26 | std::vector subgraphs; 27 | std::shared_ptr subgraph_match; 28 | std::shared_ptr graph; 29 | }; 30 | } // namespace tilegraph::fusion::subgraph -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. 
For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/cpp 3 | { 4 | "name": "C++", 5 | "build": { 6 | "dockerfile": "Dockerfile" 7 | }, 8 | // Features to add to the dev container. More info: https://containers.dev/features. 9 | // "features": {}, 10 | // Configure tool-specific properties. 11 | "customizations": { 12 | // Configure properties specific to VS Code. 13 | "vscode": { 14 | "settings": {}, 15 | "extensions": [ 16 | "streetsidesoftware.code-spell-checker" 17 | ] 18 | } 19 | } 20 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 21 | // "forwardPorts": [], 22 | // Use 'postCreateCommand' to run commands after the container is created. 23 | // "postCreateCommand": "gcc -v", 24 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 25 | // "remoteUser": "root" 26 | } -------------------------------------------------------------------------------- /src/kernels/cuda/cuda_var.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/cuda/cuda_var.hpp" 2 | #include "kernels/cuda/memory.hpp" 3 | 4 | namespace tilegraph::kernel::cuda { 5 | CudaVar::CudaVar(MemoryType memory_level, DataType data_type, uint32_t len, 6 | std::string name) 7 | : Var(memory_level, data_type, len, name) {} 8 | 9 | std::string CudaVar::declareVar(int indient) { 10 | std::string var; 11 | if (memory_level != MemoryType::Warp) { 12 | var = declareMemory(indient, name, memory_level, data_type, len); 13 | } else { 14 | } 15 | return var; 16 | } 17 | 18 | std::string CudaVar::initVar(int indient) { 19 | std::string var; 20 | for (int i = 0; i < indient; i++) { 21 | var += " "; 22 | } 23 | return var; 24 | } 25 | 26 | std::string CudaVar::getVarIndex(uint32_t index) { 27 | return fmt::format("{}[{}]", name, index); 28 | } 29 | 30 | std::string CudaVar::getVarIndexByVar(std::string index) { 31 | return fmt::format("{}[{}]", name, index); 32 | } 33 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /src/core/operators/gemm.cpp: -------------------------------------------------------------------------------- 1 | #include "core/operators/gemm.hpp" 2 | #include "common/common.hpp" 3 | 4 | namespace tilegraph::operators { 5 | GEMM::GEMM(float alpha, float beta, bool transA, bool transB) 6 | : alpha(alpha), beta(beta), transA(transA), transB(transB) {} 7 | 8 | std::vector GEMM::inferShape( 9 | std::vector inputs) { 10 | // default Row Major Matrix 11 | ASSERT(inputs.size() == 2, "GEMM should have 2 inputs"); 12 | ASSERT(inputs[0]->tensor_dimension.size() == 2, 13 | "GEMM input 0 should be 2D"); 14 | ASSERT(inputs[1]->tensor_dimension.size() == 2, 15 | "GEMM input 1 should be 2D"); 16 | ASSERT(inputs[0]->tensor_dimension[1] == inputs[1]->tensor_dimension[0], 17 | "GEMM input 0 and input 1 " 18 | "should have same shape " 19 | "except last dimension"); 20 | auto output = std::make_shared(std::vector{ 21 | inputs[0]->tensor_dimension[0], inputs[1]->tensor_dimension[1]}); 22 | return {output}; 23 | } 24 | 25 | } // namespace tilegraph::operators -------------------------------------------------------------------------------- /include/optimizer/fusion/persistent_kernel_fusion.hpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "optimizer/fusion/graph_fusion_base.hpp" 3 | 4 | #include 5 | #include 6 | 7 | 
using namespace tilegraph::graph; 8 | 9 | namespace tilegraph::fusion { 10 | class PersistentKernelFusion : public GraphFusionBase { 11 | public: 12 | PersistentKernelFusion() = default; 13 | PersistentKernelFusion(Graph::Pointer graph); 14 | bool fusion(Graph::Pointer graph) override; 15 | 16 | private: 17 | std::size_t findRoot( 18 | std::unordered_map node_to_group, 19 | std::size_t node_idx); 20 | 21 | std::unordered_map> findGroups( 22 | std::unordered_map node_to_group, 23 | std::unordered_map nodes, 24 | std::unordered_map node_map); 25 | 26 | std::pair, std::vector> 27 | searchInOut(std::vector group); 28 | }; 29 | } // namespace tilegraph::fusion -------------------------------------------------------------------------------- /src/optimizer/fusion/subgraph_fusion/subgraph_fusion_base.cpp: -------------------------------------------------------------------------------- 1 | #include "optimizer/fusion/subgraph_fusion/subgraph_fusion_base.hpp" 2 | 3 | #include 4 | 5 | namespace tilegraph::fusion::subgraph { 6 | SubgraphFusionBase::SubgraphFusionBase(graph::Graph::Pointer graph) 7 | : graph(graph) { 8 | this->subgraph_match = std::make_shared(graph); 9 | } 10 | 11 | Result 12 | SubgraphFusionBase::match_and_fuse_subgraph() { 13 | for (auto subgraph : this->subgraphs) { 14 | if (this->subgraph_match->Match(subgraph)) { 15 | auto records = this->subgraph_match->get_matched_subgraph(); 16 | logi("Matched record size: {}", records.size()); 17 | for (auto record : records) { 18 | this->fuse_subgraph(record); 19 | } 20 | } else { 21 | return Err(SubgraphFusionBase::FusionError{ 22 | SubgraphFusionBase::FusionError::Kind::UnmatchedSubgraph}); 23 | } 24 | } 25 | return Ok(); 26 | } 27 | } // namespace tilegraph::fusion::subgraph -------------------------------------------------------------------------------- /src/codegen/cuda_compiler.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "codegen/cuda_compiler.hpp" 4 | #include "common/common.hpp" 5 | 6 | namespace tilegraph::codegen { 7 | 8 | auto CudaCompiler::hardware() const noexcept -> std::string_view { 9 | return "CUDA"; 10 | } 11 | auto CudaCompiler::extension() const noexcept -> std::string_view { 12 | return "cu"; 13 | } 14 | void *CudaCompiler::_compile(std::filesystem::path const &src, 15 | const char *symbol) { 16 | auto out = src, so = src; 17 | out.replace_extension("o"); 18 | so.replace_filename("libkernel.so"); 19 | { 20 | std::string command; 21 | command = fmt::format("nvcc -Xcompiler \"-fPIC\" {} -c -o {}", 22 | src.c_str(), out.c_str()); 23 | std::system(command.c_str()); 24 | command = 25 | fmt::format("nvcc -shared {} -o {}", out.c_str(), so.c_str()); 26 | std::system(command.c_str()); 27 | } 28 | 29 | auto handle = dlopen(so.c_str(), RTLD_LAZY); 30 | ASSERT(handle, "Failed to load kernel library: {}", dlerror()); 31 | auto function = dlsym(handle, symbol); 32 | ASSERT(function, "Failed to load kernel function: {}", dlerror()); 33 | return function; 34 | } 35 | 36 | } // namespace tilegraph::codegen -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TileGraph 2 | **Note: TiledKernel is no longer an active project, and the successor is [ThrillerFlow](https://github.com/TiledTensor/ThrillerFlow).** 3 | 4 | TileGraph is an experimental DNN compiler that utilizes static code generation and kernel fusion techniques. 
5 | 6 | [Design Docs](docs/design.md) 7 | 8 | ## Overview 9 | 10 | ![](docs/tilegraph.png) 11 | 12 | ## Get Started 13 | ### Clone Project 14 | ``` 15 | git clone git@github.com:KuangjuX/TileGraph.git 16 | git submodule update --init --recursive 17 | ``` 18 | 19 | ### Install Dependencies 20 | ``` 21 | cd 3rd-party/fmt && mkdir build && cd build && cmake -DCMAKE_POSITION_INDEPENDENT_CODE=TRUE .. && make 22 | ``` 23 | 24 | ### Test 25 | ``` 26 | make test 27 | ``` 28 | 29 | ### Run Examples 30 | ``` 31 | make examples 32 | ``` 33 | 34 | ## Reference Projects 35 | 36 | - [InfiniTensor/RefactorGraph](https://github.com/InfiniTensor/RefactorGraph): A layered, decoupled deep learning inference engine. 37 | - [microsoft/nnfusion](https://github.com/microsoft/nnfusion): A flexible and efficient deep neural network (DNN) compiler that generates high-performance executable from a DNN model description. 38 | 39 | ## Reference Papers 40 | 41 | - BOLT: BRIDGING THE GAP BETWEEN AUTO-TUNERS AND HARDWARE-NATIVE PERFORMANCE (MLSys'22) 42 | - Welder: Scheduling Deep Learning Memory Access via Tile-graph (OSDI'23) 43 | - Graphene: An IR for Optimized Tensor Computations on GPUs (ASPLOS'23) 44 | -------------------------------------------------------------------------------- /include/core/graph/gedge.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include "core/graph/gnode.hpp" 7 | #include "core/tensor.hpp" 8 | 9 | namespace tilegraph::graph { 10 | class GNode; 11 | class GEdge { 12 | private: 13 | static int64_t edge_count; 14 | 15 | public: 16 | const int64_t index; 17 | std::string name; 18 | std::shared_ptr producer; 19 | std::vector> consumers; 20 | std::shared_ptr tensor; 21 | 22 | public: 23 | GEdge() = delete; 24 | GEdge(std::shared_ptr tensor_value, 25 | std::string name_value = ""); 26 | GEdge(const std::vector &dimension, 27 | std::string name_value = "", std::string tensor_name_value = "", 28 | TensorDatatype dtype = TensorDatatype::FLOAT, 29 | TensorType type = TensorType::VARIABLE); 30 | 31 | ~GEdge() = default; 32 | void setProducer(std::shared_ptr node); 33 | void addConsumer(std::shared_ptr node); 34 | bool earseConsumer(GNode::Pointer node); 35 | std::shared_ptr getProducer(); 36 | std::vector> getConsumers(); 37 | std::shared_ptr getTensor(); 38 | 39 | using Pointer = std::shared_ptr; 40 | }; 41 | } // namespace tilegraph::graph 42 | -------------------------------------------------------------------------------- /src/codegen/compiler.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "codegen/compiler.hpp" 7 | #include "common/common.hpp" 8 | 9 | namespace tilegraph::codegen { 10 | namespace fs = std::filesystem; 11 | 12 | auto Compiler::repo_path() -> fs::path const & { 13 | static std::once_flag pathFlag; 14 | static fs::path path; 15 | std::call_once(pathFlag, [] { 16 | auto codegenDir = getenv("CODEGEN_DIR"); 17 | path = fs::path(codegenDir ? 
codegenDir : "build") / "code_repo" / 18 | std::to_string(getpid()); 19 | ASSERT(fs::create_directories(path), 20 | "Failed to create directory \"{}\" for code generation", 21 | path.c_str()); 22 | }); 23 | return path; 24 | } 25 | 26 | void *Compiler::compile(const char *dir_, const char *code, 27 | const char *symbol) { 28 | auto [it, ok] = _dirs.try_emplace(dir_, nullptr); 29 | if (!ok) { 30 | return it->second; 31 | } 32 | auto dir = repo_path() / hardware() / dir_; 33 | auto src = dir / fmt::format("lib.{}", extension()); 34 | fs::create_directories(dir); 35 | std::ofstream(src) << code; 36 | return it->second = _compile(src, symbol); 37 | } 38 | void *Compiler::fetch(const char *dir_) { return _dirs.at(dir_); } 39 | 40 | } // namespace tilegraph::codegen -------------------------------------------------------------------------------- /src/core/graph/gedge.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/gedge.hpp" 2 | 3 | #include 4 | 5 | namespace tilegraph::graph { 6 | 7 | int64_t GEdge::edge_count = 0; 8 | 9 | GEdge::GEdge(std::shared_ptr tensor_value, std::string name_value) 10 | : index(edge_count++), name(name_value), tensor(tensor_value) {} 11 | 12 | GEdge::GEdge(const std::vector &dimension, std::string name_value, 13 | std::string tensor_name_value, TensorDatatype dtype, 14 | TensorType type) 15 | : index(edge_count++), 16 | name(name_value), 17 | tensor(std::make_shared(dimension, tensor_name_value, dtype, 18 | type)) {} 19 | 20 | void GEdge::addConsumer(std::shared_ptr node) { 21 | consumers.push_back(node); 22 | } 23 | 24 | bool GEdge::earseConsumer(GNode::Pointer node) { 25 | auto it = std::find(consumers.begin(), consumers.end(), node); 26 | if (it != consumers.end()) { 27 | consumers.erase(it); 28 | return true; 29 | } 30 | return false; 31 | } 32 | 33 | void GEdge::setProducer(std::shared_ptr node) { producer = node; } 34 | 35 | std::shared_ptr GEdge::getProducer() { return producer; } 36 | 37 | std::vector> GEdge::getConsumers() { 38 | return consumers; 39 | } 40 | 41 | std::shared_ptr GEdge::getTensor() { return tensor; } 42 | 43 | } // namespace tilegraph::graph -------------------------------------------------------------------------------- /include/common/error_handler.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tilegraph { 6 | struct UnimplementError : public std::logic_error { 7 | explicit UnimplementError(std::string msg) 8 | : std::logic_error(std::move(msg)) {} 9 | }; 10 | 11 | struct UnreachableError : public std::logic_error { 12 | explicit UnreachableError(std::string msg) 13 | : std::logic_error(std::move(msg)) {} 14 | }; 15 | } // namespace tilegraph 16 | 17 | #define ERROR_MSG(MSG) fmt::format("{} Source {}:{}", (MSG), __FILE__, __LINE__) 18 | #define RUNTIME_ERROR(MSG) throw std::runtime_error(ERROR_MSG(MSG)) 19 | #define OUT_OF_RANGE(MSG, A, B) \ 20 | throw std::out_of_range(ERROR_MSG(fmt::format("{}/{} {}", (A), (B), (MSG)))) 21 | #define TODO(MSG) throw tilegraph::UnimplementError(ERROR_MSG(MSG)) 22 | 23 | #define UNREACHABLEX(T, F, ...) \ 24 | [&]() -> T { \ 25 | throw tilegraph::UnreachableError( \ 26 | ERROR_MSG(fmt::format("Unreachable: " #F, ##__VA_ARGS__))); \ 27 | }() 28 | #define UNREACHABLE() UNREACHABLEX(void, "no message") 29 | 30 | #ifndef DISABLE_ASSERT 31 | #define ASSERT(CONDITION, F, ...) 
\ 32 | { \ 33 | if (!(CONDITION)) \ 34 | RUNTIME_ERROR(fmt::format("Assertion: " #F, ##__VA_ARGS__)); \ 35 | } 36 | #else 37 | #define ASSERT(CONDITION, F) 38 | #endif -------------------------------------------------------------------------------- /include/kernels/cuda/memory.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "core/type.hpp" 3 | #include 4 | #include 5 | namespace tilegraph::kernel::cuda { 6 | std::string declareMemory(int indient, std::string name, 7 | MemoryType mem_type, DataType data_type, 8 | uint32_t len) { 9 | std::string memory_declare; 10 | 11 | for (int i = 0; i < indient; i++) { 12 | memory_declare += " "; 13 | } 14 | 15 | switch (mem_type) { 16 | case MemoryType::Global: 17 | memory_declare += ""; 18 | break; 19 | case MemoryType::Shared: 20 | memory_declare += "__shared__"; 21 | break; 22 | case MemoryType::Warp: 23 | break; 24 | case MemoryType::Register: 25 | memory_declare += "register"; 26 | break; 27 | default: 28 | memory_declare += ""; 29 | break; 30 | } 31 | memory_declare += " "; 32 | 33 | switch (data_type) { 34 | case DataType::Float: 35 | memory_declare += "float"; 36 | break; 37 | case DataType::Half: 38 | memory_declare += "half"; 39 | break; 40 | default: 41 | fmt::println("[declareMemory] Invalid data type."); 42 | } 43 | memory_declare += " "; 44 | memory_declare += name; 45 | memory_declare += fmt::format("[{}];\n", len); 46 | return memory_declare; 47 | } 48 | } // namespace tilegraph::kernel::cuda 49 | -------------------------------------------------------------------------------- /include/core/graph/graph_base.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | #include "core/graph/gedge.hpp" 7 | #include "core/graph/gnode.hpp" 8 | 9 | namespace tilegraph::graph { 10 | class GraphBase { 11 | static int64_t graph_count; 12 | 13 | public: 14 | std::string name; 15 | const int64_t index; 16 | std::vector> nodes; 17 | std::vector> inputs; 18 | std::vector> outputs; 19 | std::vector> inter_edges; 20 | 21 | GraphBase(std::vector> operators_list = {}, 22 | std::vector> inputs_list = {}, 23 | std::vector> outputs_list = {}, 24 | std::string name_value = ""); 25 | ~GraphBase() = default; 26 | // Connect the graph based nodes and edges. 27 | void connect(); 28 | // Topological sort the graph. 29 | std::vector> topoSort(); 30 | // Earse the node from the graph. 31 | bool earseNode(GNode::Pointer node); 32 | // Add the node to the graph. 33 | bool addNode(GNode::Pointer node); 34 | // Fuse the subgraph into a single node. 35 | bool fuseNode(std::vector> old_nodes, 36 | std::shared_ptr subgraph_node); 37 | 38 | using Pointer = std::shared_ptr; 39 | 40 | private: 41 | // Disconnect the node from the graph. 
42 | bool disconect(GNode::Pointer node); 43 | }; 44 | } // namespace tilegraph::graph -------------------------------------------------------------------------------- /examples/graph_base.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "core/graph/graph_base.hpp" 5 | #include "core/graph/gedge.hpp" 6 | #include "core/graph/gnode.hpp" 7 | #include "core/tensor.hpp" 8 | #include "common/common.hpp" 9 | 10 | using namespace tilegraph; 11 | using namespace tilegraph::graph; 12 | 13 | int main() { 14 | auto tensor_a = std::make_shared(Tensor({5120, 5120})); 15 | auto tensor_b = std::make_shared(Tensor({5120, 5120})); 16 | auto edge_a = std::make_shared(GEdge(tensor_a)); 17 | auto edge_b = std::make_shared(GEdge(tensor_b)); 18 | auto tensor_out_add = std::make_shared(Tensor({5120, 5120})); 19 | auto edge_out_add = std::make_shared(GEdge(tensor_out_add)); 20 | auto node_a = std::make_shared( 21 | GNode({edge_a, edge_b}, {edge_out_add}, OperatorType::ADD)); 22 | 23 | auto tensor_out_relu = std::make_shared(Tensor({5120, 5120})); 24 | auto edge_out_relu = std::make_shared(GEdge(tensor_out_relu)); 25 | auto node_b = std::make_shared( 26 | GNode({edge_out_add}, {edge_out_relu}, OperatorType::RELU)); 27 | 28 | auto graph = std::make_shared( 29 | GraphBase({node_a, node_b}, {edge_a, edge_b}, {edge_out_relu})); 30 | 31 | graph->connect(); 32 | auto sorted = graph->topoSort(); 33 | 34 | ASSERT(sorted.size() == 2, "Graph node size is not 2"); 35 | ASSERT(sorted[0]->getOperatorType() == OperatorType::ADD, 36 | "Graph node type is not ADD"); 37 | ASSERT(sorted[1]->getOperatorType() == OperatorType::RELU, 38 | "Graph node type is not RELU"); 39 | fmt::println("Topo sort test passed!"); 40 | return 0; 41 | } -------------------------------------------------------------------------------- /include/kernels/cuda/gemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "kernels/cuda/function.hpp" 3 | #include "kernels/cuda/cuda_var.hpp" 4 | #include "kernels/cuda/cuda_iteration.hpp" 5 | #include 6 | #include 7 | 8 | namespace tilegraph::kernel::cuda { 9 | class CudaGEMMKernel { 10 | public: 11 | // GEMM parametes. 
12 | uint32_t M; 13 | uint32_t N; 14 | uint32_t K; 15 | uint32_t ShardedM; 16 | uint32_t ShardedN; 17 | uint32_t ShardedK; 18 | uint32_t WarpM; 19 | uint32_t WarpN; 20 | uint32_t WarpK; 21 | uint32_t WmmaM; 22 | uint32_t WmmaN; 23 | uint32_t WmmaK; 24 | 25 | bool transpose_a; 26 | bool transpose_b; 27 | 28 | MemoryType memory_level; 29 | MemoryType output_level; 30 | 31 | // Variables 32 | std::set> inputs; 33 | std::set> outputs; 34 | std::set> vars; 35 | // Functions 36 | std::set> functions; 37 | std::set>> iterations; 38 | // Function Unit 39 | std::unique_ptr function_unit; 40 | 41 | CudaGEMMKernel(uint32_t M, uint32_t N, uint32_t K, uint32_t ShardedM, 42 | uint32_t ShardedN, uint32_t ShardedK, uint32_t WarpM, 43 | uint32_t WarpN, uint32_t WarpK, uint32_t WmmaM, 44 | uint32_t WmmaN, uint32_t WmmaK, bool tramspose_a, 45 | bool transpose_b, MemoryType memory_level, 46 | MemoryType output_level); 47 | std::string genTCGEMM(std::string name); 48 | }; 49 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /src/kernels/iteration.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/iteration.hpp" 2 | #include "kernels/cuda/cuda_var.hpp" 3 | #include 4 | 5 | namespace tilegraph::kernel { 6 | template <> 7 | Iteration::Iteration( 8 | std::unique_ptr iter_var, 9 | std::variant> step, 10 | std::variant> start, 11 | std::variant> end) 12 | : iter_var(std::move(iter_var)), step(step), start(start), end(end) {} 13 | 14 | template <> 15 | std::string Iteration::genIter(int indient) { 16 | std::string iter; 17 | for (int i = 0; i < indient; i++) { 18 | iter += " "; 19 | } 20 | auto start_var = 21 | std::get_if(&start) == nullptr 22 | ? std::get>(start)->name 23 | : std::to_string(std::get(start)); 24 | auto end_var = std::get_if(&end) == nullptr 25 | ? std::get>(end)->name 26 | : std::to_string(std::get(end)); 27 | auto step_var = 28 | std::get_if(&step) == nullptr 29 | ? std::get>(step)->name 30 | : std::to_string(std::get(step)); 31 | iter += fmt::format("for (int {} = {}; {} < {}; {} += {}) {{\n", 32 | iter_var->name, start_var, iter_var->name, end_var, 33 | iter_var->name, step_var); 34 | return iter; 35 | } 36 | 37 | template <> 38 | std::string Iteration::getIterVar() { 39 | return iter_var->name; 40 | } 41 | 42 | } // namespace tilegraph::kernel -------------------------------------------------------------------------------- /include/core/graph/gnode.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "core/type.hpp" 8 | #include "core/operators/operator.hpp" 9 | #include "core/graph/subgraph.hpp" 10 | 11 | namespace tilegraph::graph { 12 | using namespace tilegraph::operators; 13 | class GEdge; 14 | class GNode { 15 | public: 16 | struct GNodeError { 17 | enum class Kind { InferError }; 18 | 19 | Kind kind; 20 | }; 21 | using Pointer = std::shared_ptr; 22 | 23 | std::string name; 24 | const int64_t index; 25 | int64_t in_degree; 26 | OperatorType op_type; 27 | // virtual class. 
28 | Operator::OpBox op; 29 | std::vector> inputs; 30 | std::vector> outputs; 31 | std::vector> predecessors; 32 | std::vector> successors; 33 | 34 | public: 35 | GNode(std::vector> inputs_list = {}, 36 | std::vector> outputs_list = {}, 37 | OperatorType op_type = OperatorType::ADD, 38 | Operator::OpBox op = nullptr, std::string name_value = ""); 39 | ~GNode() = default; 40 | int64_t getIndex(); 41 | std::shared_ptr getOutput(int64_t index); 42 | std::vector> getInputs(); 43 | std::vector> getOutputs(); 44 | OperatorType getOperatorType(); 45 | 46 | bool earseSuccessor(Pointer node); 47 | bool earsePredecessor(Pointer node); 48 | 49 | Result>, GNodeError> inferShape(); 50 | 51 | private: 52 | static int64_t node_count; 53 | }; 54 | } // namespace tilegraph::graph -------------------------------------------------------------------------------- /include/core/type.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace tilegraph { 6 | // Cacheline 7 | using Cacheline = std::tuple; 8 | // MemoryDispatch 9 | enum class MemoryDispatch { RANDOM, FIFO, LRU, LFU }; 10 | enum class TensorDatatype { HALF, FLOAT, DOUBLE, INT32 }; 11 | enum class TensorLayout { NCHW, NHWC, ARRAY }; 12 | enum class TensorType { CONST, VARIABLE }; 13 | // OperatorType 14 | enum class OperatorType { 15 | // Binary 16 | ADD, 17 | SUB, 18 | MUL, 19 | DIV, 20 | EQ, 21 | GE, 22 | GT, 23 | LE, 24 | LT, 25 | NE, 26 | AND, 27 | OR, 28 | XOR, 29 | FLOORMOD, 30 | FLOORDIV, 31 | 32 | // Unary 33 | SIGMOID, 34 | RELU, 35 | SQRT, 36 | RSQRT, 37 | RECIP, 38 | SIN, 39 | COS, 40 | TANH, 41 | GELU, 42 | // Memory 43 | LOAD, 44 | ALLOCATE, 45 | STORE, 46 | FREE, 47 | // Sync 48 | SYNC, 49 | // GEMM 50 | GEMM, 51 | SOFTMAX, 52 | 53 | GEMM_RELU, 54 | 55 | FUSED 56 | }; 57 | 58 | // enum class MemoryType { Register, Warp, Shared, Global }; 59 | enum class MemoryType { Register, Warp, Shared, Global }; 60 | 61 | enum class FuncType { Device, Global, Host }; 62 | 63 | enum class DataType { Void, Half, Float, Int32 }; 64 | 65 | // KernelType 66 | enum class KernelType { 67 | BINARY, 68 | UNARY, 69 | REDUCE, 70 | BROADCAST, 71 | MEMORY, 72 | FMA, 73 | SYNC 74 | }; 75 | 76 | // CacheType 77 | enum class CacheType { CACHE, LDRAM }; 78 | 79 | // CacheHitLocation 80 | enum class CacheHitLocation { CACHE, LDRAM, NOT_FOUND, ERROR }; 81 | 82 | std::string toString(OperatorType op_type); 83 | 84 | } // namespace tilegraph -------------------------------------------------------------------------------- /tests/graph/test_subgraph_match.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "core/type.hpp" 3 | #include "core/graph/graph_base.hpp" 4 | #include "core/graph/gnode.hpp" 5 | #include "core/graph/gedge.hpp" 6 | #include "optimizer/fusion/graph_fusion_base.hpp" 7 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 8 | 9 | #include 10 | #include 11 | 12 | using namespace tilegraph; 13 | using namespace tilegraph::fusion; 14 | using namespace tilegraph::graph; 15 | using namespace tilegraph::fusion::subgraph; 16 | 17 | TEST(SubGraphFuse, gemm_relu) { 18 | // Relu -> GEMM -> Relu -> Softmax 19 | fmtlog::setLogLevel(fmtlog::LogLevel::INF); 20 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 21 | auto edge_out_relu1 = std::make_shared(GEdge({5120, 5120})); 22 | auto relu1 = std::make_shared( 23 | GNode({edge_a}, {edge_out_relu1}, {OperatorType::RELU})); 24 | 25 | 
auto edge_b = std::make_shared(GEdge({5120, 5120})); 26 | auto edge_out_gemm = std::make_shared(GEdge({5120, 5120})); 27 | auto gemm = std::make_shared( 28 | GNode({edge_out_relu1, edge_b}, {edge_out_gemm}, {OperatorType::GEMM})); 29 | 30 | auto edge_out_relu2 = std::make_shared(GEdge({5120, 5120})); 31 | auto relu2 = std::make_shared( 32 | GNode({edge_out_gemm}, {edge_out_relu2}, {OperatorType::RELU})); 33 | 34 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 35 | auto softmax = std::make_shared( 36 | GNode({edge_out_relu2}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 37 | 38 | auto graph = std::make_shared(Graph( 39 | {relu1, gemm, relu2, softmax}, {edge_a, edge_b}, {edge_out_softmax})); 40 | graph->connect(); 41 | 42 | auto gemm_relu_fusion = std::make_shared(graph); 43 | 44 | gemm_relu_fusion->create_subgraphs(); 45 | gemm_relu_fusion->match_and_fuse_subgraph(); 46 | 47 | // Test graph fusion 48 | auto ordered_ops = graph->topoSort(); 49 | EXPECT_EQ(ordered_ops.size(), 3); 50 | 51 | EXPECT_EQ(ordered_ops[0]->getOperatorType(), OperatorType::RELU); 52 | EXPECT_EQ(ordered_ops[1]->getOperatorType(), OperatorType::GEMM_RELU); 53 | EXPECT_EQ(ordered_ops[2]->getOperatorType(), OperatorType::SOFTMAX); 54 | } -------------------------------------------------------------------------------- /src/optimizer/fusion/subgraph_fusion/gemm_relu_fusion.cpp: -------------------------------------------------------------------------------- 1 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 2 | #include "core/graph/gnode.hpp" 3 | #include "core/type.hpp" 4 | 5 | #include 6 | 7 | namespace tilegraph::fusion::subgraph { 8 | 9 | GemmReluFusion::GemmReluFusion(std::shared_ptr graph) 10 | : SubgraphFusionBase(graph) {} 11 | 12 | void GemmReluFusion::create_subgraphs() { 13 | using namespace graph; 14 | auto check_root = [](std::shared_ptr gnode) -> bool { 15 | if (gnode->getOperatorType() != OperatorType::GEMM) { 16 | return false; 17 | } 18 | return true; 19 | }; 20 | 21 | SubGraph::Pointer s_gemm_relu = std::make_shared(); 22 | s_gemm_relu->name = "GEMM_RELU"; 23 | s_gemm_relu->check_starting_node = check_root; 24 | 25 | { 26 | Pattern::Pointer p_gemm_relu = std::make_shared(); 27 | std::vector ops{OperatorType::GEMM, 28 | OperatorType::RELU}; 29 | p_gemm_relu->descriptions.push_back(std::make_pair(ops, 1)); 30 | p_gemm_relu->reverse_order = false; 31 | auto check_gemm_relu = [](const PatternRecord& pr) -> bool { 32 | return true; 33 | }; 34 | p_gemm_relu->check.push_back(check_gemm_relu); 35 | 36 | s_gemm_relu->patterns.push_back(p_gemm_relu); 37 | } 38 | 39 | subgraphs.push_back(s_gemm_relu); 40 | } 41 | 42 | Result GemmReluFusion::fuse_subgraph( 43 | graph::SubGraphRecord::Pointer subgraph_record) { 44 | auto pr_gemm_relu = subgraph_record->pattern_records[0]; 45 | auto gemm = pr_gemm_relu->nodes[0]; 46 | auto relu = pr_gemm_relu->nodes[1]; 47 | 48 | auto fused_node = std::make_shared(graph::GNode( 49 | gemm->getInputs(), relu->getOutputs(), OperatorType::GEMM_RELU)); 50 | if (!graph->fuseNode({gemm, relu}, fused_node)) { 51 | loge("Failed to fuse subgraph"); 52 | return Err(SubgraphFusionBase::FusionError{ 53 | SubgraphFusionBase::FusionError::Kind::FailedToFuseSubgraph}); 54 | } 55 | return Ok(); 56 | } 57 | } // namespace tilegraph::fusion::subgraph -------------------------------------------------------------------------------- /examples/fuse.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | 
#include "core/type.hpp" 3 | #include "core/operators/elementwise.hpp" 4 | #include "core/operators/gemm.hpp" 5 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 6 | #include "common/common.hpp" 7 | #include 8 | #include 9 | 10 | using namespace tilegraph; 11 | using namespace tilegraph::graph; 12 | using namespace tilegraph::fusion; 13 | using namespace tilegraph::fusion::subgraph; 14 | 15 | int main() { 16 | fmtlog::setLogLevel(fmtlog::LogLevel::INF); 17 | // Relu -> GEMM -> Relu -> Softmax 18 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 19 | auto edge_out_relu1 = std::make_shared(GEdge({5120, 5120})); 20 | auto relu1 = std::make_shared( 21 | GNode({edge_a}, {edge_out_relu1}, {OperatorType::RELU})); 22 | 23 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 24 | auto edge_out_gemm = std::make_shared(GEdge({5120, 5120})); 25 | auto gemm = std::make_shared( 26 | GNode({edge_out_relu1, edge_b}, {edge_out_gemm}, {OperatorType::GEMM})); 27 | 28 | auto edge_out_relu2 = std::make_shared(GEdge({5120, 5120})); 29 | auto relu2 = std::make_shared( 30 | GNode({edge_out_gemm}, {edge_out_relu2}, {OperatorType::RELU})); 31 | 32 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 33 | auto softmax = std::make_shared( 34 | GNode({edge_out_relu2}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 35 | 36 | auto graph = std::make_shared(Graph( 37 | {relu1, gemm, relu2, softmax}, {edge_a, edge_b}, {edge_out_softmax})); 38 | graph->connect(); 39 | 40 | auto gemm_relu_fusion = std::make_shared(graph); 41 | 42 | gemm_relu_fusion->create_subgraphs(); 43 | gemm_relu_fusion->match_and_fuse_subgraph(); 44 | 45 | auto ordered_ops = graph->topoSort(); 46 | 47 | ASSERT(ordered_ops.size() == 3, "Graph node size is not 3"); 48 | ASSERT(ordered_ops[0]->getOperatorType() == OperatorType::RELU, 49 | "Graph node type is not RELU"); 50 | ASSERT(ordered_ops[1]->getOperatorType() == OperatorType::GEMM_RELU, 51 | "Graph node type is not GEMM_RELU"); 52 | ASSERT(ordered_ops[2]->getOperatorType() == OperatorType::SOFTMAX, 53 | "Graph node type is not SOFTMAX"); 54 | 55 | fmt::println("Fuse test passed!"); 56 | } -------------------------------------------------------------------------------- /examples/subgraph_match.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "core/type.hpp" 3 | #include "core/graph/graph_base.hpp" 4 | #include "core/graph/gnode.hpp" 5 | #include "core/graph/gedge.hpp" 6 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 7 | #include "common/common.hpp" 8 | 9 | #include 10 | 11 | using namespace tilegraph; 12 | using namespace tilegraph::fusion; 13 | using namespace tilegraph::graph; 14 | using namespace tilegraph::fusion::subgraph; 15 | 16 | int main() { 17 | fmtlog::setLogLevel(fmtlog::LogLevel::INF); 18 | // Relu -> GEMM -> Relu -> Softmax 19 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 20 | auto edge_out_relu1 = std::make_shared(GEdge({5120, 5120})); 21 | auto relu1 = std::make_shared( 22 | GNode({edge_a}, {edge_out_relu1}, {OperatorType::RELU})); 23 | 24 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 25 | auto edge_out_gemm = std::make_shared(GEdge({5120, 5120})); 26 | auto gemm = std::make_shared( 27 | GNode({edge_out_relu1, edge_b}, {edge_out_gemm}, {OperatorType::GEMM})); 28 | 29 | auto edge_out_relu2 = std::make_shared(GEdge({5120, 5120})); 30 | auto relu2 = std::make_shared( 31 | GNode({edge_out_gemm}, {edge_out_relu2}, 
{OperatorType::RELU})); 32 | 33 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 34 | auto softmax = std::make_shared( 35 | GNode({edge_out_relu2}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 36 | 37 | auto graph = std::make_shared(Graph( 38 | {relu1, gemm, relu2, softmax}, {edge_a, edge_b}, {edge_out_softmax})); 39 | graph->connect(); 40 | 41 | auto gemm_relu_fusion = std::make_shared(graph); 42 | 43 | gemm_relu_fusion->create_subgraphs(); 44 | gemm_relu_fusion->match_and_fuse_subgraph(); 45 | 46 | auto ordered_ops = graph->topoSort(); 47 | 48 | ASSERT(ordered_ops.size() == 3, "Graph node size is not 3"); 49 | ASSERT(ordered_ops[0]->getOperatorType() == OperatorType::RELU, 50 | "Graph node type is not RELU"); 51 | ASSERT(ordered_ops[1]->getOperatorType() == OperatorType::GEMM_RELU, 52 | "Graph node type is not GEMM_RELU"); 53 | ASSERT(ordered_ops[2]->getOperatorType() == OperatorType::SOFTMAX, 54 | "Graph node type is not SOFTMAX"); 55 | 56 | fmt::println("SubGraph Match test passed!"); 57 | 58 | return 0; 59 | } -------------------------------------------------------------------------------- /tests/fusion/test_persisten_kernel_fusion.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "core/type.hpp" 3 | #include "core/graph/graph_base.hpp" 4 | #include "core/graph/gnode.hpp" 5 | #include "core/graph/gedge.hpp" 6 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 7 | #include "optimizer/fusion/persistent_kernel_fusion.hpp" 8 | #include "common/common.hpp" 9 | 10 | #include 11 | #include 12 | 13 | using namespace tilegraph; 14 | using namespace tilegraph::fusion; 15 | using namespace tilegraph::graph; 16 | using namespace tilegraph::fusion::subgraph; 17 | 18 | TEST(PersistentKernelFusion, persistent_kernel_fusion_1) { 19 | // GEMM -> GEMM -> SOFTMAX -> GEMM 20 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 21 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 22 | auto edge_c = std::make_shared(GEdge({5120, 5120})); 23 | auto edge_d = std::make_shared(GEdge({5120, 5120})); 24 | 25 | auto edge_out_gemm1 = std::make_shared(GEdge({5120, 5120})); 26 | auto edge_out_gemm2 = std::make_shared(GEdge({5120, 5120})); 27 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 28 | auto edge_out_gemm3 = std::make_shared(GEdge({5120, 5120})); 29 | auto gemm1 = std::make_shared( 30 | GNode({edge_a, edge_b}, {edge_out_gemm1}, {OperatorType::GEMM})); 31 | auto gemm2 = std::make_shared(GNode( 32 | {edge_out_gemm1, edge_c}, {edge_out_gemm2}, {OperatorType::GEMM})); 33 | auto softmax = std::make_shared( 34 | GNode({edge_out_gemm2}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 35 | auto gemm3 = std::make_shared(GNode( 36 | {edge_out_softmax, edge_d}, {edge_out_gemm3}, {OperatorType::GEMM})); 37 | 38 | auto graph = std::make_shared(Graph({gemm1, gemm2, softmax, gemm3}, 39 | {edge_a, edge_b, edge_c, edge_d}, 40 | {edge_out_gemm3})); 41 | 42 | graph->connect(); 43 | 44 | auto persistent_kernel_fusion = std::make_shared(); 45 | persistent_kernel_fusion->fusion(graph); 46 | 47 | auto ordered_ops = graph->topoSort(); 48 | 49 | ASSERT_EQ(ordered_ops.size(), 3); 50 | ASSERT_EQ(ordered_ops[0]->getOperatorType(), OperatorType::FUSED); 51 | ASSERT_EQ(ordered_ops[1]->getOperatorType(), OperatorType::SOFTMAX); 52 | ASSERT_EQ(ordered_ops[2]->getOperatorType(), OperatorType::GEMM); 53 | } -------------------------------------------------------------------------------- /CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.17) 2 | project(TileGraph C CXX) 3 | 4 | option(USE_CUDA "Support CUDA GPU" OFF) 5 | option(USE_BANG "Support BANG MLU" OFF) 6 | 7 | option(BUILD_TEST "Build test code" ON) 8 | option(BUILD_ASAN "Build code whith ASAN" OFF) 9 | 10 | set(CMAKE_CXX_STANDARD 17) 11 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 12 | set(CMAKE_CXX_EXTENSIONS OFF) 13 | 14 | ################################################################################ 15 | # ASAN Check 16 | ################################################################################ 17 | if(BUILD_ASAN) 18 | set(CMAKE_ASAN_FLAGS "-fsanitize=address -fno-omit-frame-pointer") 19 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_ASAN_FLAGS}") 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CMAKE_ASAN_FLAGS}") 21 | endif() 22 | 23 | include_directories(include) 24 | include_directories(3rd-party/fmt/include) 25 | include_directories(3rd-party/result) 26 | include_directories(3rd-party/fmtlog) 27 | add_definitions(-D FMTLOG_HEADER_ONLY) 28 | add_definitions(-D FMT_HEADER_ONLY) 29 | 30 | if(BUILD_TEST) 31 | include_directories(3rd-party/googletest/googletest/include) 32 | add_subdirectory(3rd-party/googletest) 33 | endif() 34 | 35 | add_compile_options(-Wall) 36 | add_compile_options(-Werror) 37 | 38 | enable_testing() 39 | 40 | 41 | file(GLOB_RECURSE SOURCE 42 | src/core/*.cpp 43 | src/core/graph/*.cpp 44 | src/core/operators/*.cpp 45 | src/optimizer/fusion/*.cpp 46 | srs/optimizer/fusion/subgraph_fusion/*.cpp 47 | src/optimizer/tilling/*.cpp 48 | src/kernels/*.cpp 49 | src/kernel/cuda/*.cpp 50 | src/ir/*.cpp 51 | src/codegen/*.cpp 52 | ) 53 | 54 | add_library(tilegraph SHARED ${SOURCE}) 55 | 56 | add_compile_options(-Wall) 57 | # add_compile_options(-Werror) 58 | 59 | target_link_directories(tilegraph PRIVATE 3rd-party/fmt/build) 60 | target_link_libraries(tilegraph PRIVATE fmt) 61 | 62 | function(build_test files) 63 | # Non-recursive glob for skip failed tests 64 | file(GLOB TEST_SOURCES ${files}) 65 | foreach(testsourcefile ${TEST_SOURCES}) 66 | get_filename_component(testname ${testsourcefile} NAME_WE) 67 | add_executable(${testname} ${testsourcefile}) 68 | target_link_libraries(${testname} tilegraph GTest::gtest_main) 69 | add_test(NAME ${testname} COMMAND ${testname}) 70 | endforeach(testsourcefile ${TEST_SOURCES}) 71 | endfunction() 72 | 73 | if(BUILD_TEST) 74 | build_test(tests/operators/*.cpp) 75 | build_test(tests/graph/*.cpp) 76 | build_test(tests/fusion/*.cpp) 77 | if(USE_CUDA) 78 | build_test(tests/codegen/*.cpp) 79 | endif() 80 | endif() 81 | -------------------------------------------------------------------------------- /examples/persistent_kernel_fusion.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "core/type.hpp" 3 | #include "core/graph/graph_base.hpp" 4 | #include "core/graph/gnode.hpp" 5 | #include "core/graph/gedge.hpp" 6 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 7 | #include "optimizer/fusion/persistent_kernel_fusion.hpp" 8 | #include "common/common.hpp" 9 | 10 | #include 11 | 12 | using namespace tilegraph; 13 | using namespace tilegraph::fusion; 14 | using namespace tilegraph::graph; 15 | using namespace tilegraph::fusion::subgraph; 16 | 17 | int main() { 18 | // GEMM -> GEMM -> SOFTMAX -> GEMM 19 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 20 | auto edge_b = std::make_shared(GEdge({5120, 
5120})); 21 | auto edge_c = std::make_shared(GEdge({5120, 5120})); 22 | auto edge_d = std::make_shared(GEdge({5120, 5120})); 23 | 24 | auto edge_out_gemm1 = std::make_shared(GEdge({5120, 5120})); 25 | auto edge_out_gemm2 = std::make_shared(GEdge({5120, 5120})); 26 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 27 | auto edge_out_gemm3 = std::make_shared(GEdge({5120, 5120})); 28 | auto gemm1 = std::make_shared( 29 | GNode({edge_a, edge_b}, {edge_out_gemm1}, {OperatorType::GEMM})); 30 | auto gemm2 = std::make_shared(GNode( 31 | {edge_out_gemm1, edge_c}, {edge_out_gemm2}, {OperatorType::GEMM})); 32 | auto softmax = std::make_shared( 33 | GNode({edge_out_gemm2}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 34 | auto gemm3 = std::make_shared(GNode( 35 | {edge_out_softmax, edge_d}, {edge_out_gemm3}, {OperatorType::GEMM})); 36 | 37 | auto graph = std::make_shared(Graph({gemm1, gemm2, softmax, gemm3}, 38 | {edge_a, edge_b, edge_c, edge_d}, 39 | {edge_out_gemm3})); 40 | 41 | graph->connect(); 42 | 43 | auto persistent_kernel_fusion = std::make_shared(); 44 | persistent_kernel_fusion->fusion(graph); 45 | 46 | auto ordered_ops = graph->topoSort(); 47 | 48 | ASSERT(ordered_ops.size() == 3, "Graph node size unmatched"); 49 | ASSERT(ordered_ops[0]->getOperatorType() == OperatorType::FUSED, 50 | "Graph node type unmatched"); 51 | ASSERT(ordered_ops[1]->getOperatorType() == OperatorType::SOFTMAX, 52 | "Graph node type unmatched"); 53 | ASSERT(ordered_ops[2]->getOperatorType() == OperatorType::GEMM, 54 | "Graph node type unmatched"); 55 | 56 | fmt::println("Persistent Kernel Fusion test passed!"); 57 | } -------------------------------------------------------------------------------- /tests/graph/test_toposort.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "core/graph/graph_base.hpp" 3 | #include "core/graph/gnode.hpp" 4 | #include "core/graph/gedge.hpp" 5 | #include "core/tensor.hpp" 6 | #include "core/type.hpp" 7 | #include "core/operators/elementwise.hpp" 8 | 9 | #include 10 | 11 | #include 12 | 13 | using namespace tilegraph; 14 | using namespace tilegraph::graph; 15 | 16 | TEST(Graph, graph_base_toposort) { 17 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 18 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 19 | auto edge_out_add = std::make_shared(GEdge({5120, 5120})); 20 | auto node_a = std::make_shared( 21 | GNode({edge_a, edge_b}, {edge_out_add}, OperatorType::ADD)); 22 | 23 | auto edge_out_relu = std::make_shared(GEdge({5120, 5120})); 24 | auto node_b = std::make_shared( 25 | GNode({edge_out_add}, {edge_out_relu}, OperatorType::RELU)); 26 | 27 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 28 | auto node_c = std::make_shared( 29 | GNode({edge_out_relu}, {edge_out_softmax}, OperatorType::SOFTMAX)); 30 | 31 | auto graph = std::make_shared(GraphBase( 32 | {node_a, node_b, node_c}, {edge_a, edge_b}, {edge_out_softmax})); 33 | 34 | graph->connect(); 35 | auto sorted = graph->topoSort(); 36 | EXPECT_EQ(sorted[0].get()->getOperatorType(), OperatorType::ADD); 37 | EXPECT_EQ(sorted[1].get()->getOperatorType(), OperatorType::RELU); 38 | EXPECT_EQ(sorted[2].get()->getOperatorType(), OperatorType::SOFTMAX); 39 | } 40 | 41 | TEST(Graph, graph_base_toposort_2) { 42 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 43 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 44 | auto node_a = 45 | std::make_shared(GNode({edge_a, edge_b}, {}, OperatorType::ADD)); 46 | 
ASSERT_EQ(node_a->getOutputs().size(), 1); 47 | 48 | auto node_b = std::make_shared( 49 | GNode(node_a->getOutputs(), {}, OperatorType::RELU)); 50 | ASSERT_EQ(node_b->getOutputs().size(), 1); 51 | 52 | auto node_c = std::make_shared( 53 | GNode(node_b->getOutputs(), {}, OperatorType::SOFTMAX)); 54 | ASSERT_EQ(node_c->getOutputs().size(), 1); 55 | 56 | auto graph = std::make_shared(GraphBase( 57 | {node_a, node_b, node_c}, {edge_a, edge_b}, node_c->getOutputs())); 58 | 59 | graph->connect(); 60 | auto sorted = graph->topoSort(); 61 | EXPECT_EQ(sorted[0].get()->getOperatorType(), OperatorType::ADD); 62 | EXPECT_EQ(sorted[1].get()->getOperatorType(), OperatorType::RELU); 63 | EXPECT_EQ(sorted[2].get()->getOperatorType(), OperatorType::SOFTMAX); 64 | } 65 | -------------------------------------------------------------------------------- /src/kernels/cuda/function.cpp: -------------------------------------------------------------------------------- 1 | #include "kernels/cuda/function.hpp" 2 | #include 3 | 4 | namespace tilegraph::kernel::cuda { 5 | std::string CudaFunctionKernelUnit::declareGlobal( 6 | std::string name, 7 | std::vector> arguments) { 8 | std::string function = fmt::format("__global__ void {} (", name); 9 | for (auto arg : arguments) { 10 | function += fmt::format("{} {}", arg.first, arg.second); 11 | } 12 | function += ")"; 13 | return function; 14 | } 15 | 16 | CudaFunction::CudaFunction(std::string name, FuncType func_type, 17 | DataType return_type) 18 | : name(name), func_type(func_type), return_type(return_type) {} 19 | 20 | std::string CudaFunction::declareFunction() { 21 | std::string func_type_str; 22 | std::string return_type_str; 23 | switch (func_type) { 24 | case FuncType::Global: 25 | func_type_str = "__global__"; 26 | break; 27 | case FuncType::Device: 28 | func_type_str = "__device__"; 29 | break; 30 | case FuncType::Host: 31 | func_type_str = "__host__"; 32 | break; 33 | default: 34 | fmt::println( 35 | "[CudaFunction::declareFunction()] Invalid func_type"); 36 | func_type_str = "__host__"; 37 | } 38 | 39 | switch (return_type) { 40 | case DataType::Void: 41 | return_type_str = "void"; 42 | break; 43 | default: 44 | // fmt::print("Invalid return type: {}\n", return_type); 45 | fmt::println( 46 | "[CudaFunction::declareFunction()] Invalid return type"); 47 | } 48 | 49 | std::string function = 50 | fmt::format("{} {} {}(", func_type_str, return_type_str, name); 51 | 52 | for (auto arg : arguments) { 53 | switch (arg.first) { 54 | case DataType::Float: 55 | function += fmt::format("float* {} ", arg.second); 56 | break; 57 | case DataType::Half: 58 | function += fmt::format("half* {} ", arg.second); 59 | break; 60 | default: 61 | fmt::println( 62 | "[CudaFunction::declareFunction()] Invalid argument " 63 | "type"); 64 | function += fmt::format("int* {} ", arg.second); 65 | } 66 | } 67 | function += ")"; 68 | function += ";\n"; 69 | return function; 70 | } 71 | 72 | } // namespace tilegraph::kernel::cuda -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "unordered_set": "cpp", 4 | "unordered_map": "cpp", 5 | "memory": "cpp", 6 | "__nullptr": "cpp", 7 | "any": "cpp", 8 | "exception": "cpp", 9 | "initializer_list": "cpp", 10 | "new": "cpp", 11 | "optional": "cpp", 12 | "stdexcept": "cpp", 13 | "type_traits": "cpp", 14 | "typeinfo": "cpp", 15 | "variant": "cpp", 16 | "array": "cpp", 17 | "atomic": 
"cpp", 18 | "bit": "cpp", 19 | "*.tcc": "cpp", 20 | "bitset": "cpp", 21 | "cctype": "cpp", 22 | "cfenv": "cpp", 23 | "chrono": "cpp", 24 | "cinttypes": "cpp", 25 | "clocale": "cpp", 26 | "cmath": "cpp", 27 | "codecvt": "cpp", 28 | "compare": "cpp", 29 | "complex": "cpp", 30 | "concepts": "cpp", 31 | "condition_variable": "cpp", 32 | "csignal": "cpp", 33 | "cstdarg": "cpp", 34 | "cstddef": "cpp", 35 | "cstdint": "cpp", 36 | "cstdio": "cpp", 37 | "cstdlib": "cpp", 38 | "cstring": "cpp", 39 | "ctime": "cpp", 40 | "cwchar": "cpp", 41 | "cwctype": "cpp", 42 | "deque": "cpp", 43 | "forward_list": "cpp", 44 | "list": "cpp", 45 | "map": "cpp", 46 | "set": "cpp", 47 | "string": "cpp", 48 | "vector": "cpp", 49 | "algorithm": "cpp", 50 | "functional": "cpp", 51 | "iterator": "cpp", 52 | "memory_resource": "cpp", 53 | "numeric": "cpp", 54 | "random": "cpp", 55 | "ratio": "cpp", 56 | "string_view": "cpp", 57 | "system_error": "cpp", 58 | "tuple": "cpp", 59 | "utility": "cpp", 60 | "fstream": "cpp", 61 | "iomanip": "cpp", 62 | "iosfwd": "cpp", 63 | "iostream": "cpp", 64 | "istream": "cpp", 65 | "limits": "cpp", 66 | "mutex": "cpp", 67 | "numbers": "cpp", 68 | "ostream": "cpp", 69 | "semaphore": "cpp", 70 | "span": "cpp", 71 | "sstream": "cpp", 72 | "stop_token": "cpp", 73 | "streambuf": "cpp", 74 | "thread": "cpp", 75 | "typeindex": "cpp", 76 | "__functional_03": "cpp", 77 | "__split_buffer": "cpp", 78 | "scoped_allocator": "cpp", 79 | "hash_map": "cpp", 80 | "hash_set": "cpp", 81 | "barrier": "cpp", 82 | "charconv": "cpp", 83 | "coroutine": "cpp", 84 | "csetjmp": "cpp", 85 | "cuchar": "cpp", 86 | "regex": "cpp", 87 | "source_location": "cpp", 88 | "future": "cpp", 89 | "latch": "cpp", 90 | "ranges": "cpp", 91 | "shared_mutex": "cpp", 92 | "syncstream": "cpp", 93 | "valarray": "cpp", 94 | "__bit_reference": "cpp", 95 | "__hash_table": "cpp", 96 | "__tree": "cpp", 97 | "queue": "cpp", 98 | "stack": "cpp", 99 | "__node_handle": "cpp" 100 | } 101 | } -------------------------------------------------------------------------------- /tests/fusion/test_gemm_fusion.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph.hpp" 2 | #include "core/type.hpp" 3 | #include "core/graph/graph_base.hpp" 4 | #include "core/graph/gnode.hpp" 5 | #include "core/graph/gedge.hpp" 6 | #include "optimizer/fusion/subgraph_fusion/gemm_relu_fusion.hpp" 7 | 8 | #include 9 | 10 | using namespace tilegraph; 11 | using namespace tilegraph::graph; 12 | using namespace tilegraph::fusion; 13 | using namespace tilegraph::fusion::subgraph; 14 | 15 | TEST(GemmFusion, gemm_relu) { 16 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 17 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 18 | auto edge_out_gemm = std::make_shared(GEdge({5120, 5120})); 19 | auto gemm = std::make_shared( 20 | GNode({edge_a, edge_b}, {edge_out_gemm}, {OperatorType::GEMM})); 21 | 22 | auto edge_out_relu = std::make_shared(GEdge({5120, 5120})); 23 | auto relu = std::make_shared( 24 | GNode({edge_out_gemm}, {edge_out_relu}, {OperatorType::RELU})); 25 | 26 | auto graph = std::make_shared( 27 | Graph({gemm, relu}, {edge_a, edge_b}, {edge_out_relu})); 28 | graph->connect(); 29 | 30 | auto gemm_relu_fusion = std::make_shared(graph); 31 | 32 | gemm_relu_fusion->create_subgraphs(); 33 | gemm_relu_fusion->match_and_fuse_subgraph(); 34 | 35 | auto ordered_ops = graph->topoSort(); 36 | EXPECT_EQ(ordered_ops.size(), 1); 37 | EXPECT_EQ(ordered_ops[0]->getOperatorType(), OperatorType::GEMM_RELU); 38 | } 39 | 
40 | TEST(GemmFusion, gemm_relu_softmax) { 41 | // GEMM -> Relu -> Softmax 42 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 43 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 44 | auto edge_out_gemm = std::make_shared(GEdge({5120, 5120})); 45 | auto gemm = std::make_shared( 46 | GNode({edge_a, edge_b}, {edge_out_gemm}, {OperatorType::GEMM})); 47 | 48 | auto edge_out_relu = std::make_shared(GEdge({5120, 5120})); 49 | auto relu = std::make_shared( 50 | GNode({edge_out_gemm}, {edge_out_relu}, {OperatorType::RELU})); 51 | 52 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 53 | auto softmax = std::make_shared( 54 | GNode({edge_out_relu}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 55 | 56 | auto graph = std::make_shared( 57 | Graph({gemm, relu, softmax}, {edge_a, edge_b}, {edge_out_softmax})); 58 | graph->connect(); 59 | 60 | auto gemm_relu_fusion = std::make_shared(graph); 61 | 62 | gemm_relu_fusion->create_subgraphs(); 63 | gemm_relu_fusion->match_and_fuse_subgraph(); 64 | 65 | auto ordered_ops = graph->topoSort(); 66 | EXPECT_EQ(ordered_ops.size(), 2); 67 | EXPECT_EQ(ordered_ops[0]->getOperatorType(), OperatorType::GEMM_RELU); 68 | EXPECT_EQ(ordered_ops[1]->getOperatorType(), OperatorType::SOFTMAX); 69 | } 70 | 71 | TEST(Fusion, relu_gemm_relu_softmax) { 72 | // Relu -> GEMM -> Relu -> Softmax 73 | auto edge_a = std::make_shared(GEdge({5120, 5120})); 74 | auto edge_out_relu1 = std::make_shared(GEdge({5120, 5120})); 75 | auto relu1 = std::make_shared( 76 | GNode({edge_a}, {edge_out_relu1}, {OperatorType::RELU})); 77 | 78 | auto edge_b = std::make_shared(GEdge({5120, 5120})); 79 | auto edge_out_gemm = std::make_shared(GEdge({5120, 5120})); 80 | auto gemm = std::make_shared( 81 | GNode({edge_out_relu1, edge_b}, {edge_out_gemm}, {OperatorType::GEMM})); 82 | 83 | auto edge_out_relu2 = std::make_shared(GEdge({5120, 5120})); 84 | auto relu2 = std::make_shared( 85 | GNode({edge_out_gemm}, {edge_out_relu2}, {OperatorType::RELU})); 86 | 87 | auto edge_out_softmax = std::make_shared(GEdge({5120, 5120})); 88 | auto softmax = std::make_shared( 89 | GNode({edge_out_relu2}, {edge_out_softmax}, {OperatorType::SOFTMAX})); 90 | 91 | auto graph = std::make_shared(Graph( 92 | {relu1, gemm, relu2, softmax}, {edge_a, edge_b}, {edge_out_softmax})); 93 | graph->connect(); 94 | 95 | auto gemm_relu_fusion = std::make_shared(graph); 96 | 97 | gemm_relu_fusion->create_subgraphs(); 98 | gemm_relu_fusion->match_and_fuse_subgraph(); 99 | 100 | auto ordered_ops = graph->topoSort(); 101 | ASSERT_EQ(ordered_ops.size(), 3); 102 | ASSERT_EQ(ordered_ops[0]->getOperatorType(), OperatorType::RELU); 103 | ASSERT_EQ(ordered_ops[1]->getOperatorType(), OperatorType::GEMM_RELU); 104 | EXPECT_EQ(ordered_ops[2]->getOperatorType(), OperatorType::SOFTMAX); 105 | } 106 | -------------------------------------------------------------------------------- /src/core/graph/gnode.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/gnode.hpp" 2 | #include "core/graph/gedge.hpp" 3 | #include "core/tensor.hpp" 4 | #include "core/operators/gemm.hpp" 5 | #include "core/operators/binary.hpp" 6 | #include "core/operators/unary.hpp" 7 | #include "core/operators/fused.hpp" 8 | 9 | #include 10 | #include 11 | 12 | namespace tilegraph::graph { 13 | using namespace tilegraph::operators; 14 | 15 | int64_t GNode::node_count = 0; 16 | 17 | GNode::GNode(std::vector> inputs_list, 18 | std::vector> outputs_list, 19 | OperatorType op_type, Operator::OpBox op, 20 | 
std::string name_value) 21 | : name(name_value), 22 | index(node_count++), 23 | in_degree(0), 24 | op_type(op_type), 25 | op(op), 26 | inputs(inputs_list), 27 | outputs(outputs_list) { 28 | name = (name == "" ? "Operator_" + std::to_string(index) : name); 29 | if (op == nullptr) { 30 | switch (op_type) { 31 | // Default GEMM operator. 32 | case OperatorType::GEMM: 33 | this->op = std::make_shared(); 34 | break; 35 | // Unary operators. 36 | case OperatorType::SIN: 37 | case OperatorType::COS: 38 | case OperatorType::SQRT: 39 | case OperatorType::RELU: 40 | case OperatorType::SOFTMAX: 41 | case OperatorType::SIGMOID: 42 | case OperatorType::TANH: 43 | this->op = std::make_shared(op_type); 44 | break; 45 | // Binary operators. 46 | case OperatorType::ADD: 47 | case OperatorType::SUB: 48 | case OperatorType::MUL: 49 | case OperatorType::DIV: 50 | this->op = std::make_shared(op_type); 51 | break; 52 | case OperatorType::FUSED: 53 | this->op = std::make_shared(); 54 | break; 55 | default: 56 | loge("[GNode::GNode] Operator type is not supported."); 57 | break; 58 | } 59 | } 60 | 61 | if (outputs.empty()) { 62 | if (this->inferShape().isErr()) { 63 | loge("[GNode::GNode] Failed to infer node shape."); 64 | } else { 65 | this->outputs = this->inferShape().unwrap(); 66 | } 67 | } 68 | } 69 | 70 | int64_t GNode::getIndex() { return index; } 71 | 72 | std::shared_ptr GNode::getOutput(int64_t index) { 73 | return outputs[index]; 74 | } 75 | 76 | std::vector> GNode::getInputs() { return inputs; } 77 | 78 | std::vector> GNode::getOutputs() { return outputs; } 79 | 80 | OperatorType GNode::getOperatorType() { return op_type; } 81 | 82 | bool GNode::earseSuccessor(Pointer node) { 83 | auto it = std::find(successors.begin(), successors.end(), node); 84 | if (it != successors.end()) { 85 | successors.erase(it); 86 | return true; 87 | } 88 | return false; 89 | } 90 | 91 | bool GNode::earsePredecessor(Pointer node) { 92 | auto it = std::find(predecessors.begin(), predecessors.end(), node); 93 | if (it != predecessors.end()) { 94 | predecessors.erase(it); 95 | return true; 96 | } 97 | return false; 98 | } 99 | 100 | Result>, GNode::GNodeError> 101 | GNode::inferShape() { 102 | if (this->op != nullptr) { 103 | // Get tensors from inputs. 
104 | std::vector> input_tensors; 105 | for (auto &input : this->inputs) { 106 | input_tensors.push_back(input->getTensor()); 107 | } 108 | auto output_tensors = this->op->inferShape(input_tensors); 109 | 110 | std::vector> outputs; 111 | for (auto &output_tensor : output_tensors) { 112 | auto output = std::make_shared(output_tensor); 113 | outputs.push_back(output); 114 | } 115 | // this->outputs = outputs; 116 | return Ok(outputs); 117 | } else { 118 | loge("[GNode::inferShape] Operator is nullptr."); 119 | return Err(GNode::GNodeError{GNode::GNodeError::Kind::InferError}); 120 | } 121 | } 122 | 123 | } // namespace tilegraph::graph -------------------------------------------------------------------------------- /src/optimizer/fusion/persistent_kernel_fusion.cpp: -------------------------------------------------------------------------------- 1 | #include "optimizer/fusion/persistent_kernel_fusion.hpp" 2 | #include "core/type.hpp" 3 | 4 | #include 5 | 6 | using namespace tilegraph::graph; 7 | 8 | namespace tilegraph::fusion { 9 | bool PersistentKernelFusion::fusion(Graph::Pointer graph) { 10 | // GEMM, GEMM + RELU, GEMM + ADD + RELU 11 | // 12 | std::unordered_map nodes; 13 | // 14 | std::unordered_map node_map; 15 | // 16 | std::unordered_map node_to_group; 17 | // std::vector gemm_nodes; 18 | 19 | // Step 1. Fuse GEMM_ADD, GEMM_ADD_RELU, GEMM_ADD_GELU and so on. 20 | // TODO: 21 | 22 | // Step 2. Find all GEMM and fused operators. 23 | auto check_node = [](std::shared_ptr gnode) -> bool { 24 | auto op_type = gnode->getOperatorType(); 25 | if (op_type == OperatorType::GEMM || 26 | op_type == OperatorType::GEMM_RELU) { 27 | return true; 28 | } 29 | return false; 30 | }; 31 | 32 | size_t counter = 0; 33 | auto ordered_ops = graph->topoSort(); 34 | for (auto op : ordered_ops) { 35 | if (check_node(op)) { 36 | nodes[counter] = op; 37 | node_map[op] = counter; 38 | node_to_group[counter] = counter; 39 | counter++; 40 | } 41 | } 42 | 43 | // Step 3. Union GEMM and group them. 44 | for (auto op : nodes) { 45 | auto successors = op.second->successors; 46 | for (auto successor : successors) { 47 | if (check_node(successor)) { 48 | node_to_group[node_map[successor]] = op.first; 49 | } 50 | } 51 | } 52 | 53 | // Step 4. Find all groups. 54 | auto groups = findGroups(node_to_group, nodes, node_map); 55 | 56 | // Step 5. Fuse all groups. 
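        // Illustrative walk-through (derived from the example in
        // examples/persistent_kernel_fusion.cpp): for GEMM -> GEMM -> SOFTMAX -> GEMM,
        // step 3 unions the first two GEMMs (the second is a direct successor of the
        // first), so findGroups() returns one group of size 2 and one of size 1; only
        // the size-2 group is replaced by a single FUSED node below, which yields the
        // expected FUSED -> SOFTMAX -> GEMM topological order.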
57 | for (auto group : groups) { 58 | auto group_idx = group.first; 59 | auto group_nodes = group.second; 60 | for (auto node : group_nodes) { 61 | logi( 62 | "[PersistentKernelFusion::fusion] group idx {}, node: {} " 63 | "{}", 64 | group_idx, node->name, toString(node->getOperatorType())); 65 | } 66 | 67 | if (group_nodes.size() > 1) { 68 | auto in_out_edges = searchInOut(group_nodes); 69 | auto in_edges = in_out_edges.first; 70 | auto out_edges = in_out_edges.second; 71 | 72 | // Create fused nodes 73 | auto fused_node = std::make_shared( 74 | GNode(in_edges, out_edges, OperatorType::FUSED)); 75 | 76 | graph->fuseNode(group_nodes, fused_node); 77 | } 78 | } 79 | 80 | return true; 81 | } 82 | 83 | std::size_t PersistentKernelFusion::findRoot( 84 | std::unordered_map node_to_group, 85 | std::size_t node_idx) { 86 | while (node_to_group[node_idx] != node_idx) { 87 | node_idx = node_to_group[node_idx]; 88 | } 89 | return node_idx; 90 | } 91 | 92 | std::unordered_map> 93 | PersistentKernelFusion::findGroups( 94 | std::unordered_map node_to_group, 95 | std::unordered_map nodes, 96 | std::unordered_map node_map) { 97 | std::unordered_map> groups; 98 | for (auto node : nodes) { 99 | auto root = findRoot(node_to_group, node.first); 100 | groups[root].push_back(node.second); 101 | } 102 | return groups; 103 | } 104 | 105 | std::pair, std::vector> 106 | PersistentKernelFusion::searchInOut(std::vector group) { 107 | std::vector in_edges; 108 | std::vector out_edges; 109 | for (auto node : group) { 110 | // Search in edges by producer. 111 | auto inputs = node->getInputs(); 112 | for (auto input : inputs) { 113 | // Check producer if in this group. 114 | auto producer = input->getProducer(); 115 | if (std::find(group.begin(), group.end(), producer) == 116 | group.end()) { 117 | in_edges.push_back(input); 118 | } 119 | } 120 | // Search out edges by consumers. 121 | auto outputs = node->getOutputs(); 122 | for (auto output : outputs) { 123 | // Check consumers if in this group. 
124 | auto consumers = output->getConsumers(); 125 | for (auto consumer : consumers) { 126 | if (std::find(group.begin(), group.end(), consumer) == 127 | group.end()) { 128 | out_edges.push_back(output); 129 | } 130 | } 131 | } 132 | } 133 | return std::make_pair(in_edges, out_edges); 134 | } 135 | } // namespace tilegraph::fusion -------------------------------------------------------------------------------- /src/core/graph/subgraph_match.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "core/graph/subgraph_match.hpp" 4 | 5 | using namespace tilegraph; 6 | using namespace tilegraph::graph; 7 | 8 | bool SubGraphMatch::Match(SubGraph::Pointer subgraph) { 9 | // check starting node 10 | for (auto node : m_graph->topoSort()) { 11 | if (subgraph->check_starting_node(node) && 12 | m_starting_nodes.find(node) == m_starting_nodes.end()) { 13 | SubGraphRecord::Pointer subgraph_record = 14 | std::make_shared(node, subgraph); 15 | if (FindSubGraph(subgraph_record, subgraph, node)) { 16 | m_matched_records.push_back(subgraph_record); 17 | m_starting_nodes.insert(node); 18 | } 19 | } 20 | } 21 | 22 | return !m_matched_records.empty(); 23 | } 24 | 25 | bool SubGraphMatch::FindSubGraph(SubGraphRecord::Pointer subgraph_record, 26 | SubGraph::Pointer subgraph, 27 | std::shared_ptr start) { 28 | ASSERT(subgraph->patterns.size() > 0, "Subgraph patterns is empty!"); 29 | auto init_pattern = subgraph->patterns[0]; 30 | std::vector init_pattern_records; 31 | if (FindPattern(init_pattern, init_pattern_records, start)) { 32 | for (auto pr : init_pattern_records) { 33 | if (SearchSubGraph(subgraph_record, subgraph, pr, 1) && 34 | subgraph_record->is_valid()) { 35 | return true; // return true when we find the first subgraph 36 | } else if (!subgraph_record->is_valid()) { 37 | loge("Subgraph Invalid!"); 38 | return false; 39 | } 40 | } 41 | } 42 | 43 | return false; 44 | } 45 | 46 | bool SubGraphMatch::SearchSubGraph(SubGraphRecord::Pointer subgraph_record, 47 | SubGraph::Pointer subgraph, 48 | PatternRecord::Pointer cur_pr, size_t idx) { 49 | std::stack s; 50 | std::vector matched_pattern_records; 51 | std::unordered_set pr_symbols; 52 | s.push(cur_pr); 53 | while (!s.empty()) { 54 | cur_pr = s.top(); 55 | s.pop(); 56 | subgraph_record->pattern_records.push_back(cur_pr); 57 | pr_symbols.insert(cur_pr->get_symbol()); 58 | 59 | if (idx == subgraph->patterns.size() && 60 | subgraph_record->pattern_records.size() == 61 | subgraph->patterns.size()) { 62 | return true; 63 | } else { 64 | auto start = cur_pr->get_next_start_node(); 65 | auto next_pattern = subgraph->patterns[idx]; 66 | // to ensure that subgraoh_record->pattern_records does not contain 67 | // the same record. 
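            // (Record uniqueness is keyed by PatternRecord::get_symbol(), which
            // concatenates the indices of the matched nodes; a candidate pattern
            // record is only pushed onto the search stack if its symbol has not
            // been seen yet.)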
68 | bool is_valid = false; 69 | if (FindPattern(next_pattern, matched_pattern_records, start)) { 70 | for (auto pr : matched_pattern_records) { 71 | auto symbol = pr->get_symbol(); 72 | if (pr_symbols.find(symbol) == pr_symbols.end()) { 73 | s.push(pr); 74 | is_valid = true; 75 | } 76 | } 77 | 78 | idx++; 79 | } 80 | 81 | if (!is_valid) { 82 | auto back_pr = subgraph_record->pattern_records.back(); 83 | subgraph_record->pattern_records.pop_back(); 84 | pr_symbols.erase(back_pr->get_symbol()); 85 | } 86 | } 87 | } 88 | 89 | return false; 90 | } 91 | 92 | bool SubGraphMatch::FindPattern( 93 | Pattern::Pointer pattern, 94 | std::vector& pattern_records, 95 | std::shared_ptr start) { 96 | pattern_records.clear(); 97 | std::vector> pattern_nodes; 98 | pattern_nodes.push_back(start); 99 | for (size_t i = 0; i < pattern->descriptions.size(); i++) { 100 | SearchPattern(start, i, 1, pattern_records, pattern_nodes, pattern); 101 | } 102 | 103 | if (pattern_records.empty()) { 104 | return false; 105 | } 106 | return true; 107 | } 108 | 109 | void SubGraphMatch::SearchPattern( 110 | std::shared_ptr cur_node, size_t description_idx, size_t idx, 111 | std::vector& pattern_records, 112 | std::vector>& pattern_nodes, 113 | Pattern::Pointer pattern) { 114 | auto description_ops = pattern->descriptions[description_idx].first; 115 | if (idx == description_ops.size() && 116 | pattern_nodes.size() == description_ops.size()) { 117 | PatternRecord::Pointer pr = std::make_shared(pattern); 118 | pr->nodes = pattern_nodes; 119 | pr->set_pattern_description_idx(description_idx); 120 | if (pr->is_valid()) 121 | pattern_records.push_back(pr); 122 | else { 123 | loge("PatternRecord Invalid!"); 124 | } 125 | } else { 126 | std::vector> edges; 127 | if (pattern->reverse_order) { 128 | edges = cur_node->getInputs(); 129 | } else { 130 | edges = cur_node->getOutputs(); 131 | } 132 | 133 | for (auto edge : edges) { 134 | std::vector> sub_nodes = {}; 135 | if (pattern->reverse_order) { 136 | sub_nodes.push_back(edge->getProducer()); 137 | 138 | } else { 139 | sub_nodes = edge->getConsumers(); 140 | } 141 | 142 | for (auto sub_node : sub_nodes) { 143 | if (sub_node->getOperatorType() == description_ops[idx]) { 144 | logi("[SubGraphMatch::SearchPattern] Find Pattern: {}", 145 | toString(sub_node->getOperatorType())); 146 | pattern_nodes.push_back(sub_node); 147 | SearchPattern(sub_node, description_idx, idx + 1, 148 | pattern_records, pattern_nodes, pattern); 149 | pattern_nodes.pop_back(); 150 | } else { 151 | loge("[SubGraphMatch::SearchPattern] Not Find Pattern: {}", 152 | toString(sub_node->getOperatorType())); 153 | } 154 | } 155 | } 156 | } 157 | } -------------------------------------------------------------------------------- /src/core/graph/graph_base.cpp: -------------------------------------------------------------------------------- 1 | #include "core/graph/graph_base.hpp" 2 | 3 | #include 4 | #include 5 | 6 | namespace tilegraph::graph { 7 | int64_t GraphBase::graph_count = 0; 8 | 9 | GraphBase::GraphBase(std::vector> operators_list, 10 | std::vector> inputs_list, 11 | std::vector> outputs_list, 12 | std::string name_value) 13 | : name(name_value), 14 | index(graph_count++), 15 | nodes(operators_list), 16 | inputs(inputs_list), 17 | outputs(outputs_list) { 18 | name = (name == "" ? 
"Graph_" + std::to_string(index) : name); 19 | } 20 | 21 | void GraphBase::connect() { 22 | for (auto node : nodes) { 23 | auto outputs = node->getOutputs(); 24 | if (outputs.empty()) { 25 | if (node->inferShape().isErr()) { 26 | loge("[GraphBase::connect] Failed to infer node shape."); 27 | } else { 28 | node->outputs = node->inferShape().unwrap(); 29 | } 30 | } 31 | for (auto edge : node->inputs) { 32 | auto it = edge.get(); 33 | it->addConsumer(node); 34 | if (it->producer != NULL) { 35 | node.get()->predecessors.push_back(it->producer); 36 | it->producer.get()->successors.push_back(node); 37 | } 38 | } 39 | for (auto it : node.get()->outputs) { 40 | it->setProducer(node); 41 | } 42 | for (auto it : node.get()->inputs) { 43 | node->in_degree += it->producer == NULL ? 0 : 1; 44 | } 45 | } 46 | } 47 | 48 | bool GraphBase::earseNode(GNode::Pointer node) { 49 | if (!disconect(node)) { 50 | loge("Failed to disconect node."); 51 | return false; 52 | } 53 | // Remove node form operators. 54 | auto operators_iter = std::find(nodes.begin(), nodes.end(), node); 55 | if (operators_iter != nodes.end()) { 56 | nodes.erase(operators_iter); 57 | return true; 58 | } 59 | 60 | loge("Failed to remove node from operators."); 61 | return false; 62 | } 63 | 64 | bool GraphBase::addNode(GNode::Pointer node) { 65 | auto input_edges = node->getInputs(); 66 | auto output_edges = node->getOutputs(); 67 | for (auto input : input_edges) { 68 | // Add node to input edges' consumer list. 69 | input->addConsumer(node); 70 | if (input->producer != NULL) { 71 | node->predecessors.push_back(input->producer); 72 | input->producer->successors.push_back(node); 73 | } 74 | } 75 | 76 | for (auto output : output_edges) { 77 | output->setProducer(node); 78 | if (output->consumers.size() > 0) { 79 | for (auto consumer : output->consumers) { 80 | node->successors.push_back(consumer); 81 | consumer->predecessors.push_back(node); 82 | } 83 | } 84 | } 85 | 86 | for (auto input : input_edges) { 87 | node->in_degree += input->producer == NULL ? 0 : 1; 88 | } 89 | 90 | nodes.push_back(node); 91 | return true; 92 | } 93 | 94 | bool GraphBase::disconect(GNode::Pointer node) { 95 | auto input_edges = node->getInputs(); 96 | auto output_edges = node->getOutputs(); 97 | 98 | for (auto input : input_edges) { 99 | // Remove node from consumer list. 100 | if (!input->earseConsumer(node)) { 101 | loge("Failed to remove node from consumer list."); 102 | return false; 103 | } 104 | // Remove node from input edges' producer list. 105 | if (input->producer != NULL) { 106 | if (!input->producer->earseSuccessor(node)) { 107 | loge( 108 | "Failed to remove node from inputs' producer node's " 109 | "successor list."); 110 | return false; 111 | } 112 | } 113 | } 114 | 115 | for (auto output : output_edges) { 116 | // Remove node from producer list. 
117 | output->producer = NULL; 118 | if (output->consumers.size() != 0) { 119 | for (auto consumer : output->consumers) { 120 | if (!consumer->earsePredecessor(node)) { 121 | loge( 122 | "Failed to remove node from outputs' consumer " 123 | "node's predecessor list."); 124 | return false; 125 | } 126 | } 127 | } 128 | } 129 | return true; 130 | } 131 | 132 | std::vector> GraphBase::topoSort() { 133 | std::unordered_map, int64_t> operators_indegree; 134 | for (auto op : nodes) { 135 | operators_indegree[op] = op->in_degree; 136 | } 137 | std::vector> result; 138 | while (!operators_indegree.empty()) { 139 | for (auto op = operators_indegree.begin(); 140 | op != operators_indegree.end(); ++op) { 141 | if (op->second == 0) { 142 | result.push_back(op->first); 143 | for (auto successor : (op->first)->successors) { 144 | --operators_indegree[successor]; 145 | } 146 | operators_indegree.erase(op->first); 147 | break; 148 | } 149 | } 150 | } 151 | return result; 152 | } 153 | 154 | bool GraphBase::fuseNode(std::vector> old_nodes, 155 | std::shared_ptr subgraph_node) { 156 | subgraph_node->in_degree = 0; 157 | subgraph_node->predecessors.clear(); 158 | subgraph_node->successors.clear(); 159 | 160 | for (auto old_node : old_nodes) { 161 | earseNode(old_node); 162 | } 163 | addNode(subgraph_node); 164 | return true; 165 | } 166 | 167 | } // namespace tilegraph::graph -------------------------------------------------------------------------------- /include/core/graph/subgraph_match.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "core/graph/gnode.hpp" 10 | #include "core/graph/graph.hpp" 11 | #include "common/common.hpp" 12 | 13 | namespace tilegraph::graph { 14 | 15 | struct PatternRecord; 16 | struct Pattern; 17 | struct SubGraph; 18 | struct SubGraphRecord; 19 | class SubGraphMatch; 20 | 21 | struct Pattern { 22 | // pair 23 | // eg. { {{"Add", "Mul"}, 1}, {{"Add", "Convert","Div"}, 2}} 24 | std::vector, size_t>> descriptions; 25 | // pattern search order 26 | bool reverse_order = false; 27 | 28 | /* 29 | cutomized check for the pattern 30 | for example, if we want to ensure that the second node of the pattern 31 | has 2 outputs, since such information is not included in description, 32 | we could add check function to check. 
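    A minimal illustrative sketch of such a check (hypothetical code, mirroring
    the lambda signature used in gemm_relu_fusion.cpp):

        pattern->check.push_back([](const PatternRecord& pr) -> bool {
            return pr.nodes.size() > 1 && pr.nodes[1]->getOutputs().size() == 2;
        });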
33 | */ 34 | std::vector> check; 35 | // PatternRecord::Pointer pattern_record; 36 | using Pointer = std::shared_ptr; 37 | }; 38 | 39 | struct PatternRecord { 40 | public: 41 | PatternRecord(Pattern::Pointer p) : pattern(p) {} 42 | 43 | std::shared_ptr get_next_start_node() { 44 | ASSERT(pattern_description_idx < pattern->descriptions.size(), 45 | "Pattern description index out of range."); 46 | size_t idx = pattern->descriptions[pattern_description_idx].second; 47 | ASSERT(idx < nodes.size(), "Index out of range."); 48 | return nodes[idx]; 49 | } 50 | 51 | bool is_valid() { 52 | if (pattern == nullptr || nodes.empty()) return false; 53 | 54 | for (auto func : pattern->check) { 55 | if (!func(*this)) return false; 56 | } 57 | 58 | return true; 59 | } 60 | 61 | std::string get_symbol() { 62 | std::string identity; 63 | for (auto node : nodes) { 64 | auto id = node->getIndex(); 65 | identity += std::to_string(id) + "_"; 66 | } 67 | return identity; 68 | } 69 | 70 | /* 71 | pair< pattern nodes, pattern index> 72 | Pattern.descriptions may contain several descriptions with equal logic 73 | meaning, PatternRecord.nodes contains all nodes of matched pattern. For 74 | each pattern descrpition in Pattern.descriptions, there may exist 75 | several pattern in the graph, pair.second is used to indentify which 76 | description pair.first matches. 77 | */ 78 | std::vector> nodes; 79 | void set_pattern_description_idx(size_t idx) { 80 | pattern_description_idx = idx; 81 | } 82 | size_t get_pattern_description_idx() const { 83 | return pattern_description_idx; 84 | } 85 | Pattern::Pointer pattern; 86 | using Pointer = std::shared_ptr; 87 | 88 | private: 89 | size_t pattern_description_idx; 90 | }; 91 | 92 | /* 93 | SubGraph consists of many Pattern. 94 | for a subgraph like 95 | A 96 | / \ 97 | B C 98 | | | 99 | D F 100 | | / 101 | E / 102 | \ / 103 | G 104 | 105 | , we can description it as pattern(A->B->D->E->G) with starting node A 106 | in non-reverse order followed by pattern(G->F->C) in reverse order. 
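    (Illustrative note, inferred from PatternRecord::get_next_start_node(): the
    size_t stored with each pattern description is the index, within that
    pattern's matched nodes, of the node handed to the next pattern as its
    starting point; here G, the last node matched by the first pattern, would be
    the start of the reverse-order pattern G->F->C.)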
107 | */
108 | struct SubGraph {
109 | std::string name;
110 | std::function<bool(std::shared_ptr<GNode>)> check_starting_node;
111 | std::vector<Pattern::Pointer> patterns;
112 | std::vector<std::function<bool(SubGraphRecord&)>> check;
113 | using Pointer = std::shared_ptr<SubGraph>;
114 | };
115 | 
116 | struct SubGraphRecord {
117 | public:
118 | SubGraphRecord(std::shared_ptr<GNode> sn, SubGraph::Pointer sg)
119 | : subgraph(sg), starting_node(sn) {}
120 | 
121 | bool is_valid() {
122 | if (subgraph == nullptr || pattern_records.empty()) return false;
123 | 
124 | if (!subgraph->check_starting_node(starting_node)) return false;
125 | 
126 | for (auto pr : pattern_records) {
127 | if (!pr->is_valid()) return false;
128 | }
129 | 
130 | for (auto func : subgraph->check) {
131 | if (!func(*this)) return false;
132 | }
133 | 
134 | return true;
135 | }
136 | 
137 | void set_starting_node(std::shared_ptr<GNode> node) {
138 | starting_node = node;
139 | }
140 | const std::shared_ptr<GNode>& get_starting_node() const {
141 | return starting_node;
142 | }
143 | std::vector<PatternRecord::Pointer> pattern_records;
144 | SubGraph::Pointer subgraph;
145 | using Pointer = std::shared_ptr<SubGraphRecord>;
146 | 
147 | private:
148 | std::shared_ptr<GNode> starting_node;
149 | };
150 | 
151 | class SubGraphMatch {
152 | public:
153 | SubGraphMatch(std::shared_ptr<Graph> g) : m_graph(g) {}
154 | 
155 | bool Match(SubGraph::Pointer subgraph);
156 | bool FindSubGraph(SubGraphRecord::Pointer subgraph_record,
157 | SubGraph::Pointer subgraph,
158 | std::shared_ptr<GNode> start);
159 | bool SearchSubGraph(SubGraphRecord::Pointer subgraph_record,
160 | SubGraph::Pointer subgraph,
161 | PatternRecord::Pointer cur_pr, size_t idx);
162 | bool FindPattern(Pattern::Pointer pattern,
163 | std::vector<PatternRecord::Pointer>& pattern_records,
164 | std::shared_ptr<GNode> start);
165 | void SearchPattern(std::shared_ptr<GNode> cur_node,
166 | size_t description_idx, size_t idx,
167 | std::vector<PatternRecord::Pointer>& pattern_records,
168 | std::vector<std::shared_ptr<GNode>>& pattern_nodes,
169 | Pattern::Pointer pattern);
170 | const std::vector<SubGraphRecord::Pointer>& get_matched_subgraph()
171 | const {
172 | return m_matched_records;
173 | }
174 | void clear_matched_records() { m_matched_records.clear(); }
175 | 
176 | private:
177 | std::shared_ptr<Graph> m_graph;
178 | std::vector<SubGraphRecord::Pointer> m_matched_records;
179 | std::unordered_set<std::shared_ptr<GNode>> m_starting_nodes;
180 | };
181 | } // namespace tilegraph::graph
--------------------------------------------------------------------------------
/src/kernels/cuda/gemm.cpp:
--------------------------------------------------------------------------------
1 | #include "kernels/cuda/gemm.hpp"
2 | #include "kernels/cuda/sync.hpp"
3 | #include "kernels/cuda/tensor_core.hpp"
4 | #include "kernels/kernel_unit.hpp"
5 | #include <fmt/core.h>
6 | 
7 | namespace tilegraph::kernel::cuda {
8 | CudaGEMMKernel::CudaGEMMKernel(uint32_t M, uint32_t N, uint32_t K,
9 | uint32_t ShardedM, uint32_t ShardedN,
10 | uint32_t ShardedK, uint32_t WarpM,
11 | uint32_t WarpN, uint32_t WarpK,
12 | uint32_t WmmaM, uint32_t WmmaN,
13 | uint32_t WmmaK, bool transpose_a,
14 | bool transpose_b, MemoryType memory_level,
15 | MemoryType output_level)
16 | : M(M),
17 | N(N),
18 | K(K),
19 | ShardedM(ShardedM),
20 | ShardedN(ShardedN),
21 | ShardedK(ShardedK),
22 | WarpM(WarpM),
23 | WarpN(WarpN),
24 | WarpK(WarpK),
25 | WmmaM(WmmaM),
26 | WmmaN(WmmaN),
27 | WmmaK(WmmaK),
28 | transpose_a(transpose_a),
29 | transpose_b(transpose_b),
30 | memory_level(memory_level),
31 | output_level(output_level) {}
32 | 
33 | std::string CudaGEMMKernel::genTCGEMM(std::string name) {
34 | std::string function;
35 | // Define Functions;
36 | // auto load_smem_a = std::make_shared(
37 | // "loadSmemA", FuncType::Device, DataType::Void);
38 | // auto load_smem_b = std::make_shared(
39 | // "loadSmemB", FuncType::Device, DataType::Void);
40 | // auto load_smem_c = std::make_shared(
41 | // "loadSmemC", FuncType::Device, DataType::Void);
42 | // auto store_smem_c = std::make_shared(
43 | // "StoreSmemC", FuncType::Device, DataType::Void);
44 | // auto load_frag_a = std::make_shared(
45 | // "LoadFragA", FuncType::Device, DataType::Void);
46 | // auto load_frag_b = std::make_shared(
47 | // "LoadFragB", FuncType::Device, DataType::Void);
48 | // auto store_accum = std::make_shared(
49 | // "StoreAccum", FuncType::Device, DataType::Void);
50 | 
51 | // functions.insert(load_smem_a);
52 | // functions.insert(load_smem_b);
53 | // functions.insert(load_smem_c);
54 | // functions.insert(store_smem_c);
55 | // functions.insert(load_frag_a);
56 | // functions.insert(load_frag_b);
57 | // functions.insert(store_accum);
58 | 
59 | // for (auto func : functions) {
60 | // function += func->declareFunction();
61 | // }
62 | 
63 | // uint16_t indient = 0;
64 | // std::vector> arguments;
65 | // function += function_unit->declareGlobal(name, arguments);
66 | // function += "{\n";
67 | // indient += 4;
68 | 
69 | int indient = 4;
70 | 
71 | // Fuse level | Output Level
72 | // Global | Global
73 | // Global | Shared
74 | // Global | Warp
75 | // Shared | Global
76 | // Shared | Shared
77 | // Shared | Warp
78 | // Warp | Global
79 | // Warp | Shared
80 | // Warp | Warp
81 | // Generate the tensor core GEMM CUDA implementation.
82 | if (memory_level == MemoryType::Global &&
83 | output_level == MemoryType::Global) {
84 | // inputs: global memory A, B
85 | // output: global memory C
86 | // compute and store the result into global memory C.
87 | // Declare shared memory tiles.
88 | auto smem_a = std::make_shared<CudaVar>(
89 | MemoryType::Shared, DataType::Half, ShardedM * ShardedK, "SA");
90 | auto smem_b = std::make_shared<CudaVar>(
91 | MemoryType::Shared, DataType::Half, ShardedK * ShardedN, "SB");
92 | auto smem_c = std::make_shared<CudaVar>(
93 | MemoryType::Shared, DataType::Half, ShardedM * ShardedN, "SC");
94 | 
95 | // Declare warp-level fragments.
96 | auto frag_a = std::make_shared<CudaVar>(
97 | MemoryType::Warp, DataType::Half, WarpM * WarpK, "FA");
98 | auto frag_b = std::make_shared<CudaVar>(
99 | MemoryType::Warp, DataType::Half, WarpK * WarpN, "FB");
100 | auto accum = std::make_shared<CudaVar>(
101 | MemoryType::Warp, DataType::Half, WarpM * WarpN, "AC");
102 | 
103 | vars.insert(smem_a);
104 | vars.insert(smem_b);
105 | vars.insert(smem_c);
106 | vars.insert(frag_a);
107 | vars.insert(frag_b);
108 | vars.insert(accum);
109 | 
110 | for (auto var : vars) {
111 | function += var->declareVar(indient);
112 | }
113 | 
114 | // accum.initVar(indient);
115 | 
116 | auto iter_k = std::make_unique<CudaIteration>(
117 | std::make_unique<CudaVar>(MemoryType::Shared, DataType::Int32,
118 | 0, "ki"),
119 | std::variant<int, std::shared_ptr<CudaVar>>(1),
120 | std::variant<int, std::shared_ptr<CudaVar>>(0),
121 | std::variant<int, std::shared_ptr<CudaVar>>(
122 | (int)(K / ShardedK)));
123 | function += iter_k->genIter(indient);
124 | 
125 | indient += 4;
126 | // TODO: load sharded memory
127 | function += insertSyncnorize(indient, MemoryType::Shared);
128 | 
129 | auto iter_m = std::make_unique<CudaIteration>(
130 | std::make_unique<CudaVar>(MemoryType::Shared, DataType::Int32,
131 | 0, "mii"),
132 | std::variant<int, std::shared_ptr<CudaVar>>(1),
133 | std::variant<int, std::shared_ptr<CudaVar>>(0),
134 | std::variant<int, std::shared_ptr<CudaVar>>(
135 | (int)(WarpM / WmmaM)));
136 | 
137 | auto iter_n = std::make_unique<CudaIteration>(
138 | std::make_unique<CudaVar>(MemoryType::Shared, DataType::Int32,
139 | 0, "nii"),
140 | std::variant<int, std::shared_ptr<CudaVar>>(1),
141 | std::variant<int, std::shared_ptr<CudaVar>>(0),
142 | std::variant<int, std::shared_ptr<CudaVar>>(
143 | (int)(WarpN / WmmaN)));
144 | 
145 | function += iter_m->genIter(indient);
146 | indient += 4;
147 | function += iter_n->genIter(indient);
148 | indient += 4;
149 | 
150 | // Insert mma sync to compute the matrix multiplication.
151 | auto warp_a = frag_a->getVarIndexByVar(iter_m->getIterVar());
152 | auto warp_b = frag_b->getVarIndexByVar(iter_n->getIterVar());
153 | auto warp_c = accum->getVarIndexByVar(
154 | fmt::format("mii * {} + nii", WarpN / WmmaN));
155 | function += genWmmaSync(indient, warp_a, warp_b, warp_c, warp_c);
156 | 
157 | indient -= 4;
158 | function += insertIndient(indient);
159 | function += "}\n";
160 | 
161 | indient -= 4;
162 | function += insertIndient(indient);
163 | function += "}\n";
164 | 
165 | indient -= 4;
166 | function += insertIndient(indient);
167 | function += "}\n";
168 | 
169 | // TODO: Store accum into smem_c;
170 | function += insertSyncnorize(indient, MemoryType::Shared);
171 | } else if (memory_level == MemoryType::Shared &&
172 | output_level == MemoryType::Shared) {
173 | // Fuse GEMM in shared memory:
174 | // input: shared memory A, global memory B.
175 | // Load global memory B into shared memory,
176 | // compute, and store the result into shared memory C.
177 | } else if (memory_level == MemoryType::Warp &&
178 | output_level == MemoryType::Warp) {
179 | // Fuse GEMM in warp-level memory:
180 | // input: warp tile A, global memory B.
181 | // Load global memory B into a warp tile, compute,
182 | // and store the result into the accumulator.
183 | }
184 | 
185 | // TODO: Store smem_c into C;
186 | 
187 | // Function End;
188 | // indient -= 4;
189 | // function += insertIndient(indient);
190 | // function += "}\n";
191 | 
192 | return function;
193 | }
194 | } // namespace tilegraph::kernel::cuda
--------------------------------------------------------------------------------
/docs/tilegraph.drawio:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
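A minimal usage sketch of the subgraph matcher declared in include/core/graph/subgraph_match.hpp, assuming a graph has already been built. The op names ("Gemm", "Relu"), the build_graph() helper, and the reconstructed template arguments are illustrative assumptions, not code taken from the repository's examples.

// Sketch only: drives the SubGraphMatch interface visible above.
#include <memory>

#include "core/graph/graph.hpp"
#include "core/graph/subgraph_match.hpp"

using namespace tilegraph::graph;

std::shared_ptr<Graph> build_graph();  // hypothetical helper, defined elsewhere

int main() {
    auto graph = build_graph();

    // One linear pattern: a chain of op names plus the index of the matched
    // node that a following pattern would start from (unused here).
    auto chain = std::make_shared<Pattern>();
    chain->descriptions = {{{"Gemm", "Relu"}, 0}};

    auto subgraph = std::make_shared<SubGraph>();
    subgraph->name = "gemm_relu";
    subgraph->patterns = {chain};
    subgraph->check_starting_node = [](std::shared_ptr<GNode>) { return true; };

    SubGraphMatch matcher(graph);
    if (matcher.Match(subgraph)) {
        for (const auto& record : matcher.get_matched_subgraph()) {
            // Each record carries the starting node plus one PatternRecord
            // per pattern of the matched subgraph.
            (void)record;
        }
    }
    return 0;
}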