├── .gitignore ├── .gitmodules ├── test ├── CMakeLists.txt ├── include │ ├── test_query.h │ └── test_base.h └── src │ ├── basic_tests.cpp │ ├── test_query.cpp │ └── test_base.cpp ├── CMakeLists.txt ├── .github └── workflows │ ├── cpplint.yml │ └── cmake.yml ├── include └── tilt │ ├── base │ ├── log.h │ ├── ctype.h │ └── type.h │ ├── pass │ ├── codegen │ │ ├── vinstr.h │ │ ├── loopgen.h │ │ └── llvmgen.h │ ├── visitor.h │ ├── printer.h │ └── irgen.h │ ├── ir │ ├── op.h │ ├── node.h │ ├── lstream.h │ ├── loop.h │ └── expr.h │ ├── engine │ └── engine.h │ └── builder │ └── tilder.h ├── scripts └── gen_vinstr.sh ├── src ├── CMakeLists.txt ├── engine │ └── engine.cpp ├── ir │ └── ir.cpp ├── builder │ └── tilder.cpp └── pass │ ├── codegen │ ├── vinstr.cpp │ ├── loopgen.cpp │ └── llvmgen.cpp │ └── printer.cpp ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | build 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/googletest"] 2 | path = third_party/googletest 3 | url = https://github.com/google/googletest.git -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(TEST_FILES 2 | src/basic_tests.cpp 3 | src/test_base.cpp 4 | src/test_query.cpp 5 | ) 6 | 7 | add_executable(tilt_test ${TEST_FILES}) 8 | target_include_directories(tilt_test PUBLIC include) 9 | target_link_libraries(tilt_test gtest_main tilt) 10 | 11 | include(GoogleTest) 12 | gtest_discover_tests(tilt_test) 13 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.13.4) 2 | 
set(CMAKE_C_COMPILER clang) 3 | set(CMAKE_CXX_COMPILER clang++) 4 | 5 | project(tilt) 6 | 7 | set(CMAKE_CXX_STANDARD 17) 8 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 9 | set(CMAKE_CXX_EXTENSIONS OFF) 10 | 11 | add_subdirectory(third_party/googletest) 12 | add_subdirectory(src) 13 | 14 | enable_testing() 15 | add_subdirectory(test) 16 | -------------------------------------------------------------------------------- /.github/workflows/cpplint.yml: -------------------------------------------------------------------------------- 1 | name: cpplint 2 | on: [pull_request] 3 | jobs: 4 | cpplint: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v1 8 | - uses: actions/setup-python@v1 9 | - run: pip install cpplint 10 | - run: cpplint --exclude=third_party --exclude=main.cpp --recursive --filter=-legal/copyright,-whitespace/braces,-whitespace/indent,-build/namespaces,-build/include_subdir --linelength=120 . 11 | 12 | -------------------------------------------------------------------------------- /include/tilt/base/log.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_BASE_LOG_H_ 2 | #define INCLUDE_TILT_BASE_LOG_H_ 3 | 4 | #include 5 | 6 | #define LOG(severity) \ 7 | std::cerr << "[" << __FILE__ << ":" << __LINE__ << "] [" #severity "] " 8 | 9 | #define CHECK(EXPR, MSG) \ 10 | if (!(EXPR)) { LOG(FATAL) << "Check failed: `" #EXPR "` " MSG << std::endl; std::abort(); } 11 | 12 | #define ASSERT(EXPR) CHECK(EXPR, "") 13 | 14 | #endif // INCLUDE_TILT_BASE_LOG_H_ 15 | -------------------------------------------------------------------------------- /scripts/gen_vinstr.sh: -------------------------------------------------------------------------------- 1 | CMAKE_CXX_COMPILER=$1 2 | CMAKE_CURRENT_SOURCE_DIR=$2 3 | CMAKE_CURRENT_BINARY_DIR=$3 4 | 5 | ${CMAKE_CXX_COMPILER} -emit-llvm -S ${CMAKE_CURRENT_SOURCE_DIR}/pass/codegen/vinstr.cpp \ 6 | -I ${CMAKE_CURRENT_SOURCE_DIR}/../include/ \ 7 | -o 
${CMAKE_CURRENT_BINARY_DIR}/vinstr.ll 8 | 9 | VINSTR_IR=$(cat ${CMAKE_CURRENT_BINARY_DIR}/vinstr.ll) 10 | 11 | echo "const char* vinstr_str = R\"( 12 | ${VINSTR_IR} 13 | )\"; 14 | " > ${CMAKE_CURRENT_BINARY_DIR}/vinstr_str.cpp 15 | -------------------------------------------------------------------------------- /include/tilt/base/ctype.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_BASE_CTYPE_H_ 2 | #define INCLUDE_TILT_BASE_CTYPE_H_ 3 | 4 | #include 5 | 6 | typedef int64_t ts_t; 7 | typedef int64_t idx_t; 8 | typedef uint32_t dur_t; 9 | 10 | extern "C" { 11 | 12 | struct ival_t { 13 | ts_t t; 14 | dur_t d; 15 | }; 16 | 17 | struct region_t { 18 | ts_t st; 19 | ts_t et; 20 | idx_t head; 21 | idx_t count; 22 | uint32_t mask; 23 | ival_t* tl; 24 | char* data; 25 | }; 26 | 27 | } // extern "C" 28 | 29 | #endif // INCLUDE_TILT_BASE_CTYPE_H_ 30 | -------------------------------------------------------------------------------- /test/include/test_query.h: -------------------------------------------------------------------------------- 1 | #ifndef TEST_INCLUDE_TEST_QUERY_H_ 2 | #define TEST_INCLUDE_TEST_QUERY_H_ 3 | 4 | #include 5 | 6 | #include "tilt/builder/tilder.h" 7 | 8 | using namespace std; 9 | using namespace tilt; 10 | using namespace tilt::tilder; 11 | 12 | Op _Select(_sym, function); 13 | Op _MovingSum(_sym, int64_t, int64_t); 14 | Op _Join(_sym, _sym); 15 | Op _WindowAvg(string, _sym, int64_t); 16 | Op _Norm(string, _sym, int64_t); 17 | Op _Resample(string, _sym, int64_t, int64_t); 18 | 19 | Expr _Count(_sym); 20 | Expr _Sum(_sym); 21 | Expr _Average(_sym); 22 | Expr _StdDev(_sym); 23 | 24 | #endif // TEST_INCLUDE_TEST_QUERY_H_ 25 | -------------------------------------------------------------------------------- /test/src/basic_tests.cpp: -------------------------------------------------------------------------------- 1 | #include "test_base.h" 2 | 3 | TEST(MathOpTests, AddOpTest) { add_test(); } 4 | 
TEST(MathOpTests, SubOpTest) { sub_test(); } 5 | TEST(MathOpTests, MulOpTest) { mul_test(); } 6 | TEST(MathOpTests, DivOpTest) { div_test(); } 7 | TEST(MathOpTests, ModOpTest) { mod_test(); } 8 | TEST(MathOpTests, MaxOpTest) { max_test(); } 9 | TEST(MathOpTests, MinOpTest) { min_test(); } 10 | TEST(MathOpTests, NegOpTest) { neg_test(); } 11 | TEST(MathOpTests, SqrtOpTest) { sqrt_test(); } 12 | TEST(MathOpTests, PowOPTest) { pow_test(); } 13 | TEST(MathOpTests, CeilOPTest) { ceil_test(); } 14 | TEST(MathOpTests, FloorOPTest) { floor_test(); } 15 | TEST(MathOpTests, AbsOPTest) { abs_test(); } 16 | TEST(CastOpTests, CastTest) { cast_test(); } 17 | TEST(QuiltTest, MovingSumTest) { moving_sum_test(); } 18 | TEST(QuiltTest, NormTest) { norm_test(); } 19 | TEST(QuiltTest, ResampleTest) { resample_test(); } 20 | -------------------------------------------------------------------------------- /include/tilt/pass/codegen/vinstr.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_PASS_CODEGEN_VINSTR_H_ 2 | #define INCLUDE_TILT_PASS_CODEGEN_VINSTR_H_ 3 | 4 | #include "tilt/base/ctype.h" 5 | 6 | #define TILT_VINSTR_ATTR __attribute__((always_inline)) 7 | 8 | namespace tilt { 9 | extern "C" { 10 | 11 | TILT_VINSTR_ATTR uint32_t get_buf_size(idx_t); 12 | TILT_VINSTR_ATTR idx_t get_start_idx(region_t*); 13 | TILT_VINSTR_ATTR idx_t get_end_idx(region_t*); 14 | TILT_VINSTR_ATTR ts_t get_start_time(region_t*); 15 | TILT_VINSTR_ATTR ts_t get_end_time(region_t*); 16 | TILT_VINSTR_ATTR ts_t get_ckpt(region_t*, ts_t, idx_t); 17 | TILT_VINSTR_ATTR idx_t advance(region_t*, idx_t, ts_t); 18 | TILT_VINSTR_ATTR char* fetch(region_t*, ts_t, idx_t, uint32_t); 19 | TILT_VINSTR_ATTR region_t* make_region(region_t*, region_t*, ts_t, idx_t, ts_t, idx_t); 20 | TILT_VINSTR_ATTR region_t* init_region(region_t*, ts_t, uint32_t, ival_t*, char*); 21 | TILT_VINSTR_ATTR region_t* commit_data(region_t*, ts_t); 22 | TILT_VINSTR_ATTR region_t* 
commit_null(region_t*, ts_t); 23 | 24 | } // extern "C" 25 | } // namespace tilt 26 | 27 | #endif // INCLUDE_TILT_PASS_CODEGEN_VINSTR_H_ 28 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(SRC_FILES 2 | ir/ir.cpp 3 | builder/tilder.cpp 4 | pass/printer.cpp 5 | pass/codegen/loopgen.cpp 6 | pass/codegen/llvmgen.cpp 7 | pass/codegen/vinstr.cpp 8 | engine/engine.cpp 9 | ) 10 | 11 | find_package(LLVM 15 REQUIRED CONFIG) 12 | message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") 13 | message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") 14 | 15 | add_definitions(${LLVM_DEFINITIONS}) 16 | llvm_map_components_to_libnames(llvm_libs native orcjit mcjit objcarcopts) 17 | 18 | # Generate vinstr IR for JIT 19 | # 20 | # We have two commands that run scripts/gen_vinstr.sh because 21 | # execute_process is for configure (cmake) and target is for build (make) 22 | add_custom_command( 23 | OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/vinstr_str.cpp 24 | COMMAND bash ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/gen_vinstr.sh ${CMAKE_CXX_COMPILER} ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} 25 | DEPENDS pass/codegen/vinstr.cpp 26 | ) 27 | 28 | add_library(tilt STATIC ${SRC_FILES} ${CMAKE_CURRENT_BINARY_DIR}/vinstr_str.cpp) 29 | 30 | target_link_libraries(tilt ${llvm_libs}) 31 | target_include_directories(tilt PUBLIC ${LLVM_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/../include) 32 | target_compile_options(tilt PRIVATE -Wall -Wextra -pedantic -Werror -Wno-unused-parameter) 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TiLT: A Temporal Query Compiler 2 | TiLT is a query compiler and execution engine for temporal stream processing applications. 3 | 4 | ## Building TiLT from source 5 | 6 | ### Prerequisites 7 | 1. 
CMake 3.13.4 8 | 2. LLVM 15 9 | 3. Clang++ 15 10 | 11 | ### Build and install LLVM and Clang 12 | Download and unpack [llvm-project-15.0.7.tar.xz](https://github.com/llvm/llvm-project/releases/download/llvmorg-15.0.7/llvm-project-15.0.7.src.tar.xz) 13 | 14 | cd llvm-project-15.0.7 15 | mkdir build 16 | cd build 17 | cmake -DCMAKE_BUILD_TYPE=Release \ 18 | -DLLVM_ENABLE_RTTI=ON \ 19 | -DLLVM_TARGETS_TO_BUILD="X86" \ 20 | -DLLVM_ENABLE_PROJECTS="clang;clang-tools-extra" \ 21 | -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi" \ 22 | -DLLVM_ENABLE_ZLIB=OFF \ 23 | -DLLVM_ENABLE_ZSTD=OFF \ 24 | -DLLVM_ENABLE_TERMINFO=OFF \ 25 | -DLLVM_BUILD_LLVM_DYLIB=ON \ 26 | -DLLVM_LINK_LLVM_DYLIB=ON \ 27 | -DCMAKE_INSTALL_PREFIX= ../llvm 28 | cmake --build . 29 | cmake --build . --target install 30 | 31 | ### Build TiLT 32 | Clone TiLT repository along with the submodules 33 | 34 | git clone https://github.com/ampersand-projects/tilt.git --recursive 35 | mkdir build 36 | cd build 37 | cmake -DLLVM_DIR=/lib/cmake/llvm .. 38 | cmake --build . 
39 |
--------------------------------------------------------------------------------
/test/include/test_base.h:
--------------------------------------------------------------------------------
1 | #ifndef TEST_INCLUDE_TEST_BASE_H_
2 | #define TEST_INCLUDE_TEST_BASE_H_
3 |
4 | #include
5 | #include
6 |
7 | #include "tilt/ir/op.h"
8 |
9 | #include "test_query.h"
10 |
11 | #include "gtest/gtest.h"
12 |
13 | using namespace std;
14 | using namespace tilt;
15 |
16 | // A test event: `st` and `et` are presumably the event's start and end
17 | // timestamps (TODO confirm against test_base.cpp); `payload` is its value.
18 | template
19 | struct Event {
20 |     int64_t st;
21 |     int64_t et;
22 |     T payload;
23 | };
24 |
25 | // Reference query implementation: a function from a vector of events to a
26 | // vector of events (presumably input events to expected output events).
27 | template
28 | using QueryFn = function>(vector>)>;
29 |
30 | // Declared here and defined in the test sources, not in this header.
31 | // Presumably runs `Op` over the input events and compares the result with
32 | // the QueryFn reference. TODO confirm.
33 | template
34 | void op_test(Op, QueryFn, vector>);
35 |
36 | template
37 | void select_test(function, function);
38 |
39 | // NOTE(review): an anonymous namespace in a header gives every including
40 | // translation unit its own copy of these helpers; `inline` functions would
41 | // avoid the duplication.
42 | namespace {
43 |
44 | // Default comparison: exact equality via gtest's ASSERT_EQ.
45 | template
46 | void assert_eq(T exp, T act) { ASSERT_EQ(exp, act); }
47 |
48 | // float and double use gtest's ULP-tolerant comparisons instead of exact equality.
49 | template<>
50 | void assert_eq(float exp, float act) { ASSERT_FLOAT_EQ(exp, act); }
51 |
52 | template<>
53 | void assert_eq(double exp, double act) { ASSERT_DOUBLE_EQ(exp, act); }
54 |
55 | }  // namespace
56 |
57 | // Math ops tests
58 | void add_test();
59 | void sub_test();
60 | void mul_test();
61 | void div_test();
62 | void mod_test();
63 | void max_test();
64 | void min_test();
65 | void neg_test();
66 | void sqrt_test();
67 | void pow_test();
68 | void ceil_test();
69 | void floor_test();
70 | void abs_test();
71 |
72 | // Cast op test
73 | void cast_test();
74 |
75 | // quilt tests
76 | void moving_sum_test();
77 | void norm_test();
78 | void resample_test();
79 |
80 | #endif  // TEST_INCLUDE_TEST_BASE_H_
81 |
--------------------------------------------------------------------------------
/include/tilt/ir/op.h:
--------------------------------------------------------------------------------
1 | #ifndef INCLUDE_TILT_IR_OP_H_
2 | #define INCLUDE_TILT_IR_OP_H_
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 |
9 | #include "tilt/ir/expr.h"
10 | #include "tilt/ir/lstream.h"
11 |
12 | using namespace std;
13 |
14 | namespace tilt {
15 |
16 | // Operator node: an LStream whose type pairs the output symbol's data type with
17 | // `iter`. `inputs` are the operator's input symbols, `syms` its symbol table,
18 | // `pred` a predicate expression (presumably gating when output is produced --
19 | // TODO confirm), `output` the result symbol, and `aux` an optional Aux map.
20 | struct OpNode : public LStream {
21 |     Iter iter;
22 |     Params inputs;
23 |     SymTable syms;
24 |     Expr pred;
25 |     Sym output;
26 |     Aux aux;
27 |
28 |     OpNode(Iter iter, Params inputs, SymTable syms, Expr pred, Sym output, Aux aux = {}) :
29 |         LStream(Type(output->type.dtype, iter)), iter(iter),
30 |         inputs(std::move(inputs)), syms(std::move(syms)), pred(pred), output(output), aux(std::move(aux))
31 |     {}
32 |
33 |     void Accept(Visitor&) const final;
34 | };
35 | typedef shared_ptr Op;
36 |
37 | // Accumulate function type (state, st, et, data) -> state
38 | typedef function AccTy;
39 |
40 | // Value node that folds `lstream` with `acc` (see AccTy), starting from `state`.
41 | // The constructor type-checks `acc` once by applying it to placeholder `st`/`et`
42 | // (TIME) and `data` symbols and asserting the result has `state`'s type.
43 | struct Reduce : public ValNode {
44 |     Sym lstream;
45 |     Val state;
46 |     AccTy acc;
47 |
48 |     Reduce(Sym lstream, Val state, AccTy acc) :
49 |         ValNode(state->type.dtype), lstream(lstream), state(state), acc(acc)
50 |     {
51 |         auto st = make_shared("st", Type(types::TIME));
52 |         auto et = make_shared("et", Type(types::TIME));
53 |         auto data = make_shared("data", Type(lstream->type.dtype));
54 |         ASSERT(acc(state, st, et, data)->type == Type(state->type.dtype));
55 |     }
56 |
57 |     void Accept(Visitor&) const final;
58 | };
59 |
60 | }  // namespace tilt
61 |
62 |
63 | #endif  // INCLUDE_TILT_IR_OP_H_
64 |
--------------------------------------------------------------------------------
/include/tilt/ir/node.h:
--------------------------------------------------------------------------------
1 | #ifndef INCLUDE_TILT_IR_NODE_H_
2 | #define INCLUDE_TILT_IR_NODE_H_
3 |
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | #include "tilt/base/type.h"
11 |
12 | using namespace std;
13 |
14 | namespace tilt {
15 |
16 | class Visitor;
17 | struct ExprNode;
18 | typedef shared_ptr Expr;
19 | struct Symbol;
20 | typedef shared_ptr Sym;
21 | typedef vector Params;
22 | typedef map SymTable;
23 | typedef map Aux;
24 |
25 | // Base class of the IR expression nodes. Each node carries an immutable `type`
26 | // and implements Accept(Visitor&), which dispatches to the matching Visit overload.
27 | struct ExprNode {
28 |     const Type type;
29 |
30 |     explicit ExprNode(Type type) : type(type) {}
31 |
32 |     virtual ~ExprNode() {}
33 |
34 |     virtual void Accept(Visitor&) const = 0;
35 | };
36 |
37 | // A named IR value of `type`; the Expr overload takes its type from `expr`.
38 | struct Symbol : public ExprNode {
39 |     const string name;
40 |
41 |     Symbol(string name, Type type) : ExprNode(type), name(name) {}
42 |     Symbol(string name, Expr expr) : Symbol(name, expr->type) {}
43 |
44 |     void Accept(Visitor&) const override;
45 | };
46 |
47 | // A named function with input parameters, an output symbol and a symbol table;
48 | // its type is the output symbol's type. Subclasses implement get_name(), and the
49 | // protected constructor lets them set only a name and a type.
50 | struct FuncNode : public ExprNode {
51 |     string name;
52 |     Params inputs;
53 |     Sym output;
54 |     SymTable syms;
55 |
56 |     FuncNode(string name, Params inputs, Sym output, SymTable syms) :
57 |         ExprNode(output->type), name(name), inputs(std::move(inputs)), output(output), syms(std::move(syms))
58 |     {}
59 |
60 |     virtual const string get_name() const = 0;
61 |
62 |  protected:
63 |     FuncNode(string name, Type type) : ExprNode(std::move(type)), name(name) {}
64 | };
65 | typedef shared_ptr Func;
66 |
67 | // Expression whose type is built from a bare DataType as Type(dtype).
68 | struct ValNode : public ExprNode {
69 |     explicit ValNode(DataType dtype) : ExprNode(Type(dtype)) {}
70 | };
71 | typedef shared_ptr Val;
72 |
73 | }  // namespace tilt
74 |
75 | #endif  // INCLUDE_TILT_IR_NODE_H_
76 |
--------------------------------------------------------------------------------
/.github/workflows/cmake.yml:
--------------------------------------------------------------------------------
1 | name: CMake
2 |
3 | on:
4 |   pull_request:
5 |     branches: [ "master" ]
6 |
7 | env:
8 |   # Customize the CMake build type here (Release, Debug, RelWithDebInfo, etc.)
9 |   BUILD_TYPE: Release
10 |
11 | jobs:
12 |   build:
13 |     # The CMake configure and build commands are platform agnostic and should work equally well on Windows or Mac.
14 |     # You can convert this to a matrix build if you need cross-platform coverage.
15 |     # See: https://docs.github.com/en/free-pro-team@latest/actions/learn-github-actions/managing-complex-workflows#using-a-build-matrix
16 |
17 |     runs-on: ubuntu-22.04
18 |
19 |     steps:
20 |     - name: Install LLVM and Clang
21 |       run: |
22 |         wget https://apt.llvm.org/llvm.sh
23 |         chmod +x llvm.sh
24 |         sudo ./llvm.sh 15
25 |         sudo ln -f /usr/bin/clang-15 /usr/bin/clang  # top-level CMakeLists.txt hard-codes clang/clang++
26 |         sudo ln -f /usr/bin/clang++-15 /usr/bin/clang++
27 |
28 |     - uses: actions/checkout@v4
29 |       with:
30 |         submodules: true
31 |
32 |     - name: Configure CMake
33 |       # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make.
34 |       # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type
35 |       run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}}
36 |
37 |     - name: Build
38 |       # Build your program with the given configuration
39 |       run: cmake --build ${{github.workspace}}/build --config ${{env.BUILD_TYPE}}
40 |
41 |     - name: Test
42 |       working-directory: ${{github.workspace}}/build
43 |       # Execute tests defined by the CMake configuration.
44 | # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail 45 | run: ctest -C ${{env.BUILD_TYPE}} 46 | -------------------------------------------------------------------------------- /include/tilt/pass/visitor.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_PASS_VISITOR_H_ 2 | #define INCLUDE_TILT_PASS_VISITOR_H_ 3 | 4 | #include "tilt/ir/expr.h" 5 | #include "tilt/ir/lstream.h" 6 | #include "tilt/ir/op.h" 7 | #include "tilt/ir/loop.h" 8 | 9 | namespace tilt { 10 | 11 | class Visitor { 12 | public: 13 | /** 14 | * TiLT IR 15 | */ 16 | virtual void Visit(const Symbol&) = 0; 17 | virtual void Visit(const Out&) = 0; 18 | virtual void Visit(const Beat&) = 0; 19 | virtual void Visit(const Call&) = 0; 20 | virtual void Visit(const Read&) = 0; 21 | virtual void Visit(const IfElse&) = 0; 22 | virtual void Visit(const Select&) = 0; 23 | virtual void Visit(const Get&) = 0; 24 | virtual void Visit(const New&) = 0; 25 | virtual void Visit(const Exists&) = 0; 26 | virtual void Visit(const ConstNode&) = 0; 27 | virtual void Visit(const Cast&) = 0; 28 | virtual void Visit(const NaryExpr&) = 0; 29 | virtual void Visit(const SubLStream&) = 0; 30 | virtual void Visit(const Element&) = 0; 31 | virtual void Visit(const OpNode&) = 0; 32 | virtual void Visit(const Reduce&) = 0; 33 | 34 | /** 35 | * Loop IR 36 | */ 37 | virtual void Visit(const Fetch&) = 0; 38 | virtual void Visit(const Write&) = 0; 39 | virtual void Visit(const Advance&) = 0; 40 | virtual void Visit(const GetCkpt&) = 0; 41 | virtual void Visit(const GetStartIdx&) = 0; 42 | virtual void Visit(const GetEndIdx&) = 0; 43 | virtual void Visit(const GetStartTime&) = 0; 44 | virtual void Visit(const GetEndTime&) = 0; 45 | virtual void Visit(const CommitData&) = 0; 46 | virtual void Visit(const CommitNull&) = 0; 47 | virtual void Visit(const AllocRegion&) = 0; 48 | virtual void Visit(const MakeRegion&) = 0; 49 | virtual void Visit(const 
LoopNode&) = 0; 50 | }; 51 | 52 | } // namespace tilt 53 | 54 | #endif // INCLUDE_TILT_PASS_VISITOR_H_ 55 | -------------------------------------------------------------------------------- /include/tilt/ir/lstream.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_IR_LSTREAM_H_ 2 | #define INCLUDE_TILT_IR_LSTREAM_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include "tilt/ir/node.h" 8 | 9 | using namespace std; 10 | 11 | namespace tilt { 12 | 13 | struct LStream : public ExprNode { 14 | explicit LStream(Type type) : ExprNode(std::move(type)) { ASSERT(!this->type.is_val()); } 15 | }; 16 | 17 | struct Out : public Symbol { 18 | explicit Out(DataType dtype) : Symbol("", Type(dtype, Iter(0, -2))) {} 19 | 20 | void Accept(Visitor&) const final; 21 | }; 22 | 23 | struct Beat : public Symbol { 24 | explicit Beat(Iter iter) : Symbol(iter.str(), Type(types::TIME, iter)) 25 | { 26 | ASSERT(this->type.is_beat()); 27 | } 28 | 29 | void Accept(Visitor&) const final; 30 | }; 31 | 32 | struct Point { 33 | const int64_t offset; 34 | 35 | explicit Point(int64_t offset) : offset(offset) { ASSERT(offset <= 0); } 36 | Point() : Point(0) {} 37 | 38 | bool operator<(const Point& o) const { return offset < o.offset; } 39 | }; 40 | typedef shared_ptr Pointer; 41 | 42 | struct Window { 43 | Point start; 44 | Point end; 45 | 46 | Window(Point start, Point end) : start(start), end(end) { ASSERT(start < end); } 47 | Window(int64_t start, int64_t end) : Window(Point(start), Point(end)) {} 48 | }; 49 | 50 | struct SubLStream : public LStream { 51 | Sym lstream; 52 | const Window win; 53 | 54 | SubLStream(Sym lstream, Window win) : 55 | LStream(lstream->type), lstream(lstream), win(win) 56 | {} 57 | 58 | void Accept(Visitor&) const final; 59 | }; 60 | 61 | struct Element : public ValNode { 62 | Sym lstream; 63 | const Point pt; 64 | 65 | Element(Sym lstream, Point pt) : 66 | ValNode(lstream->type.dtype), lstream(lstream), pt(pt) 67 | { 68 | 
ASSERT(!lstream->type.is_val()); 69 | } 70 | 71 | void Accept(Visitor&) const final; 72 | }; 73 | 74 | } // namespace tilt 75 | 76 | #endif // INCLUDE_TILT_IR_LSTREAM_H_ 77 | -------------------------------------------------------------------------------- /src/engine/engine.cpp: -------------------------------------------------------------------------------- 1 | #include "llvm/IR/Verifier.h" 2 | #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" 3 | 4 | #include "tilt/engine/engine.h" 5 | 6 | using namespace tilt; 7 | using namespace std::placeholders; 8 | 9 | ExecEngine* ExecEngine::Get() 10 | { 11 | static unique_ptr engine; 12 | 13 | if (!engine) { 14 | InitializeNativeTarget(); 15 | InitializeNativeTargetAsmPrinter(); 16 | 17 | auto jtmb = cantFail(JITTargetMachineBuilder::detectHost()); 18 | auto dl = cantFail(jtmb.getDefaultDataLayoutForTarget()); 19 | 20 | engine = make_unique(std::move(jtmb), std::move(dl)); 21 | } 22 | 23 | return engine.get(); 24 | } 25 | 26 | void ExecEngine::AddModule(unique_ptr m) 27 | { 28 | raw_fd_ostream r(fileno(stdout), false); 29 | verifyModule(*m, &r); 30 | 31 | cantFail(optimizer.add(jd, ThreadSafeModule(std::move(m), ctx))); 32 | } 33 | 34 | LLVMContext& ExecEngine::GetCtx() { return *ctx.getContext(); } 35 | 36 | intptr_t ExecEngine::Lookup(StringRef name) 37 | { 38 | auto fn_sym = cantFail(es->lookup({ &jd }, mangler(name.str()))); 39 | return (intptr_t) fn_sym.getAddress(); 40 | } 41 | 42 | Expected ExecEngine::optimize_module(ThreadSafeModule tsm, const MaterializationResponsibility &r) 43 | { 44 | tsm.withModuleDo([](Module &m) { 45 | unsigned opt_level = 3; 46 | unsigned opt_size = 0; 47 | 48 | llvm::PassManagerBuilder builder; 49 | builder.OptLevel = opt_level; 50 | builder.Inliner = createFunctionInliningPass(opt_level, opt_size, false); 51 | 52 | llvm::legacy::PassManager mpm; 53 | builder.populateModulePassManager(mpm); 54 | mpm.run(m); 55 | }); 56 | 57 | return std::move(tsm); 58 | } 59 | 60 | unique_ptr 
ExecEngine::createExecutionSession() { 61 | unique_ptr epc = llvm::cantFail(SelfExecutorProcessControl::Create()); 62 | return std::make_unique(std::move(epc)); 63 | } 64 | -------------------------------------------------------------------------------- /src/ir/ir.cpp: -------------------------------------------------------------------------------- 1 | #include "tilt/ir/expr.h" 2 | #include "tilt/ir/lstream.h" 3 | #include "tilt/ir/op.h" 4 | #include "tilt/ir/loop.h" 5 | #include "tilt/pass/visitor.h" 6 | 7 | using namespace tilt; 8 | 9 | void Symbol::Accept(Visitor& v) const { v.Visit(*this); } 10 | void Out::Accept(Visitor& v) const { v.Visit(*this); } 11 | void Beat::Accept(Visitor& v) const { v.Visit(*this); } 12 | void Call::Accept(Visitor& v) const { v.Visit(*this); } 13 | void IfElse::Accept(Visitor& v) const { v.Visit(*this); } 14 | void Select::Accept(Visitor& v) const { v.Visit(*this); } 15 | void Get::Accept(Visitor& v) const { v.Visit(*this); } 16 | void New::Accept(Visitor& v) const { v.Visit(*this); } 17 | void Exists::Accept(Visitor& v) const { v.Visit(*this); } 18 | void ConstNode::Accept(Visitor& v) const { v.Visit(*this); } 19 | void Cast::Accept(Visitor& v) const { v.Visit(*this); } 20 | void NaryExpr::Accept(Visitor& v) const { v.Visit(*this); } 21 | void SubLStream::Accept(Visitor& v) const { v.Visit(*this); } 22 | void Element::Accept(Visitor& v) const { v.Visit(*this); } 23 | void OpNode::Accept(Visitor& v) const { v.Visit(*this); } 24 | void Reduce::Accept(Visitor& v) const { v.Visit(*this); } 25 | void Fetch::Accept(Visitor& v) const { v.Visit(*this); } 26 | void Read::Accept(Visitor& v) const { v.Visit(*this); } 27 | void Write::Accept(Visitor& v) const { v.Visit(*this); } 28 | void Advance::Accept(Visitor& v) const { v.Visit(*this); } 29 | void GetCkpt::Accept(Visitor& v) const { v.Visit(*this); } 30 | void GetStartIdx::Accept(Visitor& v) const { v.Visit(*this); } 31 | void GetEndIdx::Accept(Visitor& v) const { v.Visit(*this); } 32 | void 
GetStartTime::Accept(Visitor& v) const { v.Visit(*this); } 33 | void GetEndTime::Accept(Visitor& v) const { v.Visit(*this); } 34 | void CommitData::Accept(Visitor& v) const { v.Visit(*this); } 35 | void CommitNull::Accept(Visitor& v) const { v.Visit(*this); } 36 | void AllocRegion::Accept(Visitor& v) const { v.Visit(*this); } 37 | void MakeRegion::Accept(Visitor& v) const { v.Visit(*this); } 38 | void LoopNode::Accept(Visitor& v) const { v.Visit(*this); } 39 | -------------------------------------------------------------------------------- /src/builder/tilder.cpp: -------------------------------------------------------------------------------- 1 | #include "tilt/base/type.h" 2 | #include "tilt/builder/tilder.h" 3 | 4 | namespace tilt::tilder { 5 | 6 | _expr _expr_add(Expr a, Expr b) { return _add(a, b); } 7 | _expr _expr_sub(Expr a, Expr b) { return _sub(a, b); } 8 | _expr _expr_mul(Expr a, Expr b) { return _mul(a, b); } 9 | _expr
_expr_div(Expr a, Expr b) { return _div(a, b); } 10 | _expr _expr_neg(Expr a) { return _neg(a); } 11 | _expr _expr_mod(Expr a, Expr b) { return _mod(a, b); } 12 | _expr _expr_lt(Expr a, Expr b) { return _lt(a, b); } 13 | _expr _expr_lte(Expr a, Expr b) { return _lte(a, b); } 14 | _expr _expr_gt(Expr a, Expr b) { return _gt(a, b); } 15 | _expr _expr_gte(Expr a, Expr b) { return _gte(a, b); } 16 | _expr _expr_eq(Expr a, Expr b) { return _eq(a, b); } 17 | _expr _expr_not(Expr a) { return _not(a); } 18 | _expr _expr_and(Expr a, Expr b) { return _and(a, b); } 19 | _expr _expr_or(Expr a, Expr b) { return _or(a, b); } 20 | _expr _expr_get(Expr a, size_t n) { return _get(a, n); } 21 | _expr _expr_elem(Sym a, Point pt) { return _elem(a, pt); } 22 | _expr _expr_subls(Sym a, Window win) { return _subls(a, win); } 23 | 24 | _expr _i8(int8_t v) { return _const(BaseType::INT8, v); } 25 | _expr _i16(int16_t v) { return _const(BaseType::INT16, v); } 26 | _expr _i32(int32_t v) { return _const(BaseType::INT32, v); } 27 | _expr _i64(int64_t v) { return _const(BaseType::INT64, v); } 28 | _expr _u8(uint8_t v) { return _const(BaseType::UINT8, v); } 29 | _expr _u16(uint16_t v) { return _const(BaseType::UINT16, v); } 30 | _expr _u32(uint32_t v) { return _const(BaseType::UINT32, v); } 31 | _expr _u64(uint64_t v) { return _const(BaseType::UINT64, v); } 32 | _expr _f32(float v) { return _const(BaseType::FLOAT32, v); } 33 | _expr _f64(double v) { return _const(BaseType::FLOAT64, v); } 34 | _expr _ts(int64_t v) { return _const(BaseType::TIME, v); } 35 | _expr _idx(idx_t v) { return _const(BaseType::INDEX, v); } 36 | _expr _true() { return _const(BaseType::BOOL, 1); } 37 | _expr _false() { return _const(BaseType::BOOL, 0); } 38 | 39 | } // namespace tilt::tilder 40 | -------------------------------------------------------------------------------- /include/tilt/engine/engine.h: -------------------------------------------------------------------------------- 1 | #ifndef 
INCLUDE_TILT_ENGINE_ENGINE_H_ 2 | #define INCLUDE_TILT_ENGINE_ENGINE_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include "llvm/ADT/StringRef.h" 8 | #include "llvm/Support/TargetSelect.h" 9 | #include "llvm/ExecutionEngine/Orc/CompileUtils.h" 10 | #include "llvm/ExecutionEngine/Orc/Core.h" 11 | #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" 12 | #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" 13 | #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 14 | #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" 15 | #include "llvm/ExecutionEngine/SectionMemoryManager.h" 16 | #include "llvm/ExecutionEngine/JITSymbol.h" 17 | #include "llvm/IR/DataLayout.h" 18 | #include "llvm/IR/LLVMContext.h" 19 | #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" 20 | #include "llvm/IR/LegacyPassManager.h" 21 | #include "llvm/Transforms/IPO/PassManagerBuilder.h" 22 | #include "llvm/Transforms/IPO.h" 23 | 24 | using namespace std; 25 | using namespace llvm; 26 | using namespace llvm::orc; 27 | 28 | namespace tilt { 29 | 30 | class ExecEngine { 31 | public: 32 | ExecEngine(JITTargetMachineBuilder jtmb, DataLayout dl) : 33 | es(createExecutionSession()), 34 | linker(*es, []() { return make_unique(); }), 35 | compiler(*es, linker, make_unique(std::move(jtmb))), 36 | optimizer(*es, compiler, optimize_module), 37 | dl(std::move(dl)), mangler(*es, this->dl), 38 | ctx(make_unique()), 39 | jd(es->createBareJITDylib("__tilt_dylib")) 40 | { 41 | jd.addGenerator(cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(dl.getGlobalPrefix()))); 42 | } 43 | 44 | static ExecEngine* Get(); 45 | void AddModule(unique_ptr); 46 | LLVMContext& GetCtx(); 47 | intptr_t Lookup(StringRef); 48 | 49 | private: 50 | static Expected optimize_module(ThreadSafeModule, const MaterializationResponsibility&); 51 | static unique_ptr createExecutionSession(); 52 | 53 | unique_ptr es; 54 | RTDyldObjectLinkingLayer linker; 55 | IRCompileLayer compiler; 56 | IRTransformLayer optimizer; 57 | 
58 | DataLayout dl; 59 | MangleAndInterner mangler; 60 | ThreadSafeContext ctx; 61 | 62 | JITDylib& jd; 63 | }; 64 | 65 | } // namespace tilt 66 | 67 | #endif // INCLUDE_TILT_ENGINE_ENGINE_H_ 68 | -------------------------------------------------------------------------------- /src/pass/codegen/vinstr.cpp: -------------------------------------------------------------------------------- 1 | #include "tilt/pass/codegen/vinstr.h" 2 | 3 | namespace tilt { 4 | extern "C" { 5 | 6 | uint32_t get_buf_size(idx_t len) 7 | { 8 | uint32_t ring = 1; 9 | while (len) { len >>= 1; ring <<= 1; } 10 | return ring; 11 | } 12 | 13 | idx_t get_start_idx(region_t* reg) 14 | { 15 | auto size = reg->mask + 1; 16 | auto count = (reg->count < size) ? reg->count : size; 17 | return reg->head - count + 1; 18 | } 19 | 20 | idx_t get_end_idx(region_t* reg) { return reg->head; } 21 | 22 | ts_t get_start_time(region_t* reg) { return reg->st; } 23 | 24 | ts_t get_end_time(region_t* reg) { return reg->et; } 25 | 26 | int64_t get_ckpt(region_t* reg, ts_t t, idx_t i) 27 | { 28 | auto civl = reg->tl[i & reg->mask]; 29 | return (t <= civl.t) ? civl.t : (civl.t + civl.d); 30 | } 31 | 32 | idx_t advance(region_t* reg, idx_t i, ts_t t) 33 | { 34 | while ((reg->tl[i & reg->mask].t + reg->tl[i & reg->mask].d) < t) { i++; } 35 | return i; 36 | } 37 | 38 | char* fetch(region_t* reg, ts_t t, idx_t i, uint32_t bytes) 39 | { 40 | auto ivl = reg->tl[i & reg->mask]; 41 | return (t <= ivl.t) ? 
nullptr : (reg->data + ((i & reg->mask) * bytes)); 42 | } 43 | 44 | region_t* make_region(region_t* out_reg, region_t* in_reg, ts_t st, idx_t si, ts_t et, idx_t ei) 45 | { 46 | out_reg->st = st; 47 | out_reg->et = et; 48 | out_reg->head = ei; 49 | out_reg->count = ei - si + 1; 50 | out_reg->mask = in_reg->mask; 51 | out_reg->tl = in_reg->tl; 52 | out_reg->data = in_reg->data; 53 | 54 | return out_reg; 55 | } 56 | 57 | region_t* init_region(region_t* reg, ts_t t, uint32_t size, ival_t* tl, char* data) 58 | { 59 | reg->st = t; 60 | reg->et = t; 61 | reg->head = -1; 62 | reg->count = 0; 63 | reg->mask = size - 1; 64 | reg->tl = tl; 65 | reg->data = data; 66 | commit_null(reg, t); 67 | return reg; 68 | } 69 | 70 | region_t* commit_data(region_t* reg, ts_t t) 71 | { 72 | auto last_ckpt = reg->et; 73 | reg->et = t; 74 | reg->head++; 75 | reg->count++; 76 | 77 | reg->tl[reg->head & reg->mask].t = last_ckpt; 78 | reg->tl[reg->head & reg->mask].d = t - last_ckpt; 79 | 80 | return reg; 81 | } 82 | 83 | region_t* commit_null(region_t* reg, ts_t t) 84 | { 85 | reg->et = t; 86 | reg->tl[(reg->head + 1) & reg->mask].t = t; 87 | reg->tl[(reg->head + 1) & reg->mask].d = 0; 88 | return reg; 89 | } 90 | 91 | } // extern "C" 92 | } // namespace tilt 93 | -------------------------------------------------------------------------------- /include/tilt/pass/codegen/loopgen.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_PASS_CODEGEN_LOOPGEN_H_ 2 | #define INCLUDE_TILT_PASS_CODEGEN_LOOPGEN_H_ 3 | 4 | #include 5 | #include 6 | 7 | #include "tilt/pass/irgen.h" 8 | 9 | using namespace std; 10 | 11 | namespace tilt { 12 | 13 | class LoopGenCtx : public IRGenCtx { 14 | public: 15 | LoopGenCtx(Sym sym, const OpNode* op, Loop loop) : 16 | IRGenCtx(sym, &op->syms, &loop->syms), op(op), loop(loop) 17 | {} 18 | 19 | private: 20 | const OpNode* op; 21 | Loop loop; 22 | 23 | map> pt_idx_maps; 24 | map idx_diff_map; 25 | map sym_ref; 26 | 27 
| friend class LoopGen; 28 | }; 29 | 30 | class LoopGen : public IRGen { 31 | public: 32 | explicit LoopGen(LoopGenCtx ctx) : _ctx(std::move(ctx)) {} 33 | 34 | static Loop Build(Sym, const OpNode*); 35 | 36 | private: 37 | LoopGenCtx& ctx() override { return _ctx; } 38 | 39 | Expr get_timer(const Point, bool); 40 | Index& get_idx(const Sym, const Point); 41 | Sym get_ref(const Sym sym) { return ctx().sym_ref.at(sym); } 42 | void set_ref(Sym sym, Sym ref) { ctx().sym_ref[sym] = ref; } 43 | void build_tloop(function, function); 44 | void build_loop(); 45 | 46 | Expr visit(const Symbol&) final; 47 | Expr visit(const Out&) final; 48 | Expr visit(const Beat&) final; 49 | Expr visit(const Call&) final; 50 | Expr visit(const IfElse&) final; 51 | Expr visit(const Select&) final; 52 | Expr visit(const Get&) final; 53 | Expr visit(const New&) final; 54 | Expr visit(const Exists&) final; 55 | Expr visit(const ConstNode&) final; 56 | Expr visit(const Cast&) final; 57 | Expr visit(const NaryExpr&) final; 58 | Expr visit(const SubLStream&) final; 59 | Expr visit(const Element&) final; 60 | Expr visit(const OpNode&) final; 61 | Expr visit(const Reduce&) final; 62 | Expr visit(const Fetch&) final { throw runtime_error("Invalid expression"); }; 63 | Expr visit(const Read&) final { throw runtime_error("Invalid expression"); }; 64 | Expr visit(const Write&) final { throw runtime_error("Invalid expression"); }; 65 | Expr visit(const Advance&) final { throw runtime_error("Invalid expression"); }; 66 | Expr visit(const GetCkpt&) final { throw runtime_error("Invalid expression"); }; 67 | Expr visit(const GetStartIdx&) final { throw runtime_error("Invalid expression"); }; 68 | Expr visit(const GetEndIdx&) final { throw runtime_error("Invalid expression"); }; 69 | Expr visit(const GetStartTime&) final { throw runtime_error("Invalid expression"); }; 70 | Expr visit(const GetEndTime&) final { throw runtime_error("Invalid expression"); }; 71 | Expr visit(const CommitData&) final { throw 
runtime_error("Invalid expression"); }; 72 | Expr visit(const CommitNull&) final { throw runtime_error("Invalid expression"); }; 73 | Expr visit(const AllocRegion&) final { throw runtime_error("Invalid expression"); }; 74 | Expr visit(const MakeRegion&) final { throw runtime_error("Invalid expression"); }; 75 | Expr visit(const LoopNode&) final { throw runtime_error("Invalid expression"); }; 76 | 77 | LoopGenCtx _ctx; 78 | }; 79 | 80 | } // namespace tilt 81 | 82 | #endif // INCLUDE_TILT_PASS_CODEGEN_LOOPGEN_H_ 83 | -------------------------------------------------------------------------------- /include/tilt/pass/printer.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_PASS_PRINTER_H_ 2 | #define INCLUDE_TILT_PASS_PRINTER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "llvm/IR/Module.h" 10 | #include "tilt/pass/visitor.h" 11 | 12 | using namespace std; 13 | 14 | namespace tilt { 15 | 16 | class IRPrinterCtx { 17 | public: 18 | IRPrinterCtx() : indent(0), nesting(0) {} 19 | 20 | private: 21 | size_t indent; 22 | size_t nesting; 23 | 24 | friend class IRPrinter; 25 | }; 26 | 27 | class IRPrinter : public Visitor { 28 | public: 29 | IRPrinter() : IRPrinter(IRPrinterCtx()) {} 30 | explicit IRPrinter(IRPrinterCtx ctx) : IRPrinter(std::move(ctx), 2) {} 31 | 32 | IRPrinter(IRPrinterCtx ctx, size_t tabstop) : 33 | ctx(std::move(ctx)), tabstop(tabstop) 34 | {} 35 | 36 | static string Build(const Expr); 37 | static string Build(const llvm::Module*); 38 | 39 | void Visit(const Symbol&) override; 40 | void Visit(const Out&) override; 41 | void Visit(const Beat&) override; 42 | void Visit(const Call&) override; 43 | void Visit(const IfElse&) override; 44 | void Visit(const Select&) override; 45 | void Visit(const Get&) override; 46 | void Visit(const New&) override; 47 | void Visit(const Exists&) override; 48 | void Visit(const ConstNode&) override; 49 | void Visit(const Cast&) override; 
50 | void Visit(const NaryExpr&) override; 51 | void Visit(const SubLStream&) override; 52 | void Visit(const Element&) override; 53 | void Visit(const OpNode&) override; 54 | void Visit(const Reduce&) override; 55 | void Visit(const Fetch&) override; 56 | void Visit(const Read&) override; 57 | void Visit(const Write&) override; 58 | void Visit(const Advance&) override; 59 | void Visit(const GetCkpt&) override; 60 | void Visit(const GetStartIdx&) override; 61 | void Visit(const GetEndIdx&) override; 62 | void Visit(const GetStartTime&) override; 63 | void Visit(const GetEndTime&) override; 64 | void Visit(const CommitData&) override; 65 | void Visit(const CommitNull&) override; 66 | void Visit(const AllocRegion&) override; 67 | void Visit(const MakeRegion&) override; 68 | void Visit(const LoopNode&) override; 69 | 70 | private: 71 | void enter_op() { ctx.nesting++; } 72 | void exit_op() { ctx.nesting--; } 73 | void enter_block() { ctx.indent++; emitnewline(); } 74 | void exit_block() { ctx.indent--; emitnewline(); } 75 | 76 | void emittab() { ostr << string(1 << tabstop, ' '); } 77 | void emitnewline() { ostr << endl << string(ctx.indent << tabstop, ' '); } 78 | void emit(string str) { ostr << str; } 79 | void emitcomment(string comment) { ostr << "/* " << comment << " */"; } 80 | 81 | void emitunary(const string op, const Expr a) 82 | { 83 | ostr << op; 84 | a->Accept(*this); 85 | } 86 | 87 | void emitbinary(const Expr a, const string op, const Expr b) 88 | { 89 | ostr << "("; 90 | a->Accept(*this); 91 | ostr << " " << op << " "; 92 | b->Accept(*this); 93 | ostr << ")"; 94 | } 95 | 96 | void emitassign(const Expr lhs, const Expr rhs) 97 | { 98 | lhs->Accept(*this); 99 | ostr << " = "; 100 | rhs->Accept(*this); 101 | ostr << ";"; 102 | } 103 | 104 | void emitfunc(const string name, const vector args) 105 | { 106 | ostr << name << "("; 107 | for (size_t i = 0; i < args.size()-1; i++) { 108 | args[i]->Accept(*this); 109 | ostr << ", "; 110 | } 111 | 
args.back()->Accept(*this); 112 | ostr << ")"; 113 | } 114 | 115 | IRPrinterCtx ctx; 116 | size_t tabstop; 117 | ostringstream ostr; 118 | }; 119 | 120 | } // namespace tilt 121 | 122 | #endif // INCLUDE_TILT_PASS_PRINTER_H_ 123 | -------------------------------------------------------------------------------- /include/tilt/pass/codegen/llvmgen.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_PASS_CODEGEN_LLVMGEN_H_ 2 | #define INCLUDE_TILT_PASS_CODEGEN_LLVMGEN_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "tilt/pass/irgen.h" 12 | 13 | #include "llvm/IR/LLVMContext.h" 14 | #include "llvm/IR/IRBuilder.h" 15 | #include "llvm/IR/Module.h" 16 | #include "llvm/IR/Verifier.h" 17 | #include "llvm/Linker/Linker.h" 18 | #include "llvm/IRReader/IRReader.h" 19 | #include "llvm/Support/SourceMgr.h" 20 | #include "llvm/Support/MemoryBuffer.h" 21 | 22 | using namespace std; 23 | 24 | extern const char* vinstr_str; 25 | 26 | namespace tilt { 27 | 28 | class LLVMGenCtx : public IRGenCtx { 29 | public: 30 | LLVMGenCtx(const LoopNode* loop, llvm::LLVMContext* llctx) : 31 | IRGenCtx(nullptr, &loop->syms, new map()), 32 | loop(loop), llctx(llctx), map_backup(unique_ptr>(out_sym_tbl)) 33 | {} 34 | 35 | private: 36 | const LoopNode* loop; 37 | llvm::LLVMContext* llctx; 38 | unique_ptr> map_backup; 39 | friend class LLVMGen; 40 | }; 41 | 42 | class LLVMGen : public IRGen { 43 | public: 44 | explicit LLVMGen(LLVMGenCtx llgenctx) : 45 | _ctx(std::move(llgenctx)), _llctx(*ctx().llctx), 46 | _llmod(make_unique(ctx().loop->name, _llctx)), 47 | _builder(make_unique>(_llctx)) 48 | { 49 | register_vinstrs(); 50 | } 51 | 52 | static unique_ptr Build(const Loop, llvm::LLVMContext&); 53 | 54 | private: 55 | LLVMGenCtx& ctx() override { return _ctx; } 56 | 57 | llvm::Value* visit(const Symbol&) final; 58 | llvm::Value* visit(const Out&) final { throw std::runtime_error("Invalid 
expression"); } 59 | llvm::Value* visit(const Beat&) final { throw std::runtime_error("Invalid expression"); } 60 | llvm::Value* visit(const Call&) final; 61 | llvm::Value* visit(const IfElse&) final; 62 | llvm::Value* visit(const Select&) final; 63 | llvm::Value* visit(const Get&) final; 64 | llvm::Value* visit(const New&) final; 65 | llvm::Value* visit(const Exists&) final; 66 | llvm::Value* visit(const ConstNode&) final; 67 | llvm::Value* visit(const Cast&) final; 68 | llvm::Value* visit(const NaryExpr&) final; 69 | llvm::Value* visit(const SubLStream&) final { throw std::runtime_error("Invalid expression"); } 70 | llvm::Value* visit(const Element&) final { throw std::runtime_error("Invalid expression"); } 71 | llvm::Value* visit(const OpNode&) final { throw std::runtime_error("Invalid expression"); } 72 | llvm::Value* visit(const Reduce&) final { throw std::runtime_error("Invalid expression"); } 73 | llvm::Value* visit(const Fetch&) final; 74 | llvm::Value* visit(const Read&) final; 75 | llvm::Value* visit(const Write&) final; 76 | llvm::Value* visit(const Advance&) final; 77 | llvm::Value* visit(const GetCkpt&) final; 78 | llvm::Value* visit(const GetStartIdx&) final; 79 | llvm::Value* visit(const GetEndIdx&) final; 80 | llvm::Value* visit(const GetStartTime&) final; 81 | llvm::Value* visit(const GetEndTime&) final; 82 | llvm::Value* visit(const CommitData&) final; 83 | llvm::Value* visit(const CommitNull&) final; 84 | llvm::Value* visit(const AllocRegion&) final; 85 | llvm::Value* visit(const MakeRegion&) final; 86 | llvm::Value* visit(const LoopNode&) final; 87 | 88 | void set_expr(const Sym& sym_ptr, llvm::Value* val) override 89 | { 90 | IRGen::set_expr(sym_ptr, val); 91 | val->setName(sym_ptr->name); 92 | } 93 | 94 | void register_vinstrs(); 95 | 96 | llvm::Function* llfunc(const string, llvm::Type*, vector); 97 | llvm::Value* llcall(const string, llvm::Type*, vector); 98 | llvm::Value* llcall(const string, llvm::Type*, vector); 99 | 100 | llvm::Value* 
llsizeof(llvm::Type*); 101 | 102 | llvm::Type* lltype(const DataType&); 103 | llvm::Type* lltype(const Type&); 104 | llvm::Type* lltype(const ExprNode& expr) { return lltype(expr.type); } 105 | llvm::Type* lltype(const Expr& expr) { return lltype(expr->type); } 106 | 107 | llvm::Type* llregtype() { return llvm::StructType::getTypeByName(llctx(), "struct.region_t"); } 108 | llvm::Type* llregptrtype() { return llvm::PointerType::get(llregtype(), 0); } 109 | 110 | llvm::Module* llmod() { return _llmod.get(); } 111 | llvm::LLVMContext& llctx() { return _llctx; } 112 | llvm::IRBuilder<>* builder() { return _builder.get(); } 113 | 114 | LLVMGenCtx _ctx; 115 | llvm::LLVMContext& _llctx; 116 | unique_ptr _llmod; 117 | unique_ptr> _builder; 118 | }; 119 | 120 | } // namespace tilt 121 | 122 | #endif // INCLUDE_TILT_PASS_CODEGEN_LLVMGEN_H_ 123 | -------------------------------------------------------------------------------- /include/tilt/builder/tilder.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_BUILDER_TILDER_H_ 2 | #define INCLUDE_TILT_BUILDER_TILDER_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "tilt/ir/expr.h" 9 | #include "tilt/ir/lstream.h" 10 | #include "tilt/ir/op.h" 11 | #include "tilt/ir/loop.h" 12 | 13 | namespace tilt::tilder { 14 | 15 | template 16 | struct _expr; 17 | 18 | _expr _expr_add(Expr, Expr); 19 | _expr _expr_sub(Expr, Expr); 20 | _expr _expr_mul(Expr, Expr); 21 | _expr
_expr_div(Expr, Expr); 22 | _expr _expr_neg(Expr); 23 | _expr _expr_mod(Expr, Expr); 24 | _expr _expr_lt(Expr, Expr); 25 | _expr _expr_lte(Expr, Expr); 26 | _expr _expr_gt(Expr, Expr); 27 | _expr _expr_gte(Expr, Expr); 28 | _expr _expr_eq(Expr, Expr); 29 | _expr _expr_not(Expr); 30 | _expr _expr_and(Expr, Expr); 31 | _expr _expr_or(Expr, Expr); 32 | _expr _expr_get(Expr, size_t); 33 | _expr _expr_elem(Sym, Point); 34 | _expr _expr_subls(Sym, Window); 35 | 36 | template 37 | struct _expr : public shared_ptr { 38 | explicit _expr(shared_ptr&& ptr) : shared_ptr(std::move(ptr)) {} 39 | 40 | _expr operator+(Expr o) const { return _expr_add(*this, o); } 41 | _expr operator-(Expr o) const { return _expr_sub(*this, o); } 42 | _expr operator*(Expr o) const { return _expr_mul(*this, o); } 43 | _expr
operator/(Expr o) const { return _expr_div(*this, o); } 44 | _expr operator-() const { return _expr_neg(*this); } 45 | _expr operator%(Expr o) const { return _expr_mod(*this, o); } 46 | _expr operator<(Expr o) const { return _expr_lt(*this, o); } 47 | _expr operator<=(Expr o) const { return _expr_lte(*this, o); } 48 | _expr operator>(Expr o) const { return _expr_gt(*this, o); } 49 | _expr operator>=(Expr o) const { return _expr_gte(*this, o); } 50 | _expr operator==(Expr o) const { return _expr_eq(*this, o); } 51 | _expr operator!() const { return _expr_not(*this); } 52 | _expr operator&&(Expr o) const { return _expr_and(*this, o); } 53 | _expr operator||(Expr o) const { return _expr_or(*this, o); } 54 | _expr operator<<(size_t n) const { return _expr_get(*this, n); } 55 | }; 56 | 57 | // Symbol 58 | struct _sym : public _expr { 59 | _sym(string name, Type type) : _expr(make_shared(name, type)) {} 60 | _sym(string name, Expr expr) : _expr(make_shared(name, expr)) {} 61 | explicit _sym(const Symbol& symbol) : _sym(symbol.name, symbol.type) {} 62 | 63 | _expr operator[](Point pt) const { return _expr_elem(*this, pt); } 64 | _expr operator[](Window win) const { return _expr_subls(*this, win); } 65 | }; 66 | 67 | struct _out : public _expr { 68 | explicit _out(DataType dtype) : _expr(make_shared(dtype)) {} 69 | explicit _out(const Out& out) : _out(out.type.dtype) {} 70 | 71 | _expr operator[](Point pt) const { return _expr_elem(*this, pt); } 72 | _expr operator[](Window win) const { return _expr_subls(*this, win); } 73 | }; 74 | 75 | struct _beat : public _expr { 76 | explicit _beat(Iter iter) : _expr(make_shared(iter)) {} 77 | explicit _beat(const Beat& beat) : _beat(beat.type.iter) {} 78 | 79 | _expr operator[](Point pt) const { return _expr_elem(*this, pt); } 80 | _expr operator[](Window win) const { return _expr_subls(*this, win); } 81 | }; 82 | 83 | #define REGISTER_EXPR(NAME, EXPR) \ 84 | template \ 85 | struct NAME : public _expr { \ 86 | explicit NAME(Args... 
args) : \ 87 | _expr(std::move(make_shared(std::forward(args)...))) \ 88 | {} \ 89 | }; 90 | 91 | // Arithmetic expressions 92 | REGISTER_EXPR(_add, Add) 93 | REGISTER_EXPR(_sub, Sub) 94 | REGISTER_EXPR(_mul, Mul) 95 | REGISTER_EXPR(_div, Div) 96 | REGISTER_EXPR(_max, Max) 97 | REGISTER_EXPR(_min, Min) 98 | REGISTER_EXPR(_abs, Abs) 99 | REGISTER_EXPR(_neg, Neg) 100 | REGISTER_EXPR(_mod, Mod) 101 | REGISTER_EXPR(_sqrt, Sqrt) 102 | REGISTER_EXPR(_pow, Pow) 103 | REGISTER_EXPR(_ceil, Ceil) 104 | REGISTER_EXPR(_floor, Floor) 105 | REGISTER_EXPR(_lt, LessThan) 106 | REGISTER_EXPR(_lte, LessThanEqual) 107 | REGISTER_EXPR(_gt, GreaterThan) 108 | REGISTER_EXPR(_gte, GreaterThanEqual) 109 | REGISTER_EXPR(_eq, Equals) 110 | 111 | // Logical expressions 112 | REGISTER_EXPR(_exists, Exists) 113 | REGISTER_EXPR(_not, Not) 114 | REGISTER_EXPR(_and, And) 115 | REGISTER_EXPR(_or, Or) 116 | 117 | // Constant expressions 118 | REGISTER_EXPR(_const, ConstNode) 119 | 120 | // LStream operations 121 | REGISTER_EXPR(_subls, SubLStream) 122 | REGISTER_EXPR(_elem, Element) 123 | REGISTER_EXPR(_op, OpNode) 124 | 125 | // Misc expressions 126 | REGISTER_EXPR(_call, Call) 127 | REGISTER_EXPR(_read, Read) 128 | REGISTER_EXPR(_get, Get) 129 | REGISTER_EXPR(_new, New) 130 | REGISTER_EXPR(_ifelse, IfElse) 131 | REGISTER_EXPR(_sel, Select) 132 | REGISTER_EXPR(_red, Reduce) 133 | REGISTER_EXPR(_cast, Cast) 134 | 135 | // Loop IR expressions 136 | REGISTER_EXPR(_time, TimeNode) 137 | REGISTER_EXPR(_index, IndexNode) 138 | REGISTER_EXPR(_fetch, Fetch) 139 | REGISTER_EXPR(_write, Write) 140 | REGISTER_EXPR(_adv, Advance) 141 | REGISTER_EXPR(_get_ckpt, GetCkpt) 142 | REGISTER_EXPR(_get_start_idx, GetStartIdx) 143 | REGISTER_EXPR(_get_end_idx, GetEndIdx) 144 | REGISTER_EXPR(_get_start_time, GetStartTime) 145 | REGISTER_EXPR(_get_end_time, GetEndTime) 146 | REGISTER_EXPR(_commit_data, CommitData) 147 | REGISTER_EXPR(_commit_null, CommitNull) 148 | REGISTER_EXPR(_alloc_reg, AllocRegion) 149 | 
REGISTER_EXPR(_make_reg, MakeRegion) 150 | REGISTER_EXPR(_loop, LoopNode) 151 | 152 | #undef REGISTER_EXPR 153 | 154 | 155 | _expr _i8(int8_t); 156 | _expr _i16(int16_t); 157 | _expr _i32(int32_t); 158 | _expr _i64(int64_t); 159 | _expr _u8(uint8_t); 160 | _expr _u16(uint16_t); 161 | _expr _u32(uint32_t); 162 | _expr _u64(uint64_t); 163 | _expr _f32(float); 164 | _expr _f64(double); 165 | _expr _ch(char); 166 | _expr _ts(ts_t); 167 | _expr _idx(idx_t); 168 | _expr _true(); 169 | _expr _false(); 170 | 171 | using _iter = Iter; 172 | using _pt = Point; 173 | using _win = Window; 174 | 175 | } // namespace tilt::tilder 176 | 177 | #endif // INCLUDE_TILT_BUILDER_TILDER_H_ 178 | -------------------------------------------------------------------------------- /include/tilt/ir/loop.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_IR_LOOP_H_ 2 | #define INCLUDE_TILT_IR_LOOP_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "tilt/ir/node.h" 11 | #include "tilt/ir/expr.h" 12 | 13 | using namespace std; 14 | 15 | namespace tilt { 16 | 17 | struct TimeNode : public Symbol { 18 | explicit TimeNode(string name) : Symbol(name, Type(types::TIME)) {} 19 | }; 20 | typedef shared_ptr Time; 21 | 22 | struct IndexNode : public Symbol { 23 | explicit IndexNode(string name) : Symbol(name, Type(types::INDEX)) {} 24 | }; 25 | typedef shared_ptr Index; 26 | 27 | struct Fetch : public ValNode { 28 | Expr reg; 29 | Expr time; 30 | Expr idx; 31 | 32 | Fetch(Expr reg, Expr time, Expr idx) : 33 | ValNode(reg->type.dtype.ptr()), reg(reg), time(time), idx(idx) 34 | { 35 | ASSERT(!reg->type.is_val()); 36 | ASSERT(time->type.dtype == types::TIME); 37 | ASSERT(idx->type.dtype == types::INDEX); 38 | } 39 | 40 | void Accept(Visitor&) const final; 41 | }; 42 | 43 | struct Read : public ValNode { 44 | Expr ptr; 45 | 46 | explicit Read(Expr ptr) : ValNode(ptr->type.dtype.deref()), ptr(ptr) {} 47 | 48 | void 
Accept(Visitor&) const final; 49 | }; 50 | 51 | struct Write : public ExprNode { 52 | Expr reg; 53 | Expr ptr; 54 | Expr data; 55 | 56 | Write(Expr reg, Expr ptr, Expr data) : 57 | ExprNode(reg->type), reg(reg), ptr(ptr), data(data) 58 | { 59 | ASSERT(ptr->type.dtype.is_ptr()); 60 | } 61 | 62 | void Accept(Visitor&) const final; 63 | }; 64 | 65 | struct Advance : public ValNode { 66 | Expr reg; 67 | Expr idx; 68 | Expr time; 69 | 70 | Advance(Expr reg, Expr idx, Expr time) : 71 | ValNode(types::INDEX), reg(reg), idx(idx), time(time) 72 | { 73 | ASSERT(!reg->type.is_val()); 74 | ASSERT(idx->type.dtype == types::INDEX); 75 | ASSERT(time->type.dtype == types::TIME); 76 | } 77 | 78 | void Accept(Visitor&) const final; 79 | }; 80 | 81 | struct GetCkpt : public ValNode { 82 | Expr reg; 83 | Expr time; 84 | Expr idx; 85 | 86 | GetCkpt(Expr reg, Expr time, Expr idx) : 87 | ValNode(types::TIME), reg(reg), time(time), idx(idx) 88 | { 89 | ASSERT(!reg->type.is_val()); 90 | ASSERT(time->type.dtype == types::TIME); 91 | ASSERT(idx->type.dtype == types::INDEX); 92 | } 93 | 94 | void Accept(Visitor&) const final; 95 | }; 96 | 97 | struct GetStartIdx : public ValNode { 98 | Expr reg; 99 | 100 | explicit GetStartIdx(Expr reg) : ValNode(types::INDEX), reg(reg) 101 | { 102 | ASSERT(!reg->type.is_val()); 103 | } 104 | 105 | void Accept(Visitor&) const final; 106 | }; 107 | 108 | struct GetEndIdx : public ValNode { 109 | Expr reg; 110 | 111 | explicit GetEndIdx(Expr reg) : ValNode(types::INDEX), reg(reg) 112 | { 113 | ASSERT(!reg->type.is_val()); 114 | } 115 | 116 | void Accept(Visitor&) const final; 117 | }; 118 | 119 | struct GetStartTime : public ValNode { 120 | Expr reg; 121 | 122 | explicit GetStartTime(Expr reg) : ValNode(types::TIME), reg(reg) 123 | { 124 | ASSERT(!reg->type.is_val()); 125 | } 126 | 127 | void Accept(Visitor&) const final; 128 | }; 129 | 130 | struct GetEndTime : public ValNode { 131 | Expr reg; 132 | 133 | explicit GetEndTime(Expr reg) : ValNode(types::TIME), 
reg(reg) 134 | { 135 | ASSERT(!reg->type.is_val()); 136 | } 137 | 138 | void Accept(Visitor&) const final; 139 | }; 140 | 141 | struct CommitData : public ExprNode { 142 | Expr reg; 143 | Expr time; 144 | 145 | CommitData(Expr reg, Expr time) : 146 | ExprNode(reg->type), reg(reg), time(time) 147 | { 148 | ASSERT(!reg->type.is_val()); 149 | ASSERT(time->type.dtype == types::TIME); 150 | } 151 | 152 | void Accept(Visitor&) const final; 153 | }; 154 | 155 | struct CommitNull : public ExprNode { 156 | Expr reg; 157 | Expr time; 158 | 159 | CommitNull(Expr reg, Expr time) : 160 | ExprNode(reg->type), reg(reg), time(time) 161 | { 162 | ASSERT(!reg->type.is_val()); 163 | ASSERT(time->type.dtype == types::TIME); 164 | } 165 | 166 | void Accept(Visitor&) const final; 167 | }; 168 | 169 | struct AllocRegion : public ExprNode { 170 | Val size; 171 | Expr start_time; 172 | 173 | AllocRegion(Type type, Val size, Expr start_time) : 174 | ExprNode(type), size(size), start_time(start_time) 175 | { 176 | ASSERT(!type.is_val()); 177 | ASSERT(size->type.dtype == types::INDEX); 178 | ASSERT(start_time->type.dtype == types::TIME); 179 | } 180 | 181 | void Accept(Visitor&) const final; 182 | }; 183 | 184 | struct MakeRegion : public ExprNode { 185 | Expr reg; 186 | Expr st; 187 | Expr si; 188 | Expr et; 189 | Expr ei; 190 | 191 | MakeRegion(Expr reg, Expr st, Expr si, Expr et, Expr ei) : 192 | ExprNode(reg->type), reg(reg), st(st), si(si), et(et), ei(ei) 193 | { 194 | ASSERT(!reg->type.is_val()); 195 | ASSERT(st->type.dtype == types::TIME); 196 | ASSERT(si->type.dtype == types::INDEX); 197 | ASSERT(et->type.dtype == types::TIME); 198 | ASSERT(ei->type.dtype == types::INDEX); 199 | } 200 | 201 | void Accept(Visitor&) const final; 202 | }; 203 | 204 | struct IfElse : public ExprNode { 205 | Expr cond; 206 | Expr true_body; 207 | Expr false_body; 208 | 209 | IfElse(Expr cond, Expr true_body, Expr false_body) : 210 | ExprNode(true_body->type), cond(cond), true_body(true_body), 
false_body(false_body) 211 | { 212 | ASSERT(cond->type.dtype == types::BOOL); 213 | ASSERT(true_body->type.dtype == false_body->type.dtype); 214 | } 215 | 216 | void Accept(Visitor&) const final; 217 | }; 218 | 219 | struct LoopNode : public FuncNode { 220 | // Loop counter 221 | Time t; 222 | 223 | // Indices 224 | vector idxs; 225 | 226 | // States 227 | map state_bases; 228 | 229 | // loop condition 230 | Expr exit_cond; 231 | 232 | // Inner loops 233 | vector> inner_loops; 234 | 235 | LoopNode(string name, Type type) : FuncNode(name, std::move(type)) {} 236 | explicit LoopNode(Sym sym) : LoopNode(sym->name, sym->type) {} 237 | 238 | const string get_name() const override { return "loop_" + this->name; } 239 | 240 | void Accept(Visitor&) const final; 241 | }; 242 | typedef shared_ptr Loop; 243 | 244 | } // namespace tilt 245 | 246 | #endif // INCLUDE_TILT_IR_LOOP_H_ 247 | -------------------------------------------------------------------------------- /include/tilt/pass/irgen.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_PASS_IRGEN_H_ 2 | #define INCLUDE_TILT_PASS_IRGEN_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "tilt/pass/visitor.h" 9 | #include "tilt/builder/tilder.h" 10 | 11 | using namespace std; 12 | 13 | namespace tilt { 14 | 15 | template 16 | class IRGen; 17 | 18 | template 19 | class IRGenCtx { 20 | protected: 21 | IRGenCtx(Sym sym, const map* in_sym_tbl, map* out_sym_tbl) : 22 | sym(sym), in_sym_tbl(in_sym_tbl), out_sym_tbl(out_sym_tbl) 23 | {} 24 | 25 | Sym sym; 26 | const map* in_sym_tbl; 27 | map* out_sym_tbl; 28 | map sym_map; 29 | OutExprTy val; 30 | 31 | template 32 | friend class IRGen; 33 | }; 34 | 35 | template 36 | class IRGen : public Visitor { 37 | protected: 38 | virtual CtxTy& ctx() = 0; 39 | 40 | virtual OutExprTy visit(const Symbol&) = 0; 41 | virtual OutExprTy visit(const Out&) = 0; 42 | virtual OutExprTy visit(const Beat&) = 0; 43 | virtual OutExprTy 
visit(const IfElse&) = 0; 44 | virtual OutExprTy visit(const Select&) = 0; 45 | virtual OutExprTy visit(const Get&) = 0; 46 | virtual OutExprTy visit(const New&) = 0; 47 | virtual OutExprTy visit(const Exists&) = 0; 48 | virtual OutExprTy visit(const ConstNode&) = 0; 49 | virtual OutExprTy visit(const Cast&) = 0; 50 | virtual OutExprTy visit(const NaryExpr&) = 0; 51 | virtual OutExprTy visit(const SubLStream&) = 0; 52 | virtual OutExprTy visit(const Element&) = 0; 53 | virtual OutExprTy visit(const OpNode&) = 0; 54 | virtual OutExprTy visit(const Reduce&) = 0; 55 | virtual OutExprTy visit(const Fetch&) = 0; 56 | virtual OutExprTy visit(const Read&) = 0; 57 | virtual OutExprTy visit(const Write&) = 0; 58 | virtual OutExprTy visit(const Advance&) = 0; 59 | virtual OutExprTy visit(const GetCkpt&) = 0; 60 | virtual OutExprTy visit(const GetStartIdx&) = 0; 61 | virtual OutExprTy visit(const GetEndIdx&) = 0; 62 | virtual OutExprTy visit(const GetStartTime&) = 0; 63 | virtual OutExprTy visit(const GetEndTime&) = 0; 64 | virtual OutExprTy visit(const CommitData&) = 0; 65 | virtual OutExprTy visit(const CommitNull&) = 0; 66 | virtual OutExprTy visit(const AllocRegion&) = 0; 67 | virtual OutExprTy visit(const MakeRegion&) = 0; 68 | virtual OutExprTy visit(const Call&) = 0; 69 | virtual OutExprTy visit(const LoopNode&) = 0; 70 | 71 | void Visit(const Out& expr) final { val() = visit(expr); } 72 | void Visit(const Beat& expr) final { val() = visit(expr); } 73 | void Visit(const IfElse& expr) final { val() = visit(expr); } 74 | void Visit(const Select& expr) final { val() = visit(expr); } 75 | void Visit(const Get& expr) final { val() = visit(expr); } 76 | void Visit(const New& expr) final { val() = visit(expr); } 77 | void Visit(const Exists& expr) final { val() = visit(expr); } 78 | void Visit(const ConstNode& expr) final { val() = visit(expr); } 79 | void Visit(const Cast& expr) final { val() = visit(expr); } 80 | void Visit(const NaryExpr& expr) final { val() = visit(expr); 
} 81 | void Visit(const SubLStream& expr) final { val() = visit(expr); } 82 | void Visit(const Element& expr) final { val() = visit(expr); } 83 | void Visit(const OpNode& expr) final { val() = visit(expr); } 84 | void Visit(const Reduce& expr) final { val() = visit(expr); } 85 | void Visit(const Fetch& expr) final { val() = visit(expr); } 86 | void Visit(const Read& expr) final { val() = visit(expr); } 87 | void Visit(const Write& expr) final { val() = visit(expr); } 88 | void Visit(const Advance& expr) final { val() = visit(expr); } 89 | void Visit(const GetCkpt& expr) final { val() = visit(expr); } 90 | void Visit(const GetStartIdx& expr) final { val() = visit(expr); } 91 | void Visit(const GetEndIdx& expr) final { val() = visit(expr); } 92 | void Visit(const GetStartTime& expr) final { val() = visit(expr); } 93 | void Visit(const GetEndTime& expr) final { val() = visit(expr); } 94 | void Visit(const CommitData& expr) final { val() = visit(expr); } 95 | void Visit(const CommitNull& expr) final { val() = visit(expr); } 96 | void Visit(const AllocRegion& expr) final { val() = visit(expr); } 97 | void Visit(const MakeRegion& expr) final { val() = visit(expr); } 98 | void Visit(const Call& expr) final { val() = visit(expr); } 99 | void Visit(const LoopNode& expr) final { val() = visit(expr); } 100 | 101 | CtxTy& switch_ctx(CtxTy& new_ctx) { swap(new_ctx, ctx()); return new_ctx; } 102 | 103 | Sym tmp_sym(const Symbol& symbol) 104 | { 105 | shared_ptr tmp_sym(const_cast(&symbol), [](Symbol*) {}); 106 | return tmp_sym; 107 | } 108 | 109 | OutExprTy get_expr(const Sym& sym) { auto& m = *(ctx().out_sym_tbl); return m.at(sym); } 110 | OutExprTy get_expr(const Symbol& symbol) { return get_expr(tmp_sym(symbol)); } 111 | 112 | virtual void set_expr(const Sym& sym, OutExprTy val) 113 | { 114 | set_sym(sym, sym); 115 | auto& m = *(ctx().out_sym_tbl); 116 | m[sym] = val; 117 | } 118 | void set_expr(const Symbol& symbol, OutExprTy val) { set_expr(tmp_sym(symbol), val); } 119 | 
120 | Sym& get_sym(const Sym& in_sym) { return ctx().sym_map.at(in_sym); } 121 | Sym& get_sym(const Symbol& symbol) { return get_sym(tmp_sym(symbol)); } 122 | void set_sym(const Sym& in_sym, const Sym out_sym) { ctx().sym_map[in_sym] = out_sym; } 123 | void set_sym(const Symbol& in_symbol, const Sym out_sym) { set_sym(tmp_sym(in_symbol), out_sym); } 124 | 125 | OutExprTy& val() { return ctx().val; } 126 | 127 | OutExprTy eval(const InExprTy expr) 128 | { 129 | OutExprTy val = nullptr; 130 | 131 | swap(val, ctx().val); 132 | expr->Accept(*this); 133 | swap(ctx().val, val); 134 | 135 | return val; 136 | } 137 | 138 | void Visit(const Symbol& symbol) final 139 | { 140 | auto tmp = tmp_sym(symbol); 141 | 142 | if (ctx().sym_map.find(tmp) == ctx().sym_map.end()) { 143 | auto expr = ctx().in_sym_tbl->at(tmp); 144 | 145 | swap(ctx().sym, tmp); 146 | auto value = eval(expr); 147 | swap(tmp, ctx().sym); 148 | 149 | auto sym_clone = tilder::_sym(symbol); 150 | set_sym(tmp, sym_clone); 151 | this->set_expr(sym_clone, value); 152 | } 153 | 154 | val() = visit(symbol); 155 | } 156 | }; 157 | 158 | } // namespace tilt 159 | 160 | #endif // INCLUDE_TILT_PASS_IRGEN_H_ 161 | -------------------------------------------------------------------------------- /include/tilt/ir/expr.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_IR_EXPR_H_ 2 | #define INCLUDE_TILT_IR_EXPR_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "tilt/base/type.h" 11 | #include "tilt/base/log.h" 12 | #include "tilt/ir/node.h" 13 | 14 | using namespace std; 15 | 16 | namespace tilt { 17 | 18 | struct Call : public ExprNode { 19 | string name; 20 | vector args; 21 | 22 | Call(string name, Type type, vector args) : 23 | ExprNode(type), name(name), args(std::move(args)) 24 | {} 25 | 26 | void Accept(Visitor&) const final; 27 | }; 28 | 29 | struct Select : public ValNode { 30 | Expr cond; 31 | Expr true_body; 32 | 
Expr false_body; 33 | 34 | Select(Expr cond, Expr true_body, Expr false_body) : 35 | ValNode(true_body->type.dtype), cond(cond), true_body(true_body), false_body(false_body) 36 | { 37 | ASSERT(cond->type.dtype == types::BOOL); 38 | ASSERT(true_body->type.dtype == false_body->type.dtype); 39 | } 40 | 41 | void Accept(Visitor&) const final; 42 | }; 43 | 44 | struct Get : public ValNode { 45 | Expr input; 46 | size_t n; 47 | 48 | Get(Expr input, size_t n) : 49 | ValNode(input->type.dtype.dtypes[n]), input(input), n(n) 50 | { 51 | ASSERT(input->type.dtype.is_struct()); 52 | } 53 | 54 | void Accept(Visitor&) const final; 55 | }; 56 | 57 | struct New : public ValNode { 58 | vector inputs; 59 | 60 | explicit New(vector inputs) : 61 | ValNode(get_new_type(inputs)), inputs(inputs) 62 | {} 63 | 64 | void Accept(Visitor&) const final; 65 | 66 | private: 67 | static DataType get_new_type(vector inputs) 68 | { 69 | vector dtypes; 70 | for (const auto& input : inputs) { 71 | dtypes.push_back(input->type.dtype); 72 | } 73 | return DataType(BaseType::STRUCT, (dtypes)); 74 | } 75 | }; 76 | 77 | struct ConstNode : public ValNode { 78 | const double val; 79 | 80 | ConstNode(BaseType btype, double val) : 81 | ValNode(DataType(btype)), val(val) 82 | {} 83 | 84 | void Accept(Visitor&) const final; 85 | }; 86 | typedef shared_ptr Const; 87 | 88 | struct Exists : public ValNode { 89 | Sym sym; 90 | 91 | explicit Exists(Sym sym) : ValNode(types::BOOL), sym(sym) {} 92 | 93 | void Accept(Visitor&) const final; 94 | }; 95 | 96 | struct Cast : public ValNode { 97 | Expr arg; 98 | 99 | Cast(DataType dtype, Expr arg) : ValNode(dtype), arg(arg) 100 | { 101 | ASSERT(!arg->type.dtype.is_struct() && !dtype.is_struct()); 102 | } 103 | 104 | void Accept(Visitor&) const final; 105 | }; 106 | 107 | struct NaryExpr : public ValNode { 108 | MathOp op; 109 | vector args; 110 | 111 | NaryExpr(DataType dtype, MathOp op, vector args) : 112 | ValNode(dtype), op(op), args(std::move(args)) 113 | { 114 | 
ASSERT(!arg(0)->type.dtype.is_ptr() && !arg(0)->type.dtype.is_struct()); 115 | } 116 | 117 | Expr arg(size_t i) const { return args[i]; } 118 | 119 | size_t size() const { return args.size(); } 120 | 121 | void Accept(Visitor&) const final; 122 | }; 123 | 124 | struct UnaryExpr : public NaryExpr { 125 | UnaryExpr(DataType dtype, MathOp op, Expr input) 126 | : NaryExpr(dtype, op, vector{input}) 127 | {} 128 | }; 129 | 130 | struct BinaryExpr : public NaryExpr { 131 | BinaryExpr(DataType dtype, MathOp op, Expr left, Expr right) 132 | : NaryExpr(dtype, op, vector{left, right}) 133 | { 134 | ASSERT(left->type == right->type); 135 | } 136 | }; 137 | 138 | struct Not : public UnaryExpr { 139 | explicit Not(Expr a) : UnaryExpr(types::BOOL, MathOp::NOT, a) 140 | { 141 | ASSERT(a->type.dtype == types::BOOL); 142 | } 143 | }; 144 | 145 | struct Abs : public UnaryExpr { 146 | explicit Abs(Expr a) : UnaryExpr(a->type.dtype, MathOp::ABS, a) {} 147 | }; 148 | 149 | struct Neg : public UnaryExpr { 150 | explicit Neg(Expr a) : UnaryExpr(a->type.dtype, MathOp::NEG, a) {} 151 | }; 152 | 153 | struct Sqrt : public UnaryExpr { 154 | explicit Sqrt(Expr a) : UnaryExpr(a->type.dtype, MathOp::SQRT, a) {} 155 | }; 156 | 157 | struct Ceil : public UnaryExpr { 158 | explicit Ceil(Expr a) : UnaryExpr(a->type.dtype, MathOp::CEIL, a) { 159 | ASSERT(a->type.dtype.is_float()); 160 | } 161 | }; 162 | 163 | struct Floor : public UnaryExpr { 164 | explicit Floor(Expr a) : UnaryExpr(a->type.dtype, MathOp::FLOOR, a) { 165 | ASSERT(a->type.dtype.is_float()); 166 | } 167 | }; 168 | 169 | struct Equals : public BinaryExpr { 170 | Equals(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::EQ, a, b) {} 171 | }; 172 | 173 | struct And : public BinaryExpr { 174 | And(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::AND, a, b) 175 | { 176 | ASSERT(a->type.dtype == types::BOOL); 177 | } 178 | }; 179 | 180 | struct Or : public BinaryExpr { 181 | Or(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::OR, a, b) 182 
| { 183 | ASSERT(a->type.dtype == types::BOOL); 184 | } 185 | }; 186 | 187 | struct LessThan : public BinaryExpr { 188 | LessThan(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::LT, a, b) {} 189 | }; 190 | 191 | struct GreaterThan : public BinaryExpr { 192 | GreaterThan(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::GT, a, b) {} 193 | }; 194 | 195 | struct LessThanEqual : public BinaryExpr { 196 | LessThanEqual(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::LTE, a, b) {} 197 | }; 198 | 199 | struct GreaterThanEqual : public BinaryExpr { 200 | GreaterThanEqual(Expr a, Expr b) : BinaryExpr(types::BOOL, MathOp::GTE, a, b) {} 201 | }; 202 | 203 | struct Add : public BinaryExpr { 204 | Add(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::ADD, a, b) {} 205 | }; 206 | 207 | struct Sub : public BinaryExpr { 208 | Sub(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::SUB, a, b) {} 209 | }; 210 | 211 | struct Mul : public BinaryExpr { 212 | Mul(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::MUL, a, b) {} 213 | }; 214 | 215 | struct Div : public BinaryExpr { 216 | Div(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::DIV, a, b) {} 217 | }; 218 | 219 | struct Max : public BinaryExpr { 220 | Max(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::MAX, a, b) {} 221 | }; 222 | 223 | struct Min : public BinaryExpr { 224 | Min(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::MIN, a, b) {} 225 | }; 226 | 227 | struct Mod : public BinaryExpr { 228 | Mod(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::MOD, a, b) 229 | { 230 | ASSERT(!a->type.dtype.is_float()); 231 | } 232 | }; 233 | 234 | struct Pow : public BinaryExpr { 235 | Pow(Expr a, Expr b) : BinaryExpr(a->type.dtype, MathOp::POW, a, b) { 236 | ASSERT(a->type.dtype.is_float()); 237 | } 238 | }; 239 | 240 | } // namespace tilt 241 | 242 | #endif // INCLUDE_TILT_IR_EXPR_H_ 243 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 
36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 
75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 
115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. 
If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /include/tilt/base/type.h: -------------------------------------------------------------------------------- 1 | #ifndef INCLUDE_TILT_BASE_TYPE_H_ 2 | #define INCLUDE_TILT_BASE_TYPE_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "tilt/base/ctype.h" 10 | #include "tilt/base/log.h" 11 | 12 | using namespace std; 13 | 14 | namespace tilt { 15 | 16 | enum class BaseType { 17 | BOOL, 18 | INT8, 19 | INT16, 20 | INT32, 21 | INT64, 22 | UINT8, 23 | UINT16, 24 | UINT32, 25 | UINT64, 26 | FLOAT32, 27 | FLOAT64, 28 | STRUCT, 29 | PTR, 30 | UNKNOWN, 31 | 32 | // Loop IR types 33 | TIME, 34 | INDEX, 35 | IVAL, 36 | }; 37 | 38 | struct DataType { 39 | const BaseType btype; 40 | const vector dtypes; 41 | const size_t size; 42 | 43 | DataType(BaseType btype, vector dtypes, size_t size = 0) : 44 | btype(btype), dtypes(dtypes), size(size) 45 | { 46 | switch (btype) { 47 | case BaseType::STRUCT: ASSERT(dtypes.size() > 0); break; 48 | case BaseType::PTR: ASSERT(dtypes.size() == 1); 
break; 49 | default: ASSERT(dtypes.size() == 0); break; 50 | } 51 | } 52 | 53 | explicit DataType(BaseType btype, size_t size = 0) : 54 | DataType(btype, {}, size) 55 | {} 56 | 57 | bool operator==(const DataType& o) const 58 | { 59 | return (this->btype == o.btype) 60 | && (this->dtypes == o.dtypes) 61 | && (this->size == o.size); 62 | } 63 | 64 | bool is_struct() const { return btype == BaseType::STRUCT; } 65 | bool is_ptr() const { return btype == BaseType::PTR; } 66 | bool is_arr() const { return size > 0; } 67 | 68 | bool is_float() const 69 | { 70 | return (this->btype == BaseType::FLOAT32) 71 | || (this->btype == BaseType::FLOAT64); 72 | } 73 | 74 | bool is_int() const 75 | { 76 | return (this->btype == BaseType::INT8) 77 | || (this->btype == BaseType::INT16) 78 | || (this->btype == BaseType::INT32) 79 | || (this->btype == BaseType::INT64) 80 | || (this->btype == BaseType::UINT8) 81 | || (this->btype == BaseType::UINT16) 82 | || (this->btype == BaseType::UINT32) 83 | || (this->btype == BaseType::UINT64); 84 | } 85 | 86 | bool is_signed() const 87 | { 88 | return (this->btype == BaseType::INT8) 89 | || (this->btype == BaseType::INT16) 90 | || (this->btype == BaseType::INT32) 91 | || (this->btype == BaseType::INT64) 92 | || (this->btype == BaseType::FLOAT32) 93 | || (this->btype == BaseType::FLOAT64) 94 | || (this->btype == BaseType::TIME); 95 | } 96 | 97 | DataType ptr() const { return DataType(BaseType::PTR, {*this}); } 98 | 99 | DataType deref() const 100 | { 101 | ASSERT(this->is_ptr()); 102 | return this->dtypes[0]; 103 | } 104 | 105 | string str() const 106 | { 107 | switch (btype) { 108 | case BaseType::BOOL: return "b"; 109 | case BaseType::INT8: return "i8"; 110 | case BaseType::UINT8: return "u8"; 111 | case BaseType::INT16: return "i16"; 112 | case BaseType::UINT16: return "u16"; 113 | case BaseType::INT32: return "i32"; 114 | case BaseType::UINT32: return "u32"; 115 | case BaseType::INT64: return "i64"; 116 | case BaseType::UINT64: return "u64"; 
117 | case BaseType::FLOAT32: return "f32"; 118 | case BaseType::FLOAT64: return "f64"; 119 | case BaseType::TIME: return "t"; 120 | case BaseType::INDEX: return "x"; 121 | case BaseType::PTR: return "*" + dtypes[0].str(); 122 | case BaseType::STRUCT: { 123 | string res = ""; 124 | for (const auto& dtype : dtypes) { 125 | res += dtype.str() + ", "; 126 | } 127 | res.resize(res.size() - 2); 128 | return "{" + res + "}"; 129 | } 130 | case BaseType::IVAL: 131 | case BaseType::UNKNOWN: 132 | default: throw std::runtime_error("Invalid type"); 133 | } 134 | } 135 | }; 136 | 137 | struct Iter { 138 | int64_t offset; 139 | int64_t period; 140 | 141 | Iter(int64_t offset, int64_t period) : 142 | offset(offset), period(period) 143 | {} 144 | 145 | Iter() : offset(0), period(0) {} 146 | 147 | bool operator==(const Iter& o) const 148 | { 149 | return (this->offset == o.offset) 150 | && (this->period == o.period); 151 | } 152 | 153 | string str() const { return "(" + to_string(offset) + ", " + to_string(period) + ")"; } 154 | }; 155 | 156 | struct Type { 157 | const DataType dtype; 158 | const Iter iter; 159 | 160 | Type(DataType dtype, Iter iter) : 161 | dtype(std::move(dtype)), iter(iter) 162 | {} 163 | 164 | explicit Type(DataType dtype) : Type(std::move(dtype), Iter()) {} 165 | 166 | bool is_val() const { return iter.period == 0; } 167 | bool is_beat() const { return iter.period > 0 && dtype.btype == BaseType::TIME; } 168 | bool is_out() const { return iter.period == -2; } 169 | 170 | bool operator==(const Type& o) const 171 | { 172 | return (this->dtype == o.dtype) 173 | && (this->iter == o.iter); 174 | } 175 | 176 | string str() const { return iter.str() + " " + dtype.str(); } 177 | }; 178 | 179 | enum class MathOp { 180 | ADD, SUB, MUL, DIV, MAX, MIN, 181 | MOD, SQRT, POW, 182 | ABS, NEG, CEIL, FLOOR, 183 | LT, LTE, GT, GTE, EQ, 184 | NOT, AND, OR, 185 | }; 186 | 187 | } // namespace tilt 188 | 189 | namespace tilt::types { 190 | 191 | static const DataType 
BOOL(BaseType::BOOL); 192 | static const DataType INT8(BaseType::INT8); 193 | static const DataType INT16(BaseType::INT16); 194 | static const DataType INT32(BaseType::INT32); 195 | static const DataType INT64(BaseType::INT64); 196 | static const DataType UINT8(BaseType::UINT8); 197 | static const DataType UINT16(BaseType::UINT16); 198 | static const DataType UINT32(BaseType::UINT32); 199 | static const DataType UINT64(BaseType::UINT64); 200 | static const DataType FLOAT32(BaseType::FLOAT32); 201 | static const DataType FLOAT64(BaseType::FLOAT64); 202 | static const DataType CHAR_PTR = DataType(BaseType::PTR, {types::INT8}); 203 | static const DataType TIME(BaseType::TIME); 204 | static const DataType INDEX(BaseType::INDEX); 205 | static const DataType IVAL(BaseType::IVAL); 206 | 207 | template struct Converter { static const BaseType btype = BaseType::UNKNOWN; }; 208 | template<> struct Converter { static const BaseType btype = BaseType::BOOL; }; 209 | template<> struct Converter { static const BaseType btype = BaseType::INT8; }; 210 | template<> struct Converter { static const BaseType btype = BaseType::INT8; }; 211 | template<> struct Converter { static const BaseType btype = BaseType::INT16; }; 212 | template<> struct Converter { static const BaseType btype = BaseType::INT32; }; 213 | template<> struct Converter { static const BaseType btype = BaseType::INT64; }; 214 | template<> struct Converter { static const BaseType btype = BaseType::UINT8; }; 215 | template<> struct Converter { static const BaseType btype = BaseType::UINT16; }; 216 | template<> struct Converter { static const BaseType btype = BaseType::UINT32; }; 217 | template<> struct Converter { static const BaseType btype = BaseType::UINT64; }; 218 | template<> struct Converter { static const BaseType btype = BaseType::FLOAT32; }; 219 | template<> struct Converter { static const BaseType btype = BaseType::FLOAT64; }; 220 | 221 | template 222 | static void convert(BaseType* btypes) {} 223 | 224 | 
template 225 | static void convert(BaseType* btypes) 226 | { 227 | btypes[n - sizeof...(Ts) - 1] = Converter::btype; 228 | convert(btypes); 229 | } 230 | 231 | template 232 | DataType STRUCT() 233 | { 234 | vector btypes(sizeof...(Ts)); 235 | convert(btypes.data()); 236 | 237 | vector dtypes; 238 | for (const auto& btype : btypes) { 239 | dtypes.push_back(DataType(btype)); 240 | } 241 | return DataType(BaseType::STRUCT, dtypes); 242 | } 243 | 244 | } // namespace tilt::types 245 | 246 | #endif // INCLUDE_TILT_BASE_TYPE_H_ 247 | -------------------------------------------------------------------------------- /test/src/test_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "test_query.h" 7 | 8 | Op _Select(_sym in, function sel_expr) 9 | { 10 | auto e = in[_pt(0)]; 11 | auto e_sym = _sym("e", e); 12 | auto sel = sel_expr(_get(e_sym, 0)); 13 | auto sel_sym = _sym("sel", sel); 14 | auto sel_op = _op( 15 | _iter(0, 1), 16 | Params{ in }, 17 | SymTable{ {e_sym, e}, {sel_sym, sel} }, 18 | _exists(e_sym), 19 | sel_sym); 20 | return sel_op; 21 | } 22 | 23 | Expr _Count(_sym win) 24 | { 25 | auto acc = [](Expr s, Expr st, Expr et, Expr d) { return _add(s, _f32(1)); }; 26 | return _red(win, _f32(0), acc); 27 | } 28 | 29 | Expr _Sum(_sym win) 30 | { 31 | auto acc = [](Expr s, Expr st, Expr et, Expr d) { return _add(s, d); }; 32 | return _red(win, _f32(0), acc); 33 | } 34 | 35 | Op _WindowAvg(string query_name, _sym in, int64_t w) 36 | { 37 | auto window = in[_win(-w, 0)]; 38 | auto window_sym = _sym("win", window); 39 | auto count = _Count(window_sym); 40 | auto count_sym = _sym(query_name + "_count", count); 41 | auto sum = _Sum(window_sym); 42 | auto sum_sym = _sym(query_name + "_sum", sum); 43 | auto avg = sum_sym / count_sym; 44 | auto avg_sym = _sym("avg", avg); 45 | auto wc_op = _op( 46 | _iter(0, w), 47 | Params{ in }, 48 | SymTable{ {window_sym, window}, 
{count_sym, count}, {sum_sym, sum}, {avg_sym, avg} }, 49 | _true(), 50 | avg_sym); 51 | return wc_op; 52 | } 53 | 54 | Op _Join(_sym left, _sym right) 55 | { 56 | auto e_left = left[_pt(0)]; 57 | auto e_left_sym = _sym("left", e_left); 58 | auto e_right = right[_pt(0)]; 59 | auto e_right_sym = _sym("right", e_right); 60 | auto norm = e_left_sym - e_right_sym; 61 | auto norm_sym = _sym("norm", norm); 62 | auto left_exist = _exists(e_left_sym); 63 | auto right_exist = _exists(e_right_sym); 64 | auto join_cond = left_exist && right_exist; 65 | auto join_op = _op( 66 | _iter(0, 1), 67 | Params{ left, right }, 68 | SymTable{ 69 | {e_left_sym, e_left}, 70 | {e_right_sym, e_right}, 71 | {norm_sym, norm}, 72 | }, 73 | join_cond, 74 | norm_sym); 75 | return join_op; 76 | } 77 | 78 | Op _SelectSub(_sym in, _sym avg) 79 | { 80 | auto e = in[_pt(0)]; 81 | auto e_sym = _sym("e", e); 82 | auto res = e_sym - avg; 83 | auto res_sym = _sym("res", res); 84 | auto sel_op = _op( 85 | _iter(0, 1), 86 | Params{in, avg}, 87 | SymTable{{e_sym, e}, {res_sym, res}}, 88 | _exists(e_sym), 89 | res_sym); 90 | return sel_op; 91 | } 92 | 93 | Op _SelectDiv(_sym in, _sym std) 94 | { 95 | auto e = in[_pt(0)]; 96 | auto e_sym = _sym("e", e); 97 | auto res = e_sym / std; 98 | auto res_sym = _sym("res", res); 99 | auto sel_op = _op( 100 | _iter(0, 1), 101 | Params{in, std}, 102 | SymTable{{e_sym, e}, {res_sym, res}}, 103 | _exists(e_sym), 104 | res_sym); 105 | return sel_op; 106 | } 107 | 108 | Expr _Average(_sym win) 109 | { 110 | auto acc = [](Expr s, Expr st, Expr et, Expr d) { 111 | auto sum = _get(s, 0); 112 | auto count = _get(s, 1); 113 | return _new(vector{_add(sum, d), _add(count, _f32(1))}); 114 | }; 115 | return _red(win, _new(vector{_f32(0), _f32(0)}), acc); 116 | } 117 | 118 | Expr _StdDev(_sym win) 119 | { 120 | auto acc = [](Expr s, Expr st, Expr et, Expr d) { 121 | auto sum = _get(s, 0); 122 | auto count = _get(s, 1); 123 | return _new(vector{_add(sum, _mul(d, d)), _add(count, 
_f32(1))}); 124 | }; 125 | return _red(win, _new(vector{_f32(0), _f32(0)}), acc); 126 | } 127 | 128 | Op _Norm(string query_name, _sym in, int64_t len) 129 | { 130 | auto inwin = in[_win(-len, 0)]; 131 | auto inwin_sym = _sym("inwin", inwin); 132 | 133 | // avg state 134 | auto avg_state = _Average(inwin_sym); 135 | auto avg_state_sym = _sym(query_name + "_avg_state", avg_state); 136 | 137 | // avg value 138 | auto avg = _div(_get(avg_state_sym, 0), _get(avg_state_sym, 1)); 139 | auto avg_sym = _sym("avg", avg); 140 | 141 | // avg join 142 | auto avg_op = _SelectSub(inwin_sym, avg_sym); 143 | auto avg_op_sym = _sym(query_name + "_avgop", avg_op); 144 | 145 | // stddev state 146 | auto std_state = _StdDev(avg_op_sym); 147 | auto std_state_sym = _sym(query_name + "_stddev_state", std_state); 148 | 149 | // stddev value 150 | auto std = _sqrt(_div(_get(std_state_sym, 0), _get(std_state_sym, 1))); 151 | auto std_sym = _sym("std", std); 152 | 153 | // std join 154 | auto std_op = _SelectDiv(avg_op_sym, std_sym); 155 | auto std_op_sym = _sym(query_name + "_stdop", std_op); 156 | 157 | // query operation 158 | auto query_op = _op( 159 | _iter(0, len), 160 | Params{ in }, 161 | SymTable{ 162 | {inwin_sym, inwin}, 163 | {avg_state_sym, avg_state}, 164 | {avg_sym, avg}, 165 | {avg_op_sym, avg_op}, 166 | {std_state_sym, std_state}, 167 | {std_sym, std}, 168 | {std_op_sym, std_op} 169 | }, 170 | _true(), 171 | std_op_sym); 172 | 173 | return query_op; 174 | } 175 | 176 | Op _MovingSum(_sym in, int64_t dur, int64_t w) 177 | { 178 | auto e = in[_pt(0)]; 179 | auto e_sym = _sym("e", e); 180 | auto p = in[_pt(-w)]; 181 | auto p_sym = _sym("p", p); 182 | auto out = _out(types::INT32); 183 | auto o = out[_pt(-dur)]; 184 | auto o_sym = _sym("o", o); 185 | auto p_val = _ifelse(_exists(p_sym), p_sym, _i32(0)); 186 | auto p_val_sym = _sym("p_val", p_val); 187 | auto o_val = _ifelse(_exists(o_sym), o_sym, _i32(0)); 188 | auto o_val_sym = _sym("o_val", o_val); 189 | auto res = (e_sym + 
o_val_sym) - p_val_sym; 190 | auto res_sym = _sym("res", res); 191 | auto sel_op = _op( 192 | _iter(0, dur), 193 | Params{in}, 194 | SymTable{ 195 | {e_sym, e}, 196 | {p_sym, p}, 197 | {o_sym, o}, 198 | {p_val_sym, p_val}, 199 | {o_val_sym, o_val}, 200 | {res_sym, res}, 201 | }, 202 | _exists(e_sym), 203 | res_sym); 204 | return sel_op; 205 | } 206 | 207 | Op _Pair(_sym in, int64_t iperiod) 208 | { 209 | auto ev = in[_pt(0)]; 210 | auto ev_sym = _sym("ev", ev); 211 | auto sv = in[_pt(-iperiod)]; 212 | auto sv_sym = _sym("sv", sv); 213 | auto beat = _beat(_iter(0, iperiod)); 214 | auto et = _cast(types::FLOAT32, beat[_pt(0)]); 215 | auto et_sym = _sym("et", et); 216 | auto st = et_sym - _f32(iperiod); 217 | auto st_sym = _sym("st", st); 218 | auto res = _new(vector{st_sym, sv_sym, et_sym, ev_sym}); 219 | auto res_sym = _sym("res", res); 220 | auto pair_op = _op( 221 | _iter(0, iperiod), 222 | Params{in, beat}, 223 | SymTable{ 224 | {st_sym, st}, 225 | {sv_sym, sv}, 226 | {et_sym, et}, 227 | {ev_sym, ev}, 228 | {res_sym, res}, 229 | }, 230 | _exists(sv_sym) && _exists(ev_sym), 231 | res_sym); 232 | return pair_op; 233 | } 234 | 235 | Op _Interpolate(_sym in, int64_t operiod) 236 | { 237 | auto e = in[_pt(0)]; 238 | auto e_sym = _sym("e", e); 239 | auto beat = _beat(_iter(0, operiod)); 240 | auto t = _cast(types::FLOAT32, beat[_pt(0)]); 241 | auto t_sym = _sym("t", t); 242 | auto st = e_sym << 0; 243 | auto sv = e_sym << 1; 244 | auto et = e_sym << 2; 245 | auto ev = e_sym << 3; 246 | auto res = (((ev - sv) * (t_sym - st)) / (et - st)) + sv; 247 | auto res_sym = _sym("res", res); 248 | auto inter_op = _op( 249 | _iter(0, operiod), 250 | Params{in, beat}, 251 | SymTable{ 252 | {e_sym, e}, 253 | {t_sym, t}, 254 | {res_sym, res}, 255 | }, 256 | _exists(e_sym), 257 | res_sym); 258 | return inter_op; 259 | } 260 | 261 | Op _Resample(string query_name, _sym in, int64_t iperiod, int64_t operiod) 262 | { 263 | auto win_size = lcm(iperiod, operiod); 264 | auto win = 
in[_win(-win_size, 0)]; 265 | auto win_sym = _sym("win", win); 266 | auto pair = _Pair(win_sym, iperiod); 267 | auto pair_sym = _sym(query_name + "_pair", pair); 268 | auto inter = _Interpolate(pair_sym, operiod); 269 | auto inter_sym = _sym(query_name + "_inter", inter); 270 | auto resample_op = _op( 271 | _iter(0, win_size), 272 | Params{in}, 273 | SymTable{ 274 | {win_sym, win}, 275 | {pair_sym, pair}, 276 | {inter_sym, inter}, 277 | }, 278 | _true(), 279 | inter_sym); 280 | return resample_op; 281 | } 282 | -------------------------------------------------------------------------------- /src/pass/printer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "tilt/pass/printer.h" 4 | #include "tilt/builder/tilder.h" 5 | 6 | using namespace tilt; 7 | using namespace tilt::tilder; 8 | using namespace std; 9 | 10 | static const auto EXISTS = "\u2203"; 11 | static const auto FORALL = "\u2200"; 12 | static const auto IN = "\u2208"; 13 | static const auto PHI = "\u0278"; 14 | 15 | string idx_str(int64_t idx) 16 | { 17 | ostringstream ostr; 18 | if (idx > 0) { 19 | ostr << " + " << idx; 20 | } else if (idx < 0) { 21 | ostr << " - " << -idx; 22 | } 23 | return ostr.str(); 24 | } 25 | 26 | void IRPrinter::Visit(const Symbol& sym) 27 | { 28 | if (!sym.type.is_val()) { ostr << "~"; } 29 | ostr << sym.name; 30 | } 31 | 32 | void IRPrinter::Visit(const Out& out) { Visit(static_cast(out)); } 33 | 34 | void IRPrinter::Visit(const Beat& beat) { Visit(static_cast(beat)); } 35 | 36 | void IRPrinter::Visit(const Exists& exists) 37 | { 38 | emitunary(EXISTS, exists.sym); 39 | } 40 | 41 | void IRPrinter::Visit(const ConstNode& cnst) 42 | { 43 | switch (cnst.type.dtype.btype) { 44 | case BaseType::BOOL: ostr << (cnst.val ? 
"true" : "false"); break; 45 | case BaseType::INT8: 46 | case BaseType::INT16: 47 | case BaseType::INT32: 48 | case BaseType::INT64: ostr << cnst.val << "i"; break; 49 | case BaseType::UINT8: 50 | case BaseType::UINT16: 51 | case BaseType::UINT32: 52 | case BaseType::UINT64: ostr << cnst.val << "u"; break; 53 | case BaseType::FLOAT32: 54 | case BaseType::FLOAT64: ostr << cnst.val << "f"; break; 55 | case BaseType::TIME: ostr << cnst.val << "t"; break; 56 | case BaseType::INDEX: ostr << cnst.val << "x"; break; 57 | default: throw std::runtime_error("Invalid constant type"); 58 | } 59 | } 60 | 61 | void IRPrinter::Visit(const Cast& e) 62 | { 63 | string destty; 64 | switch (e.type.dtype.btype) { 65 | case BaseType::INT8: destty = "int8"; break; 66 | case BaseType::INT16: destty = "int16"; break; 67 | case BaseType::INT32: destty = "int32"; break; 68 | case BaseType::INT64: destty = "long"; break; 69 | case BaseType::UINT8: destty = "uint8"; break; 70 | case BaseType::UINT16: destty = "uint16"; break; 71 | case BaseType::UINT32: destty = "uint32"; break; 72 | case BaseType::UINT64: destty = "ulong"; break; 73 | case BaseType::FLOAT32: destty = "float"; break; 74 | case BaseType::FLOAT64: destty = "double"; break; 75 | case BaseType::BOOL: destty = "bool"; break; 76 | case BaseType::TIME: destty = "ts"; break; 77 | case BaseType::INDEX: destty = "idx"; break; 78 | default: throw std::runtime_error("Invalid destination type for cast"); 79 | } 80 | 81 | ostr << "(" << destty << ") "; 82 | e.arg->Accept(*this); 83 | } 84 | 85 | void IRPrinter::Visit(const NaryExpr& e) 86 | { 87 | switch (e.op) { 88 | case MathOp::ADD: emitbinary(e.arg(0), "+", e.arg(1)); break; 89 | case MathOp::SUB: emitbinary(e.arg(0), "-", e.arg(1)); break; 90 | case MathOp::MUL: emitbinary(e.arg(0), "*", e.arg(1)); break; 91 | case MathOp::DIV: emitbinary(e.arg(0), "/", e.arg(1)); break; 92 | case MathOp::MAX: emitfunc("max", {e.arg(0), e.arg(1)}); break; 93 | case MathOp::MIN: emitfunc("min", 
{e.arg(0), e.arg(1)}); break; 94 | case MathOp::MOD: emitbinary(e.arg(0), "%", e.arg(1)); break; 95 | case MathOp::ABS: ostr << "|"; e.arg(0)->Accept(*this); ostr << "|"; break; 96 | case MathOp::NEG: emitunary("-", {e.arg(0)}); break; 97 | case MathOp::SQRT: emitfunc("sqrt", {e.arg(0)}); break; 98 | case MathOp::POW: emitfunc("pow", {e.arg(0), e.arg(1)}); break; 99 | case MathOp::CEIL: emitfunc("ceil", {e.arg(0)}); break; 100 | case MathOp::FLOOR: emitfunc("floor", {e.arg(0)}); break; 101 | case MathOp::EQ: emitbinary(e.arg(0), "==", e.arg(1)); break; 102 | case MathOp::NOT: emitunary("!", e.arg(0)); break; 103 | case MathOp::AND: emitbinary(e.arg(0), "&&", e.arg(1)); break; 104 | case MathOp::OR: emitbinary(e.arg(0), "||", e.arg(1)); break; 105 | case MathOp::LT: emitbinary(e.arg(0), "<", e.arg(1)); break; 106 | case MathOp::LTE: emitbinary(e.arg(0), "<=", e.arg(1)); break; 107 | case MathOp::GT: emitbinary(e.arg(0), ">", e.arg(1)); break; 108 | case MathOp::GTE: emitbinary(e.arg(0), ">=", e.arg(1)); break; 109 | default: throw std::runtime_error("Invalid math operation"); 110 | } 111 | } 112 | 113 | void IRPrinter::Visit(const SubLStream& subls) 114 | { 115 | subls.lstream->Accept(*this); 116 | ostr << "[t" << ctx.nesting << idx_str(subls.win.start.offset) 117 | << " : t" << ctx.nesting << idx_str(subls.win.end.offset) << "]"; 118 | } 119 | 120 | void IRPrinter::Visit(const Element& elem) 121 | { 122 | elem.lstream->Accept(*this); 123 | ostr << "[t" << ctx.nesting; 124 | ostr << idx_str(elem.pt.offset); 125 | ostr << "]"; 126 | } 127 | 128 | void IRPrinter::Visit(const OpNode& op) 129 | { 130 | enter_op(); 131 | ostr << FORALL << "t" << ctx.nesting << " " << IN << " " << op.iter.str() << " "; 132 | 133 | ostr << "["; 134 | for (auto in : op.inputs) { 135 | in->Accept(*this); 136 | ostr << "; "; 137 | } 138 | ostr << "] {"; 139 | 140 | enter_block(); 141 | for (auto& it : op.syms) { 142 | it.first->Accept(*this); 143 | ostr << " = "; 144 | 
it.second->Accept(*this); 145 | emitnewline(); 146 | } 147 | ostr << "return "; 148 | op.pred->Accept(*this); 149 | ostr << " ? "; 150 | op.output->Accept(*this); 151 | ostr << " : " << PHI; 152 | exit_block(); 153 | 154 | ostr << "}"; 155 | exit_op(); 156 | } 157 | 158 | void IRPrinter::Visit(const Reduce& red) 159 | { 160 | auto state_str = IRPrinter::Build(red.state); 161 | 162 | auto e = _elem(red.lstream, _pt(0)); 163 | auto e_sym = _sym("e", e); 164 | auto state_init_sym = _sym("state = " + state_str, red.state); 165 | auto state_sym = _sym("state", red.state); 166 | auto t_name = "t" + to_string(ctx.nesting + 1); 167 | auto t = _sym(t_name, Type(types::TIME)); 168 | auto t_base = _sym("^" + t_name, Type(types::TIME)); 169 | auto res = red.acc(state_sym, t_base, t, e_sym); 170 | auto red_op = _op( 171 | _iter(0, 1), 172 | Params{red.lstream, state_init_sym}, 173 | SymTable{ 174 | {e_sym, e}, 175 | {state_sym, res}, 176 | }, 177 | _exists(e_sym), 178 | state_sym); 179 | red_op->Accept(*this); 180 | } 181 | 182 | void IRPrinter::Visit(const Fetch& fetch) 183 | { 184 | emitfunc("fetch", { fetch.reg, fetch.time, fetch.idx }); 185 | } 186 | 187 | void IRPrinter::Visit(const Read& read) 188 | { 189 | emitunary("*", read.ptr); 190 | } 191 | 192 | void IRPrinter::Visit(const Write& write) 193 | { 194 | emitfunc("write", { write.reg, write.ptr, write.data }); 195 | } 196 | 197 | void IRPrinter::Visit(const Advance& adv) 198 | { 199 | emitfunc("advance", { adv.reg, adv.idx, adv.time }); 200 | } 201 | 202 | void IRPrinter::Visit(const GetCkpt& next) 203 | { 204 | emitfunc("get_ckpt", { next.reg, next.time, next.idx }); 205 | } 206 | 207 | void IRPrinter::Visit(const GetStartIdx& gsi) 208 | { 209 | emitfunc("get_start_idx", { gsi.reg }); 210 | } 211 | 212 | void IRPrinter::Visit(const GetEndIdx& gei) 213 | { 214 | emitfunc("get_end_idx", { gei.reg }); 215 | } 216 | 217 | void IRPrinter::Visit(const GetStartTime& gst) 218 | { 219 | emitfunc("get_start_time", { gst.reg }); 
220 | } 221 | 222 | void IRPrinter::Visit(const GetEndTime& get) 223 | { 224 | emitfunc("get_end_time", { get.reg }); 225 | } 226 | 227 | void IRPrinter::Visit(const CommitData& commit) 228 | { 229 | emitfunc("commit_data", { commit.reg, commit.time }); 230 | } 231 | 232 | void IRPrinter::Visit(const CommitNull& commit) 233 | { 234 | emitfunc("commit_null", { commit.reg, commit.time }); 235 | } 236 | 237 | void IRPrinter::Visit(const AllocRegion& alloc_reg) 238 | { 239 | emitfunc("alloc_region", { alloc_reg.size }); 240 | } 241 | 242 | void IRPrinter::Visit(const MakeRegion& mr) 243 | { 244 | emitfunc("make_region", { mr.reg, mr.st, mr.si, mr.et, mr.ei }); 245 | } 246 | 247 | void IRPrinter::Visit(const Call& call) 248 | { 249 | emitfunc(call.name, call.args); 250 | } 251 | 252 | void IRPrinter::Visit(const IfElse& ifelse) 253 | { 254 | ifelse.cond->Accept(*this); 255 | ostr << " ? "; 256 | ifelse.true_body->Accept(*this); 257 | ostr << " : "; 258 | ifelse.false_body->Accept(*this); 259 | } 260 | 261 | void IRPrinter::Visit(const Select& select) 262 | { 263 | ostr << "("; 264 | select.cond->Accept(*this); 265 | ostr << " ? 
"; 266 | select.true_body->Accept(*this); 267 | ostr << " : "; 268 | select.false_body->Accept(*this); 269 | ostr << ")"; 270 | } 271 | 272 | void IRPrinter::Visit(const Get& get) 273 | { 274 | emitfunc("get", {get.input, _u64(get.n)}); 275 | } 276 | 277 | void IRPrinter::Visit(const New& _new) 278 | { 279 | emitfunc("new", _new.inputs); 280 | } 281 | 282 | void IRPrinter::Visit(const LoopNode& loop) 283 | { 284 | for (const auto& inner_loop : loop.inner_loops) 285 | { 286 | inner_loop->Accept(*this); 287 | } 288 | 289 | vector args; 290 | args.insert(args.end(), loop.inputs.begin(), loop.inputs.end()); 291 | emitfunc(loop.get_name(), args); 292 | emitnewline(); 293 | 294 | ostr << "{"; 295 | enter_block(); 296 | 297 | emitcomment("initialization"); 298 | emitnewline(); 299 | unordered_set bases; 300 | for (const auto& [_, base] : loop.state_bases) { 301 | emitassign(base, loop.syms.at(base)); 302 | bases.insert(base); 303 | emitnewline(); 304 | } 305 | emitnewline(); 306 | 307 | ostr << "while(1) {"; 308 | enter_block(); 309 | 310 | emitcomment("loop condition check"); 311 | emitnewline(); 312 | ostr << "if ("; 313 | loop.exit_cond->Accept(*this); 314 | ostr << ") break;"; 315 | emitnewline(); 316 | emitnewline(); 317 | 318 | emitcomment("update indices"); 319 | emitnewline(); 320 | for (const auto& idx : loop.idxs) { 321 | emitassign(idx, loop.syms.at(idx)); 322 | emitnewline(); 323 | } 324 | emitnewline(); 325 | 326 | emitcomment("update timer"); 327 | emitnewline(); 328 | emitassign(loop.t, loop.syms.at(loop.t)); 329 | emitnewline(); 330 | emitnewline(); 331 | 332 | emitcomment("set local variables"); 333 | emitnewline(); 334 | for (const auto& [sym, expr] : loop.syms) { 335 | if (bases.find(sym) == bases.end() && 336 | loop.state_bases.find(sym) == loop.state_bases.end()) { 337 | emitassign(sym, expr); 338 | emitnewline(); 339 | } 340 | } 341 | emitnewline(); 342 | 343 | emitcomment("loop body"); 344 | emitnewline(); 345 | emitassign(loop.output, 
loop.syms.at(loop.output)); 346 | emitnewline(); 347 | emitnewline(); 348 | 349 | emitcomment("Update states"); 350 | for (const auto& [var, base] : loop.state_bases) { 351 | emitnewline(); 352 | emitassign(base, var); 353 | } 354 | 355 | exit_block(); 356 | ostr << "}"; 357 | 358 | emitnewline(); 359 | emitnewline(); 360 | ostr << "return "; 361 | loop.state_bases.at(loop.output)->Accept(*this); 362 | ostr << ";"; 363 | 364 | exit_block(); 365 | ostr << "}"; 366 | emitnewline(); 367 | emitnewline(); 368 | } 369 | 370 | string IRPrinter::Build(const Expr expr) 371 | { 372 | IRPrinter printer; 373 | expr->Accept(printer); 374 | return printer.ostr.str(); 375 | } 376 | 377 | string IRPrinter::Build(const llvm::Module* mod) 378 | { 379 | std::string str; 380 | llvm::raw_string_ostream ostr(str); 381 | ostr << *mod; 382 | ostr.flush(); 383 | return ostr.str(); 384 | } 385 | -------------------------------------------------------------------------------- /src/pass/codegen/loopgen.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "tilt/pass/codegen/loopgen.h" 5 | 6 | using namespace tilt; 7 | using namespace tilt::tilder; 8 | using namespace std; 9 | 10 | Expr LoopGen::get_timer(const Point pt, bool use_base = false) 11 | { 12 | Expr t = nullptr; 13 | if (use_base) { 14 | auto t_base = ctx().loop->state_bases[ctx().loop->t]; 15 | t = _add(t_base, _ts(ctx().op->iter.period)); 16 | } else { 17 | t = ctx().loop->t; 18 | } 19 | return _add(t, _ts(pt.offset)); 20 | } 21 | 22 | Expr get_beat_idx(Sym reg, Expr time) 23 | { 24 | auto period = _ts(reg->type.iter.period); 25 | auto offset = _ts(reg->type.iter.offset + 1); 26 | return _add(_cast(types::INDEX, _div(_sub(time, offset), period)), _idx(1)); 27 | } 28 | 29 | Expr get_beat_time(Sym reg, Expr idx) 30 | { 31 | auto period = _ts(reg->type.iter.period); 32 | return _mul(_cast(types::TIME, idx), period); 33 | } 34 | 35 | Index& 
LoopGen::get_idx(const Sym reg, const Point pt) 36 | { 37 | auto& pt_idx_map = ctx().pt_idx_maps[reg]; 38 | if (pt_idx_map.find(pt) == pt_idx_map.end()) { 39 | auto time = get_timer(pt, true); 40 | auto idx = _index("i" + to_string(pt.offset) + "_" + reg->name); 41 | Expr next_ckpt = nullptr; 42 | 43 | if (reg->type.is_beat()) { 44 | // Index updater 45 | set_expr(idx, get_beat_idx(reg, time)); 46 | 47 | // Index shift expression 48 | next_ckpt = get_beat_time(reg, idx); 49 | } else { 50 | auto idx_base = _index("i" + to_string(pt.offset) + "_" + reg->name + "_base"); 51 | 52 | // Index initializer 53 | set_expr(idx_base, _get_start_idx(reg)); 54 | ctx().loop->state_bases[idx] = idx_base; 55 | 56 | // Index updater 57 | auto adv_expr = _adv(reg, idx_base, time); 58 | set_expr(idx, adv_expr); 59 | 60 | // Index shift expression 61 | next_ckpt = _get_ckpt(reg, time, idx); 62 | } 63 | 64 | pt_idx_map[pt] = idx; 65 | ctx().loop->idxs.push_back(idx); 66 | if (!reg->type.is_out()) { 67 | ctx().idx_diff_map[idx] = _sub(next_ckpt, time); 68 | } 69 | } 70 | 71 | return pt_idx_map[pt]; 72 | } 73 | 74 | void LoopGen::build_tloop(function true_body, function false_body) 75 | { 76 | auto loop = ctx().loop; 77 | auto name = loop->name; 78 | 79 | // Loop function arguments 80 | auto t_start = _time("t_start"); 81 | auto t_end = _time("t_end"); 82 | auto out_arg = _sym(name, ctx().loop->type); 83 | loop->inputs.push_back(t_start); 84 | loop->inputs.push_back(t_end); 85 | loop->inputs.push_back(out_arg); 86 | for (auto& in : ctx().op->inputs) { 87 | auto in_reg = _sym(in->name, in->type); 88 | set_sym(in, in_reg); 89 | if (!in_reg->type.is_beat()) { 90 | loop->inputs.push_back(in_reg); 91 | } 92 | } 93 | 94 | // Create loop counter 95 | auto t_base = _time("t_base"); 96 | set_expr(t_base, t_start); 97 | loop->t = _time("t"); 98 | loop->state_bases[loop->t] = t_base; 99 | 100 | // Loop exit condition 101 | loop->exit_cond = _eq(t_base, t_end); 102 | 103 | // Create loop return value 
104 | auto output_base = _sym("output_base", ctx().loop->type); 105 | set_expr(output_base, out_arg); 106 | loop->output = _sym("output", ctx().loop->type); 107 | loop->state_bases[loop->output] = output_base; 108 | 109 | // Evaluate loop body 110 | auto pred_expr = eval(ctx().op->pred); 111 | eval(ctx().op->output); 112 | 113 | // Loop counter update expression 114 | Expr delta = nullptr; 115 | for (const auto& [idx, diff_expr] : ctx().idx_diff_map) { 116 | if (delta) { 117 | delta = _min(delta, diff_expr); 118 | } else { 119 | delta = diff_expr; 120 | } 121 | } 122 | auto t_period = _ts(ctx().op->iter.period); 123 | auto t_incr = _mul(_div(delta, t_period), t_period); 124 | set_expr(loop->t, _min(t_end, _add(get_timer(_pt(0), true), t_incr))); 125 | 126 | // Create loop output 127 | set_expr(loop->output, _ifelse(pred_expr, true_body(), false_body())); 128 | } 129 | 130 | void LoopGen::build_loop() 131 | { 132 | auto true_body = [&]() -> Expr { 133 | auto loop = ctx().loop; 134 | auto output_base = loop->state_bases[loop->output]; 135 | auto out_expr = get_sym(ctx().op->output); 136 | 137 | // Update loop output: 138 | // 1. Outer loop returns the output region of the inner loop 139 | // 2. 
Inner loop updates the output region 140 | if (out_expr->type.is_val()) { 141 | auto new_reg = _commit_data(output_base, loop->t); 142 | auto new_reg_sym = _sym("new_reg", new_reg); 143 | set_expr(new_reg_sym, new_reg); 144 | 145 | auto idx = _get_end_idx(new_reg_sym); 146 | auto dptr = _fetch(new_reg_sym, loop->t, idx); 147 | return _write(new_reg_sym, dptr, out_expr); 148 | } else { 149 | return out_expr; 150 | } 151 | }; 152 | 153 | auto false_body = [&]() -> Expr { 154 | auto loop = ctx().loop; 155 | auto output_base = loop->state_bases[loop->output]; 156 | return _commit_null(output_base, loop->t); 157 | }; 158 | 159 | build_tloop(true_body, false_body); 160 | } 161 | 162 | Expr LoopGen::visit(const Symbol& symbol) { return get_sym(symbol); } 163 | 164 | Expr LoopGen::visit(const Out& out) 165 | { 166 | auto out_reg = ctx().loop->inputs[2]; 167 | set_sym(out, out_reg); 168 | return out_reg; 169 | } 170 | 171 | Expr LoopGen::visit(const Beat& beat) { return _beat(beat); } 172 | 173 | Expr LoopGen::visit(const IfElse& ifelse) 174 | { 175 | auto cond = eval(ifelse.cond); 176 | auto true_body = eval(ifelse.true_body); 177 | auto false_body = eval(ifelse.false_body); 178 | return _ifelse(cond, true_body, false_body); 179 | } 180 | 181 | Expr LoopGen::visit(const Select& select) 182 | { 183 | auto cond = eval(select.cond); 184 | auto true_body = eval(select.true_body); 185 | auto false_body = eval(select.false_body); 186 | return _sel(cond, true_body, false_body); 187 | } 188 | 189 | Expr LoopGen::visit(const Call& call) 190 | { 191 | vector args; 192 | for (const auto& arg : call.args) { 193 | args.push_back(eval(arg)); 194 | } 195 | return _call(call.name, call.type, std::move(args)); 196 | } 197 | 198 | Expr LoopGen::visit(const Exists& exists) 199 | { 200 | eval(exists.sym); 201 | auto s = get_sym(exists.sym); 202 | if (s->type.is_beat() || s->type == Type(types::TIME)) { 203 | return _true(); 204 | } else if (s->type.is_val()) { 205 | auto ptr_sym = 
get_ref(exists.sym); 206 | return _exists(ptr_sym); 207 | } else { 208 | auto si = _get_start_idx(s); 209 | auto ei = _get_end_idx(s); 210 | auto st = _get_start_time(s); 211 | auto win_ptr = _fetch(s, st, si); 212 | auto win_ptr_sym = _sym(s->name + "_ptr", win_ptr); 213 | set_expr(win_ptr_sym, win_ptr); 214 | return _or(_not(_eq(si, ei)), _exists(win_ptr_sym)); 215 | } 216 | } 217 | 218 | Expr LoopGen::visit(const New& new_expr) 219 | { 220 | vector input_vals; 221 | for (const auto& input : new_expr.inputs) { 222 | input_vals.push_back(eval(input)); 223 | } 224 | return _new(std::move(input_vals)); 225 | } 226 | 227 | Expr LoopGen::visit(const Get& get) { return _get(eval(get.input), get.n); } 228 | 229 | Expr LoopGen::visit(const ConstNode& cnst) { return _const(cnst); } 230 | 231 | Expr LoopGen::visit(const Cast& e) { return _cast(e.type.dtype, eval(e.arg)); } 232 | 233 | Expr LoopGen::visit(const NaryExpr& e) 234 | { 235 | vector args; 236 | for (auto arg : e.args) { 237 | args.push_back(eval(arg)); 238 | } 239 | return make_shared(e.type.dtype, e.op, std::move(args)); 240 | } 241 | 242 | Expr LoopGen::visit(const SubLStream& subls) 243 | { 244 | eval(subls.lstream); 245 | auto& reg = get_sym(subls.lstream); 246 | if (reg->type.is_beat()) { 247 | return reg; 248 | } else { 249 | auto st = get_timer(subls.win.start, subls.lstream->type.is_out()); 250 | auto& si = get_idx(reg, subls.win.start); 251 | auto et = get_timer(subls.win.end, subls.lstream->type.is_out()); 252 | auto& ei = get_idx(reg, subls.win.end); 253 | return _make_reg(reg, st, si, et, ei); 254 | } 255 | } 256 | 257 | Expr LoopGen::visit(const Element& elem) 258 | { 259 | eval(elem.lstream); 260 | auto& reg = get_sym(elem.lstream); 261 | auto& idx = get_idx(reg, elem.pt); 262 | 263 | if (reg->type.is_beat()) { 264 | return get_beat_time(reg, idx); 265 | } else { 266 | auto time = get_timer(elem.pt, elem.lstream->type.is_out()); 267 | auto ptr = _fetch(reg, time, idx); 268 | auto ptr_sym = 
_sym(ctx().sym->name + "_ptr", ptr); 269 | set_expr(ptr_sym, ptr); 270 | set_ref(ctx().sym, ptr_sym); 271 | return _read(ptr_sym); 272 | } 273 | } 274 | 275 | Expr LoopGen::visit(const OpNode& op) 276 | { 277 | auto inner_op = &op; 278 | auto inner_loop = LoopGen::Build(ctx().sym, inner_op); 279 | auto outer_op = ctx().op; 280 | auto outer_loop = ctx().loop; 281 | 282 | outer_loop->inner_loops.push_back(inner_loop); 283 | 284 | auto t_end = get_timer(_pt(0)); 285 | auto t_start = get_timer(_pt(-outer_op->iter.period)); 286 | 287 | vector inputs; 288 | Val size_expr = _idx(1); 289 | for (const auto& input : inner_op->inputs) { 290 | auto input_val = eval(input); 291 | if (!input_val->type.is_beat()) { 292 | inputs.push_back(input_val); 293 | } 294 | if (!input_val->type.is_val()) { 295 | if (input_val->type.is_beat()) { 296 | auto period = _idx(inner_op->iter.period); 297 | auto beat = _idx(input_val->type.iter.period); 298 | size_expr = _add(size_expr, _div(period, beat)); 299 | } else { 300 | auto start = _get_start_idx(input_val); 301 | auto end = _get_end_idx(input_val); 302 | size_expr = _add(size_expr, _sub(end, start)); 303 | } 304 | } 305 | } 306 | 307 | Sym out_sym; 308 | if (outer_op->output == ctx().sym) { 309 | out_sym = outer_loop->state_bases[outer_loop->output]; 310 | } else if (outer_op->aux.find(ctx().sym) != outer_op->aux.end()) { 311 | out_sym = get_sym(outer_op->aux.at(ctx().sym)); 312 | } else { 313 | auto out_reg = _alloc_reg(op.type, size_expr, t_start); 314 | out_sym = _sym(ctx().sym->name + "_reg", out_reg); 315 | set_expr(out_sym, out_reg); 316 | } 317 | 318 | vector args = {t_start, t_end, out_sym}; 319 | for (const auto& input : inputs) { 320 | args.push_back(input); 321 | } 322 | return _call(inner_loop->get_name(), inner_loop->type, std::move(args)); 323 | } 324 | 325 | Expr LoopGen::visit(const Reduce& red) 326 | { 327 | auto e = _elem(red.lstream, _pt(0)); 328 | auto e_sym = _sym("e", e); 329 | auto red_op = _op( 330 | _iter(0, 1), 
331 | Params{red.lstream}, 332 | SymTable{ 333 | {e_sym, e} 334 | }, 335 | _exists(e_sym), 336 | e_sym); 337 | 338 | auto red_loop = _loop(ctx().sym); 339 | LoopGenCtx new_ctx(ctx().sym, red_op.get(), red_loop); 340 | 341 | auto& old_ctx = switch_ctx(new_ctx); 342 | 343 | auto true_body = [&]() -> Expr { 344 | auto loop = ctx().loop; 345 | auto output_base = loop->state_bases[loop->output]; 346 | auto t = loop->t; 347 | auto t_base = loop->state_bases[t]; 348 | auto out_sym = get_sym(ctx().op->output); 349 | return eval(red.acc(output_base, t_base, t, out_sym)); 350 | }; 351 | 352 | auto false_body = [&]() -> Expr { 353 | auto loop = ctx().loop; 354 | auto output_base = loop->state_bases[loop->output]; 355 | return output_base; 356 | }; 357 | build_tloop(true_body, false_body); 358 | switch_ctx(old_ctx); 359 | 360 | auto outer_loop = ctx().loop; 361 | outer_loop->inner_loops.push_back(red_loop); 362 | 363 | auto red_input = eval(red.lstream); 364 | auto t_start = _get_start_time(red_input); 365 | auto t_end = _get_end_time(red_input); 366 | vector args = { t_start, t_end, eval(red.state), red_input }; 367 | return _call(red_loop->get_name(), red_loop->type, args); 368 | } 369 | 370 | Loop LoopGen::Build(Sym sym, const OpNode* op) 371 | { 372 | auto loop = _loop(sym); 373 | LoopGenCtx ctx(sym, op, loop); 374 | LoopGen loopgen(std::move(ctx)); 375 | loopgen.build_loop(); 376 | return loopgen.ctx().loop; 377 | } 378 | -------------------------------------------------------------------------------- /test/src/test_base.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "tilt/pass/codegen/loopgen.h" 8 | #include "tilt/pass/codegen/llvmgen.h" 9 | #include "tilt/pass/codegen/vinstr.h" 10 | #include "tilt/engine/engine.h" 11 | 12 | #include "test_base.h" 13 | 14 | using namespace tilt; 15 | using namespace tilt::tilder; 16 | 17 | void run_op(string 
query_name, Op op, ts_t st, ts_t et, region_t* out_reg, region_t* in_reg) 18 | { 19 | auto op_sym = _sym(query_name, op); 20 | auto loop = LoopGen::Build(op_sym, op.get()); 21 | 22 | auto jit = ExecEngine::Get(); 23 | auto& llctx = jit->GetCtx(); 24 | 25 | auto llmod = LLVMGen::Build(loop, llctx); 26 | jit->AddModule(std::move(llmod)); 27 | 28 | auto loop_addr = (region_t* (*)(ts_t, ts_t, region_t*, region_t*)) jit->Lookup(loop->get_name()); 29 | 30 | loop_addr(st, et, out_reg, in_reg); 31 | } 32 | 33 | template 34 | void op_test(string query_name, Op op, ts_t st, ts_t et, QueryFn query_fn, vector> input) 35 | { 36 | auto in_st = input[0].st; 37 | auto true_out = query_fn(input); 38 | 39 | region_t in_reg; 40 | auto in_tl = vector(input.size()); 41 | auto in_data = vector(input.size()); 42 | auto in_data_ptr = reinterpret_cast(in_data.data()); 43 | init_region(&in_reg, in_st, get_buf_size(input.size()), in_tl.data(), in_data_ptr); 44 | for (size_t i = 0; i < input.size(); i++) { 45 | auto t = input[i].et; 46 | commit_data(&in_reg, t); 47 | auto* ptr = reinterpret_cast(fetch(&in_reg, t, get_end_idx(&in_reg), sizeof(InTy))); 48 | *ptr = input[i].payload; 49 | } 50 | 51 | region_t out_reg; 52 | auto out_tl = vector(true_out.size()); 53 | auto out_data = vector(true_out.size()); 54 | auto out_data_ptr = reinterpret_cast(out_data.data()); 55 | init_region(&out_reg, st, get_buf_size(true_out.size()), out_tl.data(), out_data_ptr); 56 | 57 | run_op(query_name, op, st, et, &out_reg, &in_reg); 58 | 59 | for (size_t i = 0; i < true_out.size(); i++) { 60 | auto true_st = true_out[i].st; 61 | auto true_et = true_out[i].et; 62 | auto true_payload = true_out[i].payload; 63 | auto out_st = out_tl[i].t; 64 | auto out_et = out_st + out_tl[i].d; 65 | auto out_payload = out_data[i]; 66 | 67 | assert_eq(true_st, out_st); 68 | assert_eq(true_et, out_et); 69 | assert_eq(true_payload, out_payload); 70 | } 71 | } 72 | 73 | template 74 | void unary_op_test(string query_name, Op op, ts_t st, 
ts_t et, QueryFn query_fn, size_t len, int64_t dur) 75 | { 76 | std::srand(time(nullptr)); 77 | 78 | vector> input(len); 79 | for (size_t i = 0; i < len; i++) { 80 | int64_t st = dur * i; 81 | int64_t et = st + dur; 82 | InTy payload = static_cast(std::rand() / static_cast(RAND_MAX / 100000)); 83 | input[i] = {st, et, payload}; 84 | } 85 | 86 | op_test(query_name, op, st, et, query_fn, input); 87 | } 88 | 89 | template 90 | void select_test(string query_name, function sel_expr, function sel_fn) 91 | { 92 | size_t len = 1000; 93 | int64_t dur = 5; 94 | 95 | auto in_sym = _sym("in", tilt::Type(types::STRUCT(), _iter(0, -1))); 96 | auto sel_op = _Select(in_sym, sel_expr); 97 | 98 | auto sel_query_fn = [sel_fn] (vector> in) { 99 | vector> out; 100 | 101 | for (size_t i = 0; i < in.size(); i++) { 102 | out.push_back({in[i].st, in[i].et, sel_fn(in[i].payload)}); 103 | } 104 | 105 | return std::move(out); 106 | }; 107 | 108 | unary_op_test(query_name, sel_op, 0, len * dur, sel_query_fn, len, dur); 109 | } 110 | 111 | void add_test() 112 | { 113 | select_test("iadd", 114 | [] (Expr s) { return _add(s, _i32(10)); }, 115 | [] (int32_t s) { return s + 10; }); 116 | select_test("fadd", 117 | [] (Expr s) { return _add(s, _f32(5)); }, 118 | [] (float s) { return s + 5.0; }); 119 | } 120 | 121 | void sub_test() 122 | { 123 | select_test("isub", 124 | [] (Expr s) { return _sub(s, _i32(10)); }, 125 | [] (int32_t s) { return s - 10; }); 126 | select_test("fsub", 127 | [] (Expr s) { return _sub(s, _f32(15)); }, 128 | [] (float s) { return s - 15.0; }); 129 | } 130 | 131 | void mul_test() 132 | { 133 | select_test("imul", 134 | [] (Expr s) { return _mul(s, _i32(10)); }, 135 | [] (int32_t s) { return s * 10; }); 136 | select_test("fmul", 137 | [] (Expr s) { return _mul(s, _f32(10)); }, 138 | [] (float s) { return s * 10.0f; }); 139 | } 140 | 141 | void div_test() 142 | { 143 | select_test("idiv", 144 | [] (Expr s) { return _div(s, _i32(10)); }, 145 | [] (int32_t s) { return s / 10; }); 
146 | select_test("udiv", 147 | [] (Expr s) { return _div(s, _u32(10)); }, 148 | [] (uint32_t s) { return s / 10u; }); 149 | select_test("fdiv", 150 | [] (Expr s) { return _div(s, _f32(10)); }, 151 | [] (float s) { return s / 10.0f; }); 152 | } 153 | 154 | void mod_test() 155 | { 156 | select_test("imod", 157 | [] (Expr s) { return _mod(s, _i32(10)); }, 158 | [] (int32_t s) { return s % 10; }); 159 | select_test("umod", 160 | [] (Expr s) { return _mod(s, _u32(10)); }, 161 | [] (uint32_t s) { return s % 10u; }); 162 | } 163 | 164 | void max_test() 165 | { 166 | select_test("imax", 167 | [] (Expr s) { return _max(s, _i32(10)); }, 168 | [] (int32_t s) { return std::max(s, 10); }); 169 | select_test("umax", 170 | [] (Expr s) { return _max(s, _u32(10)); }, 171 | [] (uint32_t s) { return std::max(s, 10u); }); 172 | select_test("fmax", 173 | [] (Expr s) { return _max(s, _f32(10)); }, 174 | [] (float s) { return std::max(s, 10.0f); }); 175 | } 176 | 177 | void min_test() 178 | { 179 | select_test("imin", 180 | [] (Expr s) { return _min(s, _i32(10)); }, 181 | [] (int32_t s) { return std::min(s, 10); }); 182 | select_test("umin", 183 | [] (Expr s) { return _min(s, _u32(10)); }, 184 | [] (uint32_t s) { return std::min(s, 10u); }); 185 | select_test("fmin", 186 | [] (Expr s) { return _min(s, _f32(10)); }, 187 | [] (float s) { return std::min(s, 10.0f); }); 188 | } 189 | 190 | void neg_test() 191 | { 192 | select_test("ineg", 193 | [] (Expr s) { return _neg(s); }, 194 | [] (int32_t s) { return -s; }); 195 | select_test("fneg", 196 | [] (Expr s) { return _neg(s); }, 197 | [] (float s) { return -s; }); 198 | select_test("dneg", 199 | [] (Expr s) { return _neg(s); }, 200 | [] (double s) { return -s; }); 201 | } 202 | 203 | void sqrt_test() 204 | { 205 | select_test("fsqrt", 206 | [] (Expr s) { return _sqrt(s); }, 207 | [] (float s) { return std::sqrt(s); }); 208 | select_test("dsqrt", 209 | [] (Expr s) { return _sqrt(s); }, 210 | [] (double s) { return std::sqrt(s); }); 211 | } 
212 | 213 | void pow_test() 214 | { 215 | select_test("fpow", 216 | [] (Expr s) { return _pow(s, _f32(2)); }, 217 | [] (float s) { return std::pow(s, 2); }); 218 | select_test("dpow", 219 | [] (Expr s) { return _pow(s, _f64(2)); }, 220 | [] (double s) { return std::pow(s, 2); }); 221 | } 222 | 223 | void ceil_test() 224 | { 225 | select_test("fceil", 226 | [] (Expr s) { return _ceil(s); }, 227 | [] (float s) { return std::ceil(s); }); 228 | select_test("dceil", 229 | [] (Expr s) { return _ceil(s); }, 230 | [] (double s) { return std::ceil(s); }); 231 | } 232 | 233 | void floor_test() 234 | { 235 | select_test("ffloor", 236 | [] (Expr s) { return _floor(s); }, 237 | [] (float s) { return std::floor(s); }); 238 | select_test("dfloor", 239 | [] (Expr s) { return _floor(s); }, 240 | [] (double s) { return std::floor(s); }); 241 | } 242 | 243 | void abs_test() 244 | { 245 | select_test("fabs", 246 | [] (Expr s) { return _abs(s); }, 247 | [] (float s) { return std::abs(s); }); 248 | select_test("dabs", 249 | [] (Expr s) { return _abs(s); }, 250 | [] (double s) { return std::abs(s); }); 251 | select_test("iabs", 252 | [] (Expr s) { return _abs(s); }, 253 | [] (int32_t s) { return std::abs(s); }); 254 | } 255 | 256 | void cast_test() 257 | { 258 | select_test("sitofp", 259 | [] (Expr s) { return _cast(types::FLOAT32, s); }, 260 | [] (int32_t s) { return static_cast(s); }); 261 | select_test("uitofp", 262 | [] (Expr s) { return _cast(types::FLOAT32, s); }, 263 | [] (uint32_t s) { return static_cast(s); }); 264 | select_test("fptosi", 265 | [] (Expr s) { return _cast(types::INT32, s); }, 266 | [] (float s) { return static_cast(s); }); 267 | select_test("fptoui", 268 | [] (Expr s) { return _cast(types::UINT32, s); }, 269 | [] (float s) { return static_cast(s); }); 270 | select_test("int8toint32", 271 | [] (Expr s) { return _cast(types::INT32, s); }, 272 | [] (int8_t s) { return static_cast(s); }); 273 | } 274 | 275 | void moving_sum_test() 276 | { 277 | size_t len = 30; 278 | 
int64_t dur = 1; 279 | int64_t w = 10; 280 | 281 | auto in_sym = _sym("in", tilt::Type(types::INT32, _iter(0, -1))); 282 | auto mov_op = _MovingSum(in_sym, dur, w); 283 | 284 | auto mov_query_fn = [w] (vector> in) { 285 | vector> out(in.size()); 286 | 287 | for (int i = 0; i < in.size(); i++) { 288 | auto out_i = i - 1; 289 | auto tail_i = i - w; 290 | auto payload = in[i].payload 291 | - ((tail_i < 0) ? 0 : in[tail_i].payload) 292 | + ((out_i < 0) ? 0 : out[out_i].payload); 293 | out[i] = {in[i].st, in[i].et, payload}; 294 | } 295 | 296 | return std::move(out); 297 | }; 298 | 299 | unary_op_test("moving_sum", mov_op, 0, len * dur, mov_query_fn, len, dur); 300 | } 301 | 302 | void norm_test() 303 | { 304 | size_t len = 1000; 305 | int64_t dur = 1; 306 | int64_t w = 10; 307 | 308 | auto in_sym = _sym("in", tilt::Type(types::FLOAT32, _iter(0, -1))); 309 | auto norm_op = _Norm("norm", in_sym, w); 310 | 311 | auto norm_query_fn = [w] (vector> in) { 312 | vector> out(in.size()); 313 | size_t num_windows = in.size() / w; 314 | 315 | for (size_t i = 0; i < num_windows; i++) { 316 | float sum = 0.0, mean, variance = 0.0, std_dev; 317 | 318 | for (size_t j = 0; j < w; j++) { 319 | sum += in[i * w + j].payload; 320 | } 321 | mean = sum / w; 322 | for (size_t j = 0; j < w; j++) { 323 | variance += pow(in[i * w + j].payload - mean, 2); 324 | } 325 | std_dev = sqrt(variance / w); 326 | 327 | for (size_t j = 0; j < w; j++) { 328 | size_t idx = i * w + j; 329 | float z_score = (in[idx].payload - mean) / std_dev; 330 | out[idx] = {in[idx].st, in[idx].et, z_score}; 331 | } 332 | } 333 | 334 | return std::move(out); 335 | }; 336 | 337 | unary_op_test("norm", norm_op, 0, len * dur, norm_query_fn, len, dur); 338 | } 339 | 340 | void run_resample(string query_name, int64_t iperiod, int64_t operiod) 341 | { 342 | size_t len = 100; 343 | int64_t dur = iperiod; 344 | 345 | auto in_sym = _sym("in", tilt::Type(types::FLOAT32, _iter(0, -1))); 346 | auto resample_op = _Resample(query_name, 
in_sym, iperiod, operiod); 347 | 348 | auto resample_query_fn = [iperiod, operiod] (vector> in) { 349 | vector> out; 350 | 351 | for (size_t i = 1; i < in.size(); i++) { 352 | int64_t st = in[i-1].et; 353 | int64_t et = in[i].et; 354 | float sv = in[i-1].payload; 355 | float ev = in[i].payload; 356 | 357 | int64_t out_t = (st / operiod + 1) * operiod; 358 | for (; out_t <= et; out_t += operiod) { 359 | float payload = (((ev - sv) * (out_t - st)) / (et - st)) + sv; 360 | out.push_back({out_t - operiod, out_t, payload}); 361 | } 362 | } 363 | 364 | return std::move(out); 365 | }; 366 | 367 | unary_op_test(query_name, resample_op, 0, len * dur, resample_query_fn, len, dur); 368 | } 369 | /* Resample test cases. The two numeric arguments are presumably (input period, output period), so the "up" cases emit events more often than they arrive and the "down" cases less often - confirm against run_resample's signature. */ 370 | void resample_test() 371 | { 372 | run_resample("up_sample1", 5, 4); 373 | run_resample("up_sample2", 6, 3); 374 | run_resample("down_sample1", 4, 5); 375 | run_resample("down_sample2", 3, 6); 376 | } 377 | -------------------------------------------------------------------------------- /src/pass/codegen/llvmgen.cpp: -------------------------------------------------------------------------------- 1 | #include "tilt/base/type.h" 2 | #include "tilt/pass/codegen/llvmgen.h" 3 | 4 | #include "llvm/IR/Function.h" 5 | #include "llvm/IR/DataLayout.h" 6 | #include "llvm/IR/InstrTypes.h" 7 | 8 | using namespace tilt; 9 | using namespace llvm; 10 | /* llfunc: creates a non-variadic function `name` with external linkage in llmod(). */ 11 | Function* LLVMGen::llfunc(const string name, llvm::Type* ret_type, vector arg_types) 12 | { 13 | auto fn_type = FunctionType::get(ret_type, arg_types, false); 14 | return Function::Create(fn_type, Function::ExternalLinkage, name, llmod()); 15 | } 16 | /* llcall (Value* arguments): emits a call to `name`. getOrInsertFunction declares it in llmod() on first use, with a signature built from `ret_type` and the argument values' types. */ 17 | Value* LLVMGen::llcall(const string name, llvm::Type* ret_type, vector arg_vals) 18 | { 19 | vector arg_types; 20 | for (const auto& arg_val : arg_vals) { 21 | arg_types.push_back(arg_val->getType()); 22 | } 23 | 24 | auto fn_type = FunctionType::get(ret_type, arg_types, false); 25 | auto fn = llmod()->getOrInsertFunction(name, fn_type); 26 | return builder()->CreateCall(fn, arg_vals); 27
| } 28 | /* llcall (expression arguments): evaluates each argument expression in order, then emits the call through the Value* overload. */ 29 | Value* LLVMGen::llcall(const string name, llvm::Type* ret_type, vector args) 30 | { 31 | vector arg_vals; 32 | for (const auto& arg : args) { 33 | arg_vals.push_back(eval(arg)); 34 | } 35 | 36 | return llcall(name, ret_type, arg_vals); 37 | } 38 | /* llsizeof: returns the ABI allocation size of `type` in bytes as an i32 constant. This is the stride between consecutive elements of an array of `type`, i.e. the layout of the data arrays alloca'd in visit(AllocRegion). It is presumably what the runtime `fetch` helper uses to index a region - confirm. The alloc size is used instead of getTypeSizeInBits()/8 because the latter does not work for sub-byte types such as i1 (BOOL). */ 39 | Value* LLVMGen::llsizeof(llvm::Type* type) 40 | { 41 | auto size = llmod()->getDataLayout().getTypeAllocSize(type).getFixedSize(); 42 | 43 | return ConstantInt::get(lltype(types::UINT32), size); 44 | } 45 | /* Maps a TiLT DataType to its LLVM type. TIME and INDEX reuse the lowering of their converter base types. IVAL is looked up by name ("struct.ival_t") and is nullptr if that struct is not yet defined in llctx() (presumably it comes from the linked vinstr module - confirm). */ 46 | llvm::Type* LLVMGen::lltype(const DataType& dtype) 47 | { 48 | switch (dtype.btype) { 49 | case BaseType::BOOL: 50 | return llvm::Type::getInt1Ty(llctx()); 51 | case BaseType::INT8: 52 | case BaseType::UINT8: 53 | return llvm::Type::getInt8Ty(llctx()); 54 | case BaseType::INT16: 55 | case BaseType::UINT16: 56 | return llvm::Type::getInt16Ty(llctx()); 57 | case BaseType::INT32: 58 | case BaseType::UINT32: 59 | return llvm::Type::getInt32Ty(llctx()); 60 | case BaseType::INT64: 61 | case BaseType::UINT64: 62 | return llvm::Type::getInt64Ty(llctx()); 63 | case BaseType::FLOAT32: 64 | return llvm::Type::getFloatTy(llctx()); 65 | case BaseType::FLOAT64: 66 | return llvm::Type::getDoubleTy(llctx()); 67 | case BaseType::TIME: 68 | return lltype(DataType(types::Converter::btype)); 69 | case BaseType::INDEX: 70 | return lltype(DataType(types::Converter::btype)); 71 | case BaseType::IVAL: 72 | return StructType::getTypeByName(llctx(), "struct.ival_t"); 73 | case BaseType::STRUCT: { 74 | vector lltypes; 75 | for (const auto& dt : dtype.dtypes) { 76 | lltypes.push_back(lltype(dt)); 77 | } 78 | return StructType::get(llctx(), lltypes); 79 | } 80 | case BaseType::PTR: 81 | return PointerType::get(lltype(dtype.dtypes[0]), 0); 82 | case BaseType::UNKNOWN: 83 | default: 84 | throw std::runtime_error("Invalid type"); 85 | } 86 | } 87 | /* Value types lower to their data type. Every non-value type (a region) is passed around as llregptrtype(). */ 88 | llvm::Type* LLVMGen::lltype(const Type& type) 89 | { 90 | if (type.is_val()) { 91 | return lltype(type.dtype); 92 | } else { 93 | return llregptrtype(); 94 |
} 95 | } 96 | /* Symbols are resolved with get_sym and mapped, via get_expr, to the LLVM value recorded for them. */ 97 | Value* LLVMGen::visit(const Symbol& symbol) { return get_expr(get_sym(symbol)); } 98 | /* Lowers IfElse to then/else/merge blocks joined by a two-input PHI. then_bb and else_bb are re-read after each body is evaluated because evaluating a body may itself emit new blocks (e.g. a nested IfElse). The PHI must name the block that actually branches to merge. */ 99 | Value* LLVMGen::visit(const IfElse& ifelse) 100 | { 101 | auto loop_fn = builder()->GetInsertBlock()->getParent(); 102 | auto then_bb = BasicBlock::Create(llctx(), "then"); 103 | auto else_bb = BasicBlock::Create(llctx(), "else"); 104 | auto merge_bb = BasicBlock::Create(llctx(), "merge"); 105 | 106 | // Condition check 107 | auto cond = eval(ifelse.cond); 108 | builder()->CreateCondBr(cond, then_bb, else_bb); 109 | 110 | // Then block 111 | loop_fn->getBasicBlockList().push_back(then_bb); 112 | builder()->SetInsertPoint(then_bb); 113 | auto true_val = eval(ifelse.true_body); 114 | then_bb = builder()->GetInsertBlock(); 115 | builder()->CreateBr(merge_bb); 116 | 117 | // Else block 118 | loop_fn->getBasicBlockList().push_back(else_bb); 119 | builder()->SetInsertPoint(else_bb); 120 | auto false_val = eval(ifelse.false_body); 121 | else_bb = builder()->GetInsertBlock(); 122 | builder()->CreateBr(merge_bb); 123 | 124 | // Merge block 125 | loop_fn->getBasicBlockList().push_back(merge_bb); 126 | builder()->SetInsertPoint(merge_bb); 127 | auto merge_phi = builder()->CreatePHI(lltype(ifelse), 2); 128 | merge_phi->addIncoming(true_val, then_bb); 129 | merge_phi->addIncoming(false_val, else_bb); 130 | 131 | return merge_phi; 132 | } 133 | /* Unlike IfElse, both branches are evaluated unconditionally and combined with a select instruction. */ 134 | Value* LLVMGen::visit(const Select& select) 135 | { 136 | auto cond = eval(select.cond); 137 | auto true_val = eval(select.true_body); 138 | auto false_val = eval(select.false_body); 139 | return builder()->CreateSelect(cond, true_val, false_val); 140 | } 141 | /* Extracts member `get.n` from an aggregate value. */ 142 | Value* LLVMGen::visit(const Get& get) 143 | { 144 | auto input = eval(get.input); 145 | return builder()->CreateExtractValue(input, get.n); 146 | } 147 | /* Builds a struct value by storing each input into a stack slot and then loading the complete struct back. */ 148 | Value* LLVMGen::visit(const New& _new) 149 | { 150 | auto new_type = lltype(_new); 151 | auto ptr = builder()->CreateAlloca(new_type); 152 | 153 | for (size_t i = 0; i < _new.inputs.size(); i++) { 154 |
auto val_ptr = builder()->CreateStructGEP(new_type, ptr, i); 155 | builder()->CreateStore(eval(_new.inputs[i]), val_ptr); 156 | } 157 | 158 | return builder()->CreateLoad(new_type, ptr); 159 | } 160 | /* Integer-like constants (BOOL, signed/unsigned ints, TIME, INDEX) become ConstantInt; FLOAT32/FLOAT64 become ConstantFP. */ 161 | Value* LLVMGen::visit(const ConstNode& cnst) 162 | { 163 | switch (cnst.type.dtype.btype) { 164 | case BaseType::BOOL: 165 | case BaseType::INT8: 166 | case BaseType::INT16: 167 | case BaseType::INT32: 168 | case BaseType::INT64: 169 | case BaseType::UINT8: 170 | case BaseType::UINT16: 171 | case BaseType::UINT32: 172 | case BaseType::UINT64: 173 | case BaseType::TIME: 174 | case BaseType::INDEX: return ConstantInt::get(lltype(cnst), cnst.val); 175 | case BaseType::FLOAT32: 176 | case BaseType::FLOAT64: return ConstantFP::get(lltype(cnst), cnst.val); 177 | default: throw std::runtime_error("Invalid constant type"); 178 | } 179 | } 180 | /* CastInst::getCastOpcode picks the cast opcode from the source and destination types and their signedness. */ 181 | Value* LLVMGen::visit(const Cast& e) 182 | { 183 | auto input_val = eval(e.arg); 184 | auto dest_type = lltype(e); 185 | auto op = CastInst::getCastOpcode(input_val, e.arg->type.dtype.is_signed(), dest_type, e.type.dtype.is_signed()); 186 | return builder()->CreateCast(op, input_val, dest_type); 187 | } 188 | /* Lowers arithmetic, comparison and logical operators. Arithmetic picks the float/signed/unsigned instruction from the result type. Comparisons pick it from the first operand's type, since the comparison result itself is a boolean. */ 189 | Value* LLVMGen::visit(const NaryExpr& e) 190 | { 191 | switch (e.op) { 192 | case MathOp::ADD: { 193 | if (e.type.dtype.is_float()) { 194 | return builder()->CreateFAdd(eval(e.arg(0)), eval(e.arg(1))); 195 | } else { 196 | return builder()->CreateAdd(eval(e.arg(0)), eval(e.arg(1))); 197 | } 198 | } 199 | case MathOp::SUB: { 200 | if (e.type.dtype.is_float()) { 201 | return builder()->CreateFSub(eval(e.arg(0)), eval(e.arg(1))); 202 | } else { 203 | return builder()->CreateSub(eval(e.arg(0)), eval(e.arg(1))); 204 | } 205 | } 206 | case MathOp::MUL: { 207 | if (e.type.dtype.is_float()) { 208 | return builder()->CreateFMul(eval(e.arg(0)), eval(e.arg(1))); 209 | } else { 210 | return builder()->CreateMul(eval(e.arg(0)), eval(e.arg(1))); 211 | } 212 | } 213 | case MathOp::DIV: { 214 | if
(e.type.dtype.is_float()) { 215 | return builder()->CreateFDiv(eval(e.arg(0)), eval(e.arg(1))); 216 | } else if (e.type.dtype.is_signed()) { 217 | return builder()->CreateSDiv(eval(e.arg(0)), eval(e.arg(1))); 218 | } else { 219 | return builder()->CreateUDiv(eval(e.arg(0)), eval(e.arg(1))); 220 | } 221 | } 222 | case MathOp::MAX: { 223 | auto left = eval(e.arg(0)); 224 | auto right = eval(e.arg(1)); 225 | 226 | Value* cond; 227 | if (e.type.dtype.is_float()) { 228 | cond = builder()->CreateFCmpOGE(left, right); 229 | } else if (e.type.dtype.is_signed()) { 230 | cond = builder()->CreateICmpSGE(left, right); 231 | } else { 232 | cond = builder()->CreateICmpUGE(left, right); 233 | } 234 | return builder()->CreateSelect(cond, left, right); 235 | } 236 | case MathOp::MIN: { 237 | auto left = eval(e.arg(0)); 238 | auto right = eval(e.arg(1)); 239 | 240 | Value* cond; 241 | if (e.type.dtype.is_float()) { 242 | cond = builder()->CreateFCmpOLE(left, right); 243 | } else if (e.type.dtype.is_signed()) { 244 | cond = builder()->CreateICmpSLE(left, right); 245 | } else { 246 | cond = builder()->CreateICmpULE(left, right); 247 | } 248 | return builder()->CreateSelect(cond, left, right); 249 | } 250 | case MathOp::MOD: { /* floating-point remainder needs frem; srem/urem are integer-only instructions */ 251 | if (e.type.dtype.is_float()) { return builder()->CreateFRem(eval(e.arg(0)), eval(e.arg(1))); } else if (e.type.dtype.is_signed()) { 252 | return builder()->CreateSRem(eval(e.arg(0)), eval(e.arg(1))); 253 | } else { 254 | return builder()->CreateURem(eval(e.arg(0)), eval(e.arg(1))); 255 | } 256 | } 257 | case MathOp::ABS: { 258 | auto input = eval(e.arg(0)); 259 | 260 | if (e.type.dtype.is_float()) { 261 | return builder()->CreateIntrinsic(Intrinsic::fabs, {lltype(e.arg(0))}, {input}); 262 | } else { 263 | auto neg = builder()->CreateNeg(input); 264 | /* compare against a zero of the operand's own integer type: icmp requires both operands to have the same type, so a fixed i32 zero is invalid for INT8/16/64 inputs */ 265 | Value* cond; 266 | if (e.type.dtype.is_signed()) { 267 | cond = builder()->CreateICmpSGE(input, ConstantInt::get(input->getType(), 0)); 268 | } else { 269 | cond = builder()->CreateICmpUGE(input, ConstantInt::get(input->getType(), 0)); 270 | } 271 | return builder()->CreateSelect(cond, input,
neg); 272 | } 273 | } 274 | case MathOp::NEG: { 275 | if (e.type.dtype.is_float()) { 276 | return builder()->CreateFNeg(eval(e.arg(0))); 277 | } else { 278 | return builder()->CreateNeg(eval(e.arg(0))); 279 | } 280 | } 281 | case MathOp::SQRT: return builder()->CreateIntrinsic(Intrinsic::sqrt, {lltype(e.arg(0))}, {eval(e.arg(0))}); 282 | case MathOp::POW: return builder()->CreateIntrinsic( 283 | Intrinsic::pow, {lltype(e.arg(0))}, {eval(e.arg(0)), eval(e.arg(1))}); 284 | case MathOp::CEIL: return builder()->CreateIntrinsic(Intrinsic::ceil, {lltype(e.arg(0))}, {eval(e.arg(0))}); 285 | case MathOp::FLOOR: return builder()->CreateIntrinsic(Intrinsic::floor, {lltype(e.arg(0))}, {eval(e.arg(0))}); 286 | case MathOp::EQ: { 287 | if (e.arg(0)->type.dtype.is_float()) { 288 | return builder()->CreateFCmpOEQ(eval(e.arg(0)), eval(e.arg(1))); 289 | } else { 290 | return builder()->CreateICmpEQ(eval(e.arg(0)), eval(e.arg(1))); 291 | } 292 | } 293 | case MathOp::LT: { 294 | if (e.arg(0)->type.dtype.is_float()) { 295 | return builder()->CreateFCmpOLT(eval(e.arg(0)), eval(e.arg(1))); 296 | } else if (e.arg(0)->type.dtype.is_signed()) { 297 | return builder()->CreateICmpSLT(eval(e.arg(0)), eval(e.arg(1))); 298 | } else { 299 | return builder()->CreateICmpULT(eval(e.arg(0)), eval(e.arg(1))); 300 | } 301 | } 302 | case MathOp::LTE: { 303 | if (e.arg(0)->type.dtype.is_float()) { 304 | return builder()->CreateFCmpOLE(eval(e.arg(0)), eval(e.arg(1))); 305 | } else if (e.arg(0)->type.dtype.is_signed()) { 306 | return builder()->CreateICmpSLE(eval(e.arg(0)), eval(e.arg(1))); 307 | } else { 308 | return builder()->CreateICmpULE(eval(e.arg(0)), eval(e.arg(1))); 309 | } 310 | } 311 | case MathOp::GT: { 312 | if (e.arg(0)->type.dtype.is_float()) { 313 | return builder()->CreateFCmpOGT(eval(e.arg(0)), eval(e.arg(1))); 314 | } else if (e.arg(0)->type.dtype.is_signed()) { 315 | return builder()->CreateICmpSGT(eval(e.arg(0)), eval(e.arg(1))); 316 | } else { 317 | return
builder()->CreateICmpUGT(eval(e.arg(0)), eval(e.arg(1))); 318 | } 319 | } 320 | case MathOp::GTE: { 321 | if (e.arg(0)->type.dtype.is_float()) { 322 | return builder()->CreateFCmpOGE(eval(e.arg(0)), eval(e.arg(1))); 323 | } else if (e.arg(0)->type.dtype.is_signed()) { 324 | return builder()->CreateICmpSGE(eval(e.arg(0)), eval(e.arg(1))); 325 | } else { 326 | return builder()->CreateICmpUGE(eval(e.arg(0)), eval(e.arg(1))); 327 | } 328 | } 329 | case MathOp::NOT: return builder()->CreateNot(eval(e.arg(0))); 330 | case MathOp::AND: return builder()->CreateAnd(eval(e.arg(0)), eval(e.arg(1))); 331 | case MathOp::OR: return builder()->CreateOr(eval(e.arg(0)), eval(e.arg(1))); 332 | default: throw std::runtime_error("Invalid math operation"); 333 | } 334 | } 335 | /* Exists is true iff the symbol's value is a non-null pointer. */ 336 | Value* LLVMGen::visit(const Exists& exists) 337 | { 338 | return builder()->CreateIsNotNull(eval(exists.sym)); 339 | } 340 | /* Calls the external helper `fetch(reg, time, idx, element_size)`, declared to return char*, and bitcasts the result to this node's pointer type. The element size is llsizeof() of the region's data type. */ 341 | Value* LLVMGen::visit(const Fetch& fetch) 342 | { 343 | auto& dtype = fetch.reg->type.dtype; 344 | auto reg_val = eval(fetch.reg); 345 | auto time_val = eval(fetch.time); 346 | auto idx_val = eval(fetch.idx); 347 | auto size_val = llsizeof(lltype(dtype)); 348 | auto ret_type = lltype(types::CHAR_PTR); 349 | auto addr = llcall("fetch", ret_type, { reg_val, time_val, idx_val, size_val }); 350 | 351 | return builder()->CreateBitCast(addr, lltype(fetch)); 352 | } 353 | /* The region accessors below lower to calls of same-named external helpers (presumably the vinstr functions linked in by register_vinstrs - confirm). llcall evaluates each argument expression. */ 354 | Value* LLVMGen::visit(const Advance& adv) 355 | { 356 | return llcall("advance", lltype(adv), { adv.reg, adv.idx, adv.time }); 357 | } 358 | 359 | Value* LLVMGen::visit(const GetCkpt& next) 360 | { 361 | return llcall("get_ckpt", lltype(next), { next.reg, next.time, next.idx }); 362 | } 363 | 364 | Value* LLVMGen::visit(const GetStartIdx& start_idx) 365 | { 366 | return llcall("get_start_idx", lltype(start_idx), { start_idx.reg }); 367 | } 368 | 369 | Value* LLVMGen::visit(const GetEndIdx& end_idx) 370 | { 371 | return llcall("get_end_idx", lltype(end_idx), { end_idx.reg }); 372 | } 373 | 374 | Value*
LLVMGen::visit(const GetStartTime& start_time) 375 | { 376 | return llcall("get_start_time", lltype(start_time), { start_time.reg }); 377 | } 378 | 379 | Value* LLVMGen::visit(const GetEndTime& end_time) 380 | { 381 | return llcall("get_end_time", lltype(end_time), { end_time.reg }); 382 | } 383 | 384 | Value* LLVMGen::visit(const CommitNull& commit) 385 | { 386 | return llcall("commit_null", lltype(commit), { commit.reg, commit.time }); 387 | } 388 | 389 | Value* LLVMGen::visit(const CommitData& commit) 390 | { 391 | return llcall("commit_data", lltype(commit), { commit.reg, commit.time }); 392 | } 393 | /* Loads the pointee of `read.ptr`, typed by dereferencing the pointer's data type. */ 394 | Value* LLVMGen::visit(const Read& read) 395 | { 396 | auto ptr_val = eval(read.ptr); 397 | auto ptr_type = read.ptr->type.dtype; 398 | return builder()->CreateLoad(lltype(ptr_type.deref()), ptr_val); 399 | } 400 | /* Stores `data` through `ptr` and returns the (unchanged) region value `reg`. */ 401 | Value* LLVMGen::visit(const Write& write) 402 | { 403 | auto reg_val = eval(write.reg); 404 | auto ptr_val = eval(write.ptr); 405 | auto data_val = eval(write.data); 406 | builder()->CreateStore(data_val, ptr_val); 407 | return reg_val; 408 | } 409 | /* Allocates a region on the stack. get_buf_size yields the element count; a timeline array of IVALs and a data array of the region's element type are then alloca'd with that count and passed, together with a fresh region struct, to init_region. These are dynamic allocas: when emitted inside a loop body they are released every iteration by the stackrestore that visit(LoopNode) emits. */ 410 | Value* LLVMGen::visit(const AllocRegion& alloc) 411 | { 412 | auto time_val = eval(alloc.start_time); 413 | auto size_val = llcall("get_buf_size", lltype(types::UINT32), { eval(alloc.size) }); 414 | auto tl_arr = builder()->CreateAlloca(lltype(types::IVAL), size_val); 415 | auto data_arr = builder()->CreateAlloca(lltype(alloc.type.dtype), size_val); 416 | auto char_arr = builder()->CreateBitCast(data_arr, lltype(types::CHAR_PTR)); 417 | 418 | auto reg_val = builder()->CreateAlloca(llregtype()); 419 | return llcall("init_region", lltype(alloc), { reg_val, time_val, size_val, tl_arr, char_arr }); 420 | } 421 | /* Creates a new region struct on the stack and initializes it with the external make_region helper, using the input region and the st/si/et/ei bounds (presumably start/end time and index - confirm). */ 422 | Value* LLVMGen::visit(const MakeRegion& make_reg) 423 | { 424 | auto in_reg_val = eval(make_reg.reg); 425 | auto st_val = eval(make_reg.st); 426 | auto si_val = eval(make_reg.si); 427 | auto et_val = eval(make_reg.et); 428 | auto ei_val = eval(make_reg.ei); 429 | auto out_reg_val
= builder()->CreateAlloca(llregtype()); 430 | return llcall("make_region", lltype(make_reg), { out_reg_val, in_reg_val, st_val, si_val, et_val, ei_val }); 431 | } 432 | /* Calls the function `call.name` with the evaluated arguments, declaring it on first use. */ 433 | Value* LLVMGen::visit(const Call& call) 434 | { 435 | return llcall(call.name, lltype(call), call.args); 436 | } 437 | /* Generates one LLVM function per loop. Inner loops are generated first, each in its own context. The loop itself has the shape preheader -> header (state PHIs + exit check) -> body -> end (stackrestore) -> header, and the exit block returns the state value of loop.output. Region parameters are marked noalias. */ 438 | Value* LLVMGen::visit(const LoopNode& loop) 439 | { 440 | // Build inner loops 441 | for (const auto& inner_loop : loop.inner_loops) { 442 | LLVMGenCtx new_ctx(inner_loop.get(), &llctx()); 443 | auto& old_ctx = switch_ctx(new_ctx); 444 | inner_loop->Accept(*this); 445 | switch_ctx(old_ctx); 446 | } 447 | 448 | // Build current loop 449 | auto preheader_bb = BasicBlock::Create(llctx(), "preheader"); 450 | auto header_bb = BasicBlock::Create(llctx(), "header"); 451 | auto body_bb = BasicBlock::Create(llctx(), "body"); 452 | auto end_bb = BasicBlock::Create(llctx(), "end"); 453 | auto exit_bb = BasicBlock::Create(llctx(), "exit"); 454 | 455 | // Define function signature 456 | vector args_type; 457 | for (const auto& input : loop.inputs) { 458 | args_type.push_back(lltype(input->type)); 459 | } 460 | auto loop_fn = llfunc(loop.get_name(), lltype(loop.output), args_type); 461 | for (size_t i = 0; i < loop.inputs.size(); i++) { 462 | auto input = loop.inputs[i]; 463 | set_expr(input, loop_fn->getArg(i)); 464 | } 465 | // We add `noalias` attribute to the region parameters to help compiler autovectorize 466 | for (size_t i = 0; i < loop.inputs.size(); i++) { 467 | // If type is not a value, then it should be a region 468 | if (!loop.inputs[i]->type.is_val()) { 469 | loop_fn->addParamAttr(i, Attribute::NoAlias); 470 | } 471 | } 472 | 473 | // Initialization of loop states 474 | loop_fn->getBasicBlockList().push_back(preheader_bb); 475 | builder()->SetInsertPoint(preheader_bb); 476 | map base_inits; 477 | for (const auto& [_, base] : loop.state_bases) { 478 | base_inits[base] = eval(loop.syms.at(base)); 479 | } 480 | builder()->CreateBr(header_bb); 481 | 482 | // Phi nodes for
loop states 483 | loop_fn->getBasicBlockList().push_back(header_bb); 484 | builder()->SetInsertPoint(header_bb); 485 | for (const auto& [base_sym, val] : base_inits) { 486 | auto base = builder()->CreatePHI(lltype(base_sym->type), 2, base_sym->name); 487 | set_expr(base_sym, base); 488 | base->addIncoming(val, preheader_bb); 489 | } 490 | 491 | // Check exit condition 492 | builder()->CreateCondBr(eval(loop.exit_cond), exit_bb, body_bb); 493 | 494 | // Loop body 495 | loop_fn->getBasicBlockList().push_back(body_bb); 496 | builder()->SetInsertPoint(body_bb); 497 | auto stack_val = builder()->CreateIntrinsic(Intrinsic::stacksave, {}, {}); 498 | 499 | // Update indices 500 | for (const auto& idx : loop.idxs) { 501 | eval(idx); 502 | } 503 | 504 | // Update loop counter 505 | eval(loop.t); 506 | 507 | // Evaluate loop output 508 | eval(loop.output); 509 | for (const auto& [var, base] : loop.state_bases) { 510 | auto base_phi = dyn_cast(eval(base)); 511 | base_phi->addIncoming(eval(var), end_bb); 512 | } 513 | builder()->CreateBr(end_bb); 514 | 515 | // Jump back to loop header 516 | loop_fn->getBasicBlockList().push_back(end_bb); 517 | builder()->SetInsertPoint(end_bb); 518 | builder()->CreateIntrinsic(Intrinsic::stackrestore, {}, {stack_val}); 519 | builder()->CreateBr(header_bb); 520 | 521 | // Loop exit 522 | loop_fn->getBasicBlockList().push_back(exit_bb); 523 | builder()->SetInsertPoint(exit_bb); 524 | builder()->CreateRet(eval(loop.state_bases.at(loop.output))); 525 | 526 | return loop_fn; 527 | } 528 | /* Entry point: runs code generation for `loop` (inner loops are handled recursively by visit(LoopNode)) and hands the generator's module (_llmod) over to the caller. */ 529 | unique_ptr LLVMGen::Build(const Loop loop, llvm::LLVMContext& llctx) 530 | { 531 | LLVMGenCtx ctx(loop.get(), &llctx); 532 | LLVMGen llgen(std::move(ctx)); 533 | loop->Accept(llgen); 534 | return std::move(llgen._llmod); 535 | } 536 | /* Parses the embedded vinstr IR (vinstr_str), verifies it, links it into llmod() and then gives every function it defined internal linkage. Throws std::runtime_error if parsing, verification or linking fails. */ 537 | void LLVMGen::register_vinstrs() { 538 | const auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(vinstr_str)); 539 | 540 | llvm::SMDiagnostic error; 541 | std::unique_ptr vinstr_mod = llvm::parseIR(*buffer,
error, llctx()); 542 | if (!vinstr_mod) { 543 | throw std::runtime_error("Failed to parse vinstr bitcode"); 544 | } 545 | if (llvm::verifyModule(*vinstr_mod)) { 546 | throw std::runtime_error("Failed to verify vinstr module"); 547 | } 548 | 549 | // For some reason if we try to set internal linkage before we link 550 | // modules, then the JIT will be unable to find the symbols. 551 | // Instead we collect the function names first, then add internal 552 | // linkage to them after linking the modules 553 | std::vector vinstr_names; 554 | for (const auto& function : vinstr_mod->functions()) { 555 | if (function.isDeclaration()) { 556 | continue; 557 | } 558 | vinstr_names.push_back(function.getName().str()); 559 | } 560 | 561 | if (llvm::Linker::linkModules(*llmod(), std::move(vinstr_mod))) { throw std::runtime_error("Failed to link vinstr module"); } /* linkModules returns true on error */ 562 | for (const auto& name : vinstr_names) { 563 | if (auto fn = llmod()->getFunction(name)) { fn->setLinkage(llvm::Function::InternalLinkage); } /* getFunction returns nullptr when a name is absent from the linked module; skip it instead of dereferencing null */ 564 | } 565 | } 566 | --------------------------------------------------------------------------------